Spaces:
Runtime error
Runtime error
import torch | |
from huggingface_guess import model_list | |
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects | |
from backend.patcher.clip import CLIP | |
from backend.patcher.vae import VAE | |
from backend.patcher.unet import UnetPatcher | |
from backend.text_processing.classic_engine import ClassicTextProcessingEngine | |
from backend.text_processing.t5_engine import T5TextProcessingEngine | |
from backend.args import dynamic_args | |
from backend import memory_management | |
from backend.modules.k_prediction import PredictionDiscreteFlow | |
from modules.shared import opts | |
## patch SD3 Class in huggingface_guess.model_list | |
def SD3_clip_target(self, state_dict={}): | |
return {'clip_l': 'text_encoder', 'clip_g': 'text_encoder_2', 't5xxl': 'text_encoder_3'} | |
model_list.SD3.unet_target = 'transformer' | |
model_list.SD3.clip_target = SD3_clip_target | |
## end patch | |
class StableDiffusion3(ForgeDiffusionEngine): | |
matched_guesses = [model_list.SD3] | |
def __init__(self, estimated_config, huggingface_components): | |
super().__init__(estimated_config, huggingface_components) | |
self.is_inpaint = False | |
clip = CLIP( | |
model_dict={ | |
'clip_l': huggingface_components['text_encoder'], | |
'clip_g': huggingface_components['text_encoder_2'], | |
't5xxl' : huggingface_components['text_encoder_3'] | |
}, | |
tokenizer_dict={ | |
'clip_l': huggingface_components['tokenizer'], | |
'clip_g': huggingface_components['tokenizer_2'], | |
't5xxl' : huggingface_components['tokenizer_3'] | |
} | |
) | |
k_predictor = PredictionDiscreteFlow(shift=3.0) | |
vae = VAE(model=huggingface_components['vae']) | |
unet = UnetPatcher.from_model( | |
model=huggingface_components['transformer'], | |
diffusers_scheduler= None, | |
k_predictor=k_predictor, | |
config=estimated_config | |
) | |
self.text_processing_engine_l = ClassicTextProcessingEngine( | |
text_encoder=clip.cond_stage_model.clip_l, | |
tokenizer=clip.tokenizer.clip_l, | |
embedding_dir=dynamic_args['embedding_dir'], | |
embedding_key='clip_l', | |
embedding_expected_shape=768, | |
emphasis_name=dynamic_args['emphasis_name'], | |
text_projection=True, | |
minimal_clip_skip=1, | |
clip_skip=1, | |
return_pooled=True, | |
final_layer_norm=False, | |
) | |
self.text_processing_engine_g = ClassicTextProcessingEngine( | |
text_encoder=clip.cond_stage_model.clip_g, | |
tokenizer=clip.tokenizer.clip_g, | |
embedding_dir=dynamic_args['embedding_dir'], | |
embedding_key='clip_g', | |
embedding_expected_shape=1280, | |
emphasis_name=dynamic_args['emphasis_name'], | |
text_projection=True, | |
minimal_clip_skip=1, | |
clip_skip=1, | |
return_pooled=True, | |
final_layer_norm=False, | |
) | |
self.text_processing_engine_t5 = T5TextProcessingEngine( | |
text_encoder=clip.cond_stage_model.t5xxl, | |
tokenizer=clip.tokenizer.t5xxl, | |
emphasis_name=dynamic_args['emphasis_name'], | |
) | |
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None) | |
self.forge_objects_original = self.forge_objects.shallow_copy() | |
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy() | |
# WebUI Legacy | |
self.is_sd3 = True | |
def set_clip_skip(self, clip_skip): | |
self.text_processing_engine_l.clip_skip = clip_skip | |
self.text_processing_engine_g.clip_skip = clip_skip | |
def get_learned_conditioning(self, prompt: list[str]): | |
memory_management.load_model_gpu(self.forge_objects.clip.patcher) | |
cond_g, g_pooled = self.text_processing_engine_g(prompt) | |
cond_l, l_pooled = self.text_processing_engine_l(prompt) | |
if opts.sd3_enable_t5: | |
cond_t5 = self.text_processing_engine_t5(prompt) | |
else: | |
cond_t5 = torch.zeros([len(prompt), 256, 4096]).to(cond_l.device) | |
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False) | |
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt) | |
if force_zero_negative_prompt: | |
l_pooled = torch.zeros_like(l_pooled) | |
g_pooled = torch.zeros_like(g_pooled) | |
cond_l = torch.zeros_like(cond_l) | |
cond_g = torch.zeros_like(cond_g) | |
cond_t5 = torch.zeros_like(cond_t5) | |
cond_lg = torch.cat([cond_l, cond_g], dim=-1) | |
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1])) | |
cond = dict( | |
crossattn=torch.cat([cond_lg, cond_t5], dim=-2), | |
vector=torch.cat([l_pooled, g_pooled], dim=-1), | |
) | |
return cond | |
def get_prompt_lengths_on_ui(self, prompt): | |
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0]) | |
return token_count, max(255, token_count) | |
def encode_first_stage(self, x): | |
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5) | |
sample = self.forge_objects.vae.first_stage_model.process_in(sample) | |
return sample.to(x) | |
def decode_first_stage(self, x): | |
sample = self.forge_objects.vae.first_stage_model.process_out(x) | |
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0 | |
return sample.to(x) | |