|
from __future__ import annotations |
|
import comfy |
|
import torch |
|
from typing import List, Tuple |
|
from functools import partial |
|
from .modules import prompt_parser, shared, devices |
|
from .modules.shared import opts |
|
from .modules.sd_samplers_cfg_denoiser import CFGDenoiser |
|
from .modules.sd_hijack_clip import FrozenCLIPEmbedderForSDXLWithCustomWords |
|
from .modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedder2WithCustomWords |
|
from .modules.textual_inversion.textual_inversion import Embedding |
|
import comfy.sdxl_clip |
|
import comfy.sd1_clip |
|
import comfy.sample |
|
from comfy.sd1_clip import SD1Tokenizer, unescape_important, escape_important, token_weights, expand_directory_list |
|
from nodes import CLIPTextEncode |
|
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution |
|
from comfy import model_management |
|
import inspect |
|
from textwrap import dedent, indent |
|
import functools |
|
import tempfile |
|
import importlib |
|
import sys |
|
import os |
|
import re |
|
import contextlib |
|
import itertools |
|
import binascii |
|
|
|
try: |
|
from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL, CLIPTextEncodeSDXLRefiner |
|
except Exception as err: |
|
print(f"[smZNodes]: Your ComfyUI version is outdated. Please update to the latest version. ({err})") |
|
class CLIPTextEncodeSDXL(CLIPTextEncode): ... |
|
class CLIPTextEncodeSDXLRefiner(CLIPTextEncode): ... |
|
|
|
def get_learned_conditioning(self, c): |
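    # Mirrors LatentDiffusion.get_learned_conditioning from the ldm codebase (as used by A1111);
    # `self` here is the hijacked model, so the original encode path can be reused as-is.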
|
if self.cond_stage_forward is None: |
|
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): |
|
c = self.cond_stage_model.encode(c) |
|
if isinstance(c, DiagonalGaussianDistribution): |
|
c = c.mode() |
|
else: |
|
c = self.cond_stage_model(c) |
|
else: |
|
assert hasattr(self.cond_stage_model, self.cond_stage_forward) |
|
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) |
|
return c |
|
|
|
class PopulateVars: |
|
def populate_self_variables(self, from_): |
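        # Copy every instance attribute of `from_` onto self, so the wrapper classes below
        # behave like the ComfyUI tokenizer/encoder objects they wrap.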
|
super_attrs = vars(from_) |
|
self_attrs = vars(self) |
|
self_attrs.update(super_attrs) |
|
|
|
should_use_fp16_signature = inspect.signature(comfy.model_management.should_use_fp16) |
|
class ClipTextEncoderCustom: |
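    # `_forward` closely follows ComfyUI's comfy.sd1_clip.SD1ClipModel.forward, with optional
    # dtype-compatibility handling for the token embedding layer (kept disabled via enable_compat below).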
|
|
|
def _forward(self: comfy.sd1_clip.SD1ClipModel, tokens): |
|
def set_dtype_compat(dtype, newv = False): |
|
dtype_num = lambda d : int(re.sub(r'.*?(\d+)', r'\1', repr(d))) |
|
_p = should_use_fp16_signature.parameters |
|
|
|
|
|
|
|
if dtype_num(dtype) >= 32: |
|
newv = False |
|
if not newv: return |
|
            dtype = devices.dtype
|
|
|
|
|
inner_model = getattr(self.transformer, self.inner_name, None) |
|
if inner_model is not None and hasattr(inner_model, "embeddings"): |
|
inner_model.embeddings.to(dtype) |
|
else: |
|
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(dtype)) |
|
def reset_dtype_compat(): |
|
|
|
|
|
|
|
inner_model = getattr(self.transformer, self.inner_name, None) |
|
if inner_model is not None and hasattr(inner_model, "embeddings"): |
|
inner_model.embeddings.to(torch.float32) |
|
else: |
|
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) |
|
enable_compat = False |
|
if enable_compat: set_dtype_compat(torch.float16, enable_compat) |
|
|
|
backup_embeds = self.transformer.get_input_embeddings() |
|
device = backup_embeds.weight.device |
|
tokens = self.set_up_textual_embeddings(tokens, backup_embeds) |
|
tokens = torch.LongTensor(tokens).to(device) |
|
|
|
|
|
if hasattr(self.transformer, 'dtype'): |
|
dtype = self.transformer.dtype |
|
else: |
|
dtype = getattr(self.transformer, self.inner_name, self.transformer.text_model).final_layer_norm.weight.dtype |
|
|
|
if dtype != torch.float32: |
|
precision_scope = torch.autocast |
|
else: |
|
precision_scope = lambda a, dtype=None: contextlib.nullcontext(a) |
|
|
|
with precision_scope(model_management.get_autocast_device(device), dtype=dtype if enable_compat else torch.float32): |
|
attention_mask = None |
|
if self.enable_attention_masks: |
|
attention_mask = torch.zeros_like(tokens) |
|
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 |
|
for x in range(attention_mask.shape[0]): |
|
for y in range(attention_mask.shape[1]): |
|
attention_mask[x, y] = 1 |
|
if tokens[x, y] == max_token: |
|
break |
|
|
|
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) |
|
self.transformer.set_input_embeddings(backup_embeds) |
|
|
|
if self.layer == "last": |
|
z = outputs[0] |
|
else: |
|
z = outputs[1] |
|
|
|
if outputs[2] is not None: |
|
pooled_output = outputs[2].float() |
|
else: |
|
pooled_output = None |
|
|
|
if enable_compat: reset_dtype_compat() |
|
|
|
if self.text_projection is not None and pooled_output is not None: |
|
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() |
|
return z.float(), pooled_output |
|
|
|
def encode_with_transformers_comfy_(self, tokens: List[List[int]], return_pooled=False): |
|
tokens_orig = tokens |
|
try: |
|
if isinstance(tokens, torch.Tensor): |
|
tokens = tokens.tolist() |
|
z, pooled = ClipTextEncoderCustom._forward(self.wrapped, tokens) |
|
except Exception as e: |
|
z, pooled = ClipTextEncoderCustom._forward(self.wrapped, tokens_orig) |
|
|
|
|
|
if z.device != devices.device: |
|
z = z.to(device=devices.device) |
|
|
|
|
|
|
|
|
|
z.pooled = pooled |
|
return (z, pooled) if return_pooled else z |
|
|
|
def encode_with_transformers_comfy(self, tokens: List[List[int]], return_pooled=False) -> Tuple[torch.Tensor, torch.Tensor]: |
|
''' |
|
This function is different from `clip.cond_stage_model.encode_token_weights()` |
|
in that the tokens are `List[List[int]]`, not including the weights. |
|
|
|
Originally from `sd1_clip.py`: `encode()` -> `forward()` |
|
''' |
|
tokens_orig = tokens |
|
try: |
|
if isinstance(tokens, torch.Tensor): |
|
tokens = tokens.tolist() |
|
z, pooled = self.wrapped(tokens) |
|
except Exception as e: |
|
z, pooled = self.wrapped(tokens_orig) |
|
|
|
|
|
if z.device != devices.device: |
|
z = z.to(device=devices.device) |
|
|
|
|
|
|
|
|
|
z.pooled = pooled |
|
return (z, pooled) if return_pooled else z |
|
|
|
class FrozenOpenCLIPEmbedder2WithCustomWordsCustom(FrozenOpenCLIPEmbedder2WithCustomWords, ClipTextEncoderCustom, PopulateVars): |
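    # SDXL CLIP-G counterpart of FrozenCLIPEmbedderWithCustomWordsCustom (defined below): wraps
    # ComfyUI's SDXLClipG, copies state from its tokenizer, and rebuilds the A1111 emphasis
    # multipliers ((...) and [...]) from the open_clip vocabulary.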
|
def __init__(self, wrapped: comfy.sdxl_clip.SDXLClipG, hijack): |
|
self.populate_self_variables(wrapped.tokenizer_parent) |
|
super().__init__(wrapped, hijack) |
|
self.id_start = self.wrapped.tokenizer.bos_token_id |
|
self.id_end = self.wrapped.tokenizer.eos_token_id |
|
self.id_pad = 0 |
|
|
|
|
|
self.token_mults = {} |
|
vocab = self.tokenizer.get_vocab() |
|
self.comma_token = vocab.get(',</w>', None) |
|
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] |
|
for text, ident in tokens_with_parens: |
|
mult = 1.0 |
|
for c in text: |
|
if c == '[': |
|
mult /= 1.1 |
|
if c == ']': |
|
mult *= 1.1 |
|
if c == '(': |
|
mult *= 1.1 |
|
if c == ')': |
|
mult /= 1.1 |
|
if mult != 1.0: |
|
self.token_mults[ident] = mult |
|
|
|
def tokenize_line(self, line): |
|
line = parse_and_register_embeddings(self, line) |
|
return super().tokenize_line(line) |
|
|
|
def encode(self, tokens): |
|
return self.encode_with_transformers(tokens, True) |
|
|
|
def encode_with_transformers(self, tokens, return_pooled=False): |
|
return self.encode_with_transformers_comfy_(tokens, return_pooled) |
|
|
|
def encode_token_weights(self, tokens): |
|
pass |
|
|
|
def tokenize(self, texts): |
|
|
|
tokenized = [self.tokenizer(text)["input_ids"][1:-1] for text in texts] |
|
return tokenized |
|
|
|
|
|
class FrozenCLIPEmbedderWithCustomWordsCustom(FrozenCLIPEmbedderForSDXLWithCustomWords, ClipTextEncoderCustom, PopulateVars): |
|
''' |
|
Custom class that also inherits a tokenizer to have the `_try_get_embedding()` method. |
|
''' |
|
def __init__(self, wrapped: comfy.sd1_clip.SD1ClipModel, hijack): |
|
self.populate_self_variables(wrapped.tokenizer_parent) |
|
|
|
super().__init__(wrapped, hijack) |
|
|
|
def encode_token_weights(self, tokens): |
|
pass |
|
|
|
def encode(self, tokens): |
|
return self.encode_with_transformers(tokens, True) |
|
|
|
def encode_with_transformers(self, tokens, return_pooled=False): |
|
return self.encode_with_transformers_comfy_(tokens, return_pooled) |
|
|
|
def tokenize_line(self, line): |
|
line = parse_and_register_embeddings(self, line) |
|
return super().tokenize_line(line) |
|
|
|
def tokenize(self, texts): |
|
tokenized = [self.tokenizer(text)["input_ids"][1:-1] for text in texts] |
|
return tokenized |
|
|
|
emb_re_ = r"(embedding:)?(?:({}[\w\.\-\!\$\/\\]+(\.safetensors|\.pt|\.bin)|(?(1)[\w\.\-\!\$\/\\]+|(?!)))(\.safetensors|\.pt|\.bin)?)(?::(\d+\.?\d*|\d*\.\d+))?" |
|
|
|
def tokenize_with_weights_custom(self, text:str, return_word_ids=False): |
|
    '''
    Takes a prompt and converts it to a list of (token, weight, word id) elements.
    Tokens can be either integer tokens or pre-computed CLIP tensors.
    Word id values are unique per word and embedding; the id 0 is reserved for non-word tokens.
    The returned list has dimensions NxM, where M is the input size of CLIP.

    Unlike the stock ComfyUI tokenizer, this variant resolves textual-inversion embeddings with a
    regex that also matches bare embedding filenames, optional extensions and ":weight" suffixes,
    not only the "embedding:" prefix.
    '''
|
if self.pad_with_end: |
|
pad_token = self.end_token |
|
else: |
|
pad_token = 0 |
|
|
|
text = escape_important(text) |
|
parsed_weights = token_weights(text, 1.0) |
|
|
|
embs = get_valid_embeddings(self.embedding_directory) if self.embedding_directory is not None else [] |
|
embs_str = embs_str + '|' if (embs_str:='|'.join(embs)) else '' |
|
emb_re = emb_re_.format(embs_str) |
|
emb_re = re.compile(emb_re, flags=re.MULTILINE | re.UNICODE | re.IGNORECASE) |
|
|
|
|
|
tokens = [] |
|
for weighted_segment, weight in parsed_weights: |
|
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ') |
|
to_tokenize = [x for x in to_tokenize if x != ""] |
|
for word in to_tokenize: |
|
matches = emb_re.finditer(word) |
|
last_end = 0 |
|
leftovers=[] |
|
for _, match in enumerate(matches, start=1): |
|
start=match.start() |
|
end=match.end() |
|
if (fragment:=word[last_end:start]): |
|
leftovers.append(fragment) |
|
ext = ext if (ext:=match.group(4)) else '' |
|
embedding_sname = embedding_sname if (embedding_sname:=match.group(2)) else '' |
|
embedding_name = embedding_sname + ext |
|
if embedding_name: |
|
embed, leftover = self._try_get_embedding(embedding_name) |
|
if embed is None: |
|
print(f"warning, embedding:{embedding_name} does not exist, ignoring") |
|
else: |
|
if opts.debug: |
|
print(f'[smZNodes] using embedding:{embedding_name}') |
|
if len(embed.shape) == 1: |
|
tokens.append([(embed, weight)]) |
|
else: |
|
tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) |
|
last_end = end |
|
if (fragment:=word[last_end:]): |
|
leftovers.append(fragment) |
|
            word_new = ''.join(leftovers)
            tokens.append([(t, weight) for t in self.tokenizer(word_new)["input_ids"][self.tokens_start:-1]])
|
|
|
|
|
batched_tokens = [] |
|
batch = [] |
|
if self.start_token is not None: |
|
batch.append((self.start_token, 1.0, 0)) |
|
batched_tokens.append(batch) |
|
for i, t_group in enumerate(tokens): |
|
|
|
is_large = len(t_group) >= self.max_word_length |
|
|
|
while len(t_group) > 0: |
|
if len(t_group) + len(batch) > self.max_length - 1: |
|
remaining_length = self.max_length - len(batch) - 1 |
|
|
|
if is_large: |
|
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) |
|
batch.append((self.end_token, 1.0, 0)) |
|
t_group = t_group[remaining_length:] |
|
|
|
else: |
|
batch.append((self.end_token, 1.0, 0)) |
|
if self.pad_to_max_length: |
|
batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) |
|
|
|
batch = [] |
|
if self.start_token is not None: |
|
batch.append((self.start_token, 1.0, 0)) |
|
batched_tokens.append(batch) |
|
else: |
|
batch.extend([(t,w,i+1) for t,w in t_group]) |
|
t_group = [] |
|
|
|
|
|
batch.append((self.end_token, 1.0, 0)) |
|
if self.pad_to_max_length: |
|
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) |
|
|
|
if not return_word_ids: |
|
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] |
|
|
|
return batched_tokens |
|
|
|
def get_valid_embeddings(embedding_directory): |
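    # Walk the embedding directories and return the embedding file names (extension stripped,
    # re.escape'd, longest first) so they can be interpolated into the emb_re_ pattern above.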
|
from builtins import any as b_any |
|
exts = ['.safetensors', '.pt', '.bin'] |
|
if isinstance(embedding_directory, str): |
|
embedding_directory = [embedding_directory] |
|
embedding_directory = expand_directory_list(embedding_directory) |
|
embs = [] |
|
for embd in embedding_directory: |
|
for root, dirs, files in os.walk(embd, topdown=False): |
|
for name in files: |
|
if not b_any(x in os.path.splitext(name)[1] for x in exts): continue |
|
n = os.path.basename(name) |
|
for ext in exts: n=n.removesuffix(ext) |
|
embs.append(re.escape(n)) |
|
embs.sort(key=len, reverse=True) |
|
return embs |
|
|
|
def parse_and_register_embeddings(self: FrozenCLIPEmbedderWithCustomWordsCustom|FrozenOpenCLIPEmbedder2WithCustomWordsCustom, text: str, return_word_ids=False): |
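    # Find textual-inversion references in the prompt, register each found embedding with the
    # A1111-style embedding database on the hijack, and return the prompt text with every match
    # reduced to the bare embedding name (group 2 of the pattern).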
|
from builtins import any as b_any |
|
embedding_directory = self.wrapped.tokenizer_parent.embedding_directory |
|
embs = get_valid_embeddings(embedding_directory) |
|
embs_str = '|'.join(embs) |
|
emb_re = emb_re_.format(embs_str + '|' if embs_str else '') |
|
emb_re = re.compile(emb_re, flags=re.MULTILINE | re.UNICODE | re.IGNORECASE) |
|
matches = emb_re.finditer(text) |
|
for matchNum, match in enumerate(matches, start=1): |
|
found=False |
|
ext = ext if (ext:=match.group(4)) else '' |
|
embedding_sname = embedding_sname if (embedding_sname:=match.group(2)) else '' |
|
embedding_name = embedding_sname + ext |
|
if embedding_name: |
|
embed, _ = self.wrapped.tokenizer_parent._try_get_embedding(embedding_name) |
|
if embed is not None: |
|
found=True |
|
if opts.debug: |
|
print(f'[smZNodes] using embedding:{embedding_name}') |
|
if embed.device != devices.device: |
|
embed = embed.to(device=devices.device) |
|
self.hijack.embedding_db.register_embedding(Embedding(embed, embedding_sname), self) |
|
if not found: |
|
print(f"warning, embedding:{embedding_name} does not exist, ignoring") |
|
out = emb_re.sub(r"\2", text) |
|
return out |
|
|
|
def expand(tensor1, tensor2): |
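    # Tile-repeat the shorter tensor along the token dimension (dim 1) and truncate, so both
    # tensors share the same sequence length, e.g. before concatenating CLIP-L/CLIP-G outputs
    # or pairing cond with uncond.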
|
def adjust_tensor_shape(tensor_small, tensor_big): |
|
|
|
|
|
replication_factor = -(-tensor_big.size(1) // tensor_small.size(1)) |
|
|
|
|
|
tensor_small_extended = tensor_small.repeat(1, replication_factor, 1) |
|
|
|
|
|
tensor_small_matched = tensor_small_extended[:, :tensor_big.size(1), :] |
|
|
|
return tensor_small_matched |
|
|
|
|
|
if tensor1.size(1) != tensor2.size(1): |
|
|
|
if tensor1.size(1) < tensor2.size(1): |
|
tensor1 = adjust_tensor_shape(tensor1, tensor2) |
|
else: |
|
tensor2 = adjust_tensor_shape(tensor2, tensor1) |
|
return (tensor1, tensor2) |
|
|
|
def reconstruct_schedules(schedules, step): |
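    # Pick the matching prompt_parser reconstruction function depending on whether the schedules
    # came from multicond (AND-composable) conditioning or plain conditioning.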
|
create_reconstruct_fn = lambda _cc: prompt_parser.reconstruct_multicond_batch if type(_cc).__name__ == "MulticondLearnedConditioning" else prompt_parser.reconstruct_cond_batch |
|
reconstruct_fn = create_reconstruct_fn(schedules) |
|
return reconstruct_fn(schedules, step) |
|
|
|
|
|
class ClipTokenWeightEncoder: |
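    # Replacement for ComfyUI's encode_token_weights: when given raw prompt text it runs A1111's
    # prompt_parser (prompt-editing schedules / AND support); otherwise it encodes pre-tokenized
    # (token, weight) chunks the ComfyUI way via process_tokens.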
|
def encode_token_weights(self, token_weight_pairs, steps=0, current_step=0, multi=False): |
|
schedules = token_weight_pairs |
|
texts = token_weight_pairs |
|
conds_list = [[(0, 1.0)]] |
|
from .modules.sd_hijack import model_hijack |
|
try: |
|
model_hijack.hijack(self) |
|
if isinstance(token_weight_pairs, list) and isinstance(token_weight_pairs[0], str): |
|
if multi: schedules = prompt_parser.get_multicond_learned_conditioning(model_hijack.cond_stage_model, texts, steps, None, opts.use_old_scheduling) |
|
else: schedules = prompt_parser.get_learned_conditioning(model_hijack.cond_stage_model, texts, steps, None, opts.use_old_scheduling) |
|
cond = reconstruct_schedules(schedules, current_step) |
|
if type(cond) is tuple: |
|
conds_list, cond = cond |
|
pooled = cond.pooled.cpu() |
|
cond = cond.cpu() |
|
cond.pooled = pooled |
|
cond.pooled.conds_list = conds_list |
|
cond.pooled.schedules = schedules |
|
else: |
|
|
|
def encode_toks(_token_weight_pairs): |
|
zs = [] |
|
first_pooled = None |
|
for batch_chunk in _token_weight_pairs: |
|
tokens = [x[0] for x in batch_chunk] |
|
multipliers = [x[1] for x in batch_chunk] |
|
z = model_hijack.cond_stage_model.process_tokens([tokens], [multipliers]) |
|
if first_pooled == None: |
|
first_pooled = z.pooled |
|
zs.append(z) |
|
zcond = torch.hstack(zs) |
|
zcond.pooled = first_pooled |
|
return zcond |
|
|
|
if isinstance(token_weight_pairs, dict): |
|
token_weight_pairs = next(iter(token_weight_pairs.values())) |
|
cond = encode_toks(token_weight_pairs) |
|
pooled = cond.pooled.cpu() |
|
cond = cond.cpu() |
|
cond.pooled = pooled |
|
cond.pooled.conds_list = conds_list |
|
finally: |
|
model_hijack.undo_hijack(model_hijack.cond_stage_model) |
|
return (cond, cond.pooled) |
|
|
|
class SD1ClipModel(ClipTokenWeightEncoder): ... |
|
|
|
class SDXLClipG(ClipTokenWeightEncoder): ... |
|
|
|
class SDXLClipModel(ClipTokenWeightEncoder): |
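    # Encodes the "g" and "l" prompt halves with the smZ token-weight encoders, pads the two
    # outputs to a common sequence length and concatenates them along the feature dimension
    # (l first, then g), returning the g pooled output — as ComfyUI's SDXLClipModel does.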
|
|
|
def encode_token_weights(self: comfy.sdxl_clip.SDXLClipModel, token_weight_pairs, steps=0, current_step=0, multi=False): |
|
token_weight_pairs_g = token_weight_pairs["g"] |
|
token_weight_pairs_l = token_weight_pairs["l"] |
|
|
|
self.clip_g.encode_token_weights_orig = self.clip_g.encode_token_weights |
|
self.clip_l.encode_token_weights_orig = self.clip_l.encode_token_weights |
|
self.clip_g.cond_stage_model = self.clip_g |
|
self.clip_l.cond_stage_model = self.clip_l |
|
self.clip_g.encode_token_weights = partial(SDXLClipG.encode_token_weights, self.clip_g) |
|
self.clip_l.encode_token_weights = partial(SD1ClipModel.encode_token_weights, self.clip_l) |
|
try: |
|
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g, steps, current_step, multi) |
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l, steps, current_step, multi) |
|
|
|
|
|
finally: |
|
self.clip_g.encode_token_weights = self.clip_g.encode_token_weights_orig |
|
self.clip_l.encode_token_weights = self.clip_l.encode_token_weights_orig |
|
self.clip_g.cond_stage_model = None |
|
self.clip_l.cond_stage_model = None |
|
|
|
if hasattr(g_pooled, 'schedules') and hasattr(l_pooled, 'schedules'): |
|
g_pooled.schedules = {"g": g_pooled.schedules, "l": l_pooled.schedules} |
|
|
|
g_out, l_out = expand(g_out, l_out) |
|
l_out, g_out = expand(l_out, g_out) |
|
|
|
return torch.cat([l_out, g_out], dim=-1), g_pooled |
|
|
|
class SDXLRefinerClipModel(ClipTokenWeightEncoder): |
|
|
|
def encode_token_weights(self: comfy.sdxl_clip.SDXLClipModel, token_weight_pairs, steps=0, current_step=0, multi=False): |
|
self.clip_g.encode_token_weights_orig = self.clip_g.encode_token_weights |
|
self.clip_g.encode_token_weights = partial(SDXLClipG.encode_token_weights, self.clip_g) |
|
token_weight_pairs_g = token_weight_pairs["g"] |
|
try: g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g, steps, current_step, multi) |
|
finally: self.clip_g.encode_token_weights = self.clip_g.encode_token_weights_orig |
|
if hasattr(g_pooled, 'schedules'): |
|
g_pooled.schedules = {"g": g_pooled.schedules} |
|
return (g_out, g_pooled) |
|
|
|
def is_prompt_editing(schedules): |
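    # True when any schedule contains more than one step, i.e. the prompt uses
    # prompt-editing syntax such as [from:to:when].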
|
    if schedules is None: return False
|
if not isinstance(schedules, dict): |
|
schedules = {'g': schedules} |
|
for k,v in schedules.items(): |
|
if type(v) == list: |
|
if len(v[0]) != 1: return True |
|
else: |
|
if len(v.batch[0][0].schedules) != 1: return True |
|
return False |
|
|
|
|
|
|
|
from .modules import rng_philox |
|
def randn_without_seed(x, generator=None, randn_source="cpu"): |
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. |
|
|
|
Use either randn() or manual_seed() to initialize the generator.""" |
|
if randn_source == "nv": |
|
return torch.asarray(generator.randn(x.size()), device=x.device) |
|
else: |
|
if generator is not None and generator.device.type == "cpu": |
|
return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=devices.cpu, generator=generator).to(device=x.device) |
|
else: |
|
return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) |
|
|
|
class TorchHijack: |
|
"""This is here to replace torch.randn_like of k-diffusion. |
|
|
|
k-diffusion has random_sampler argument for most samplers, but not for all, so |
|
this is needed to properly replace every use of torch.randn_like. |
|
|
|
We need to replace to make images generated in batches to be same as images generated individually.""" |
|
|
|
def __init__(self, generator, randn_source): |
|
|
|
self.generator = generator |
|
self.randn_source = randn_source |
|
|
|
def __getattr__(self, item): |
|
if item == 'randn_like': |
|
return self.randn_like |
|
|
|
if hasattr(torch, item): |
|
return getattr(torch, item) |
|
|
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") |
|
|
|
def randn_like(self, x): |
|
return randn_without_seed(x, generator=self.generator, randn_source=self.randn_source) |
|
|
|
def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'): |
|
""" |
|
creates random noise given a latent image and a seed. |
|
optional arg skip can be used to skip and discard x number of noise generations for a given seed |
|
""" |
|
from .modules.shared import opts |
|
from comfy.sample import np |
|
def get_generator(seed): |
|
nonlocal device |
|
nonlocal opts |
|
_generator = torch.Generator(device=device) |
|
generator = _generator.manual_seed(seed) |
|
if opts.randn_source == 'nv': |
|
generator = rng_philox.Generator(seed) |
|
return generator |
|
generator = generator_eta = get_generator(seed) |
|
|
|
if opts.eta_noise_seed_delta > 0: |
|
seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff)) |
|
generator_eta = get_generator(seed) |
|
|
|
|
|
|
|
import comfy.k_diffusion.sampling |
|
comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source) |
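    # From here on, the k-diffusion samplers see our TorchHijack instead of the real torch module,
    # so every randn_like call during sampling uses generator_eta and the selected randn_source.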
|
|
|
if noise_inds is None: |
|
shape = latent_image.size() |
|
if opts.randn_source == 'nv': |
|
return torch.asarray(generator.randn(shape), device=devices.cpu) |
|
else: |
|
return torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator) |
|
|
|
unique_inds, inverse = np.unique(noise_inds, return_inverse=True) |
|
noises = [] |
|
for i in range(unique_inds[-1]+1): |
|
shape = [1] + list(latent_image.size())[1:] |
|
if opts.randn_source == 'nv': |
|
noise = torch.asarray(generator.randn(shape), device=devices.cpu) |
|
else: |
|
noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator) |
|
if i in unique_inds: |
|
noises.append(noise) |
|
noises = [noises[i] for i in inverse] |
|
noises = torch.cat(noises, axis=0) |
|
return noises |
|
|
|
|
|
|
|
def run(clip: comfy.sd.CLIP, text, parser, mean_normalization, |
|
multi_conditioning, use_old_emphasis_implementation, with_SDXL, |
|
ascore, width, height, crop_w, crop_h, target_width, target_height, |
|
text_g, text_l, steps=1, step=0): |
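    # Entry point for the smZ CLIPTextEncode node: applies the chosen prompt parser and A1111
    # options, then returns ComfyUI conditioning (with one extra entry per step when prompt
    # editing is detected).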
|
opts.prompt_mean_norm = mean_normalization |
|
opts.use_old_emphasis_implementation = use_old_emphasis_implementation |
|
opts.CLIP_stop_at_last_layers = abs(clip.layer_idx or 1) |
|
is_sdxl = "SDXL" in type(clip.cond_stage_model).__name__ |
|
if is_sdxl: |
|
|
|
|
|
opts.batch_cond_uncond = True |
|
|
|
parser_d = {"full": "Full parser", |
|
"compel": "Compel parser", |
|
"A1111": "A1111 parser", |
|
"fixed attention": "Fixed attention", |
|
"comfy++": "Comfy++ parser", |
|
} |
|
opts.prompt_attention = parser_d.get(parser, "Comfy parser") |
|
|
|
sdxl_params = {} |
|
if with_SDXL and is_sdxl: |
|
sdxl_params = { |
|
"aesthetic_score": ascore, "width": width, "height": height, |
|
"crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, |
|
"target_height": target_height, "text_g": text_g, "text_l": text_l |
|
} |
|
pooled={} |
|
if hasattr(comfy.sd1_clip, 'SDTokenizer'): |
|
SDTokenizer = comfy.sd1_clip.SDTokenizer |
|
else: |
|
SDTokenizer = comfy.sd1_clip.SD1Tokenizer |
|
tokenize_with_weights_orig = SDTokenizer.tokenize_with_weights |
|
if parser == "comfy": |
|
SDTokenizer.tokenize_with_weights = tokenize_with_weights_custom |
|
clip_model_type_name = type(clip.cond_stage_model).__name__ |
|
if with_SDXL and is_sdxl: |
|
if clip_model_type_name== "SDXLClipModel": |
|
out = CLIPTextEncodeSDXL().encode(clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) |
|
out[0][0][1]['aesthetic_score'] = sdxl_params['aesthetic_score'] |
|
elif clip_model_type_name == "SDXLRefinerClipModel": |
|
out = CLIPTextEncodeSDXLRefiner().encode(clip, ascore, width, height, text) |
|
for item in ['aesthetic_score', 'width', 'height', 'text_g', 'text_l']: |
|
sdxl_params.pop(item) |
|
out[0][0][1].update(sdxl_params) |
|
else: |
|
raise NotImplementedError() |
|
else: |
|
out = CLIPTextEncode().encode(clip, text) |
|
SDTokenizer.tokenize_with_weights = tokenize_with_weights_orig |
|
return out |
|
else: |
|
texts = [text] |
|
create_prompts = lambda txts: prompt_parser.SdConditioning(txts) |
|
texts = create_prompts(texts) |
|
if is_sdxl: |
|
if with_SDXL: |
|
texts = {"g": create_prompts([text_g]), "l": create_prompts([text_l])} |
|
else: |
|
texts = {"g": texts, "l": texts} |
|
|
|
|
|
clip_clone = clip |
|
clip_clone.cond_stage_model_orig = clip_clone.cond_stage_model |
|
clip_clone.cond_stage_model.encode_token_weights_orig = clip_clone.cond_stage_model.encode_token_weights |
|
|
|
def patch_cond_stage_model(): |
|
nonlocal clip_clone |
|
from .smZNodes import SD1ClipModel, SDXLClipModel, SDXLRefinerClipModel |
|
ctp = type(clip_clone.cond_stage_model) |
|
clip_clone.cond_stage_model.tokenizer = clip_clone.tokenizer |
|
if ctp is comfy.sdxl_clip.SDXLClipModel: |
|
clip_clone.cond_stage_model.encode_token_weights = SDXLClipModel.encode_token_weights |
|
clip_clone.cond_stage_model.clip_g.tokenizer = clip_clone.tokenizer.clip_g |
|
clip_clone.cond_stage_model.clip_l.tokenizer = clip_clone.tokenizer.clip_l |
|
elif ctp is comfy.sdxl_clip.SDXLRefinerClipModel: |
|
clip_clone.cond_stage_model.encode_token_weights = SDXLRefinerClipModel.encode_token_weights |
|
clip_clone.cond_stage_model.clip_g.tokenizer = clip_clone.tokenizer.clip_g |
|
else: |
|
clip_clone.cond_stage_model.encode_token_weights = SD1ClipModel.encode_token_weights |
|
|
|
tokens = texts |
|
if parser == "comfy++": |
|
SDTokenizer.tokenize_with_weights = tokenize_with_weights_custom |
|
tokens = clip_clone.tokenize(text) |
|
SDTokenizer.tokenize_with_weights = tokenize_with_weights_orig |
|
cond = pooled = None |
|
patch_cond_stage_model() |
|
try: |
|
clip_clone.cond_stage_model.encode_token_weights = partial(clip_clone.cond_stage_model.encode_token_weights, clip_clone.cond_stage_model, steps=steps, current_step=step, multi=multi_conditioning) |
|
cond, pooled = clip_clone.encode_from_tokens(tokens, True) |
|
finally: |
|
clip_clone.cond_stage_model = clip_clone.cond_stage_model_orig |
|
clip_clone.cond_stage_model.encode_token_weights = clip_clone.cond_stage_model.encode_token_weights_orig |
|
|
|
if opts.debug: |
|
print('[smZNodes] using steps', steps) |
|
gen_id = lambda : binascii.hexlify(os.urandom(1024))[64:72] |
|
id=gen_id() |
|
schedules = getattr(pooled, 'schedules', [[(0, 1.0)]]) |
|
pooled = {"pooled_output": pooled, "from_smZ": True, "smZid": id, "conds_list": pooled.conds_list, **sdxl_params} |
|
out = [[cond, pooled]] |
|
if is_prompt_editing(schedules): |
|
for x in range(1,steps): |
|
if type(schedules) is not dict: |
|
cond=reconstruct_schedules(schedules, x) |
|
if type(cond) is tuple: |
|
conds_list, cond = cond |
|
pooled['conds_list'] = conds_list |
|
cond=cond.cpu() |
|
elif type(schedules) is dict and len(schedules) == 1: |
|
cond = reconstruct_schedules(next(iter(schedules.values())), x) |
|
if type(cond) is tuple: |
|
conds_list, cond = cond |
|
pooled['conds_list'] = conds_list |
|
cond=cond.cpu() |
|
elif type(schedules) is dict: |
|
g_out = reconstruct_schedules(schedules['g'], x) |
|
if type(g_out) is tuple: _, g_out = g_out |
|
l_out = reconstruct_schedules(schedules['l'], x) |
|
if type(l_out) is tuple: _, l_out = l_out |
|
g_out, l_out = expand(g_out, l_out) |
|
l_out, g_out = expand(l_out, g_out) |
|
cond = torch.cat([l_out, g_out], dim=-1).cpu() |
|
else: |
|
raise NotImplementedError |
|
out = out + [[cond, pooled]] |
|
out[0][1]['orig_len'] = len(out) |
|
return (out,) |
|
|
|
|
|
|
|
from server import PromptServer |
|
def prompt_handler(json_data): |
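    # PromptServer hook: for every "smZ CLIPTextEncode" node in the submitted graph, find a
    # downstream sampler node and copy its step count into the node's inputs (smZ_steps) so the
    # prompt-editing schedules know the total number of steps.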
|
data=json_data['prompt'] |
|
def tmp(): |
|
nonlocal data |
|
current_clip_id = None |
|
def find_nearest_ksampler(clip_id): |
|
"""Find the nearest KSampler node that references the given CLIPTextEncode id.""" |
|
for ksampler_id, node in data.items(): |
|
if "Sampler" in node["class_type"] or "sampler" in node["class_type"]: |
|
|
|
if check_link_to_clip(ksampler_id, clip_id): |
|
return get_steps(data, ksampler_id) |
|
return None |
|
|
|
def get_steps(graph, node_id): |
|
node = graph.get(str(node_id), {}) |
|
steps_input_value = node.get("inputs", {}).get("steps", None) |
|
if steps_input_value is None: |
|
steps_input_value = node.get("inputs", {}).get("sigmas", None) |
|
|
|
while(True): |
|
|
|
if isinstance(steps_input_value, (int, float, str)): |
|
return min(max(1, int(steps_input_value)), 10000) |
|
|
|
|
|
elif isinstance(steps_input_value, list): |
|
ref_node_id, ref_input_index = steps_input_value |
|
ref_node = graph.get(str(ref_node_id), {}) |
|
steps_input_value = ref_node.get("inputs", {}).get("steps", None) |
|
if steps_input_value is None: |
|
keys = list(ref_node.get("inputs", {}).keys()) |
|
ref_input_key = keys[ref_input_index % len(keys)] |
|
steps_input_value = ref_node.get("inputs", {}).get(ref_input_key) |
|
else: |
|
return None |
|
|
|
def check_link_to_clip(node_id, clip_id, visited=None): |
|
"""Check if a given node links directly or indirectly to a CLIPTextEncode node.""" |
|
if visited is None: |
|
visited = set() |
|
|
|
node = data[node_id] |
|
|
|
if node_id in visited: |
|
return False |
|
visited.add(node_id) |
|
|
|
for input_value in node["inputs"].values(): |
|
if isinstance(input_value, list) and input_value[0] == clip_id: |
|
return True |
|
if isinstance(input_value, list) and check_link_to_clip(input_value[0], clip_id, visited): |
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
for clip_id, node in data.items(): |
|
if node["class_type"] == "smZ CLIPTextEncode": |
|
current_clip_id = clip_id |
|
steps = find_nearest_ksampler(clip_id) |
|
if steps is not None: |
|
node["inputs"]["smZ_steps"] = steps |
|
if opts.debug: |
|
print(f'[smZNodes] id: {current_clip_id} | steps: {steps}') |
|
tmp() |
|
return json_data |
|
|
|
if hasattr(PromptServer.instance, 'add_on_prompt_handler'): |
|
PromptServer.instance.add_on_prompt_handler(prompt_handler) |
|
|
|
|
|
def bounded_modulo(number, modulo_value): |
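    # Despite the name, this clamps rather than wraps: indices past the end reuse the last cond.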
|
return number if number < modulo_value else modulo_value |
|
|
|
def get_adm(c): |
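    # Fetch the ADM / pooled conditioning from a cond dict, normalizing the key name across
    # ComfyUI versions ("adm_encoded", "c_adm" or "y").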
|
for y in ["adm_encoded", "c_adm", "y"]: |
|
if y in c: |
|
c_c_adm = c[y] |
|
if y == "adm_encoded": y="c_adm" |
|
if type(c_c_adm) is not torch.Tensor: c_c_adm = c_c_adm.cond |
|
return {y: c_c_adm, 'key': y} |
|
return None |
|
|
|
getp=lambda x: x[1] if type(x) is list else x |
|
def calc_cond(c, current_step): |
|
"""Group by smZ conds that may do prompt-editing / regular conds / comfy conds.""" |
|
_cond = [] |
|
|
|
fn=lambda x : getp(x).get("from_smZ", None) is not None |
|
an_iterator = itertools.groupby(c, fn ) |
|
for key, group in an_iterator: |
|
ls=list(group) |
|
|
|
fn2=lambda x : getp(x).get("smZid", None) |
|
an_iterator2 = itertools.groupby(ls, fn2) |
|
for key2, group2 in an_iterator2: |
|
ls2=list(group2) |
|
if key2 is not None: |
|
orig_len = getp(ls2[0]).get('orig_len', 1) |
|
i = bounded_modulo(current_step, orig_len - 1) |
|
_cond = _cond + [ls2[i]] |
|
else: |
|
_cond = _cond + ls2 |
|
return _cond |
|
|
|
CFGNoisePredictorOrig = comfy.samplers.CFGNoisePredictor |
|
class CFGNoisePredictor(CFGNoisePredictorOrig): |
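    # Wraps ComfyUI's CFGNoisePredictor. Each step it selects the scheduled smZ conds via
    # calc_cond, and when opts.use_CFGDenoiser is set it routes sampling through A1111's
    # CFGDenoiser (inner_model2) instead of the stock ComfyUI path.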
|
def __init__(self, model): |
|
super().__init__(model) |
|
self.step = 0 |
|
self.inner_model2 = CFGDenoiser(model.apply_model) |
|
self.s_min_uncond = opts.s_min_uncond |
|
self.c_adm = None |
|
self.init_cond = None |
|
self.init_uncond = None |
|
self.is_prompt_editing_u = False |
|
self.is_prompt_editing_c = False |
|
|
|
def apply_model(self, *args, **kwargs): |
|
        args = list(args)  # args may arrive positionally as a tuple; make it mutable for the fallback assignments below
        x=kwargs['x'] if 'x' in kwargs else args[0]
|
timestep=kwargs['timestep'] if 'timestep' in kwargs else args[1] |
|
cond=kwargs['cond'] if 'cond' in kwargs else args[2] |
|
uncond=kwargs['uncond'] if 'uncond' in kwargs else args[3] |
|
cond_scale=kwargs['cond_scale'] if 'cond_scale' in kwargs else args[4] |
|
model_options=kwargs['model_options'] if 'model_options' in kwargs else {} |
|
|
|
cc=calc_cond(cond, self.step) |
|
uu=calc_cond(uncond, self.step) |
|
self.step += 1 |
|
|
|
if (any([getp(p).get('from_smZ', False) for p in cc]) or |
|
any([getp(p).get('from_smZ', False) for p in uu])): |
|
if model_options.get('transformer_options',None) is None: |
|
model_options['transformer_options'] = {} |
|
model_options['transformer_options']['from_smZ'] = True |
|
|
|
if not opts.use_CFGDenoiser or not model_options['transformer_options'].get('from_smZ', False): |
|
if 'cond' in kwargs: kwargs['cond'] = cc |
|
else: args[2]=cc |
|
if 'uncond' in kwargs: kwargs['uncond'] = uu |
|
else: args[3]=uu |
|
out = super().apply_model(*args, **kwargs) |
|
else: |
|
|
|
for ix in range(len(cc)): |
|
if getp(cc[ix]).get('from_smZ', False): |
|
cc = [cc[ix]] |
|
break |
|
for ix in range(len(uu)): |
|
if getp(uu[ix]).get('from_smZ', False): |
|
uu = [uu[ix]] |
|
break |
|
c=getp(cc[0]) |
|
u=getp(uu[0]) |
|
_cc = cc[0][0] if type(cc[0]) is list else cc[0]['model_conds']['c_crossattn'].cond |
|
_uu = uu[0][0] if type(uu[0]) is list else uu[0]['model_conds']['c_crossattn'].cond |
|
conds_list = c.get('conds_list', [[(0, 1.0)]]) |
|
if 'model_conds' in c: c = c['model_conds'] |
|
if 'model_conds' in u: u = u['model_conds'] |
|
c_c_adm = get_adm(c) |
|
if c_c_adm is not None: |
|
u_c_adm = get_adm(u) |
|
k = c_c_adm['key'] |
|
self.c_adm = {k: torch.cat([c_c_adm[k], u_c_adm[u_c_adm['key']]]).to(device=x.device), 'key': k} |
|
|
|
_cc, _uu = expand(_cc, _uu) |
|
_uu, _cc = expand(_uu, _cc) |
|
x.c_adm = self.c_adm |
|
image_cond = txt2img_image_conditioning(None, x) |
|
out = self.inner_model2(x, timestep, cond=(conds_list, _cc), uncond=_uu, cond_scale=cond_scale, s_min_uncond=self.s_min_uncond, image_cond=image_cond) |
|
return out |
|
|
|
def txt2img_image_conditioning(sd_model, x, width=None, height=None): |
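    # Dummy zero image conditioning (5 channels), matching the non-inpainting branch of A1111's
    # txt2img_image_conditioning.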
|
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inject_code(original_func, data): |
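    # Source-level monkey patch: take the source of original_func, insert each item's
    # code_to_insert after the first line containing its target_line, load the modified source
    # as a temporary module (reusing the original function's globals), and return the patched function.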
|
|
|
original_source = inspect.getsource(original_func) |
|
|
|
|
|
lines = original_source.split("\n") |
|
|
|
for item in data: |
|
|
|
target_line_number = None |
|
for i, line in enumerate(lines): |
|
if item['target_line'] in line: |
|
target_line_number = i + 1 |
|
|
|
|
|
indentation = '' |
|
for char in line: |
|
if char == ' ': |
|
indentation += char |
|
else: |
|
break |
|
|
|
|
|
code_to_insert = dedent(item['code_to_insert']) |
|
code_to_insert = indent(code_to_insert, indentation) |
|
break |
|
|
|
if target_line_number is None: |
|
            raise FileNotFoundError(f"Could not find target line {item['target_line']!r} in {original_func.__name__}")
|
|
|
|
|
|
|
|
|
lines.insert(target_line_number, code_to_insert) |
|
|
|
|
|
modified_source = "\n".join(lines) |
|
modified_source = dedent(modified_source.strip("\n")) |
|
|
|
|
|
|
|
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.py') as temp_file: |
|
temp_file.write(modified_source) |
|
temp_file.flush() |
|
|
|
MODULE_PATH = temp_file.name |
|
MODULE_NAME = __name__.split('.')[0] + "_patch_modules" |
|
spec = importlib.util.spec_from_file_location(MODULE_NAME, MODULE_PATH) |
|
module = importlib.util.module_from_spec(spec) |
|
sys.modules[spec.name] = module |
|
spec.loader.exec_module(module) |
|
|
|
|
|
globals_dict = original_func.__globals__ |
|
for key, value in globals_dict.items(): |
|
setattr(module, key, value) |
|
modified_module = module |
|
|
|
|
|
modified_function = getattr(modified_module, original_func.__name__) |
|
|
|
|
|
if inspect.ismethod(original_func): |
|
modified_function = modified_function.__get__(original_func.__self__, original_func.__class__) |
|
|
|
|
|
functools.update_wrapper(modified_function, original_func) |
|
|
|
|
|
return modified_function |
|
|
|
|
|
|
|
|
|
|
|
from tqdm.auto import trange |
|
@torch.no_grad() |
|
def sample_dpmpp_2m_alt(model, x, sigmas, extra_args=None, callback=None, disable=None): |
|
"""DPM-Solver++(2M).""" |
|
extra_args = {} if extra_args is None else extra_args |
|
s_in = x.new_ones([x.shape[0]]) |
|
sigma_fn = lambda t: t.neg().exp() |
|
t_fn = lambda sigma: sigma.log().neg() |
|
old_denoised = None |
|
|
|
for i in trange(len(sigmas) - 1, disable=disable): |
|
denoised = model(x, sigmas[i] * s_in, **extra_args) |
|
if callback is not None: |
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) |
|
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) |
|
h = t_next - t |
|
if old_denoised is None or sigmas[i + 1] == 0: |
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised |
|
else: |
|
h_last = t - t_fn(sigmas[i - 1]) |
|
r = h_last / h |
|
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised |
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d |
|
sigma_progress = i / len(sigmas) |
|
adjustment_factor = 1 + (0.15 * (sigma_progress * sigma_progress)) |
|
old_denoised = denoised * adjustment_factor |
|
return x |
|
|
|
|
|
def add_sample_dpmpp_2m_alt(): |
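    # Register the "dpmpp_2m_alt" sampler right after "dpmpp_2m" in ComfyUI's sampler list and
    # expose the sampling function on comfy's k_diffusion_sampling module.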
|
from comfy.samplers import KSampler, k_diffusion_sampling |
|
if "dpmpp_2m_alt" not in KSampler.SAMPLERS: |
|
try: |
|
idx = KSampler.SAMPLERS.index("dpmpp_2m") |
|
KSampler.SAMPLERS.insert(idx+1, "dpmpp_2m_alt") |
|
setattr(k_diffusion_sampling, 'sample_dpmpp_2m_alt', sample_dpmpp_2m_alt) |
|
import importlib |
|
importlib.reload(k_diffusion_sampling) |
|
except ValueError as err: |
|
pass |
|
|