# Source snapshot: commit baa8e90, 4,383 bytes.
import comfy_extras.nodes_model_merging
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with one ratio slider per top-level UNet section.

    Sliders are generated in this order: ``time_embed.``, ``label_emb.``,
    ``input_blocks.0`` .. ``input_blocks.8``, ``middle_block.0`` ..
    ``middle_block.2``, ``output_blocks.0`` .. ``output_blocks.8`` and
    ``out.``. Each slider name is a weight-name prefix. The merge itself,
    and how prefixes are matched to weights, is implemented by the
    ``ModelMergeBlocks`` base class.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input spec: both models plus one FLOAT per prefix.

        Every slider shares one spec tuple (default 1.0, range 0.0-1.0,
        step 0.01).
        """
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        # NOTE(review): the per-block prefixes below have no trailing dot.
        # That is only unambiguous because every index is a single digit.
        # If the base class matches prefixes with startswith (confirm),
        # "input_blocks.1" would also match a hypothetical "input_blocks.10".
        prefixes = ["time_embed.", "label_emb."]
        prefixes += ["input_blocks.{}".format(n) for n in range(9)]
        prefixes += ["middle_block.{}".format(n) for n in range(3)]
        prefixes += ["output_blocks.{}".format(n) for n in range(9)]
        prefixes.append("out.")

        spec = {"model1": ("MODEL",), "model2": ("MODEL",)}
        spec.update(dict.fromkeys(prefixes, ratio))
        return {"required": spec}
class ModelMergeSDXLTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with per-transformer-block ratio sliders.

    Each input/output block is split into its child ``0.`` and, where the
    block has a transformer depth entry, its child ``1.``. One extra slider
    is added per ``transformer_blocks.<j>.`` under that child.
    ``middle_block.1.`` additionally gets ten transformer sliders. The merge
    itself is performed by the ``ModelMergeBlocks`` base class.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input spec: both models plus one FLOAT per prefix.

        Every slider shares one spec tuple (default 1.0, range 0.0-1.0,
        step 0.01).
        """
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        # Input block index -> number of "transformer_blocks" under child "1.".
        input_depths = {4: 2, 5: 2, 7: 10, 8: 10}
        # Output blocks are looked up by their mirrored index (8 - idx), so
        # output blocks 0-2 get 10 transformer sliders, 3-5 get 2, 6-8 none.
        output_depths = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}

        keys = ["time_embed.", "label_emb."]

        def add_block(base, depth):
            # Child "0." always gets a slider. Child "1." and its
            # transformer_blocks are added only when the block has a depth.
            # NOTE(review): no slider is emitted for any child beyond "1."
            # (e.g. a possible "output_blocks.N.2."). Such weights get
            # whatever ratio ModelMergeBlocks falls back to — confirm this
            # is intended.
            keys.append(base + "0.")
            if depth is None:
                return
            keys.append(base + "1.")
            keys.extend("{}1.transformer_blocks.{}.".format(base, j) for j in range(depth))

        for idx in range(9):
            add_block("input_blocks.{}.".format(idx), input_depths.get(idx))

        for idx in range(3):
            base = "middle_block.{}.".format(idx)
            keys.append(base)
            if idx == 1:
                keys.extend("{}transformer_blocks.{}.".format(base, j) for j in range(10))

        for idx in range(9):
            add_block("output_blocks.{}.".format(idx), output_depths.get(8 - idx))

        keys.append("out.")

        spec = {"model1": ("MODEL",), "model2": ("MODEL",)}
        spec.update(dict.fromkeys(keys, ratio))
        return {"required": spec}
class ModelMergeSDXLDetailedTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with per-sub-layer transformer ratio sliders.

    Like the per-transformer-block variant, but instead of one slider per
    ``transformer_blocks.<j>.`` there is one slider for each listed sub-layer
    (norms, attention projections, feed-forward) inside every transformer
    block. The merge itself is performed by the ``ModelMergeBlocks`` base
    class.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input spec: both models plus one FLOAT per prefix.

        Every slider shares one spec tuple (default 1.0, range 0.0-1.0,
        step 0.01).
        """
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        # Input block index -> number of "transformer_blocks" under child "1.".
        input_depths = {4: 2, 5: 2, 7: 10, 8: 10}
        # Output blocks are looked up by their mirrored index (8 - idx), so
        # output blocks 0-2 get 10 transformer blocks, 3-5 get 2, 6-8 none.
        output_depths = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}
        # NOTE(review): these names carry no trailing dot. Assuming prefix
        # matching in the base class, e.g. "attn1.to_out" and "ff.net" also
        # cover any deeper parameters beneath them — confirm.
        sublayers = (
            "norm1", "attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out", "ff.net",
            "norm2", "attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out", "norm3",
        )

        keys = ["time_embed.", "label_emb."]

        def add_transformers(prefix, depth):
            # One key per (transformer block j, sub-layer); j varies slowest.
            keys.extend(
                "{}transformer_blocks.{}.{}".format(prefix, j, name)
                for j in range(depth)
                for name in sublayers
            )

        def add_block(base, depth):
            # Child "0." always gets a slider. Child "1." and its transformer
            # sub-layers are added only when the block has a depth.
            # NOTE(review): no slider is emitted for any child beyond "1."
            # (e.g. a possible "output_blocks.N.2."). Such weights get
            # whatever ratio ModelMergeBlocks falls back to — confirm.
            keys.append(base + "0.")
            if depth is None:
                return
            keys.append(base + "1.")
            add_transformers(base + "1.", depth)

        for idx in range(9):
            add_block("input_blocks.{}.".format(idx), input_depths.get(idx))

        for idx in range(3):
            base = "middle_block.{}.".format(idx)
            keys.append(base)
            if idx == 1:
                add_transformers(base, 10)

        for idx in range(9):
            add_block("output_blocks.{}.".format(idx), output_depths.get(8 - idx))

        keys.append("out.")

        spec = {"model1": ("MODEL",), "model2": ("MODEL",)}
        spec.update(dict.fromkeys(keys, ratio))
        return {"required": spec}
# Node registry: each node's type name (identical to its class name) -> class.
NODE_CLASS_MAPPINGS = dict(
    ModelMergeSDXL=ModelMergeSDXL,
    ModelMergeSDXLTransformers=ModelMergeSDXLTransformers,
    ModelMergeSDXLDetailedTransformers=ModelMergeSDXLDetailedTransformers,
)
|