|
import comfy_extras.nodes_model_merging |
|
|
|
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge node declaring one ratio input per top-level SDXL block group.

    Only the input declaration is defined here; everything else is inherited
    from ``ModelMergeBlocks``, which is not shown in this file.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration for this node.

        The result is ``{"required": {...}}`` holding ``model1`` and
        ``model2`` (type ``MODEL``) followed by one ``FLOAT`` entry (default
        1.0, range 0.0-1.0, step 0.01) for each block name. Entries are
        inserted in the order: time_embed, label_emb, 9 input blocks,
        3 middle blocks, 9 output blocks, out.
        """
        # Every ratio input shares this one spec tuple.
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        # NOTE(review): unlike "time_embed." the numbered block names carry no
        # trailing dot. They are input names, so they are reproduced verbatim.
        # Presumably the base class treats each name as a key prefix - confirm.
        names = ["time_embed.", "label_emb."]
        names += ["input_blocks.{}".format(n) for n in range(9)]
        names += ["middle_block.{}".format(n) for n in range(3)]
        names += ["output_blocks.{}".format(n) for n in range(9)]
        names.append("out.")

        inputs = {"model1": ("MODEL",), "model2": ("MODEL",)}
        inputs.update((name, ratio) for name in names)
        return {"required": inputs}
|
|
|
|
|
class ModelMergeSDXLTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge node declaring ratio inputs down to individual transformer blocks.

    Only the input declaration is defined here; everything else is inherited
    from ``ModelMergeBlocks``, which is not shown in this file.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration for this node.

        The result is ``{"required": {...}}`` holding ``model1`` and
        ``model2`` (type ``MODEL``) followed by ``FLOAT`` entries (default 1.0,
        range 0.0-1.0, step 0.01). Each numbered input/output block gets a
        ``.0.`` entry; blocks listed in the depth tables also get a ``.1.``
        entry and one entry per ``transformer_blocks`` index.
        """
        # Every ratio input shares this one spec tuple.
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        inputs = {"model1": ("MODEL",), "model2": ("MODEL",)}
        for name in ("time_embed.", "label_emb."):
            inputs[name] = ratio

        # Input block index -> number of transformer_blocks entries to declare.
        input_depths = {4: 2, 5: 2, 7: 10, 8: 10}
        for blk in range(9):
            base = "input_blocks.{}.".format(blk)
            inputs[base + "0."] = ratio
            depth = input_depths.get(blk)
            if depth is None:
                continue
            inputs[base + "1."] = ratio
            for layer in range(depth):
                inputs["{}1.transformer_blocks.{}.".format(base, layer)] = ratio

        for blk in range(3):
            base = "middle_block.{}.".format(blk)
            inputs[base] = ratio
            # Only middle block 1 declares per-transformer-block entries.
            if blk == 1:
                for layer in range(10):
                    inputs["{}transformer_blocks.{}.".format(base, layer)] = ratio

        # The output table is looked up with the mirrored index (8 - blk).
        output_depths = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}
        for blk in range(9):
            base = "output_blocks.{}.".format(blk)
            inputs[base + "0."] = ratio
            depth = output_depths.get(8 - blk)
            if depth is None:
                continue
            inputs[base + "1."] = ratio
            for layer in range(depth):
                inputs["{}1.transformer_blocks.{}.".format(base, layer)] = ratio

        inputs["out."] = ratio

        return {"required": inputs}
|
|
|
class ModelMergeSDXLDetailedTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge node declaring ratio inputs per sub-layer of each transformer block.

    Only the input declaration is defined here; everything else is inherited
    from ``ModelMergeBlocks``, which is not shown in this file.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration for this node.

        The result is ``{"required": {...}}`` holding ``model1`` and
        ``model2`` (type ``MODEL``) followed by ``FLOAT`` entries (default 1.0,
        range 0.0-1.0, step 0.01). Each numbered input/output block gets a
        ``.0.`` entry; blocks listed in the depth tables also get a ``.1.``
        entry and, for every ``transformer_blocks`` index, one entry per
        sub-layer name. Sub-layer entries end without a trailing dot.
        """
        # Every ratio input shares this one spec tuple.
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        # Sub-layer names appended after each transformer_blocks index, in
        # the order they are declared.
        parts = (
            "norm1",
            "attn1.to_q",
            "attn1.to_k",
            "attn1.to_v",
            "attn1.to_out",
            "ff.net",
            "norm2",
            "attn2.to_q",
            "attn2.to_k",
            "attn2.to_v",
            "attn2.to_out",
            "norm3",
        )

        inputs = {"model1": ("MODEL",), "model2": ("MODEL",)}
        for name in ("time_embed.", "label_emb."):
            inputs[name] = ratio

        # Input block index -> number of transformer_blocks entries to declare.
        input_depths = {4: 2, 5: 2, 7: 10, 8: 10}
        for blk in range(9):
            base = "input_blocks.{}.".format(blk)
            inputs[base + "0."] = ratio
            depth = input_depths.get(blk)
            if depth is None:
                continue
            inputs[base + "1."] = ratio
            inputs.update(
                ("{}1.transformer_blocks.{}.{}".format(base, layer, part), ratio)
                for layer in range(depth)
                for part in parts
            )

        for blk in range(3):
            base = "middle_block.{}.".format(blk)
            inputs[base] = ratio
            # Only middle block 1 declares per-transformer-block entries.
            if blk == 1:
                inputs.update(
                    ("{}transformer_blocks.{}.{}".format(base, layer, part), ratio)
                    for layer in range(10)
                    for part in parts
                )

        # The output table is looked up with the mirrored index (8 - blk).
        output_depths = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}
        for blk in range(9):
            base = "output_blocks.{}.".format(blk)
            inputs[base + "0."] = ratio
            depth = output_depths.get(8 - blk)
            if depth is None:
                continue
            inputs[base + "1."] = ratio
            inputs.update(
                ("{}1.transformer_blocks.{}.{}".format(base, layer, part), ratio)
                for layer in range(depth)
                for part in parts
            )

        inputs["out."] = ratio

        return {"required": inputs}
|
|
|
# Maps each node identifier string to the class that implements it.
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSDXL=ModelMergeSDXL,
    ModelMergeSDXLTransformers=ModelMergeSDXLTransformers,
    ModelMergeSDXLDetailedTransformers=ModelMergeSDXLDetailedTransformers,
)
|
|