import re
import time
import os
import psutil
import folder_paths
import comfy.utils
import comfy.sd
import comfy.controlnet
from comfy.model_patcher import ModelPatcher
from nodes import NODE_CLASS_MAPPINGS
from collections import defaultdict
from .log import log_node_info, log_node_error
from ..dit.pixArt.loader import load_pixart

# Node class names, grouped by which model inputs they carry. They are used by
# easyLoader.update_loaded_objects to work out which cached objects a prompt
# still needs.
stable_diffusion_loaders = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader", "easy hunyuanDiTLoader","easy zero123Loader", "easy svdLoader"]
stable_cascade_loaders = ["easy cascadeLoader"]
dit_loaders = ['easy pixArtLoader']
controlnet_loaders = ["easy controlnetLoader", "easy controlnetLoaderADV"]
instant_loaders = ["easy instantIDApply", "easy instantIDApplyADV"]
cascade_vae_node = ["easy preSamplingCascade", "easy fullCascadeKSampler"]
model_merge_node = ["easy XYInputs: ModelMergeBlocks"]
lora_widget = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader"]


class easyLoader:
    """Loads models (checkpoints, VAEs, LoRAs, ControlNets, ...) and caches them.

    Cached objects live in ``self.loaded_objects[object_type][key]`` as
    ``(value, timestamp)`` tuples written by :meth:`add_to_cache`. Entries are
    dropped either because the current prompt no longer references them
    (:meth:`update_loaded_objects`) or because the process exceeds the memory
    threshold (:meth:`eviction_based_on_memory`).
    """

    def __init__(self):
        # NOTE: these are defaultdicts, so *indexing* a missing key silently
        # creates an empty placeholder. Use ``in`` / ``.get()`` for lookups.
        self.loaded_objects = {
            "ckpt": defaultdict(tuple),  # {ckpt_name: (model, ...)}
            "unet": defaultdict(tuple),
            "clip": defaultdict(tuple),
            "clip_vision": defaultdict(tuple),
            "bvae": defaultdict(tuple),
            "vae": defaultdict(object),
            "lora": defaultdict(dict),  # {lora_name: {UID: (model_lora, clip_lora)}}
            "controlnet": defaultdict(dict),
            "t5": defaultdict(tuple),
            "chatglm3": defaultdict(tuple),
        }
        self.memory_threshold = self.determine_memory_threshold(0.7)
        self.lora_name_cache = []

    def clean_values(self, values: str):
        """Split a ``"; "``-separated string into a list of typed values.

        Each piece is stripped of semicolons and whitespace; empty pieces are
        skipped. A piece becomes an ``int`` if it parses as one, else a
        ``float`` if it parses as one, else it stays a string.
        """
        original_values = values.split("; ")
        cleaned_values = []
        for value in original_values:
            cleaned_value = value.strip(';').strip()
            if cleaned_value == "":
                continue
            try:
                cleaned_value = int(cleaned_value)
            except ValueError:
                try:
                    cleaned_value = float(cleaned_value)
                except ValueError:
                    pass
            cleaned_values.append(cleaned_value)
        return cleaned_values

    def clear_unused_objects(self, desired_names: set, object_type: str):
        """Delete every cached ``object_type`` entry whose key is not in ``desired_names``."""
        keys = set(self.loaded_objects[object_type].keys())
        for key in keys - desired_names:
            del self.loaded_objects[object_type][key]

    def get_input_value(self, entry, key, prompt=None):
        """Return the value of input ``key`` on a prompt node ``entry``.

        Strings are returned as-is. A list value is followed into
        ``prompt[val[0]]['inputs'][key]`` when ``prompt`` is given and
        ``val[0]`` is truthy; otherwise ``val[0]`` itself is returned.
        Anything else is converted with ``str()``.
        """
        val = entry["inputs"][key]
        if isinstance(val, str):
            return val
        elif isinstance(val, list):
            if prompt is not None and val[0]:
                return prompt[val[0]]['inputs'][key]
            else:
                return val[0]
        else:
            return str(val)

    def process_pipe_loader(self, entry, desired_ckpt_names, desired_vae_names, desired_lora_names, desired_lora_settings, num_loras=3, suffix=""):
        """Collect the ckpt, VAE and LoRA names referenced by a pipe-loader node.

        The passed-in sets are updated in place. LoRA inputs are read from
        ``{suffix}lora{idx}_name`` / ``_model_strength`` / ``_clip_strength``
        for ``idx`` in ``1..num_loras``.
        """
        for idx in range(1, num_loras + 1):
            lora_name = self.get_input_value(entry, f"{suffix}lora{idx}_name")
            desired_lora_names.add(lora_name)
            setting = f'{lora_name};{entry["inputs"][f"{suffix}lora{idx}_model_strength"]};{entry["inputs"][f"{suffix}lora{idx}_clip_strength"]}'
            desired_lora_settings.add(setting)

        desired_ckpt_names.add(self.get_input_value(entry, f"{suffix}ckpt_name"))
        desired_vae_names.add(self.get_input_value(entry, f"{suffix}vae_name"))

    def update_loaded_objects(self, prompt):
        """Drop cached objects that no node in ``prompt`` references any more.

        Walks every node in ``prompt``, gathers the model names it needs per
        object type, then clears every cache entry whose key is not wanted.
        """
        desired_ckpt_names = set()
        desired_unet_names = set()
        desired_clip_names = set()
        desired_vae_names = set()
        desired_lora_names = set()
        desired_lora_settings = set()  # collected but not consumed in this method
        desired_controlnet_names = set()
        desired_t5_names = set()
        desired_glm3_names = set()

        for entry in prompt.values():
            class_type = entry["class_type"]
            if class_type in lora_widget:
                lora_name = self.get_input_value(entry, "lora_name")
                desired_lora_names.add(lora_name)
                setting = f'{lora_name};{entry["inputs"]["lora_model_strength"]};{entry["inputs"]["lora_clip_strength"]}'
                desired_lora_settings.add(setting)

            if class_type in stable_diffusion_loaders:
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name", prompt))
                desired_vae_names.add(self.get_input_value(entry, "vae_name"))
            elif class_type in ['easy kolorsLoader']:
                desired_unet_names.add(self.get_input_value(entry, "unet_name"))
                desired_vae_names.add(self.get_input_value(entry, "vae_name"))
                desired_glm3_names.add(self.get_input_value(entry, "chatglm3_name"))
            elif class_type in dit_loaders:
                t5_name = self.get_input_value(entry, "mt5_name") if "mt5_name" in entry["inputs"] else None
                clip_name = self.get_input_value(entry, "clip_name") if "clip_name" in entry["inputs"] else None
                model_name = self.get_input_value(entry, "model_name")
                ckpt_name = self.get_input_value(entry, "ckpt_name", prompt)
                if t5_name:
                    desired_t5_names.add(t5_name)
                if clip_name:
                    desired_clip_names.add(clip_name)
                # Same key format as load_dit_ckpt uses when caching.
                desired_ckpt_names.add(ckpt_name+'_'+model_name)
            elif class_type in stable_cascade_loaders:
                desired_unet_names.add(self.get_input_value(entry, "stage_c"))
                desired_unet_names.add(self.get_input_value(entry, "stage_b"))
                desired_clip_names.add(self.get_input_value(entry, "clip_name"))
                desired_vae_names.add(self.get_input_value(entry, "stage_a"))
            elif class_type in cascade_vae_node:
                encode_vae_name = self.get_input_value(entry, "encode_vae_name")
                decode_vae_name = self.get_input_value(entry, "decode_vae_name")
                if encode_vae_name and encode_vae_name != 'None':
                    desired_vae_names.add(encode_vae_name)
                if decode_vae_name and decode_vae_name != 'None':
                    desired_vae_names.add(decode_vae_name)
            elif class_type in controlnet_loaders:
                control_net_name = self.get_input_value(entry, "control_net_name", prompt)
                scale_soft_weights = self.get_input_value(entry, "scale_soft_weights")
                desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
            elif class_type in instant_loaders:
                control_net_name = self.get_input_value(entry, "control_net_name", prompt)
                scale_soft_weights = self.get_input_value(entry, "cn_soft_weights")
                desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
            elif class_type in model_merge_node:
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_1"))
                desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_2"))
                vae_use = self.get_input_value(entry, "vae_use")
                if vae_use != 'Use Model 1' and vae_use != 'Use Model 2':
                    desired_vae_names.add(vae_use)

        # "chatglm3" was previously missing from the list of types to clear,
        # which left its branch unreachable and its cache never pruned.
        # NOTE(review): "clip_vision" is still never cleared here nor evicted
        # in eviction_based_on_memory - confirm whether that is intentional.
        desired_by_type = {
            "ckpt": desired_ckpt_names,
            "unet": desired_unet_names,
            "clip": desired_ckpt_names.union(desired_clip_names),
            "bvae": desired_ckpt_names,
            "vae": desired_vae_names,
            "lora": desired_lora_names,
            "controlnet": desired_controlnet_names,
            "t5": desired_t5_names,
            "chatglm3": desired_glm3_names,
        }
        for object_type, desired_names in desired_by_type.items():
            self.clear_unused_objects(desired_names, object_type)

    def add_to_cache(self, obj_type, key, value):
        """
        Add an item to the cache with the current timestamp.
        """
        timestamped_value = (value, time.time())
        self.loaded_objects[obj_type][key] = timestamped_value

    def determine_memory_threshold(self, percentage=0.8):
        """
        Determines the memory threshold as a percentage of the total available memory.
        Args:
        - percentage (float): The fraction of total memory to use as the threshold.
                              Should be a value between 0 and 1. Default is 0.8 (80%).
        Returns:
        - memory_threshold (float): Memory threshold in bytes
                                    (``total_memory * percentage``).
        """
        total_memory = psutil.virtual_memory().total
        memory_threshold = total_memory * percentage
        return memory_threshold

    def get_memory_usage(self):
        """
        Returns the memory usage (RSS) of the current process in bytes.
        """
        process = psutil.Process(os.getpid())
        return process.memory_info().rss

    @staticmethod
    def _entry_timestamp(item):
        """Return the timestamp of a ``(key, (value, timestamp))`` cache item.

        Placeholder entries auto-created by the defaultdicts (``()``, ``{}``,
        ``object()``) have no timestamp; they sort first so they are evicted
        before real entries instead of crashing the sort.
        """
        try:
            return item[1][1]
        except (IndexError, KeyError, TypeError):
            return 0.0

    def eviction_based_on_memory(self):
        """
        Evicts objects from cache based on memory usage and priority.

        Does nothing while the process is below ``self.memory_threshold``.
        Otherwise walks the object types in a fixed priority order and deletes
        entries oldest-first, re-measuring memory after each deletion, until
        usage drops below the threshold or the caches are exhausted.
        """
        current_memory = self.get_memory_usage()
        if current_memory < self.memory_threshold:
            return
        eviction_order = ["vae", "lora", "bvae", "clip", "ckpt", "controlnet", "unet", "t5", "chatglm3"]
        for obj_type in eviction_order:
            if current_memory < self.memory_threshold:
                break
            # Oldest entries first (sorted by the add_to_cache timestamp).
            items = sorted(self.loaded_objects[obj_type].items(), key=self._entry_timestamp)
            for item in items:
                if current_memory < self.memory_threshold:
                    break
                del self.loaded_objects[obj_type][item[0]]
                current_memory = self.get_memory_usage()

    def load_checkpoint(self, ckpt_name, config_name=None, load_vision=False):
        """Load a checkpoint, reusing cached components when all are available.

        Args:
            ckpt_name: Checkpoint file name, resolved in the "checkpoints" folder.
            config_name: Optional config file name; ``None`` / ``"Default"``
                means the config is guessed from the checkpoint.
            load_vision: When true, request the CLIP-vision model instead of CLIP.

        Returns:
            ``(model, clip, vae, clip_vision)``. ``clip`` is ``None`` on a
            cache hit with ``load_vision`` set, and ``clip_vision`` is ``None``
            on a cache hit without it.
        """
        use_config = config_name not in [None, "Default"]
        cache_name = ckpt_name + "_" + config_name if use_config else ckpt_name

        # Use .get() so a miss never creates a placeholder entry, and only
        # treat it as a hit when every component we must return is cached.
        # The baked VAE / CLIP can be evicted before the model itself, and
        # CLIP-vision is only cached when it was requested at load time; in
        # those cases fall through and reload instead of raising IndexError.
        caches = self.loaded_objects
        model_entry = caches["ckpt"].get(cache_name)
        if model_entry:
            vae_entry = caches["bvae"].get(cache_name)
            extra_entry = caches["clip_vision" if load_vision else "clip"].get(cache_name)
            if vae_entry and extra_entry:
                clip = None if load_vision else extra_entry[0]
                clip_vision = extra_entry[0] if load_vision else None
                return model_entry[0], clip, vae_entry[0], clip_vision

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)

        output_clip = not load_vision
        output_clipvision = bool(load_vision)

        if use_config:
            config_path = folder_paths.get_full_path("configs", config_name)
            loaded_ckpt = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        else:
            model_options = {}
            if re.search("nf4", ckpt_name):
                from ..bitsandbytes_NF4 import OPS
                model_options = {"custom_operations": OPS}
            loaded_ckpt = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=output_clip, output_clipvision=output_clipvision, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)

        self.add_to_cache("ckpt", cache_name, loaded_ckpt[0])
        self.add_to_cache("bvae", cache_name, loaded_ckpt[2])

        clip = loaded_ckpt[1]
        # NOTE(review): guarded because the config-based loader may return
        # only (model, clip, vae) on some ComfyUI versions - confirm.
        clip_vision = loaded_ckpt[3] if len(loaded_ckpt) > 3 else None
        if clip:
            self.add_to_cache("clip", cache_name, clip)
        if clip_vision:
            self.add_to_cache("clip_vision", cache_name, clip_vision)

        self.eviction_based_on_memory()

        return loaded_ckpt[0], clip, loaded_ckpt[2], clip_vision

    def load_vae(self, vae_name):
        """Return the VAE named ``vae_name``, loading and caching it on first use."""
        if vae_name in self.loaded_objects["vae"]:
            return self.loaded_objects["vae"][vae_name][0]

        vae_path = folder_paths.get_full_path("vae", vae_name)
        sd = comfy.utils.load_torch_file(vae_path)
        loaded_vae = comfy.sd.VAE(sd=sd)
        self.add_to_cache("vae", vae_name, loaded_vae)
        self.eviction_based_on_memory()

        return loaded_vae

    def load_unet(self, unet_name):
        """Return the UNet named ``unet_name``, loading and caching it on first use."""
        if unet_name in self.loaded_objects["unet"]:
            log_node_info("Load UNet", f"{unet_name} cached")
            return self.loaded_objects["unet"][unet_name][0]

        unet_path = folder_paths.get_full_path("unet", unet_name)
        model = comfy.sd.load_unet(unet_path)
        self.add_to_cache("unet", unet_name, model)
        self.eviction_based_on_memory()

        return model

    def load_controlnet(self, control_net_name, scale_soft_weights=1, use_cache=True):
        """Load a ControlNet, optionally with scaled soft weights.

        The cache key combines the name and ``scale_soft_weights``. When
        ``scale_soft_weights < 1`` the Advanced-ControlNet nodes are used.

        Raises:
            Exception: If soft weights are requested but the
                ``ScaledSoftControlNetWeights`` node is not registered.
        """
        unique_id = f'{control_net_name};{str(scale_soft_weights)}'
        if use_cache and unique_id in self.loaded_objects["controlnet"]:
            return self.loaded_objects["controlnet"][unique_id][0]
        if scale_soft_weights < 1:
            if "ScaledSoftControlNetWeights" in NODE_CLASS_MAPPINGS:
                soft_weight_cls = NODE_CLASS_MAPPINGS['ScaledSoftControlNetWeights']
                (weights, timestep_keyframe) = soft_weight_cls().load_weights(scale_soft_weights, False)
                cn_adv_cls = NODE_CLASS_MAPPINGS['ControlNetLoaderAdvanced']
                control_net, = cn_adv_cls().load_controlnet(control_net_name, timestep_keyframe)
            else:
                raise Exception(f"[Advanced-ControlNet Not Found] you need to install 'COMFYUI-Advanced-ControlNet'")
        else:
            controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
            control_net = comfy.controlnet.load_controlnet(controlnet_path)
        if use_cache:
            self.add_to_cache("controlnet", unique_id, control_net)
            self.eviction_based_on_memory()

        return control_net

    def load_clip(self, clip_name, type='stable_diffusion', load_clip=None):
        """Return the CLIP model ``clip_name``, loading and caching it on first use.

        Args:
            clip_name: File name resolved in the "clip" folder.
            type: One of ``'stable_diffusion'``, ``'stable_cascade'``,
                ``'sd3'``, ``'flux'``, ``'stable_audio'``.
            load_clip: Unused; kept for interface compatibility.

        Raises:
            ValueError: If ``type`` is not one of the supported names
                (previously this surfaced as an UnboundLocalError).
        """
        if clip_name in self.loaded_objects["clip"]:
            return self.loaded_objects["clip"][clip_name][0]

        clip_type_names = {
            'stable_diffusion': 'STABLE_DIFFUSION',
            'stable_cascade': 'STABLE_CASCADE',
            'sd3': 'SD3',
            'flux': 'FLUX',
            'stable_audio': 'STABLE_AUDIO',
        }
        if type not in clip_type_names:
            raise ValueError(f"Unsupported clip type: {type!r} (expected one of {sorted(clip_type_names)})")
        # Resolved lazily so only the requested member has to exist on CLIPType.
        clip_type = getattr(comfy.sd.CLIPType, clip_type_names[type])

        clip_path = folder_paths.get_full_path("clip", clip_name)
        load_clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
        self.add_to_cache("clip", clip_name, load_clip)
        self.eviction_based_on_memory()

        return load_clip

    def load_lora(self, lora, model=None, clip=None, type=None):
        """Apply a LoRA to ``model`` / ``clip`` and return the patched pair.

        Args:
            lora: Dict with ``lora_name``, ``model_strength``, ``clip_strength``
                and optionally ``model``, ``clip``, ``lbw``, ``lbw_a``, ``lbw_b``.
            model: Model to patch; falls back to ``lora["model"]`` when ``None``.
            clip: CLIP to patch; falls back to ``lora["clip"]`` when ``None``.
            type: ``'PixArt'`` selects the PixArt LoRA loader.

        Returns:
            ``(model, clip)``. When the LoRA file cannot be resolved an error
            is logged and the inputs are returned unchanged.

        Raises:
            Exception: If block weights (``lbw``) are requested but the
                Inspire-Pack node is not registered.
        """
        lora_name = lora["lora_name"]
        model = model if model is not None else lora["model"]
        clip = clip if clip is not None else lora["clip"]
        model_strength = lora["model_strength"]
        clip_strength = lora["clip_strength"]
        lbw = lora.get("lbw")
        lbw_a = lora.get("lbw_a")
        lbw_b = lora.get("lbw_b")

        # Cache key built from slices of the objects' string representation.
        model_hash = str(model)[44:-1]
        clip_hash = str(clip)[25:-1] if clip else ''

        unique_id = f'{model_hash};{clip_hash};{lora_name};{model_strength};{clip_strength}'

        if unique_id in self.loaded_objects["lora"]:
            log_node_info("Load LORA",f"{lora_name} cached")
            return self.loaded_objects["lora"][unique_id][0]

        orig_lora_name = lora_name
        lora_name = self.resolve_lora_name(lora_name)
        if lora_name is not None:
            lora_path = folder_paths.get_full_path("loras", lora_name)
        else:
            lora_path = None

        if lora_path is not None:
            log_node_info("Load LORA",f"{lora_name}: {model_strength}, {clip_strength}, LBW={lbw}, A={lbw_a}, B={lbw_b}")
            if lbw:
                # lbw / lbw_a / lbw_b were already read above with safe
                # defaults; re-indexing the dict here could raise KeyError.
                if 'LoraLoaderBlockWeight //Inspire' not in NODE_CLASS_MAPPINGS:
                    raise Exception('[InspirePack Not Found] you need to install ComfyUI-Inspire-Pack')
                cls = NODE_CLASS_MAPPINGS['LoraLoaderBlockWeight //Inspire']
                model, clip, _ = cls().doit(model, clip, lora_name, model_strength, clip_strength, False, 0,
                                            lbw_a, lbw_b, "", lbw)
            else:
                _lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
                keys = _lora.keys()
                if "down_blocks.0.resnets.0.norm1.bias" in keys:
                    print('Using LORA for Resadapter')
                    key_map = {}
                    key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
                    mapping_norm = {}

                    for key in keys:
                        if ".weight" in key:
                            key_name_in_ori_sd = key_map[key.replace(".weight", "")]
                            mapping_norm[key_name_in_ori_sd] = _lora[key]
                        elif ".bias" in key:
                            key_name_in_ori_sd = key_map[key.replace(".bias", "")]
                            mapping_norm[key_name_in_ori_sd.replace(".weight", ".bias")] = _lora[key]
                        else:
                            print("===>Unexpected key", key)
                            mapping_norm[key] = _lora[key]

                    # Build the state dict once instead of once per key.
                    model_state = model.model.state_dict()
                    for k in mapping_norm:
                        if k not in model_state:
                            print("===>Missing key:", k)
                    model.model.load_state_dict(mapping_norm, strict=False)
                    # Resadapter weights are written into the model directly;
                    # this path returns without caching.
                    return (model, clip)

                # PixArt
                if type is not None and type == 'PixArt':
                    from ..dit.pixArt.loader import load_pixart_lora
                    model = load_pixart_lora(model, _lora, lora_path, model_strength)
                else:
                    model, clip = comfy.sd.load_lora_for_models(model, clip, _lora, model_strength, clip_strength)

            self.add_to_cache("lora", unique_id, (model, clip))
            self.eviction_based_on_memory()
        else:
            log_node_error(f"LORA NOT FOUND", orig_lora_name)

        return model, clip

    def resolve_lora_name(self, name):
        """Resolve ``name`` to a known LoRA file name, or ``None`` if not found.

        An existing filesystem path is returned unchanged. Otherwise the first
        entry of the "loras" file list that ends with ``name`` is returned,
        checking the in-memory name cache before re-reading the list.
        """
        if os.path.exists(name):
            return name
        else:
            if len(self.lora_name_cache) == 0:
                loras = folder_paths.get_filename_list("loras")
                self.lora_name_cache.extend(loras)
            for x in self.lora_name_cache:
                if x.endswith(name):
                    return x

            # LoRAs added after the name cache was filled (e.g. after a page
            # refresh) are picked up by re-reading the file list here.
            log_node_info("LORA NOT IN CACHE", f"{name}")
            loras = folder_paths.get_filename_list("loras")
            for x in loras:
                if x.endswith(name):
                    self.lora_name_cache.append(x)
                    return x

            return None

    def load_main(self, ckpt_name, config_name, vae_name, lora_name, lora_model_strength, lora_clip_strength, optional_lora_stack, model_override, clip_override, vae_override, prompt, nf4=False):
        """Load the model / CLIP / VAE (and LoRAs) for a loader node.

        Either all three of ``model_override``, ``clip_override`` and
        ``vae_override`` are given, or none; otherwise the checkpoint is
        loaded from ``ckpt_name`` / ``config_name``. LoRAs from
        ``optional_lora_stack`` and ``lora_name`` are applied unless an XY-plot
        node in ``prompt`` disables them.

        Returns:
            ``(model, clip, vae, clip_vision, lora_stack)``.

        Raises:
            Exception: If only some overrides are supplied, or no CLIP ends up
                loaded.
        """
        model: ModelPatcher | None = None
        clip: comfy.sd.CLIP | None = None
        vae: comfy.sd.VAE | None = None
        clip_vision = None
        lora_stack = []

        can_load_lora = True
        # Check whether the prompt contains a model or LoRA XY-plot node; if
        # so, cache the first model with priority.
        xy_model_id = next((x for x in prompt if str(prompt[x]["class_type"]) in ["easy XYInputs: ModelMergeBlocks", "easy XYInputs: Checkpoint"]), None)
        xy_lora_id = next((x for x in prompt if str(prompt[x]["class_type"]) == "easy XYInputs: Lora"), None)
        if xy_lora_id is not None:
            can_load_lora = False
        if xy_model_id is not None:
            node = prompt[xy_model_id]
            if "ckpt_name_1" in node["inputs"]:
                ckpt_name_1 = node["inputs"]["ckpt_name_1"]
                model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name_1)
                can_load_lora = False
        # Load models
        elif model_override is not None and clip_override is not None and vae_override is not None:
            model = model_override
            clip = clip_override
            vae = vae_override
        elif model_override is not None:
            raise Exception(f"[ERROR] clip or vae is missing")
        elif vae_override is not None:
            raise Exception(f"[ERROR] model or clip is missing")
        elif clip_override is not None:
            raise Exception(f"[ERROR] model or vae is missing")
        else:
            model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name, config_name)

        if optional_lora_stack is not None and can_load_lora:
            for lora in optional_lora_stack:
                lora = {"lora_name": lora[0], "model": model, "clip": clip, "model_strength": lora[1], "clip_strength": lora[2]}
                model, clip = self.load_lora(lora)
                lora['model'] = model
                lora['clip'] = clip
                lora_stack.append(lora)

        if lora_name != "None" and can_load_lora:
            lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": lora_model_strength,
                    "clip_strength": lora_clip_strength}
            model, clip = self.load_lora(lora)
            lora_stack.append(lora)

        # Check for custom VAE
        if vae_name not in ["Baked VAE", "Baked-VAE"]:
            vae = self.load_vae(vae_name)
        # CLIP skip
        if not clip:
            raise Exception("No CLIP found")

        return model, clip, vae, clip_vision, lora_stack

    # Kolors
    def load_kolors_unet(self, unet_name):
        """Return the Kolors UNet ``unet_name``, loading and caching it on first use.

        Raises:
            RuntimeError: If the model type cannot be detected from the file.
        """
        if unet_name in self.loaded_objects["unet"]:
            log_node_info("Load Kolors UNet", f"{unet_name} cached")
            return self.loaded_objects["unet"][unet_name][0]
        else:
            from ..kolors.loader import applyKolorsUnet
            with applyKolorsUnet():
                unet_path = folder_paths.get_full_path("unet", unet_name)
                sd = comfy.utils.load_torch_file(unet_path)
                model = comfy.sd.load_unet_state_dict(sd)
                if model is None:
                    raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))

                self.add_to_cache("unet", unet_name, model)
                self.eviction_based_on_memory()

                return model

    def load_chatglm3(self, chatglm3_name):
        """Return the ChatGLM3 model ``chatglm3_name``, loading and caching it on first use."""
        from ..kolors.loader import load_chatglm3
        if chatglm3_name in self.loaded_objects["chatglm3"]:
            log_node_info("Load ChatGLM3", f"{chatglm3_name} cached")
            return self.loaded_objects["chatglm3"][chatglm3_name][0]

        chatglm_model = load_chatglm3(model_path=folder_paths.get_full_path("llm", chatglm3_name))
        self.add_to_cache("chatglm3", chatglm3_name, chatglm_model)
        self.eviction_based_on_memory()

        return chatglm_model

    # DiT
    def load_dit_ckpt(self, ckpt_name, model_name, **kwargs):
        """Load a DiT checkpoint, cached under ``ckpt_name + '_' + model_name``.

        Keyword Args:
            model_type: Defaults to ``'PixArt'``, the only type handled here.
            pixart_conf: Mapping from ``model_name`` to the PixArt model config
                (required for ``'PixArt'``).

        Returns:
            The loaded model, or ``None`` for an unhandled ``model_type``.
        """
        if (ckpt_name+'_'+model_name) in self.loaded_objects["ckpt"]:
            return self.loaded_objects["ckpt"][ckpt_name+'_'+model_name][0]
        model = None
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        model_type = kwargs['model_type'] if "model_type" in kwargs else 'PixArt'
        if model_type == 'PixArt':
            pixart_conf = kwargs['pixart_conf']
            model_conf = pixart_conf[model_name]
            model = load_pixart(ckpt_path, model_conf)
        if model:
            self.add_to_cache("ckpt", ckpt_name + '_' + model_name, model)
            self.eviction_based_on_memory()
        return model

    def load_dit_clip(self, clip_name, **kwargs):
        """Load a DiT text encoder from the "clip" folder and cache it.

        Keys prefixed with ``"bert."`` have that prefix stripped before the
        state dict is loaded.
        """
        if clip_name in self.loaded_objects["clip"]:
            return self.loaded_objects["clip"][clip_name][0]

        clip_path = folder_paths.get_full_path("clip", clip_name)
        sd = comfy.utils.load_torch_file(clip_path)

        prefix = "bert."
        state_dict = {}
        for key in sd:
            nkey = key
            if key.startswith(prefix):
                nkey = key[len(prefix):]
            state_dict[nkey] = sd[key]

        # FIXME(review): `model` is never defined in this method, so this line
        # raises NameError on every cache miss. The encoder construction seems
        # to be missing (compare load_dit_t5) - restore it before use.
        m, e = model.load_sd(state_dict)
        if len(m) > 0 or len(e) > 0:
            print(f"{clip_name}: clip missing {len(m)} keys ({len(e)} extra)")

        self.add_to_cache("clip", clip_name, model)
        self.eviction_based_on_memory()

        return model

    def load_dit_t5(self, t5_name, **kwargs):
        """Load a DiT T5 text encoder from the "t5" folder and cache it.

        Keyword Args:
            model_type: Defaults to ``'HyDiT'``, the only supported type. The
                remaining keyword arguments are forwarded to the encoder
                constructor.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        if t5_name in self.loaded_objects["t5"]:
            return self.loaded_objects["t5"][t5_name][0]

        # pop() rather than del: 'model_type' is optional, and deleting a
        # missing key raised KeyError whenever the default was used.
        model_type = kwargs.pop('model_type', 'HyDiT')
        if model_type != 'HyDiT':
            raise ValueError(f"Unsupported DiT T5 model type: {model_type!r}")
        # FIXME(review): EXM_HyDiT_Tenc_Temp is not imported in the visible
        # part of this module - confirm where it is supposed to come from.
        model = EXM_HyDiT_Tenc_Temp(model_class="mT5", **kwargs)

        t5_path = folder_paths.get_full_path("t5", t5_name)
        sd = comfy.utils.load_torch_file(t5_path)
        m, e = model.load_sd(sd)
        if len(m) > 0 or len(e) > 0:
            print(f"{t5_name}: mT5 missing {len(m)} keys ({len(e)} extra)")

        self.add_to_cache("t5", t5_name, model)
        self.eviction_based_on_memory()

        return model

    def load_t5_from_sd3_clip(self, sd3_clip, padding):
        """Return a clone of ``sd3_clip`` whose encoder holds only the T5 part.

        The T5 wrapper is deep-copied without its transformer, which is then
        shared with the original. The "end" special token is removed and the
        T5 tokenizer's ``min_length`` is set to ``padding``.
        """
        try:
            from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
        except ImportError:
            # Fallback import location for the same classes.
            from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel
        import copy

        clip = sd3_clip.clone()
        assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!"

        # remove transformer so the deepcopy below does not duplicate it
        transformer = clip.cond_stage_model.t5xxl.transformer
        clip.cond_stage_model.t5xxl.transformer = None

        # clone object
        tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
        tmp.t5xxl = copy.deepcopy(clip.cond_stage_model.t5xxl)

        # put transformer back
        clip.cond_stage_model.t5xxl.transformer = transformer
        tmp.t5xxl.transformer = transformer

        # override special tokens
        tmp.t5xxl.special_tokens = copy.deepcopy(clip.cond_stage_model.t5xxl.special_tokens)
        tmp.t5xxl.special_tokens.pop("end")  # make sure empty tokens match

        # tokenizer
        tok = SD3Tokenizer()
        tok.t5xxl.min_length = padding

        clip.cond_stage_model = tmp
        clip.tokenizer = tok

        return clip