import torch import comfy import comfy.sd1_clip from torch.nn.functional import silu from types import MethodType from comfy.sd import CLIP from comfy import ldm import ldm.modules.diffusionmodules import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.openaimodel import ldm.modules.attention from . import devices, shared, sd_hijack_unet, sd_hijack_optimizations, script_callbacks, errors from .textual_inversion import textual_inversion from ..smZNodes import FrozenCLIPEmbedderWithCustomWordsCustom, FrozenOpenCLIPEmbedder2WithCustomWordsCustom, get_learned_conditioning from functools import partial if not hasattr(ldm.modules.diffusionmodules.model, "nonlinearity_orig"): ldm.modules.diffusionmodules.model.nonlinearity_orig = ldm.modules.diffusionmodules.model.nonlinearity if not hasattr(ldm.modules.diffusionmodules.openaimodel, "th_orig"): ldm.modules.diffusionmodules.openaimodel.th_orig = ldm.modules.diffusionmodules.openaimodel.th ldm.modules.attention.CrossAttention.forward_orig = ldm.modules.attention.CrossAttention.forward ldm.modules.diffusionmodules.model.AttnBlock.forward_orig = ldm.modules.diffusionmodules.model.AttnBlock.forward optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None already_optimized = False # temp fix for displaying info since two cliptextencode's will run def list_optimizers(): script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) new_optimizers = script_callbacks.list_optimizers_callback() new_optimizers = [x for x in new_optimizers if x.is_available()] new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) optimizers.clear() optimizers.extend(new_optimizers) def apply_optimizations(option=None): global already_optimized if already_optimized: display = False list_optimizers() global current_optimizer undo_optimizations() if len(optimizers) == 0: # a script can access the model very early, and optimizations would not be filled by then current_optimizer = None return '' ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th # sgm.modules.diffusionmodules.model.nonlinearity = silu # sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th if current_optimizer is not None: current_optimizer.undo() current_optimizer = None selection = option or shared.opts.cross_attention_optimization if selection == "Automatic" and len(optimizers) > 0: matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) else: matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt == selection]), None) if selection == "None": matching_optimizer = None elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention: matching_optimizer = None elif matching_optimizer is None: matching_optimizer = optimizers[0] if matching_optimizer is not None: if shared.opts.debug: print(f"Applying attention optimization: {matching_optimizer.name}... ", end='') matching_optimizer.apply() already_optimized = True if shared.opts.debug: print("done.") current_optimizer = matching_optimizer return current_optimizer else: # if shared.opts.debug: # print("Disabling attention optimization") return '' def undo_optimizations(): sd_hijack_optimizations.undo() ldm.modules.diffusionmodules.model.nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity_orig ldm.modules.diffusionmodules.openaimodel.th = ldm.modules.diffusionmodules.openaimodel.th_orig class StableDiffusionModelHijack: fixes = None comments = [] layers = None circular_enabled = False clip = None tokenizer = None optimization_method = None embedding_db = textual_inversion.EmbeddingDatabase() def apply_optimizations(self, option=None): try: self.optimization_method = apply_optimizations(option) except Exception as e: errors.display(e, "applying optimizations") undo_optimizations() def hijack(self, m: comfy.sd1_clip.SD1ClipModel): tokenizer_parent = m.tokenizer # SD1Tokenizer # SDTokenizer tokenizer_parent2 = getattr(tokenizer_parent, tokenizer_parent.clip) if hasattr(tokenizer_parent, 'clip') else tokenizer_parent tokenizer = getattr(tokenizer_parent, tokenizer_parent.clip).tokenizer if hasattr(tokenizer_parent, 'clip') else tokenizer_parent.tokenizer if hasattr(m, 'clip'): m = getattr(m, m.clip) model_embeddings = m.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding.weight = model_embeddings.token_embedding.wrapped._parameters.get('weight').to(device=devices.device) m.tokenizer_parent0 = tokenizer_parent m.tokenizer_parent = tokenizer_parent2 m.tokenizer = tokenizer m = FrozenOpenCLIPEmbedder2WithCustomWordsCustom(m, self) if "SDXLClipG" in type(m).__name__ else FrozenCLIPEmbedderWithCustomWordsCustom(m, self) m.clip_layer = getattr(m.wrapped, "clip_layer", None) m.reset_clip_layer = getattr(m.wrapped, "reset_clip_layer", None) m.transformer = getattr(m.wrapped, "transformer", None) self.cond_stage_model = m self.clip = m apply_weighted_forward(self.clip) self.apply_optimizations() def undo_hijack(self, m): try: m = m.wrapped model_embeddings = m.transformer.text_model.embeddings if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped undo_optimizations() undo_weighted_forward(m) self.apply_circular(False) # self.layers = None self.clip = None self.cond_stage_model = None except Exception as err: print(err) def apply_circular(self, enable): if self.circular_enabled == enable: return self.circular_enabled = enable for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: layer.padding_mode = 'circular' if enable else 'zeros' def clear_comments(self): self.comments = [] def get_prompt_lengths(self, text): if self.clip is None: return 0, 0 _, token_count = self.clip.process_texts([text]) return token_count, self.clip.get_target_prompt_token_count(token_count) model_hijack = StableDiffusionModelHijack() def weighted_loss(sd_model, pred, target, mean=True): #Calculate the weight normally, but ignore the mean loss = sd_model._old_get_loss(pred, target, mean=False) # pylint: disable=protected-access #Check if we have weights available weight = getattr(sd_model, '_custom_loss_weight', None) if weight is not None: loss *= weight #Return the loss, as mean if specified return loss.mean() if mean else loss def weighted_forward(sd_model, x, c, w, *args, **kwargs): try: #Temporarily append weights to a place accessible during loss calc sd_model._custom_loss_weight = w # pylint: disable=protected-access #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set if not hasattr(sd_model, '_old_get_loss'): sd_model._old_get_loss = sd_model.get_loss # pylint: disable=protected-access sd_model.get_loss = MethodType(weighted_loss, sd_model) #Run the standard forward function, but with the patched 'get_loss' return sd_model.forward(x, c, *args, **kwargs) finally: try: #Delete temporary weights if appended del sd_model._custom_loss_weight except AttributeError: pass #If we have an old loss function, reset the loss function to the original one if hasattr(sd_model, '_old_get_loss'): sd_model.get_loss = sd_model._old_get_loss # pylint: disable=protected-access del sd_model._old_get_loss def apply_weighted_forward(sd_model): #Add new function 'weighted_forward' that can be called to calc weighted loss sd_model.weighted_forward = MethodType(weighted_forward, sd_model) def undo_weighted_forward(sd_model): try: del sd_model.weighted_forward except AttributeError: pass class EmbeddingsWithFixes(torch.nn.Module): def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): super().__init__() self.wrapped = wrapped self.embeddings = embeddings def forward(self, input_ids): batch_fixes = self.embeddings.fixes self.embeddings.fixes = None try: inputs_embeds = self.wrapped(input_ids) except: inputs_embeds = self.wrapped(input_ids.cpu()) if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: return inputs_embeds vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec emb = devices.cond_cast_unet(vec) if emb.device != tensor.device: emb = emb.to(device=tensor.device) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) try: tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) except Exception as err: print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", tensor.shape[0], emb.shape[1]) # raise err vecs.append(tensor) return torch.stack(vecs)