Spaces:
Configuration error
Configuration error
import re, time, os, psutil | |
import folder_paths | |
import comfy.utils | |
import comfy.sd | |
import comfy.controlnet | |
from comfy.model_patcher import ModelPatcher | |
from nodes import NODE_CLASS_MAPPINGS | |
from collections import defaultdict | |
from .log import log_node_info, log_node_error | |
from ..dit.pixArt.loader import load_pixart | |
# Node ``class_type`` names grouped by the kind of cached resource they
# reference; easyLoader.update_loaded_objects() matches prompt entries against
# these lists to decide which cached objects are still needed.

# Loaders that take a ``ckpt_name`` / ``vae_name`` pair.
stable_diffusion_loaders = [
    "easy fullLoader",
    "easy a1111Loader",
    "easy comfyLoader",
    "easy hunyuanDiTLoader",
    "easy zero123Loader",
    "easy svdLoader",
]
# Loaders that take ``stage_a`` / ``stage_b`` / ``stage_c`` / ``clip_name``.
stable_cascade_loaders = ["easy cascadeLoader"]
# DiT loaders (``ckpt_name`` + ``model_name``, optional ``clip_name`` / ``mt5_name``).
dit_loaders = ['easy pixArtLoader']
# Nodes with ``control_net_name`` + ``scale_soft_weights``.
controlnet_loaders = ["easy controlnetLoader", "easy controlnetLoaderADV"]
# Nodes with ``control_net_name`` + ``cn_soft_weights``.
instant_loaders = ["easy instantIDApply", "easy instantIDApplyADV"]
# Nodes with ``encode_vae_name`` / ``decode_vae_name``.
cascade_vae_node = ["easy preSamplingCascade", "easy fullCascadeKSampler"]
# XY input nodes with ``ckpt_name_1`` / ``ckpt_name_2`` / ``vae_use``.
model_merge_node = ["easy XYInputs: ModelMergeBlocks"]
# Loaders that also expose a single ``lora_name`` widget.
lora_widget = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader"]
class easyLoader:
    """Process-wide cache for checkpoints, UNets, encoders, VAEs, LoRAs and ControlNets.

    Every cached value is stored by :meth:`add_to_cache` as
    ``self.loaded_objects[<type>][<key>] = (object, timestamp)``.  Entries are
    dropped when a new prompt no longer references them
    (:meth:`update_loaded_objects`) or, oldest first, when the process RSS
    reaches ``self.memory_threshold`` (:meth:`eviction_based_on_memory`).
    """

    def __init__(self):
        self.loaded_objects = {
            "ckpt": defaultdict(tuple),  # {cache_name: (model, timestamp)}
            "unet": defaultdict(tuple),
            "clip": defaultdict(tuple),
            "clip_vision": defaultdict(tuple),
            "bvae": defaultdict(tuple),  # VAE baked into a checkpoint, keyed like "ckpt"
            "vae": defaultdict(object),
            "lora": defaultdict(dict),  # {unique_id: ((model_lora, clip_lora), timestamp)}
            "controlnet": defaultdict(dict),
            "t5": defaultdict(tuple),
            "chatglm3": defaultdict(tuple),
        }
        # Start evicting once this process uses 70% of total system memory.
        self.memory_threshold = self.determine_memory_threshold(0.7)
        # Lazily filled list of LoRA file names; see resolve_lora_name().
        self.lora_name_cache = []

    def clean_values(self, values: str):
        """Split a ``"; "``-separated string into ints, floats or stripped strings.

        Empty items are dropped, e.g. ``"1; 2.5; abc; ;"`` -> ``[1, 2.5, "abc"]``.
        """
        cleaned_values = []
        for value in values.split("; "):
            cleaned_value = value.strip(';').strip()
            if cleaned_value == "":
                continue
            try:
                cleaned_value = int(cleaned_value)
            except ValueError:
                try:
                    cleaned_value = float(cleaned_value)
                except ValueError:
                    pass  # not numeric: keep the string
            cleaned_values.append(cleaned_value)
        return cleaned_values

    def clear_unused_objects(self, desired_names: set, object_type: str):
        """Remove every ``object_type`` cache entry whose key is not in ``desired_names``."""
        cache = self.loaded_objects[object_type]
        for key in set(cache.keys()) - desired_names:
            del cache[key]

    def get_input_value(self, entry, key, prompt=None):
        """Return ``entry["inputs"][key]`` as a string.

        If the input is a list (presumably a link ``[node_id, output_index]``)
        and ``prompt`` is given, the same-named input of the node ``val[0]`` in
        ``prompt`` is returned instead; without ``prompt`` the first list item
        is returned.  Non-string scalars are converted with ``str()``.
        """
        val = entry["inputs"][key]
        if isinstance(val, str):
            return val
        if isinstance(val, list):
            if prompt is not None and val[0]:
                return prompt[val[0]]['inputs'][key]
            return val[0]
        return str(val)

    def process_pipe_loader(self, entry, desired_ckpt_names, desired_vae_names, desired_lora_names, desired_lora_settings, num_loras=3, suffix=""):
        """Collect the checkpoint, VAE and ``num_loras`` LoRA inputs of a pipe loader node.

        Adds into the given sets (mutated in place); ``suffix`` prefixes every
        input key.  LoRA settings are recorded as ``"name;model_strength;clip_strength"``.
        """
        inputs = entry["inputs"]
        for idx in range(1, num_loras + 1):
            prefix = f"{suffix}lora{idx}"
            lora_name = self.get_input_value(entry, f"{prefix}_name")
            desired_lora_names.add(lora_name)
            desired_lora_settings.add(f'{lora_name};{inputs[f"{prefix}_model_strength"]};{inputs[f"{prefix}_clip_strength"]}')
        desired_ckpt_names.add(self.get_input_value(entry, f"{suffix}ckpt_name"))
        desired_vae_names.add(self.get_input_value(entry, f"{suffix}vae_name"))

    def update_loaded_objects(self, prompt):
        """Drop cached objects that no node in ``prompt`` refers to.

        ``prompt`` maps node ids to ``{"class_type": ..., "inputs": {...}}``.
        For each recognised loader node the names of the resources it uses are
        collected; every cache type is then pruned to those names.
        """
        desired_ckpt_names = set()
        desired_unet_names = set()
        desired_clip_names = set()
        desired_vae_names = set()
        desired_lora_names = set()
        desired_controlnet_names = set()
        desired_t5_names = set()
        desired_glm3_names = set()

        for entry in prompt.values():
            class_type = entry["class_type"]
            # Not an elif: these nodes are also stable-diffusion loaders below.
            if class_type in lora_widget:
                desired_lora_names.add(self.get_input_value(entry, "lora_name"))
            if class_type in stable_diffusion_loaders:
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name", prompt))
                desired_vae_names.add(self.get_input_value(entry, "vae_name"))
            elif class_type in ['easy kolorsLoader']:
                desired_unet_names.add(self.get_input_value(entry, "unet_name"))
                desired_vae_names.add(self.get_input_value(entry, "vae_name"))
                desired_glm3_names.add(self.get_input_value(entry, "chatglm3_name"))
            elif class_type in dit_loaders:
                t5_name = self.get_input_value(entry, "mt5_name") if "mt5_name" in entry["inputs"] else None
                clip_name = self.get_input_value(entry, "clip_name") if "clip_name" in entry["inputs"] else None
                model_name = self.get_input_value(entry, "model_name")
                ckpt_name = self.get_input_value(entry, "ckpt_name", prompt)
                if t5_name:
                    desired_t5_names.add(t5_name)
                if clip_name:
                    desired_clip_names.add(clip_name)
                # Same key format as load_dit_ckpt().
                desired_ckpt_names.add(ckpt_name + '_' + model_name)
            elif class_type in stable_cascade_loaders:
                desired_unet_names.add(self.get_input_value(entry, "stage_c"))
                desired_unet_names.add(self.get_input_value(entry, "stage_b"))
                desired_clip_names.add(self.get_input_value(entry, "clip_name"))
                desired_vae_names.add(self.get_input_value(entry, "stage_a"))
            elif class_type in cascade_vae_node:
                encode_vae_name = self.get_input_value(entry, "encode_vae_name")
                decode_vae_name = self.get_input_value(entry, "decode_vae_name")
                if encode_vae_name and encode_vae_name != 'None':
                    desired_vae_names.add(encode_vae_name)
                if decode_vae_name and decode_vae_name != 'None':
                    desired_vae_names.add(decode_vae_name)
            elif class_type in controlnet_loaders:
                control_net_name = self.get_input_value(entry, "control_net_name", prompt)
                scale_soft_weights = self.get_input_value(entry, "scale_soft_weights")
                desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
            elif class_type in instant_loaders:
                control_net_name = self.get_input_value(entry, "control_net_name", prompt)
                scale_soft_weights = self.get_input_value(entry, "cn_soft_weights")
                desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
            elif class_type in model_merge_node:
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_1"))
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_2"))
                vae_use = self.get_input_value(entry, "vae_use")
                if vae_use != 'Use Model 1' and vae_use != 'Use Model 2':
                    desired_vae_names.add(vae_use)

        # NOTE(review): "lora" cache keys are unique ids built in load_lora(),
        # not bare LoRA names, so this effectively empties the LoRA cache on
        # every prompt.  Kept as-is.
        desired_by_type = {
            "ckpt": desired_ckpt_names,
            "unet": desired_unet_names,
            "clip": desired_ckpt_names | desired_clip_names,
            "clip_vision": desired_ckpt_names,  # keyed like "ckpt" by load_checkpoint()
            "bvae": desired_ckpt_names,
            "vae": desired_vae_names,
            "lora": desired_lora_names,
            "controlnet": desired_controlnet_names,
            "t5": desired_t5_names,
            "chatglm3": desired_glm3_names,
        }
        for object_type, desired_names in desired_by_type.items():
            self.clear_unused_objects(desired_names, object_type)

    def add_to_cache(self, obj_type, key, value):
        """
        Add an item to the cache with the current timestamp.
        """
        self.loaded_objects[obj_type][key] = (value, time.time())

    def determine_memory_threshold(self, percentage=0.8):
        """
        Determines the memory threshold as a percentage of the total available memory.
        Args:
        - percentage (float): The fraction of total memory to use as the threshold.
        Should be a value between 0 and 1. Default is 0.8 (80%).
        Returns:
        - memory_threshold (float): Memory threshold in bytes.
        """
        return psutil.virtual_memory().total * percentage

    def get_memory_usage(self):
        """
        Returns the memory usage (RSS) of the current process in bytes.
        """
        return psutil.Process(os.getpid()).memory_info().rss

    def eviction_based_on_memory(self):
        """
        Evicts objects from cache based on memory usage and priority.

        Cache types are emptied in ``eviction_order``, oldest entries first,
        re-measuring RSS after each deletion until it drops below the threshold.
        """
        current_memory = self.get_memory_usage()
        if current_memory < self.memory_threshold:
            return

        def _timestamp(entry):
            # Entries are (object, timestamp); anything malformed sorts first.
            return entry[1] if isinstance(entry, tuple) and len(entry) > 1 else float("-inf")

        eviction_order = ["vae", "lora", "bvae", "clip", "clip_vision", "ckpt", "controlnet", "unet", "t5", "chatglm3"]
        for obj_type in eviction_order:
            cache = self.loaded_objects[obj_type]
            # sorted() materialises the key list, so deleting while looping is safe.
            for key in sorted(cache, key=lambda k: _timestamp(cache[k])):
                if current_memory < self.memory_threshold:
                    return
                del cache[key]
                current_memory = self.get_memory_usage()

    def _get_cached_checkpoint(self, cache_name, load_vision):
        """Return a cached ``(model, clip, vae, clip_vision)`` tuple, or None on a miss.

        The model, baked VAE and the encoder selected by ``load_vision`` are
        cached (and evicted) separately, so all three must be present for a hit.
        """
        ckpt_entry = self.loaded_objects["ckpt"].get(cache_name)
        vae_entry = self.loaded_objects["bvae"].get(cache_name)
        encoder_entry = self.loaded_objects["clip_vision" if load_vision else "clip"].get(cache_name)
        # .get() avoids inserting empty defaultdict entries; () is also a miss.
        if not (ckpt_entry and vae_entry and encoder_entry):
            return None
        encoder = encoder_entry[0]
        clip = None if load_vision else encoder
        clip_vision = encoder if load_vision else None
        return ckpt_entry[0], clip, vae_entry[0], clip_vision

    def load_checkpoint(self, ckpt_name, config_name=None, load_vision=False):
        """Load (or fetch from cache) a checkpoint.

        Args:
            ckpt_name: file name under the ``checkpoints`` folder.
            config_name: optional file under ``configs``; None/"Default" means
                the model config is guessed.  A custom config is part of the cache key.
            load_vision: load the CLIP-vision encoder instead of the text CLIP.
        Returns:
            ``(model, clip, vae, clip_vision)``; ``clip`` is None when
            ``load_vision`` is true and ``clip_vision`` is None otherwise.
        """
        cache_name = ckpt_name
        if config_name not in [None, "Default"]:
            cache_name = ckpt_name + "_" + config_name

        cached = self._get_cached_checkpoint(cache_name, load_vision)
        if cached is not None:
            return cached

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        embedding_directory = folder_paths.get_folder_paths("embeddings")
        if config_name not in [None, "Default"]:
            config_path = folder_paths.get_full_path("configs", config_name)
            loaded_ckpt = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=not load_vision, embedding_directory=embedding_directory)
        else:
            model_options = {}
            if "nf4" in ckpt_name:
                from ..bitsandbytes_NF4 import OPS
                model_options = {"custom_operations": OPS}
            loaded_ckpt = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=not load_vision, output_clipvision=bool(load_vision), embedding_directory=embedding_directory, model_options=model_options)

        model, clip, vae = loaded_ckpt[0], loaded_ckpt[1], loaded_ckpt[2]
        # The explicit-config loader may return only (model, clip, vae); only
        # the guess-config loader is asked for a CLIP-vision model.
        clip_vision = loaded_ckpt[3] if len(loaded_ckpt) > 3 else None

        self.add_to_cache("ckpt", cache_name, model)
        self.add_to_cache("bvae", cache_name, vae)
        # Cache exactly the encoder that was requested (even if None) so a
        # later request with the same load_vision mode is a complete hit.
        if load_vision:
            self.add_to_cache("clip_vision", cache_name, clip_vision)
        else:
            self.add_to_cache("clip", cache_name, clip)
        self.eviction_based_on_memory()
        return model, clip, vae, clip_vision

    def load_vae(self, vae_name):
        """Load (or fetch from cache) a standalone VAE from the ``vae`` folder."""
        if vae_name in self.loaded_objects["vae"]:
            return self.loaded_objects["vae"][vae_name][0]
        vae_path = folder_paths.get_full_path("vae", vae_name)
        sd = comfy.utils.load_torch_file(vae_path)
        loaded_vae = comfy.sd.VAE(sd=sd)
        self.add_to_cache("vae", vae_name, loaded_vae)
        self.eviction_based_on_memory()
        return loaded_vae

    def load_unet(self, unet_name):
        """Load (or fetch from cache) a UNet from the ``unet`` folder."""
        if unet_name in self.loaded_objects["unet"]:
            log_node_info("Load UNet", f"{unet_name} cached")
            return self.loaded_objects["unet"][unet_name][0]
        unet_path = folder_paths.get_full_path("unet", unet_name)
        model = comfy.sd.load_unet(unet_path)
        self.add_to_cache("unet", unet_name, model)
        self.eviction_based_on_memory()
        return model

    def load_controlnet(self, control_net_name, scale_soft_weights=1, use_cache=True):
        """Load a ControlNet, optionally with scaled soft weights.

        When ``scale_soft_weights < 1`` the ComfyUI-Advanced-ControlNet nodes
        ``ScaledSoftControlNetWeights`` / ``ControlNetLoaderAdvanced`` are used
        (an Exception is raised if that pack is not installed); otherwise the
        plain ComfyUI loader.  Cached under ``"<name>;<scale_soft_weights>"``
        when ``use_cache`` is true.
        """
        unique_id = f'{control_net_name};{str(scale_soft_weights)}'
        if use_cache and unique_id in self.loaded_objects["controlnet"]:
            return self.loaded_objects["controlnet"][unique_id][0]
        if scale_soft_weights < 1:
            if "ScaledSoftControlNetWeights" not in NODE_CLASS_MAPPINGS:
                raise Exception(f"[Advanced-ControlNet Not Found] you need to install 'COMFYUI-Advanced-ControlNet'")
            soft_weight_cls = NODE_CLASS_MAPPINGS['ScaledSoftControlNetWeights']
            (weights, timestep_keyframe) = soft_weight_cls().load_weights(scale_soft_weights, False)
            cn_adv_cls = NODE_CLASS_MAPPINGS['ControlNetLoaderAdvanced']
            control_net, = cn_adv_cls().load_controlnet(control_net_name, timestep_keyframe)
        else:
            controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
            control_net = comfy.controlnet.load_controlnet(controlnet_path)
        if use_cache:
            self.add_to_cache("controlnet", unique_id, control_net)
            self.eviction_based_on_memory()
        return control_net

    def load_clip(self, clip_name, type='stable_diffusion', load_clip=None):
        """Load (or fetch from cache) a standalone text encoder from the ``clip`` folder.

        Args:
            clip_name: file name under ``clip``; also the cache key (``type`` is not).
            type: lower-case name of a ``comfy.sd.CLIPType`` member, e.g.
                ``'stable_diffusion'``, ``'stable_cascade'``, ``'sd3'``, ``'flux'``,
                ``'stable_audio'``.
            load_clip: unused; kept for backward compatibility.
        Raises:
            ValueError: if ``type`` names no ``CLIPType`` member.
        """
        if clip_name in self.loaded_objects["clip"]:
            return self.loaded_objects["clip"][clip_name][0]
        clip_type = comfy.sd.CLIPType.__members__.get(str(type).upper())
        if clip_type is None:
            raise ValueError(f"Unsupported CLIP type: {type!r}")
        clip_path = folder_paths.get_full_path("clip", clip_name)
        loaded_clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
        self.add_to_cache("clip", clip_name, loaded_clip)
        self.eviction_based_on_memory()
        return loaded_clip

    @staticmethod
    def _apply_resadapter_lora(model, clip, lora_sd):
        """Load ResAdapter norm weights directly into ``model.model``.

        NOTE(review): this mutates the given model's weights in place via
        ``load_state_dict`` and the result is not cached, so a model shared
        through the checkpoint cache keeps these weights for later users.
        """
        print('Using LORA for Resadapter')
        key_map = comfy.lora.model_lora_keys_unet(model.model, {})
        mapping_norm = {}
        for key in lora_sd.keys():
            if ".weight" in key:
                key_name_in_ori_sd = key_map[key.replace(".weight", "")]
                mapping_norm[key_name_in_ori_sd] = lora_sd[key]
            elif ".bias" in key:
                key_name_in_ori_sd = key_map[key.replace(".bias", "")]
                mapping_norm[key_name_in_ori_sd.replace(".weight", ".bias")] = lora_sd[key]
            else:
                print("===>Unexpected key", key)
                mapping_norm[key] = lora_sd[key]
        # Build the state dict once rather than once per key.
        model_state = model.model.state_dict()
        for k in mapping_norm:
            if k not in model_state:
                print("===>Missing key:", k)
        model.model.load_state_dict(mapping_norm, strict=False)
        return (model, clip)

    def load_lora(self, lora, model=None, clip=None, type=None):
        """Apply one LoRA to a model/CLIP pair and cache the patched result.

        Args:
            lora: dict with ``lora_name``, ``model_strength``, ``clip_strength``
                and fallback ``model`` / ``clip``; optional ``lbw``, ``lbw_a``,
                ``lbw_b`` select block-weighted loading via the Inspire Pack.
            model, clip: used instead of ``lora["model"]`` / ``lora["clip"]`` when not None.
            type: ``'PixArt'`` routes to the PixArt LoRA loader.
        Returns:
            ``(model, clip)`` after patching, or unchanged when the LoRA file
            cannot be found (an error is logged).
        """
        lora_name = lora["lora_name"]
        model = model if model is not None else lora["model"]
        clip = clip if clip is not None else lora["clip"]
        model_strength = lora["model_strength"]
        clip_strength = lora["clip_strength"]
        lbw = lora.get("lbw")
        lbw_a = lora.get("lbw_a")
        lbw_b = lora.get("lbw_b")

        # Key on object identity; slicing str(model) only worked for the
        # default repr of two specific classes.
        model_hash = hex(id(model))
        clip_hash = hex(id(clip)) if clip else ''
        unique_id = f'{model_hash};{clip_hash};{lora_name};{model_strength};{clip_strength}'
        if lbw:
            # Block weights change the result, so they belong in the key.
            unique_id += f';{lbw};{lbw_a};{lbw_b}'
        if unique_id in self.loaded_objects["lora"]:
            log_node_info("Load LORA", f"{lora_name} cached")
            return self.loaded_objects["lora"][unique_id][0]

        orig_lora_name = lora_name
        lora_name = self.resolve_lora_name(lora_name)
        lora_path = folder_paths.get_full_path("loras", lora_name) if lora_name is not None else None
        if lora_path is None:
            log_node_error(f"LORA NOT FOUND", orig_lora_name)
            return model, clip

        log_node_info("Load LORA", f"{lora_name}: {model_strength}, {clip_strength}, LBW={lbw}, A={lbw_a}, B={lbw_b}")
        if lbw:
            if 'LoraLoaderBlockWeight //Inspire' not in NODE_CLASS_MAPPINGS:
                raise Exception('[InspirePack Not Found] you need to install ComfyUI-Inspire-Pack')
            cls = NODE_CLASS_MAPPINGS['LoraLoaderBlockWeight //Inspire']
            model, clip, _ = cls().doit(model, clip, lora_name, model_strength, clip_strength, False, 0,
                                        lbw_a, lbw_b, "", lbw)
        else:
            _lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
            if "down_blocks.0.resnets.0.norm1.bias" in _lora:
                return self._apply_resadapter_lora(model, clip, _lora)
            if type == 'PixArt':
                from ..dit.pixArt.loader import load_pixart_lora
                model = load_pixart_lora(model, _lora, lora_path, model_strength)
            else:
                model, clip = comfy.sd.load_lora_for_models(model, clip, _lora, model_strength, clip_strength)
        self.add_to_cache("lora", unique_id, (model, clip))
        self.eviction_based_on_memory()
        return model, clip

    @staticmethod
    def _match_lora_name(name, candidates):
        """Return ``name`` if listed exactly, else the first candidate ending with it, else None."""
        if name in candidates:
            return name
        return next((x for x in candidates if x.endswith(name)), None)

    def resolve_lora_name(self, name):
        """Map ``name`` to an entry of the ``loras`` folder listing.

        An existing filesystem path is returned unchanged.  Otherwise the cached
        listing is searched, then a fresh listing (for LoRAs added after the
        cache was filled); returns None when nothing matches.
        """
        if os.path.exists(name):
            return name
        if not self.lora_name_cache:
            self.lora_name_cache.extend(folder_paths.get_filename_list("loras"))
        match = self._match_lora_name(name, self.lora_name_cache)
        if match is not None:
            return match
        # LoRAs added after the page was refreshed are not in the cache yet.
        log_node_info("LORA NOT IN CACHE", f"{name}")
        match = self._match_lora_name(name, folder_paths.get_filename_list("loras"))
        if match is not None:
            self.lora_name_cache.append(match)
        return match

    def load_main(self, ckpt_name, config_name, vae_name, lora_name, lora_model_strength, lora_clip_strength, optional_lora_stack, model_override, clip_override, vae_override, prompt, nf4=False):
        """Load the model/CLIP/VAE for a full loader node and apply its LoRAs.

        Base model resolution:
          1. An XY "ModelMergeBlocks"/"Checkpoint" input node with ``ckpt_name_1``
             in ``prompt``: that checkpoint is loaded first (so it is cached) and
             LoRAs are not applied here.
          2. Otherwise all three overrides, if given, are used as-is; a partial
             set of overrides raises.
          3. Otherwise ``ckpt_name`` / ``config_name`` is loaded.
        LoRAs from ``optional_lora_stack`` and then ``lora_name`` are applied
        unless an XY LoRA input node exists.  ``vae_name`` replaces the baked
        VAE unless it is "Baked VAE"/"Baked-VAE".  ``nf4`` is currently unused.

        Returns:
            ``(model, clip, vae, clip_vision, lora_stack)``
        Raises:
            Exception: on partial overrides or when no CLIP is available.
        """
        model: ModelPatcher | None = None
        clip: comfy.sd.CLIP | None = None
        vae: comfy.sd.VAE | None = None
        clip_vision = None
        lora_stack = []
        can_load_lora = True
        model_loaded = False

        # If an XY plot over models or LoRAs exists, cache its first model and
        # leave LoRA application to the XY node.
        xy_model_id = next((x for x in prompt if str(prompt[x]["class_type"]) in ["easy XYInputs: ModelMergeBlocks",
                                                                                 "easy XYInputs: Checkpoint"]), None)
        xy_lora_id = next((x for x in prompt if str(prompt[x]["class_type"]) == "easy XYInputs: Lora"), None)
        if xy_lora_id is not None:
            can_load_lora = False
        if xy_model_id is not None:
            node_inputs = prompt[xy_model_id]["inputs"]
            if "ckpt_name_1" in node_inputs:
                model, clip, vae, clip_vision = self.load_checkpoint(node_inputs["ckpt_name_1"])
                can_load_lora = False
                model_loaded = True

        # Fall through when the XY node did not provide a model, instead of
        # continuing with model=None.
        if not model_loaded:
            if model_override is not None and clip_override is not None and vae_override is not None:
                model = model_override
                clip = clip_override
                vae = vae_override
            elif model_override is not None:
                raise Exception(f"[ERROR] clip or vae is missing")
            elif vae_override is not None:
                raise Exception(f"[ERROR] model or clip is missing")
            elif clip_override is not None:
                raise Exception(f"[ERROR] model or vae is missing")
            else:
                model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name, config_name)

        if optional_lora_stack is not None and can_load_lora:
            for stack_item in optional_lora_stack:
                lora = {"lora_name": stack_item[0], "model": model, "clip": clip, "model_strength": stack_item[1],
                        "clip_strength": stack_item[2]}
                model, clip = self.load_lora(lora)
                lora['model'] = model
                lora['clip'] = clip
                lora_stack.append(lora)

        if lora_name != "None" and can_load_lora:
            lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": lora_model_strength,
                    "clip_strength": lora_clip_strength}
            model, clip = self.load_lora(lora)
            lora_stack.append(lora)

        # Custom VAE overrides the baked one.
        if vae_name not in ["Baked VAE", "Baked-VAE"]:
            vae = self.load_vae(vae_name)

        if not clip:
            raise Exception("No CLIP found")

        return model, clip, vae, clip_vision, lora_stack

    # Kolors
    def load_kolors_unet(self, unet_name):
        """Load (or fetch from cache) a Kolors UNet from the ``unet`` folder.

        Loading happens inside ``applyKolorsUnet()`` (presumably a patch that
        lets ComfyUI detect the Kolors architecture — confirm in ..kolors.loader).
        Raises RuntimeError if no model type could be detected.
        """
        if unet_name in self.loaded_objects["unet"]:
            log_node_info("Load Kolors UNet", f"{unet_name} cached")
            return self.loaded_objects["unet"][unet_name][0]
        from ..kolors.loader import applyKolorsUnet
        with applyKolorsUnet():
            unet_path = folder_paths.get_full_path("unet", unet_name)
            sd = comfy.utils.load_torch_file(unet_path)
            model = comfy.sd.load_unet_state_dict(sd)
            if model is None:
                raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
        self.add_to_cache("unet", unet_name, model)
        self.eviction_based_on_memory()
        return model

    def load_chatglm3(self, chatglm3_name):
        """Load (or fetch from cache) a ChatGLM3 text encoder from the ``llm`` folder."""
        from ..kolors.loader import load_chatglm3
        if chatglm3_name in self.loaded_objects["chatglm3"]:
            log_node_info("Load ChatGLM3", f"{chatglm3_name} cached")
            return self.loaded_objects["chatglm3"][chatglm3_name][0]
        chatglm_model = load_chatglm3(model_path=folder_paths.get_full_path("llm", chatglm3_name))
        self.add_to_cache("chatglm3", chatglm3_name, chatglm_model)
        self.eviction_based_on_memory()
        return chatglm_model

    # DiT
    def load_dit_ckpt(self, ckpt_name, model_name, **kwargs):
        """Load (or fetch from cache) a DiT checkpoint, cached as ``"<ckpt>_<model>"``.

        Only ``model_type='PixArt'`` (the default) is handled; it reads
        ``kwargs['pixart_conf'][model_name]``.  Other types return None.
        """
        cache_name = ckpt_name + '_' + model_name
        if cache_name in self.loaded_objects["ckpt"]:
            return self.loaded_objects["ckpt"][cache_name][0]
        model = None
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        model_type = kwargs.get("model_type", 'PixArt')
        if model_type == 'PixArt':
            model_conf = kwargs['pixart_conf'][model_name]
            model = load_pixart(ckpt_path, model_conf)
        if model:
            self.add_to_cache("ckpt", cache_name, model)
            self.eviction_based_on_memory()
        return model

    def load_dit_clip(self, clip_name, **kwargs):
        """Load (or fetch from cache) a HunyuanDiT CLIP text encoder.

        The original body used an undefined ``model``; this mirrors
        ``load_dit_t5``.  NOTE(review): relies on ``EXM_HyDiT_Tenc_Temp`` being
        available at module scope, exactly as ``load_dit_t5`` does — confirm.
        Raises ValueError for a ``model_type`` other than ``'HyDiT'``.
        """
        if clip_name in self.loaded_objects["clip"]:
            return self.loaded_objects["clip"][clip_name][0]
        model_type = kwargs.pop("model_type", "HyDiT")
        if model_type != "HyDiT":
            raise ValueError(f"Unsupported DiT clip model type: {model_type}")
        model = EXM_HyDiT_Tenc_Temp(model_class="clip", **kwargs)
        clip_path = folder_paths.get_full_path("clip", clip_name)
        sd = comfy.utils.load_torch_file(clip_path)
        # Strip a leading "bert." from the state-dict keys.
        prefix = "bert."
        state_dict = {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in sd.items()}
        m, e = model.load_sd(state_dict)
        if len(m) > 0 or len(e) > 0:
            print(f"{clip_name}: clip missing {len(m)} keys ({len(e)} extra)")
        self.add_to_cache("clip", clip_name, model)
        self.eviction_based_on_memory()
        return model

    def load_dit_t5(self, t5_name, **kwargs):
        """Load (or fetch from cache) a HunyuanDiT mT5 text encoder from the ``t5`` folder.

        Remaining ``kwargs`` are forwarded to ``EXM_HyDiT_Tenc_Temp``
        (NOTE(review): not defined in this view — confirm it is imported).
        Raises ValueError for a ``model_type`` other than ``'HyDiT'``.
        """
        if t5_name in self.loaded_objects["t5"]:
            return self.loaded_objects["t5"][t5_name][0]
        # pop() with a default: the old `del kwargs['model_type']` raised
        # KeyError exactly when the default was used.
        model_type = kwargs.pop("model_type", "HyDiT")
        if model_type != "HyDiT":
            raise ValueError(f"Unsupported DiT T5 model type: {model_type}")
        model = EXM_HyDiT_Tenc_Temp(model_class="mT5", **kwargs)
        t5_path = folder_paths.get_full_path("t5", t5_name)
        sd = comfy.utils.load_torch_file(t5_path)
        m, e = model.load_sd(sd)
        if len(m) > 0 or len(e) > 0:
            print(f"{t5_name}: mT5 missing {len(m)} keys ({len(e)} extra)")
        self.add_to_cache("t5", t5_name, model)
        self.eviction_based_on_memory()
        return model

    def load_t5_from_sd3_clip(self, sd3_clip, padding):
        """Build a T5-only CLIP from an SD3 CLIP, sharing its T5 transformer weights.

        The tokenizer's T5 ``min_length`` is set to ``padding``.
        """
        import copy
        try:
            from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
        except ImportError:
            # Older ComfyUI layout.
            from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel

        clip = sd3_clip.clone()
        assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!"
        t5xxl = clip.cond_stage_model.t5xxl
        # Detach the transformer so deepcopy does not duplicate its weights.
        # The clone may share cond_stage_model with sd3_clip, so always restore it.
        transformer = t5xxl.transformer
        t5xxl.transformer = None
        try:
            tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
            tmp.t5xxl = copy.deepcopy(t5xxl)
        finally:
            t5xxl.transformer = transformer
        tmp.t5xxl.transformer = transformer
        # Override special tokens so empty prompts tokenize the same way.
        tmp.t5xxl.special_tokens = copy.deepcopy(t5xxl.special_tokens)
        tmp.t5xxl.special_tokens.pop("end")

        tok = SD3Tokenizer()
        tok.t5xxl.min_length = padding

        clip.cond_stage_model = tmp
        clip.tokenizer = tok
        return clip