from __future__ import annotations
import comfy
import torch
from typing import List, Tuple
from functools import partial
from .modules import prompt_parser, shared, devices
from .modules.shared import opts
from .modules.sd_samplers_cfg_denoiser import CFGDenoiser
from .modules.sd_hijack_clip import FrozenCLIPEmbedderForSDXLWithCustomWords
from .modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedder2WithCustomWords
from .modules.textual_inversion.textual_inversion import Embedding
import comfy.sdxl_clip
import comfy.sd1_clip
import comfy.sample
from comfy.sd1_clip import SD1Tokenizer, unescape_important, escape_important, token_weights, expand_directory_list
from nodes import CLIPTextEncode
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy import model_management
import inspect
from textwrap import dedent, indent
import functools
import tempfile
import importlib
import sys
import os
import re
import contextlib
import itertools
import binascii
try:
from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL, CLIPTextEncodeSDXLRefiner
except Exception as err:
print(f"[smZNodes]: Your ComfyUI version is outdated. Please update to the latest version. ({err})")
class CLIPTextEncodeSDXL(CLIPTextEncode): ...
class CLIPTextEncodeSDXLRefiner(CLIPTextEncode): ...
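# Appears to mirror LatentDiffusion.get_learned_conditioning() from the original Stable Diffusion codebase;
# written as an unbound function (note the explicit `self`) so it can be attached to a model instance.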
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
class PopulateVars:
def populate_self_variables(self, from_):
super_attrs = vars(from_)
self_attrs = vars(self)
self_attrs.update(super_attrs)
should_use_fp16_signature = inspect.signature(comfy.model_management.should_use_fp16)
class ClipTextEncoderCustom:
def _forward(self: comfy.sd1_clip.SD1ClipModel, tokens):
def set_dtype_compat(dtype, newv = False):
dtype_num = lambda d : int(re.sub(r'.*?(\d+)', r'\1', repr(d)))
_p = should_use_fp16_signature.parameters
# Newer versions of ComfyUI upcast the transformer embeddings, which is technically correct.
# On newer versions we want to downcast them to torch.float16, so set newv=True.
# newv = 'device' in _p and 'prioritize_performance' in _p  # uncomment for the downcast; leave commented for default comfy behaviour
if dtype_num(dtype) >= 32:
newv = False
if not newv: return
dtype = devices.dtype if dtype != devices.dtype else dtype
# self.transformer.text_model.embeddings.position_embedding.to(dtype)
# self.transformer.text_model.embeddings.token_embedding.to(dtype)
inner_model = getattr(self.transformer, self.inner_name, None)
if inner_model is not None and hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(dtype)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(dtype))
def reset_dtype_compat():
# token_embedding_dtype = position_embedding_dtype = torch.float32
# self.transformer.text_model.embeddings.token_embedding.to(token_embedding_dtype)
# self.transformer.text_model.embeddings.position_embedding.to(position_embedding_dtype)
inner_model = getattr(self.transformer, self.inner_name, None)
if inner_model is not None and hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
enable_compat = False
if enable_compat: set_dtype_compat(torch.float16, enable_compat)
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
# dtype=backup_embeds.weight.dtype
if hasattr(self.transformer, 'dtype'):
dtype = self.transformer.dtype
else:
dtype = getattr(self.transformer, self.inner_name, self.transformer.text_model).final_layer_norm.weight.dtype
if dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, dtype=None: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device), dtype=dtype if enable_compat else torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else:
z = outputs[1]
if outputs[2] is not None:
pooled_output = outputs[2].float()
else:
pooled_output = None
if enable_compat: reset_dtype_compat()
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output
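# _forward() is a re-implementation of ComfyUI's SD1ClipModel forward pass. The set_dtype_compat /
# reset_dtype_compat helpers can temporarily downcast the embedding tables to fp16 and restore them to
# fp32 afterwards, but this path is disabled by default (enable_compat = False above).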
def encode_with_transformers_comfy_(self, tokens: List[List[int]], return_pooled=False):
tokens_orig = tokens
try:
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()
z, pooled = ClipTextEncoderCustom._forward(self.wrapped, tokens) # self.wrapped.encode(tokens)
except Exception as e:
z, pooled = ClipTextEncoderCustom._forward(self.wrapped, tokens_orig)
# z = self.encode_with_transformers__(tokens_bak)
if z.device != devices.device:
z = z.to(device=devices.device)
# if z.dtype != devices.dtype:
# z = z.to(dtype=devices.dtype)
# if pooled.dtype != devices.dtype:
# pooled = pooled.to(dtype=devices.dtype)
z.pooled = pooled
return (z, pooled) if return_pooled else z
def encode_with_transformers_comfy(self, tokens: List[List[int]], return_pooled=False) -> Tuple[torch.Tensor, torch.Tensor]:
'''
Unlike `clip.cond_stage_model.encode_token_weights()`, this takes plain
`List[List[int]]` tokens without their weights.
Originally from `sd1_clip.py`: `encode()` -> `forward()`.
'''
tokens_orig = tokens
try:
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()
z, pooled = self.wrapped(tokens) # self.wrapped.encode(tokens)
except Exception as e:
z, pooled = self.wrapped(tokens_orig)
# z = self.encode_with_transformers__(tokens_bak)
if z.device != devices.device:
z = z.to(device=devices.device)
# if z.dtype != devices.dtype:
# z = z.to(dtype=devices.dtype)
# if pooled.dtype != devices.dtype:
# pooled = pooled.to(dtype=devices.dtype)
z.pooled = pooled
return (z, pooled) if return_pooled else z
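# encode_with_transformers_comfy_() routes through the custom _forward() above, while
# encode_with_transformers_comfy() calls the wrapped comfy model directly; both attach the pooled
# output to the returned hidden states as `z.pooled`.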
class FrozenOpenCLIPEmbedder2WithCustomWordsCustom(FrozenOpenCLIPEmbedder2WithCustomWords, ClipTextEncoderCustom, PopulateVars):
def __init__(self, wrapped: comfy.sdxl_clip.SDXLClipG, hijack):
self.populate_self_variables(wrapped.tokenizer_parent)
super().__init__(wrapped, hijack)
self.id_start = self.wrapped.tokenizer.bos_token_id
self.id_end = self.wrapped.tokenizer.eos_token_id
self.id_pad = 0
# The following appears safe because ComfyUI reuses the standard CLIP tokenizer for OpenCLIP
# instead of a separate OpenCLIP tokenizer, so the bracket tokens below exist in its vocab.
self.token_mults = {}
vocab = self.tokenizer.get_vocab()
self.comma_token = vocab.get(',</w>', None)
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
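# Example: a vocab entry containing a single '(' ends up with multiplier 1.1, one containing '(('
# would get 1.21, and entries whose brackets cancel out to exactly 1.0 are not stored.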
def tokenize_line(self, line):
line = parse_and_register_embeddings(self, line)
return super().tokenize_line(line)
def encode(self, tokens):
return self.encode_with_transformers(tokens, True)
def encode_with_transformers(self, tokens, return_pooled=False):
return self.encode_with_transformers_comfy_(tokens, return_pooled)
def encode_token_weights(self, tokens):
pass
def tokenize(self, texts):
# assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
tokenized = [self.tokenizer(text)["input_ids"][1:-1] for text in texts]
return tokenized
class FrozenCLIPEmbedderWithCustomWordsCustom(FrozenCLIPEmbedderForSDXLWithCustomWords, ClipTextEncoderCustom, PopulateVars):
'''
Custom class that also inherits a tokenizer to have the `_try_get_embedding()` method.
'''
def __init__(self, wrapped: comfy.sd1_clip.SD1ClipModel, hijack):
self.populate_self_variables(wrapped.tokenizer_parent) # SD1Tokenizer
# self.embedding_identifier_tokenized = wrapped.tokenizer([self.embedding_identifier])["input_ids"][0][1:-1]
super().__init__(wrapped, hijack)
def encode_token_weights(self, tokens):
pass
def encode(self, tokens):
return self.encode_with_transformers(tokens, True)
def encode_with_transformers(self, tokens, return_pooled=False):
return self.encode_with_transformers_comfy_(tokens, return_pooled)
def tokenize_line(self, line):
line = parse_and_register_embeddings(self, line)
return super().tokenize_line(line)
def tokenize(self, texts):
tokenized = [self.tokenizer(text)["input_ids"][1:-1] for text in texts]
return tokenized
emb_re_ = r"(embedding:)?(?:({}[\w\.\-\!\$\/\\]+(\.safetensors|\.pt|\.bin)|(?(1)[\w\.\-\!\$\/\\]+|(?!)))(\.safetensors|\.pt|\.bin)?)(?::(\d+\.?\d*|\d*\.\d+))?"
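# Matches A1111-style embedding references such as "embedding:myEmbedding", "myEmbedding.pt" or
# "embedding:myEmbedding.safetensors:1.2" (names here are illustrative). Group 1 is the optional
# "embedding:" prefix, group 2 the embedding name, group 4 an optional file extension and group 5 an
# optional ":weight" suffix; "{}" is filled in with the known embedding names at runtime.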
def tokenize_with_weights_custom(self, text:str, return_word_ids=False):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can be either integer tokens or pre-computed CLIP tensors.
Word id values are unique per word and embedding; id 0 is reserved for non-word tokens.
The returned list has dimensions NxM, where M is the input size of CLIP.
'''
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
embs = get_valid_embeddings(self.embedding_directory) if self.embedding_directory is not None else []
embs_str = embs_str + '|' if (embs_str:='|'.join(embs)) else ''
emb_re = emb_re_.format(embs_str)
emb_re = re.compile(emb_re, flags=re.MULTILINE | re.UNICODE | re.IGNORECASE)
#tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
matches = emb_re.finditer(word)
last_end = 0
leftovers=[]
for _, match in enumerate(matches, start=1):
start=match.start()
end=match.end()
if (fragment:=word[last_end:start]):
leftovers.append(fragment)
ext = ext if (ext:=match.group(4)) else ''
embedding_sname = embedding_sname if (embedding_sname:=match.group(2)) else ''
embedding_name = embedding_sname + ext
if embedding_name:
embed, leftover = self._try_get_embedding(embedding_name)
if embed is None:
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:
if opts.debug:
print(f'[smZNodes] using embedding:{embedding_name}')
if len(embed.shape) == 1:
tokens.append([(embed, weight)])
else:
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
last_end = end
if (fragment:=word[last_end:]):
leftovers.append(fragment)
word_new = ''.join(leftovers)
tokens.append([(t, weight) for t in self.tokenizer(word_new)["input_ids"][self.tokens_start:-1]])  # tokenize only the text left over after extracting embeddings
#reshape token array to CLIP input size
batched_tokens = []
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
#break word in two and add end token
if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
t_group = []
#fill last batch
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
return batched_tokens
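# With return_word_ids=False the result looks like [[(start_token, 1.0), (token_or_embedding_vector,
# weight), ..., (end_token, 1.0), (pad_token, 1.0), ...]], one inner list per CLIP-sized batch.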
def get_valid_embeddings(embedding_directory):
from builtins import any as b_any
exts = ['.safetensors', '.pt', '.bin']
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
embedding_directory = expand_directory_list(embedding_directory)
embs = []
for embd in embedding_directory:
for root, dirs, files in os.walk(embd, topdown=False):
for name in files:
if not b_any(x in os.path.splitext(name)[1] for x in exts): continue
n = os.path.basename(name)
for ext in exts: n=n.removesuffix(ext)
embs.append(re.escape(n))
embs.sort(key=len, reverse=True)
return embs
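# Returns the regex-escaped base names of every embedding file found under embedding_directory,
# sorted longest-first so longer names match before any shorter name that is a prefix of them.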
def parse_and_register_embeddings(self: FrozenCLIPEmbedderWithCustomWordsCustom|FrozenOpenCLIPEmbedder2WithCustomWordsCustom, text: str, return_word_ids=False):
from builtins import any as b_any
embedding_directory = self.wrapped.tokenizer_parent.embedding_directory
embs = get_valid_embeddings(embedding_directory)
embs_str = '|'.join(embs)
emb_re = emb_re_.format(embs_str + '|' if embs_str else '')
emb_re = re.compile(emb_re, flags=re.MULTILINE | re.UNICODE | re.IGNORECASE)
matches = emb_re.finditer(text)
for matchNum, match in enumerate(matches, start=1):
found=False
ext = ext if (ext:=match.group(4)) else ''
embedding_sname = embedding_sname if (embedding_sname:=match.group(2)) else ''
embedding_name = embedding_sname + ext
if embedding_name:
embed, _ = self.wrapped.tokenizer_parent._try_get_embedding(embedding_name)
if embed is not None:
found=True
if opts.debug:
print(f'[smZNodes] using embedding:{embedding_name}')
if embed.device != devices.device:
embed = embed.to(device=devices.device)
self.hijack.embedding_db.register_embedding(Embedding(embed, embedding_sname), self)
if not found:
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
out = emb_re.sub(r"\2", text)
return out
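# Each matched reference is replaced by group 2 (the bare embedding name), stripping the optional
# "embedding:" prefix and ":weight" suffix, so the A1111 prompt parser only sees the registered name.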
def expand(tensor1, tensor2):
def adjust_tensor_shape(tensor_small, tensor_big):
# Calculate replication factor
# -(-a // b) is ceiling of division without importing math.ceil
replication_factor = -(-tensor_big.size(1) // tensor_small.size(1))
# Use repeat to extend tensor_small
tensor_small_extended = tensor_small.repeat(1, replication_factor, 1)
# Take the rows of the extended tensor_small to match tensor_big
tensor_small_matched = tensor_small_extended[:, :tensor_big.size(1), :]
return tensor_small_matched
# Check if their second dimensions are different
if tensor1.size(1) != tensor2.size(1):
# Check which tensor has the smaller second dimension and adjust its shape
if tensor1.size(1) < tensor2.size(1):
tensor1 = adjust_tensor_shape(tensor1, tensor2)
else:
tensor2 = adjust_tensor_shape(tensor2, tensor1)
return (tensor1, tensor2)
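# Example: expand(a, b) with a of shape (1, 77, 768) and b of shape (1, 154, 768) tiles a along
# dim 1 and truncates it to (1, 154, 768); tensors whose second dims already match are returned unchanged.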
def reconstruct_schedules(schedules, step):
create_reconstruct_fn = lambda _cc: prompt_parser.reconstruct_multicond_batch if type(_cc).__name__ == "MulticondLearnedConditioning" else prompt_parser.reconstruct_cond_batch
reconstruct_fn = create_reconstruct_fn(schedules)
return reconstruct_fn(schedules, step)
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs, steps=0, current_step=0, multi=False):
schedules = token_weight_pairs
texts = token_weight_pairs
conds_list = [[(0, 1.0)]]
from .modules.sd_hijack import model_hijack
try:
model_hijack.hijack(self)
if isinstance(token_weight_pairs, list) and isinstance(token_weight_pairs[0], str):
if multi: schedules = prompt_parser.get_multicond_learned_conditioning(model_hijack.cond_stage_model, texts, steps, None, opts.use_old_scheduling)
else: schedules = prompt_parser.get_learned_conditioning(model_hijack.cond_stage_model, texts, steps, None, opts.use_old_scheduling)
cond = reconstruct_schedules(schedules, current_step)
if type(cond) is tuple:
conds_list, cond = cond
pooled = cond.pooled.cpu()
cond = cond.cpu()
cond.pooled = pooled
cond.pooled.conds_list = conds_list
cond.pooled.schedules = schedules
else:
# comfy++
def encode_toks(_token_weight_pairs):
zs = []
first_pooled = None
for batch_chunk in _token_weight_pairs:
tokens = [x[0] for x in batch_chunk]
multipliers = [x[1] for x in batch_chunk]
z = model_hijack.cond_stage_model.process_tokens([tokens], [multipliers])
if first_pooled is None:
first_pooled = z.pooled
zs.append(z)
zcond = torch.hstack(zs)
zcond.pooled = first_pooled
return zcond
# non-sdxl will be something like: {"l": [[]]}
if isinstance(token_weight_pairs, dict):
token_weight_pairs = next(iter(token_weight_pairs.values()))
cond = encode_toks(token_weight_pairs)
pooled = cond.pooled.cpu()
cond = cond.cpu()
cond.pooled = pooled
cond.pooled.conds_list = conds_list
finally:
model_hijack.undo_hijack(model_hijack.cond_stage_model)
return (cond, cond.pooled)
class SD1ClipModel(ClipTokenWeightEncoder): ...
class SDXLClipG(ClipTokenWeightEncoder): ...
class SDXLClipModel(ClipTokenWeightEncoder):
def encode_token_weights(self: comfy.sdxl_clip.SDXLClipModel, token_weight_pairs, steps=0, current_step=0, multi=False):
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_l = token_weight_pairs["l"]
self.clip_g.encode_token_weights_orig = self.clip_g.encode_token_weights
self.clip_l.encode_token_weights_orig = self.clip_l.encode_token_weights
self.clip_g.cond_stage_model = self.clip_g
self.clip_l.cond_stage_model = self.clip_l
self.clip_g.encode_token_weights = partial(SDXLClipG.encode_token_weights, self.clip_g)
self.clip_l.encode_token_weights = partial(SD1ClipModel.encode_token_weights, self.clip_l)
try:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g, steps, current_step, multi)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l, steps, current_step, multi)
# g_out, g_pooled = SDXLClipG.encode_token_weights(self.clip_g, token_weight_pairs_g, steps, current_step, multi)
# l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l, steps, current_step, multi)
finally:
self.clip_g.encode_token_weights = self.clip_g.encode_token_weights_orig
self.clip_l.encode_token_weights = self.clip_l.encode_token_weights_orig
self.clip_g.cond_stage_model = None
self.clip_l.cond_stage_model = None
if hasattr(g_pooled, 'schedules') and hasattr(l_pooled, 'schedules'):
g_pooled.schedules = {"g": g_pooled.schedules, "l": l_pooled.schedules}
g_out, l_out = expand(g_out, l_out)
l_out, g_out = expand(l_out, g_out)
return torch.cat([l_out, g_out], dim=-1), g_pooled
class SDXLRefinerClipModel(ClipTokenWeightEncoder):
def encode_token_weights(self: comfy.sdxl_clip.SDXLClipModel, token_weight_pairs, steps=0, current_step=0, multi=False):
self.clip_g.encode_token_weights_orig = self.clip_g.encode_token_weights
self.clip_g.encode_token_weights = partial(SDXLClipG.encode_token_weights, self.clip_g)
token_weight_pairs_g = token_weight_pairs["g"]
try: g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g, steps, current_step, multi)
finally: self.clip_g.encode_token_weights = self.clip_g.encode_token_weights_orig
if hasattr(g_pooled, 'schedules'):
g_pooled.schedules = {"g": g_pooled.schedules}
return (g_out, g_pooled)
def is_prompt_editing(schedules):
if schedules is None: return False
if not isinstance(schedules, dict):
schedules = {'g': schedules}
for k,v in schedules.items():
if type(v) == list:
if len(v[0]) != 1: return True
else:
if len(v.batch[0][0].schedules) != 1: return True
return False
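# is_prompt_editing() is True when any schedule holds more than one (step, cond) entry,
# i.e. the conditioning is scheduled to change during sampling.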
# ===================================================================
# RNG
from .modules import rng_philox
def randn_without_seed(x, generator=None, randn_source="cpu"):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
if randn_source == "nv":
return torch.asarray(generator.randn(x.size()), device=x.device)
else:
if generator is not None and generator.device.type == "cpu":
return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=devices.cpu, generator=generator).to(device=x.device)
else:
return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator)
class TorchHijack:
"""This is here to replace torch.randn_like of k-diffusion.
k-diffusion has random_sampler argument for most samplers, but not for all, so
this is needed to properly replace every use of torch.randn_like.
We need this replacement so that images generated in batches match images generated individually."""
def __init__(self, generator, randn_source):
# self.rng = p.rng
self.generator = generator
self.randn_source = randn_source
def __getattr__(self, item):
if item == 'randn_like':
return self.randn_like
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
return randn_without_seed(x, generator=self.generator, randn_source=self.randn_source)
def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'):
"""
Creates random noise given a latent image and a seed.
The optional noise_inds argument selects which noise draws in a batch to keep, skipping and discarding the rest for a given seed.
"""
from .modules.shared import opts
from comfy.sample import np
def get_generator(seed):
nonlocal device
nonlocal opts
_generator = torch.Generator(device=device)
generator = _generator.manual_seed(seed)
if opts.randn_source == 'nv':
generator = rng_philox.Generator(seed)
return generator
generator = generator_eta = get_generator(seed)
if opts.eta_noise_seed_delta > 0:
seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff))
generator_eta = get_generator(seed)
# hijack randn_like
import comfy.k_diffusion.sampling
comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source)
if noise_inds is None:
shape = latent_image.size()
if opts.randn_source == 'nv':
return torch.asarray(generator.randn(shape), device=devices.cpu)
else:
return torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
shape = [1] + list(latent_image.size())[1:]
if opts.randn_source == 'nv':
noise = torch.asarray(generator.randn(shape), device=devices.cpu)
else:
noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return noises
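# When opts.eta_noise_seed_delta > 0, a second generator seeded with seed + delta is installed via
# TorchHijack so the sampler's randn_like() calls use the shifted seed, mirroring A1111's ENSD setting.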
# ===========================================================
def run(clip: comfy.sd.CLIP, text, parser, mean_normalization,
multi_conditioning, use_old_emphasis_implementation, with_SDXL,
ascore, width, height, crop_w, crop_h, target_width, target_height,
text_g, text_l, steps=1, step=0):
opts.prompt_mean_norm = mean_normalization
opts.use_old_emphasis_implementation = use_old_emphasis_implementation
opts.CLIP_stop_at_last_layers = abs(clip.layer_idx or 1)
is_sdxl = "SDXL" in type(clip.cond_stage_model).__name__
if is_sdxl:
# Prevents tensor shape mismatch
# This is what comfy does by default
opts.batch_cond_uncond = True
parser_d = {"full": "Full parser",
"compel": "Compel parser",
"A1111": "A1111 parser",
"fixed attention": "Fixed attention",
"comfy++": "Comfy++ parser",
}
opts.prompt_attention = parser_d.get(parser, "Comfy parser")
sdxl_params = {}
if with_SDXL and is_sdxl:
sdxl_params = {
"aesthetic_score": ascore, "width": width, "height": height,
"crop_w": crop_w, "crop_h": crop_h, "target_width": target_width,
"target_height": target_height, "text_g": text_g, "text_l": text_l
}
pooled={}
if hasattr(comfy.sd1_clip, 'SDTokenizer'):
SDTokenizer = comfy.sd1_clip.SDTokenizer
else:
SDTokenizer = comfy.sd1_clip.SD1Tokenizer
tokenize_with_weights_orig = SDTokenizer.tokenize_with_weights
if parser == "comfy":
SDTokenizer.tokenize_with_weights = tokenize_with_weights_custom
clip_model_type_name = type(clip.cond_stage_model).__name__
if with_SDXL and is_sdxl:
if clip_model_type_name == "SDXLClipModel":
out = CLIPTextEncodeSDXL().encode(clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l)
out[0][0][1]['aesthetic_score'] = sdxl_params['aesthetic_score']
elif clip_model_type_name == "SDXLRefinerClipModel":
out = CLIPTextEncodeSDXLRefiner().encode(clip, ascore, width, height, text)
for item in ['aesthetic_score', 'width', 'height', 'text_g', 'text_l']:
sdxl_params.pop(item)
out[0][0][1].update(sdxl_params)
else:
raise NotImplementedError()
else:
out = CLIPTextEncode().encode(clip, text)
SDTokenizer.tokenize_with_weights = tokenize_with_weights_orig
return out
else:
texts = [text]
create_prompts = lambda txts: prompt_parser.SdConditioning(txts)
texts = create_prompts(texts)
if is_sdxl:
if with_SDXL:
texts = {"g": create_prompts([text_g]), "l": create_prompts([text_l])}
else:
texts = {"g": texts, "l": texts}
# clip_clone = clip.clone()
clip_clone = clip
clip_clone.cond_stage_model_orig = clip_clone.cond_stage_model
clip_clone.cond_stage_model.encode_token_weights_orig = clip_clone.cond_stage_model.encode_token_weights
def patch_cond_stage_model():
nonlocal clip_clone
from .smZNodes import SD1ClipModel, SDXLClipModel, SDXLRefinerClipModel
ctp = type(clip_clone.cond_stage_model)
clip_clone.cond_stage_model.tokenizer = clip_clone.tokenizer
if ctp is comfy.sdxl_clip.SDXLClipModel:
clip_clone.cond_stage_model.encode_token_weights = SDXLClipModel.encode_token_weights
clip_clone.cond_stage_model.clip_g.tokenizer = clip_clone.tokenizer.clip_g
clip_clone.cond_stage_model.clip_l.tokenizer = clip_clone.tokenizer.clip_l
elif ctp is comfy.sdxl_clip.SDXLRefinerClipModel:
clip_clone.cond_stage_model.encode_token_weights = SDXLRefinerClipModel.encode_token_weights
clip_clone.cond_stage_model.clip_g.tokenizer = clip_clone.tokenizer.clip_g
else:
clip_clone.cond_stage_model.encode_token_weights = SD1ClipModel.encode_token_weights
tokens = texts
if parser == "comfy++":
SDTokenizer.tokenize_with_weights = tokenize_with_weights_custom
tokens = clip_clone.tokenize(text)
SDTokenizer.tokenize_with_weights = tokenize_with_weights_orig
cond = pooled = None
patch_cond_stage_model()
try:
clip_clone.cond_stage_model.encode_token_weights = partial(clip_clone.cond_stage_model.encode_token_weights, clip_clone.cond_stage_model, steps=steps, current_step=step, multi=multi_conditioning)
cond, pooled = clip_clone.encode_from_tokens(tokens, True)
finally:
clip_clone.cond_stage_model = clip_clone.cond_stage_model_orig
clip_clone.cond_stage_model.encode_token_weights = clip_clone.cond_stage_model.encode_token_weights_orig
if opts.debug:
print('[smZNodes] using steps', steps)
gen_id = lambda : binascii.hexlify(os.urandom(1024))[64:72]
id=gen_id()
schedules = getattr(pooled, 'schedules', [[(0, 1.0)]])
pooled = {"pooled_output": pooled, "from_smZ": True, "smZid": id, "conds_list": pooled.conds_list, **sdxl_params}
out = [[cond, pooled]]
if is_prompt_editing(schedules):
for x in range(1,steps):
if type(schedules) is not dict:
cond=reconstruct_schedules(schedules, x)
if type(cond) is tuple:
conds_list, cond = cond
pooled['conds_list'] = conds_list
cond=cond.cpu()
elif type(schedules) is dict and len(schedules) == 1: # SDXLRefiner
cond = reconstruct_schedules(next(iter(schedules.values())), x)
if type(cond) is tuple:
conds_list, cond = cond
pooled['conds_list'] = conds_list
cond=cond.cpu()
elif type(schedules) is dict:
g_out = reconstruct_schedules(schedules['g'], x)
if type(g_out) is tuple: _, g_out = g_out
l_out = reconstruct_schedules(schedules['l'], x)
if type(l_out) is tuple: _, l_out = l_out
g_out, l_out = expand(g_out, l_out)
l_out, g_out = expand(l_out, g_out)
cond = torch.cat([l_out, g_out], dim=-1).cpu()
else:
raise NotImplementedError
out = out + [[cond, pooled]]
out[0][1]['orig_len'] = len(out)
return (out,)
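# run() returns a single conditioning list: out[0] holds the cond for step 0 and, when prompt editing
# is detected, one extra [cond, pooled] entry is appended per remaining step. out[0][1]['orig_len']
# records how many entries belong together so calc_cond() can pick the right one at sampling time.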
# ========================================================================
from server import PromptServer
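# prompt_handler() inspects the queued workflow: for every "smZ CLIPTextEncode" node it walks the graph
# to a sampler that consumes it and copies that sampler's step count into the node's "smZ_steps" input,
# so prompt-editing schedules can be built with the correct number of steps.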
def prompt_handler(json_data):
data=json_data['prompt']
def tmp():
nonlocal data
current_clip_id = None
def find_nearest_ksampler(clip_id):
"""Find the nearest KSampler node that references the given CLIPTextEncode id."""
for ksampler_id, node in data.items():
if "Sampler" in node["class_type"] or "sampler" in node["class_type"]:
# Check if this KSampler node directly or indirectly references the given CLIPTextEncode node
if check_link_to_clip(ksampler_id, clip_id):
return get_steps(data, ksampler_id)
return None
def get_steps(graph, node_id):
node = graph.get(str(node_id), {})
steps_input_value = node.get("inputs", {}).get("steps", None)
if steps_input_value is None:
steps_input_value = node.get("inputs", {}).get("sigmas", None)
while(True):
# Base case: it's a direct value
if isinstance(steps_input_value, (int, float, str)):
return min(max(1, int(steps_input_value)), 10000)
# Loop case: it's a reference to another node
elif isinstance(steps_input_value, list):
ref_node_id, ref_input_index = steps_input_value
ref_node = graph.get(str(ref_node_id), {})
steps_input_value = ref_node.get("inputs", {}).get("steps", None)
if steps_input_value is None:
keys = list(ref_node.get("inputs", {}).keys())
ref_input_key = keys[ref_input_index % len(keys)]
steps_input_value = ref_node.get("inputs", {}).get(ref_input_key)
else:
return None
def check_link_to_clip(node_id, clip_id, visited=None):
"""Check if a given node links directly or indirectly to a CLIPTextEncode node."""
if visited is None:
visited = set()
node = data[node_id]
if node_id in visited:
return False
visited.add(node_id)
for input_value in node["inputs"].values():
if isinstance(input_value, list) and input_value[0] == clip_id:
return True
if isinstance(input_value, list) and check_link_to_clip(input_value[0], clip_id, visited):
return True
return False
# Update each CLIPTextEncode node's steps with the steps from its nearest referencing KSampler node
for clip_id, node in data.items():
if node["class_type"] == "smZ CLIPTextEncode":
current_clip_id = clip_id
steps = find_nearest_ksampler(clip_id)
if steps is not None:
node["inputs"]["smZ_steps"] = steps
if opts.debug:
print(f'[smZNodes] id: {current_clip_id} | steps: {steps}')
tmp()
return json_data
if hasattr(PromptServer.instance, 'add_on_prompt_handler'):
PromptServer.instance.add_on_prompt_handler(prompt_handler)
# ========================================================================
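# Note: despite the name, bounded_modulo() clamps `number` to at most `modulo_value`; it does not wrap.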
def bounded_modulo(number, modulo_value):
return number if number < modulo_value else modulo_value
def get_adm(c):
for y in ["adm_encoded", "c_adm", "y"]:
if y in c:
c_c_adm = c[y]
if y == "adm_encoded": y="c_adm"
if type(c_c_adm) is not torch.Tensor: c_c_adm = c_c_adm.cond
return {y: c_c_adm, 'key': y}
return None
getp=lambda x: x[1] if type(x) is list else x
def calc_cond(c, current_step):
"""Group by smZ conds that may do prompt-editing / regular conds / comfy conds."""
_cond = []
# Group by conds from smZ
fn=lambda x : getp(x).get("from_smZ", None) is not None
an_iterator = itertools.groupby(c, fn )
for key, group in an_iterator:
ls=list(group)
# Group by prompt-editing conds
fn2=lambda x : getp(x).get("smZid", None)
an_iterator2 = itertools.groupby(ls, fn2)
for key2, group2 in an_iterator2:
ls2=list(group2)
if key2 is not None:
orig_len = getp(ls2[0]).get('orig_len', 1)
i = bounded_modulo(current_step, orig_len - 1)
_cond = _cond + [ls2[i]]
else:
_cond = _cond + ls2
return _cond
CFGNoisePredictorOrig = comfy.samplers.CFGNoisePredictor
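# CFGNoisePredictor wraps comfy's noise predictor. When a cond originating from smZ is present and
# opts.use_CFGDenoiser is enabled, denoising is delegated to the A1111-style CFGDenoiser
# (self.inner_model2); otherwise the call falls through to comfy's default implementation.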
class CFGNoisePredictor(CFGNoisePredictorOrig):
def __init__(self, model):
super().__init__(model)
self.step = 0
self.inner_model2 = CFGDenoiser(model.apply_model)
self.s_min_uncond = opts.s_min_uncond
self.c_adm = None
self.init_cond = None
self.init_uncond = None
self.is_prompt_editing_u = False
self.is_prompt_editing_c = False
def apply_model(self, *args, **kwargs):
x=kwargs['x'] if 'x' in kwargs else args[0]
timestep=kwargs['timestep'] if 'timestep' in kwargs else args[1]
cond=kwargs['cond'] if 'cond' in kwargs else args[2]
uncond=kwargs['uncond'] if 'uncond' in kwargs else args[3]
cond_scale=kwargs['cond_scale'] if 'cond_scale' in kwargs else args[4]
model_options=kwargs['model_options'] if 'model_options' in kwargs else {}
cc=calc_cond(cond, self.step)
uu=calc_cond(uncond, self.step)
self.step += 1
if (any([getp(p).get('from_smZ', False) for p in cc]) or
any([getp(p).get('from_smZ', False) for p in uu])):
if model_options.get('transformer_options',None) is None:
model_options['transformer_options'] = {}
model_options['transformer_options']['from_smZ'] = True
if not opts.use_CFGDenoiser or not model_options['transformer_options'].get('from_smZ', False):
if 'cond' in kwargs: kwargs['cond'] = cc
else: args[2]=cc
if 'uncond' in kwargs: kwargs['uncond'] = uu
else: args[3]=uu
out = super().apply_model(*args, **kwargs)
else:
# Only supports one cond
for ix in range(len(cc)):
if getp(cc[ix]).get('from_smZ', False):
cc = [cc[ix]]
break
for ix in range(len(uu)):
if getp(uu[ix]).get('from_smZ', False):
uu = [uu[ix]]
break
c=getp(cc[0])
u=getp(uu[0])
_cc = cc[0][0] if type(cc[0]) is list else cc[0]['model_conds']['c_crossattn'].cond
_uu = uu[0][0] if type(uu[0]) is list else uu[0]['model_conds']['c_crossattn'].cond
conds_list = c.get('conds_list', [[(0, 1.0)]])
if 'model_conds' in c: c = c['model_conds']
if 'model_conds' in u: u = u['model_conds']
c_c_adm = get_adm(c)
if c_c_adm is not None:
u_c_adm = get_adm(u)
k = c_c_adm['key']
self.c_adm = {k: torch.cat([c_c_adm[k], u_c_adm[u_c_adm['key']]]).to(device=x.device), 'key': k}
# SDXL. Need to pad with repeats
_cc, _uu = expand(_cc, _uu)
_uu, _cc = expand(_uu, _cc)
x.c_adm = self.c_adm
image_cond = txt2img_image_conditioning(None, x)
out = self.inner_model2(x, timestep, cond=(conds_list, _cc), uncond=_uu, cond_scale=cond_scale, s_min_uncond=self.s_min_uncond, image_cond=image_cond)
return out
def txt2img_image_conditioning(sd_model, x, width=None, height=None):
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
# if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
# # The "masked-image" in this case will just be all zeros since the entire image is masked.
# image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
# image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
# # Add the fake full 1s mask to the first dimension.
# image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
# image_conditioning = image_conditioning.to(x.dtype)
# return image_conditioning
# elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
# return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
# else:
# # Dummy zero conditioning if we're not using inpainting or unclip models.
# # Still takes up a bit of memory, but no encoder call.
# # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
# return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
# =======================================================================================
def inject_code(original_func, data):
# Get the source code of the original function
original_source = inspect.getsource(original_func)
# Split the source code into lines
lines = original_source.split("\n")
for item in data:
# Find the line number of the target line
target_line_number = None
for i, line in enumerate(lines):
if item['target_line'] in line:
target_line_number = i + 1
# Find the indentation of the line where the new code will be inserted
indentation = ''
for char in line:
if char == ' ':
indentation += char
else:
break
# Indent the new code to match the original
code_to_insert = dedent(item['code_to_insert'])
code_to_insert = indent(code_to_insert, indentation)
break
if target_line_number is None:
raise FileNotFoundError(f"Could not find target line: {item['target_line']}")
# Target line not found, return the original function
# return original_func
# Insert the code to be injected after the target line
lines.insert(target_line_number, code_to_insert)
# Recreate the modified source code
modified_source = "\n".join(lines)
modified_source = dedent(modified_source.strip("\n"))
# Create a temporary file to write the modified source code so I can still view the
# source code when debugging.
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.py') as temp_file:
temp_file.write(modified_source)
temp_file.flush()
MODULE_PATH = temp_file.name
MODULE_NAME = __name__.split('.')[0] + "_patch_modules"
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH)
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
# Pass global variables to the modified module
globals_dict = original_func.__globals__
for key, value in globals_dict.items():
setattr(module, key, value)
modified_module = module
# Retrieve the modified function from the module
modified_function = getattr(modified_module, original_func.__name__)
# If the original function was a method, bind it to the first argument (self)
if inspect.ismethod(original_func):
modified_function = modified_function.__get__(original_func.__self__, original_func.__class__)
# Update the metadata of the modified function to associate it with the original function
functools.update_wrapper(modified_function, original_func)
# Return the modified function
return modified_function
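# Illustrative `data` format (the target line and snippet here are made up):
#   [{'target_line': 'extra_args = {}', 'code_to_insert': 'print("injected")'}]
# code_to_insert is dedented, re-indented to match the target line, and inserted on the line after it.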
# ========================================================================
# DPM++ 2M alt
from tqdm.auto import trange
@torch.no_grad()
def sample_dpmpp_2m_alt(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
sigma_progress = i / len(sigmas)
adjustment_factor = 1 + (0.15 * (sigma_progress * sigma_progress))
old_denoised = denoised * adjustment_factor
return x
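# This appears to differ from the stock dpmpp_2m sampler only in the last two lines of the loop:
# old_denoised is scaled by an adjustment factor that grows quadratically (up to ~1.15x) with progress.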
def add_sample_dpmpp_2m_alt():
from comfy.samplers import KSampler, k_diffusion_sampling
if "dpmpp_2m_alt" not in KSampler.SAMPLERS:
try:
idx = KSampler.SAMPLERS.index("dpmpp_2m")
KSampler.SAMPLERS.insert(idx+1, "dpmpp_2m_alt")
setattr(k_diffusion_sampling, 'sample_dpmpp_2m_alt', sample_dpmpp_2m_alt)
import importlib
importlib.reload(k_diffusion_sampling)
except ValueError as err:
pass