"""SDXL-specific model-merge nodes exposing per-block merge ratios.

Each node subclasses ``ModelMergeBlocks`` and only customises ``INPUT_TYPES``.
Besides ``model1``/``model2``, every input is a FLOAT ratio in [0, 1] whose
name is a UNet parameter-key prefix (e.g. ``"input_blocks.4.1."``).

NOTE(review): the merge itself lives in ``ModelMergeBlocks`` (not in this
file). It is assumed to apply each ratio to the weights whose key starts with
the input's name, preferring the longest matching prefix, and to fall back to
a default ratio for unmatched keys -- confirm there.
"""
import comfy_extras.nodes_model_merging


# Number of transformer blocks under sub-index 1 (".1.transformer_blocks.<j>")
# of each UNet block. Blocks not listed have no ".1." transformer stack.
_INPUT_TRANSFORMER_DEPTH = {4: 2, 5: 2, 7: 10, 8: 10}
_MIDDLE_TRANSFORMER_INDEX = 1
_MIDDLE_TRANSFORMER_DEPTH = 10
# Indexed directly by output block number (equivalent to the former
# ``{3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}`` keyed by ``8 - i``).
_OUTPUT_TRANSFORMER_DEPTH = {0: 10, 1: 10, 2: 10, 3: 2, 4: 2, 5: 2}

# Sub-module names that get an individual ratio in the "detailed" node.
_TRANSFORMER_BLOCK_PARTS = (
    "norm1",
    "attn1.to_q",
    "attn1.to_k",
    "attn1.to_v",
    "attn1.to_out",
    "ff.net",
    "norm2",
    "attn2.to_q",
    "attn2.to_k",
    "attn2.to_v",
    "attn2.to_out",
    "norm3",
)

# Output blocks 2 and 5 carry a third sub-module at index 2 (the upsampler in
# the standard SDXL UNet -- e.g. ``output_blocks.2.2.conv.weight``). None of
# the ".0."/".1." prefixes match those weights, so without these inputs they
# presumably fell back to the merge's default ratio. They are exposed as
# *optional* inputs so existing prompts that omit them remain valid.
_OUTPUT_EXTRA_KEYS = ("output_blocks.2.2.", "output_blocks.5.2.")


def _ratio_input():
    """Return a fresh FLOAT widget spec (default 1.0, range [0, 1], step 0.01)."""
    return ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})


def _model_inputs():
    """Return a new input dict pre-populated with the two MODEL inputs."""
    return {"model1": ("MODEL",), "model2": ("MODEL",)}


def _add_transformer_stack(arg_dict, argument, prefix, depth, parts):
    """Register *prefix*, then one key per transformer block and part name.

    Keys take the form ``<prefix>transformer_blocks.<j>.<part>``; passing
    ``parts=("",)`` yields one key per block (ending in ``"."``).
    """
    arg_dict[prefix] = argument
    for j in range(depth):
        for part in parts:
            arg_dict["{}transformer_blocks.{}.{}".format(prefix, j, part)] = argument


def _transformer_level_inputs(parts):
    """Build the ``INPUT_TYPES`` result shared by the transformer-level nodes.

    :param parts: suffixes appended to every ``...transformer_blocks.<j>.``
        key (see ``_add_transformer_stack``).
    :returns: ``{"required": ..., "optional": ...}``. The required inputs keep
        the original insertion order, which is presumably the widget order
        users see and workflows store positionally -- do not reorder.
    """
    argument = _ratio_input()
    arg_dict = _model_inputs()
    arg_dict["time_embed."] = argument
    arg_dict["label_emb."] = argument

    for i in range(9):
        arg_dict["input_blocks.{}.0.".format(i)] = argument
        if i in _INPUT_TRANSFORMER_DEPTH:
            _add_transformer_stack(arg_dict, argument, "input_blocks.{}.1.".format(i),
                                   _INPUT_TRANSFORMER_DEPTH[i], parts)

    for i in range(3):
        prefix = "middle_block.{}.".format(i)
        if i == _MIDDLE_TRANSFORMER_INDEX:
            _add_transformer_stack(arg_dict, argument, prefix, _MIDDLE_TRANSFORMER_DEPTH, parts)
        else:
            arg_dict[prefix] = argument

    for i in range(9):
        arg_dict["output_blocks.{}.0.".format(i)] = argument
        if i in _OUTPUT_TRANSFORMER_DEPTH:
            _add_transformer_stack(arg_dict, argument, "output_blocks.{}.1.".format(i),
                                   _OUTPUT_TRANSFORMER_DEPTH[i], parts)

    arg_dict["out."] = argument

    optional = {key: argument for key in _OUTPUT_EXTRA_KEYS}
    return {"required": arg_dict, "optional": optional}


class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with one ratio per top-level UNet block."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return one ratio per input/middle/output block plus embeddings/out.

        These names have no trailing dot, unlike the other nodes. Assuming
        prefix matching, that is harmless: with at most 9 blocks per group no
        name (e.g. ``"input_blocks.1"``) can prefix another block's key. The
        names are kept verbatim because they are this node's public inputs.
        """
        argument = _ratio_input()
        arg_dict = _model_inputs()
        arg_dict["time_embed."] = argument
        arg_dict["label_emb."] = argument

        for group, count in (("input_blocks", 9), ("middle_block", 3), ("output_blocks", 9)):
            for i in range(count):
                arg_dict["{}.{}".format(group, i)] = argument

        arg_dict["out."] = argument
        return {"required": arg_dict}


class ModelMergeSDXLTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with one ratio per sub-block and transformer block."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return ratios down to ``...transformer_blocks.<j>.`` granularity."""
        return _transformer_level_inputs(("",))


class ModelMergeSDXLDetailedTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with one ratio per transformer sub-module."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return ratios down to ``...transformer_blocks.<j>.<part>`` granularity."""
        return _transformer_level_inputs(_TRANSFORMER_BLOCK_PARTS)


NODE_CLASS_MAPPINGS = {
    "ModelMergeSDXL": ModelMergeSDXL,
    "ModelMergeSDXLTransformers": ModelMergeSDXLTransformers,
    "ModelMergeSDXLDetailedTransformers": ModelMergeSDXLDetailedTransformers,
}