Freak-ppa committed · commit 8d4cfef · verified · 1 Parent(s): 6b35f98

Upload 388 files

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. ComfyUI/CODEOWNERS +1 -0
  2. ComfyUI/CONTRIBUTING.md +41 -0
  3. ComfyUI/comfy/__pycache__/checkpoint_pickle.cpython-310.pyc +0 -0
  4. ComfyUI/comfy/__pycache__/cli_args.cpython-310.pyc +0 -0
  5. ComfyUI/comfy/__pycache__/clip_model.cpython-310.pyc +0 -0
  6. ComfyUI/comfy/__pycache__/clip_vision.cpython-310.pyc +0 -0
  7. ComfyUI/comfy/__pycache__/conds.cpython-310.pyc +0 -0
  8. ComfyUI/comfy/__pycache__/controlnet.cpython-310.pyc +0 -0
  9. ComfyUI/comfy/__pycache__/diffusers_convert.cpython-310.pyc +0 -0
  10. ComfyUI/comfy/__pycache__/diffusers_load.cpython-310.pyc +0 -0
  11. ComfyUI/comfy/__pycache__/gligen.cpython-310.pyc +0 -0
  12. ComfyUI/comfy/__pycache__/latent_formats.cpython-310.pyc +0 -0
  13. ComfyUI/comfy/__pycache__/lora.cpython-310.pyc +0 -0
  14. ComfyUI/comfy/__pycache__/model_base.cpython-310.pyc +0 -0
  15. ComfyUI/comfy/__pycache__/model_detection.cpython-310.pyc +0 -0
  16. ComfyUI/comfy/__pycache__/model_management.cpython-310.pyc +0 -0
  17. ComfyUI/comfy/__pycache__/model_patcher.cpython-310.pyc +0 -0
  18. ComfyUI/comfy/__pycache__/model_sampling.cpython-310.pyc +0 -0
  19. ComfyUI/comfy/__pycache__/ops.cpython-310.pyc +0 -0
  20. ComfyUI/comfy/__pycache__/options.cpython-310.pyc +0 -0
  21. ComfyUI/comfy/__pycache__/sample.cpython-310.pyc +0 -0
  22. ComfyUI/comfy/__pycache__/sampler_helpers.cpython-310.pyc +0 -0
  23. ComfyUI/comfy/__pycache__/samplers.cpython-310.pyc +0 -0
  24. ComfyUI/comfy/__pycache__/sd.cpython-310.pyc +0 -0
  25. ComfyUI/comfy/__pycache__/sd1_clip.cpython-310.pyc +0 -0
  26. ComfyUI/comfy/__pycache__/sdxl_clip.cpython-310.pyc +0 -0
  27. ComfyUI/comfy/__pycache__/supported_models.cpython-310.pyc +0 -0
  28. ComfyUI/comfy/__pycache__/supported_models_base.cpython-310.pyc +0 -0
  29. ComfyUI/comfy/__pycache__/types.cpython-310.pyc +0 -0
  30. ComfyUI/comfy/__pycache__/utils.cpython-310.pyc +0 -0
  31. ComfyUI/comfy/checkpoint_pickle.py +13 -0
  32. ComfyUI/comfy/cldm/__pycache__/cldm.cpython-310.pyc +0 -0
  33. ComfyUI/comfy/cldm/__pycache__/control_types.cpython-310.pyc +0 -0
  34. ComfyUI/comfy/cldm/__pycache__/mmdit.cpython-310.pyc +0 -0
  35. ComfyUI/comfy/cldm/cldm.py +437 -0
  36. ComfyUI/comfy/cldm/control_types.py +10 -0
  37. ComfyUI/comfy/cldm/mmdit.py +77 -0
  38. ComfyUI/comfy/cli_args.py +180 -0
  39. ComfyUI/comfy/clip_config_bigg.json +23 -0
  40. ComfyUI/comfy/clip_model.py +196 -0
  41. ComfyUI/comfy/clip_vision.py +121 -0
  42. ComfyUI/comfy/clip_vision_config_g.json +18 -0
  43. ComfyUI/comfy/clip_vision_config_h.json +18 -0
  44. ComfyUI/comfy/clip_vision_config_vitl.json +18 -0
  45. ComfyUI/comfy/clip_vision_config_vitl_336.json +18 -0
  46. ComfyUI/comfy/conds.py +83 -0
  47. ComfyUI/comfy/controlnet.py +622 -0
  48. ComfyUI/comfy/diffusers_convert.py +281 -0
  49. ComfyUI/comfy/diffusers_load.py +36 -0
  50. ComfyUI/comfy/extra_samplers/__pycache__/uni_pc.cpython-310.pyc +0 -0
ComfyUI/CODEOWNERS ADDED
@@ -0,0 +1 @@
+ * @comfyanonymous
ComfyUI/CONTRIBUTING.md ADDED
@@ -0,0 +1,41 @@
+ # Contributing to ComfyUI
+
+ Welcome, and thank you for your interest in contributing to ComfyUI!
+
+ There are several ways in which you can contribute, beyond writing code. The goal of this document is to provide a high-level overview of how you can get involved.
+
+ ## Asking Questions
+
+ Have a question? Instead of opening an issue, please ask on [Discord](https://comfy.org/discord) or [Matrix](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) channels. Our team and the community will help you.
+
+ ## Providing Feedback
+
+ Your comments and feedback are welcome, and the development team is available via a handful of different channels.
+
+ See the `#bug-report`, `#feature-request` and `#feedback` channels on Discord.
+
+ ## Reporting Issues
+
+ Have you identified a reproducible problem in ComfyUI? Do you have a feature request? We want to hear about it! Here's how you can report your issue as effectively as possible.
+
+
+ ### Look For an Existing Issue
+
+ Before you create a new issue, please do a search in [open issues](https://github.com/comfyanonymous/ComfyUI/issues) to see if the issue or feature request has already been filed.
+
+ If you find your issue already exists, make relevant comments and add your [reaction](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments). Use a reaction in place of a "+1" comment:
+
+ * 👍 - upvote
+ * 👎 - downvote
+
+ If you cannot find an existing issue that describes your bug or feature, create a new issue. We have an issue template in place to organize new issues.
+
+
+ ### Creating Pull Requests
+
+ * Please refer to the article on [creating pull requests](https://github.com/comfyanonymous/ComfyUI/wiki/How-to-Contribute-Code) and contributing to this project.
+
+
+ ## Thank You
+
+ Your contributions to open source, large or small, make great projects like this possible. Thank you for taking the time to contribute.
ComfyUI/comfy/__pycache__/checkpoint_pickle.cpython-310.pyc ADDED
Binary file (729 Bytes)
ComfyUI/comfy/__pycache__/cli_args.cpython-310.pyc ADDED
Binary file (8.73 kB)
ComfyUI/comfy/__pycache__/clip_model.cpython-310.pyc ADDED
Binary file (8.9 kB)
ComfyUI/comfy/__pycache__/clip_vision.cpython-310.pyc ADDED
Binary file (5.38 kB)
ComfyUI/comfy/__pycache__/conds.cpython-310.pyc ADDED
Binary file (3.3 kB)
ComfyUI/comfy/__pycache__/controlnet.cpython-310.pyc ADDED
Binary file (18.9 kB)
ComfyUI/comfy/__pycache__/diffusers_convert.cpython-310.pyc ADDED
Binary file (7.2 kB)
ComfyUI/comfy/__pycache__/diffusers_load.cpython-310.pyc ADDED
Binary file (1.33 kB)
ComfyUI/comfy/__pycache__/gligen.cpython-310.pyc ADDED
Binary file (10.3 kB)
ComfyUI/comfy/__pycache__/latent_formats.cpython-310.pyc ADDED
Binary file (6.53 kB)
ComfyUI/comfy/__pycache__/lora.cpython-310.pyc ADDED
Binary file (6.64 kB)
ComfyUI/comfy/__pycache__/model_base.cpython-310.pyc ADDED
Binary file (24.4 kB)
ComfyUI/comfy/__pycache__/model_detection.cpython-310.pyc ADDED
Binary file (15.9 kB)
ComfyUI/comfy/__pycache__/model_management.cpython-310.pyc ADDED
Binary file (21.8 kB)
ComfyUI/comfy/__pycache__/model_patcher.cpython-310.pyc ADDED
Binary file (16.2 kB)
ComfyUI/comfy/__pycache__/model_sampling.cpython-310.pyc ADDED
Binary file (12.2 kB)
ComfyUI/comfy/__pycache__/ops.cpython-310.pyc ADDED
Binary file (9.69 kB)
ComfyUI/comfy/__pycache__/options.cpython-310.pyc ADDED
Binary file (299 Bytes)
ComfyUI/comfy/__pycache__/sample.cpython-310.pyc ADDED
Binary file (2.88 kB)
ComfyUI/comfy/__pycache__/sampler_helpers.cpython-310.pyc ADDED
Binary file (2.74 kB)
ComfyUI/comfy/__pycache__/samplers.cpython-310.pyc ADDED
Binary file (22.7 kB)
ComfyUI/comfy/__pycache__/sd.cpython-310.pyc ADDED
Binary file (23.1 kB)
ComfyUI/comfy/__pycache__/sd1_clip.cpython-310.pyc ADDED
Binary file (17.5 kB)
ComfyUI/comfy/__pycache__/sdxl_clip.cpython-310.pyc ADDED
Binary file (5.65 kB)
ComfyUI/comfy/__pycache__/supported_models.cpython-310.pyc ADDED
Binary file (19.4 kB)
ComfyUI/comfy/__pycache__/supported_models_base.cpython-310.pyc ADDED
Binary file (4.05 kB)
ComfyUI/comfy/__pycache__/types.cpython-310.pyc ADDED
Binary file (1.38 kB)
ComfyUI/comfy/__pycache__/utils.cpython-310.pyc ADDED
Binary file (23 kB)
ComfyUI/comfy/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
+ import pickle
+
+ load = pickle.load
+
+ class Empty:
+     pass
+
+ class Unpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         #TODO: safe unpickle
+         if module.startswith("pytorch_lightning"):
+             return Empty
+         return super().find_class(module, name)
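For context, a minimal sketch of how a restricted unpickler like this is typically wired into torch.load. This snippet is illustrative and not part of the commit: the checkpoint path is hypothetical, and weights_only=False is only needed on recent PyTorch releases that default to safe loading when a custom pickle_module is given.

import torch
import comfy.checkpoint_pickle

# torch.load accepts a pickle_module; this module exposes both `load` and
# `Unpickler`, and its find_class() swaps any pytorch_lightning.* class for the
# Empty placeholder so Lightning-saved checkpoints load without that dependency.
state = torch.load("model.ckpt", map_location="cpu",
                   pickle_module=comfy.checkpoint_pickle, weights_only=False)
state_dict = state.get("state_dict", state)  # Lightning-style checkpoints nest weights under "state_dict"
print(f"loaded {len(state_dict)} entries")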
ComfyUI/comfy/cldm/__pycache__/cldm.cpython-310.pyc ADDED
Binary file (11.5 kB)
ComfyUI/comfy/cldm/__pycache__/control_types.cpython-310.pyc ADDED
Binary file (370 Bytes)
ComfyUI/comfy/cldm/__pycache__/mmdit.cpython-310.pyc ADDED
Binary file (2.07 kB)
ComfyUI/comfy/cldm/cldm.py ADDED
@@ -0,0 +1,437 @@
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+
8
+ from ..ldm.modules.diffusionmodules.util import (
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from ..ldm.modules.attention import SpatialTransformer
14
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
+ from ..ldm.util import exists
16
+ from .control_types import UNION_CONTROLNET_TYPES
17
+ from collections import OrderedDict
18
+ import comfy.ops
19
+ from comfy.ldm.modules.attention import optimized_attention
20
+
21
+ class OptimizedAttention(nn.Module):
22
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
23
+ super().__init__()
24
+ self.heads = nhead
25
+ self.c = c
26
+
27
+ self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
28
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
29
+
30
+ def forward(self, x):
31
+ x = self.in_proj(x)
32
+ q, k, v = x.split(self.c, dim=2)
33
+ out = optimized_attention(q, k, v, self.heads)
34
+ return self.out_proj(out)
35
+
36
+ class QuickGELU(nn.Module):
37
+ def forward(self, x: torch.Tensor):
38
+ return x * torch.sigmoid(1.702 * x)
39
+
40
+ class ResBlockUnionControlnet(nn.Module):
41
+ def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
42
+ super().__init__()
43
+ self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
44
+ self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
45
+ self.mlp = nn.Sequential(
46
+ OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
47
+ ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
48
+ self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
49
+
50
+ def attention(self, x: torch.Tensor):
51
+ return self.attn(x)
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = x + self.attention(self.ln_1(x))
55
+ x = x + self.mlp(self.ln_2(x))
56
+ return x
57
+
58
+ class ControlledUnetModel(UNetModel):
59
+ #implemented in the ldm unet
60
+ pass
61
+
62
+ class ControlNet(nn.Module):
63
+ def __init__(
64
+ self,
65
+ image_size,
66
+ in_channels,
67
+ model_channels,
68
+ hint_channels,
69
+ num_res_blocks,
70
+ dropout=0,
71
+ channel_mult=(1, 2, 4, 8),
72
+ conv_resample=True,
73
+ dims=2,
74
+ num_classes=None,
75
+ use_checkpoint=False,
76
+ dtype=torch.float32,
77
+ num_heads=-1,
78
+ num_head_channels=-1,
79
+ num_heads_upsample=-1,
80
+ use_scale_shift_norm=False,
81
+ resblock_updown=False,
82
+ use_new_attention_order=False,
83
+ use_spatial_transformer=False, # custom transformer support
84
+ transformer_depth=1, # custom transformer support
85
+ context_dim=None, # custom transformer support
86
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
87
+ legacy=True,
88
+ disable_self_attentions=None,
89
+ num_attention_blocks=None,
90
+ disable_middle_self_attn=False,
91
+ use_linear_in_transformer=False,
92
+ adm_in_channels=None,
93
+ transformer_depth_middle=None,
94
+ transformer_depth_output=None,
95
+ attn_precision=None,
96
+ union_controlnet_num_control_type=None,
97
+ device=None,
98
+ operations=comfy.ops.disable_weight_init,
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
103
+ if use_spatial_transformer:
104
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
105
+
106
+ if context_dim is not None:
107
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
108
+ # from omegaconf.listconfig import ListConfig
109
+ # if type(context_dim) == ListConfig:
110
+ # context_dim = list(context_dim)
111
+
112
+ if num_heads_upsample == -1:
113
+ num_heads_upsample = num_heads
114
+
115
+ if num_heads == -1:
116
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
117
+
118
+ if num_head_channels == -1:
119
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
120
+
121
+ self.dims = dims
122
+ self.image_size = image_size
123
+ self.in_channels = in_channels
124
+ self.model_channels = model_channels
125
+
126
+ if isinstance(num_res_blocks, int):
127
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
128
+ else:
129
+ if len(num_res_blocks) != len(channel_mult):
130
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
131
+ "as a list/tuple (per-level) with the same length as channel_mult")
132
+ self.num_res_blocks = num_res_blocks
133
+
134
+ if disable_self_attentions is not None:
135
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
136
+ assert len(disable_self_attentions) == len(channel_mult)
137
+ if num_attention_blocks is not None:
138
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
139
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
140
+
141
+ transformer_depth = transformer_depth[:]
142
+
143
+ self.dropout = dropout
144
+ self.channel_mult = channel_mult
145
+ self.conv_resample = conv_resample
146
+ self.num_classes = num_classes
147
+ self.use_checkpoint = use_checkpoint
148
+ self.dtype = dtype
149
+ self.num_heads = num_heads
150
+ self.num_head_channels = num_head_channels
151
+ self.num_heads_upsample = num_heads_upsample
152
+ self.predict_codebook_ids = n_embed is not None
153
+
154
+ time_embed_dim = model_channels * 4
155
+ self.time_embed = nn.Sequential(
156
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
157
+ nn.SiLU(),
158
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
159
+ )
160
+
161
+ if self.num_classes is not None:
162
+ if isinstance(self.num_classes, int):
163
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
164
+ elif self.num_classes == "continuous":
165
+ print("setting up linear c_adm embedding layer")
166
+ self.label_emb = nn.Linear(1, time_embed_dim)
167
+ elif self.num_classes == "sequential":
168
+ assert adm_in_channels is not None
169
+ self.label_emb = nn.Sequential(
170
+ nn.Sequential(
171
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
172
+ nn.SiLU(),
173
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
174
+ )
175
+ )
176
+ else:
177
+ raise ValueError()
178
+
179
+ self.input_blocks = nn.ModuleList(
180
+ [
181
+ TimestepEmbedSequential(
182
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
183
+ )
184
+ ]
185
+ )
186
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
187
+
188
+ self.input_hint_block = TimestepEmbedSequential(
189
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
190
+ nn.SiLU(),
191
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
192
+ nn.SiLU(),
193
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
194
+ nn.SiLU(),
195
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
196
+ nn.SiLU(),
197
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
198
+ nn.SiLU(),
199
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
200
+ nn.SiLU(),
201
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
202
+ nn.SiLU(),
203
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
204
+ )
205
+
206
+ self._feature_size = model_channels
207
+ input_block_chans = [model_channels]
208
+ ch = model_channels
209
+ ds = 1
210
+ for level, mult in enumerate(channel_mult):
211
+ for nr in range(self.num_res_blocks[level]):
212
+ layers = [
213
+ ResBlock(
214
+ ch,
215
+ time_embed_dim,
216
+ dropout,
217
+ out_channels=mult * model_channels,
218
+ dims=dims,
219
+ use_checkpoint=use_checkpoint,
220
+ use_scale_shift_norm=use_scale_shift_norm,
221
+ dtype=self.dtype,
222
+ device=device,
223
+ operations=operations,
224
+ )
225
+ ]
226
+ ch = mult * model_channels
227
+ num_transformers = transformer_depth.pop(0)
228
+ if num_transformers > 0:
229
+ if num_head_channels == -1:
230
+ dim_head = ch // num_heads
231
+ else:
232
+ num_heads = ch // num_head_channels
233
+ dim_head = num_head_channels
234
+ if legacy:
235
+ #num_heads = 1
236
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
237
+ if exists(disable_self_attentions):
238
+ disabled_sa = disable_self_attentions[level]
239
+ else:
240
+ disabled_sa = False
241
+
242
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
243
+ layers.append(
244
+ SpatialTransformer(
245
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
246
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
247
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
248
+ )
249
+ )
250
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
251
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
252
+ self._feature_size += ch
253
+ input_block_chans.append(ch)
254
+ if level != len(channel_mult) - 1:
255
+ out_ch = ch
256
+ self.input_blocks.append(
257
+ TimestepEmbedSequential(
258
+ ResBlock(
259
+ ch,
260
+ time_embed_dim,
261
+ dropout,
262
+ out_channels=out_ch,
263
+ dims=dims,
264
+ use_checkpoint=use_checkpoint,
265
+ use_scale_shift_norm=use_scale_shift_norm,
266
+ down=True,
267
+ dtype=self.dtype,
268
+ device=device,
269
+ operations=operations
270
+ )
271
+ if resblock_updown
272
+ else Downsample(
273
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
274
+ )
275
+ )
276
+ )
277
+ ch = out_ch
278
+ input_block_chans.append(ch)
279
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
280
+ ds *= 2
281
+ self._feature_size += ch
282
+
283
+ if num_head_channels == -1:
284
+ dim_head = ch // num_heads
285
+ else:
286
+ num_heads = ch // num_head_channels
287
+ dim_head = num_head_channels
288
+ if legacy:
289
+ #num_heads = 1
290
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
291
+ mid_block = [
292
+ ResBlock(
293
+ ch,
294
+ time_embed_dim,
295
+ dropout,
296
+ dims=dims,
297
+ use_checkpoint=use_checkpoint,
298
+ use_scale_shift_norm=use_scale_shift_norm,
299
+ dtype=self.dtype,
300
+ device=device,
301
+ operations=operations
302
+ )]
303
+ if transformer_depth_middle >= 0:
304
+ mid_block += [SpatialTransformer( # always uses a self-attn
305
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
306
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
307
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
308
+ ),
309
+ ResBlock(
310
+ ch,
311
+ time_embed_dim,
312
+ dropout,
313
+ dims=dims,
314
+ use_checkpoint=use_checkpoint,
315
+ use_scale_shift_norm=use_scale_shift_norm,
316
+ dtype=self.dtype,
317
+ device=device,
318
+ operations=operations
319
+ )]
320
+ self.middle_block = TimestepEmbedSequential(*mid_block)
321
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
322
+ self._feature_size += ch
323
+
324
+ if union_controlnet_num_control_type is not None:
325
+ self.num_control_type = union_controlnet_num_control_type
326
+ num_trans_channel = 320
327
+ num_trans_head = 8
328
+ num_trans_layer = 1
329
+ num_proj_channel = 320
330
+ # task_scale_factor = num_trans_channel ** 0.5
331
+ self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
332
+
333
+ self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
334
+ self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
335
+ #-----------------------------------------------------------------------------------------------------
336
+
337
+ control_add_embed_dim = 256
338
+ class ControlAddEmbedding(nn.Module):
339
+ def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
340
+ super().__init__()
341
+ self.num_control_type = num_control_type
342
+ self.in_dim = in_dim
343
+ self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
344
+ self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
345
+ def forward(self, control_type, dtype, device):
346
+ c_type = torch.zeros((self.num_control_type,), device=device)
347
+ c_type[control_type] = 1.0
348
+ c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
349
+ return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
350
+
351
+ self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
352
+ else:
353
+ self.task_embedding = None
354
+ self.control_add_embedding = None
355
+
356
+ def union_controlnet_merge(self, hint, control_type, emb, context):
357
+ # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
358
+ inputs = []
359
+ condition_list = []
360
+
361
+ for idx in range(min(1, len(control_type))):
362
+ controlnet_cond = self.input_hint_block(hint[idx], emb, context)
363
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
364
+ if idx < len(control_type):
365
+ feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
366
+
367
+ inputs.append(feat_seq.unsqueeze(1))
368
+ condition_list.append(controlnet_cond)
369
+
370
+ x = torch.cat(inputs, dim=1)
371
+ x = self.transformer_layes(x)
372
+ controlnet_cond_fuser = None
373
+ for idx in range(len(control_type)):
374
+ alpha = self.spatial_ch_projs(x[:, idx])
375
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
376
+ o = condition_list[idx] + alpha
377
+ if controlnet_cond_fuser is None:
378
+ controlnet_cond_fuser = o
379
+ else:
380
+ controlnet_cond_fuser += o
381
+ return controlnet_cond_fuser
382
+
383
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
384
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
385
+
386
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
387
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
388
+ emb = self.time_embed(t_emb)
389
+
390
+ guided_hint = None
391
+ if self.control_add_embedding is not None: #Union Controlnet
392
+ control_type = kwargs.get("control_type", [])
393
+
394
+ if any([c >= self.num_control_type for c in control_type]):
395
+ max_type = max(control_type)
396
+ max_type_name = {
397
+ v: k for k, v in UNION_CONTROLNET_TYPES.items()
398
+ }[max_type]
399
+ raise ValueError(
400
+ f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
401
+ f"({self.num_control_type}) supported.\n" +
402
+ "Please consider using the ProMax ControlNet Union model.\n" +
403
+ "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
404
+ )
405
+
406
+ emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
407
+ if len(control_type) > 0:
408
+ if len(hint.shape) < 5:
409
+ hint = hint.unsqueeze(dim=0)
410
+ guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
411
+
412
+ if guided_hint is None:
413
+ guided_hint = self.input_hint_block(hint, emb, context)
414
+
415
+ out_output = []
416
+ out_middle = []
417
+
418
+ hs = []
419
+ if self.num_classes is not None:
420
+ assert y.shape[0] == x.shape[0]
421
+ emb = emb + self.label_emb(y)
422
+
423
+ h = x
424
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
425
+ if guided_hint is not None:
426
+ h = module(h, emb, context)
427
+ h += guided_hint
428
+ guided_hint = None
429
+ else:
430
+ h = module(h, emb, context)
431
+ out_output.append(zero_conv(h, emb, context))
432
+
433
+ h = self.middle_block(h, emb, context)
434
+ out_middle.append(self.middle_block_out(h, emb, context))
435
+
436
+ return {"middle": out_middle, "output": out_output}
437
+
ComfyUI/comfy/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
+ UNION_CONTROLNET_TYPES = {
+     "openpose": 0,
+     "depth": 1,
+     "hed/pidi/scribble/ted": 2,
+     "canny/lineart/anime_lineart/mlsd": 3,
+     "normal": 4,
+     "segment": 5,
+     "tile": 6,
+     "repaint": 7,
+ }
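For context, a short illustrative snippet (not part of the commit) showing how this mapping is used, assuming the comfy package is importable: a union ControlNet is driven by the integer indices, and the reverse lookup recovers a readable name, which is what cldm.py does when it reports an out-of-range control type.

from comfy.cldm.control_types import UNION_CONTROLNET_TYPES

control_type = [UNION_CONTROLNET_TYPES["depth"]]            # -> [1]
index_to_name = {v: k for k, v in UNION_CONTROLNET_TYPES.items()}
print(index_to_name[control_type[0]])                       # -> "depth"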
ComfyUI/comfy/cldm/mmdit.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ from typing import Dict, Optional
+ import comfy.ldm.modules.diffusionmodules.mmdit
+
+ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
+     def __init__(
+         self,
+         num_blocks = None,
+         dtype = None,
+         device = None,
+         operations = None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
+         # controlnet_blocks
+         self.controlnet_blocks = torch.nn.ModuleList([])
+         for _ in range(len(self.joint_blocks)):
+             self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
+
+         self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
+             None,
+             self.patch_size,
+             self.in_channels,
+             self.hidden_size,
+             bias=True,
+             strict_img_size=False,
+             dtype=dtype,
+             device=device,
+             operations=operations
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         y: Optional[torch.Tensor] = None,
+         context: Optional[torch.Tensor] = None,
+         hint = None,
+     ) -> torch.Tensor:
+
+         #weird sd3 controlnet specific stuff
+         y = torch.zeros_like(y)
+
+         if self.context_processor is not None:
+             context = self.context_processor(context)
+
+         hw = x.shape[-2:]
+         x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+         x += self.pos_embed_input(hint)
+
+         c = self.t_embedder(timesteps, dtype=x.dtype)
+         if y is not None and self.y_embedder is not None:
+             y = self.y_embedder(y)
+             c = c + y
+
+         if context is not None:
+             context = self.context_embedder(context)
+
+         output = []
+
+         blocks = len(self.joint_blocks)
+         for i in range(blocks):
+             context, x = self.joint_blocks[i](
+                 context,
+                 x,
+                 c=c,
+                 use_checkpoint=self.use_checkpoint,
+             )
+
+             out = self.controlnet_blocks[i](x)
+             count = self.depth // blocks
+             if i == blocks - 1:
+                 count -= 1
+             for j in range(count):
+                 output.append(out)
+
+         return {"output": output}
ComfyUI/comfy/cli_args.py ADDED
@@ -0,0 +1,180 @@
1
+ import argparse
2
+ import enum
3
+ import os
4
+ from typing import Optional
5
+ import comfy.options
6
+
7
+
8
+ class EnumAction(argparse.Action):
9
+ """
10
+ Argparse action for handling Enums
11
+ """
12
+ def __init__(self, **kwargs):
13
+ # Pop off the type value
14
+ enum_type = kwargs.pop("type", None)
15
+
16
+ # Ensure an Enum subclass is provided
17
+ if enum_type is None:
18
+ raise ValueError("type must be assigned an Enum when using EnumAction")
19
+ if not issubclass(enum_type, enum.Enum):
20
+ raise TypeError("type must be an Enum when using EnumAction")
21
+
22
+ # Generate choices from the Enum
23
+ choices = tuple(e.value for e in enum_type)
24
+ kwargs.setdefault("choices", choices)
25
+ kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
26
+
27
+ super(EnumAction, self).__init__(**kwargs)
28
+
29
+ self._enum = enum_type
30
+
31
+ def __call__(self, parser, namespace, values, option_string=None):
32
+ # Convert value back into an Enum
33
+ value = self._enum(values)
34
+ setattr(namespace, self.dest, value)
35
+
36
+
37
+ parser = argparse.ArgumentParser()
38
+
39
+ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
40
+ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
41
+ parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
42
+ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
43
+ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
44
+ parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
45
+
46
+ parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
47
+ parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
48
+ parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
49
+ parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
50
+ parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
51
+ parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
52
+ parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
53
+ cm_group = parser.add_mutually_exclusive_group()
54
+ cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
55
+ cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
56
+
57
+
58
+ fp_group = parser.add_mutually_exclusive_group()
59
+ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
60
+ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
61
+
62
+ fpunet_group = parser.add_mutually_exclusive_group()
63
+ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
64
+ fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
65
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
66
+ fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
67
+
68
+ fpvae_group = parser.add_mutually_exclusive_group()
69
+ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
70
+ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
71
+ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
72
+
73
+ parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
74
+
75
+ fpte_group = parser.add_mutually_exclusive_group()
76
+ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
77
+ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
78
+ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
79
+ fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
80
+
81
+ parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
82
+
83
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
84
+
85
+ parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
86
+
87
+ class LatentPreviewMethod(enum.Enum):
88
+ NoPreviews = "none"
89
+ Auto = "auto"
90
+ Latent2RGB = "latent2rgb"
91
+ TAESD = "taesd"
92
+
93
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
94
+
95
+ attn_group = parser.add_mutually_exclusive_group()
96
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
97
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
98
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
99
+
100
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
101
+
102
+ upcast = parser.add_mutually_exclusive_group()
103
+ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
104
+ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
105
+
106
+
107
+ vram_group = parser.add_mutually_exclusive_group()
108
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
109
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
110
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
111
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
112
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
113
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
114
+
115
+ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
116
+
117
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
118
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
119
+
120
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
121
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
122
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
123
+
124
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
125
+ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
126
+
127
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
128
+
129
+ parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
130
+
131
+ # The default built-in provider hosted under web/
132
+ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
133
+
134
+ parser.add_argument(
135
+ "--front-end-version",
136
+ type=str,
137
+ default=DEFAULT_VERSION_STRING,
138
+ help="""
139
+ Specifies the version of the frontend to be used. This command needs internet connectivity to query and
140
+ download available frontend implementations from GitHub releases.
141
+
142
+ The version string should be in the format of:
143
+ [repoOwner]/[repoName]@[version]
144
+ where version is one of: "latest" or a valid version number (e.g. "1.0.0")
145
+ """,
146
+ )
147
+
148
+ def is_valid_directory(path: Optional[str]) -> Optional[str]:
149
+ """Validate if the given path is a directory."""
150
+ if path is None:
151
+ return None
152
+
153
+ if not os.path.isdir(path):
154
+ raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
155
+ return path
156
+
157
+ parser.add_argument(
158
+ "--front-end-root",
159
+ type=is_valid_directory,
160
+ default=None,
161
+ help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
162
+ )
163
+
164
+ if comfy.options.args_parsing:
165
+ args = parser.parse_args()
166
+ else:
167
+ args = parser.parse_args([])
168
+
169
+ if args.windows_standalone_build:
170
+ args.auto_launch = True
171
+
172
+ if args.disable_auto_launch:
173
+ args.auto_launch = False
174
+
175
+ import logging
176
+ logging_level = logging.INFO
177
+ if args.verbose:
178
+ logging_level = logging.DEBUG
179
+
180
+ logging.basicConfig(format="%(message)s", level=logging_level)
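For context, a small illustrative example (not part of the commit) of what EnumAction does: argparse validates the raw string against the enum values, and the action converts it back into the enum member, which is what the --preview-method flag above relies on. The Color enum and --color flag below are made up for this sketch.

import argparse
import enum
from comfy.cli_args import EnumAction

class Color(enum.Enum):
    Red = "red"
    Blue = "blue"

demo = argparse.ArgumentParser()
demo.add_argument("--color", type=Color, default=Color.Red, action=EnumAction)
ns = demo.parse_args(["--color", "blue"])
assert ns.color is Color.Blue  # the parsed value is the Enum member, not the raw string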
ComfyUI/comfy/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 49407,
+   "hidden_act": "gelu",
+   "hidden_size": 1280,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 20,
+   "num_hidden_layers": 32,
+   "pad_token_id": 1,
+   "projection_dim": 1280,
+   "torch_dtype": "float32",
+   "vocab_size": 49408
+ }
ComfyUI/comfy/clip_model.py ADDED
@@ -0,0 +1,196 @@
+ import torch
+ from comfy.ldm.modules.attention import optimized_attention_for_device
+ import comfy.ops
+
+ class CLIPAttention(torch.nn.Module):
+     def __init__(self, embed_dim, heads, dtype, device, operations):
+         super().__init__()
+
+         self.heads = heads
+         self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+         self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         out = optimized_attention(q, k, v, self.heads, mask)
+         return self.out_proj(out)
+
+ ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+                "gelu": torch.nn.functional.gelu,
+ }
+
+ class CLIPMLP(torch.nn.Module):
+     def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+         super().__init__()
+         self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+         self.activation = ACTIVATIONS[activation]
+         self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.activation(x)
+         x = self.fc2(x)
+         return x
+
+ class CLIPLayer(torch.nn.Module):
+     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+         self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+         x += self.mlp(self.layer_norm2(x))
+         return x
+
+
+ class CLIPEncoder(torch.nn.Module):
+     def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+     def forward(self, x, mask=None, intermediate_output=None):
+         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+
+         if intermediate_output is not None:
+             if intermediate_output < 0:
+                 intermediate_output = len(self.layers) + intermediate_output
+
+         intermediate = None
+         for i, l in enumerate(self.layers):
+             x = l(x, mask, optimized_attention)
+             if i == intermediate_output:
+                 intermediate = x.clone()
+         return x, intermediate
+
+ class CLIPEmbeddings(torch.nn.Module):
+     def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
+         super().__init__()
+         self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+         self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens, dtype=torch.float32):
+         return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
+
+
+ class CLIPTextModel_(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         num_layers = config_dict["num_hidden_layers"]
+         embed_dim = config_dict["hidden_size"]
+         heads = config_dict["num_attention_heads"]
+         intermediate_size = config_dict["intermediate_size"]
+         intermediate_activation = config_dict["hidden_act"]
+         self.eos_token_id = config_dict["eos_token_id"]
+
+         super().__init__()
+         self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
+         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+         x = self.embeddings(input_tokens, dtype=dtype)
+         mask = None
+         if attention_mask is not None:
+             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+             mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+
+         causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+         if mask is not None:
+             mask += causal_mask
+         else:
+             mask = causal_mask
+
+         x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
+         x = self.final_layer_norm(x)
+         if i is not None and final_layer_norm_intermediate:
+             i = self.final_layer_norm(i)
+
+         pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
+         return x, i, pooled_output
+
+ class CLIPTextModel(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         self.num_layers = config_dict["num_hidden_layers"]
+         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+         embed_dim = config_dict["hidden_size"]
+         self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+         self.text_projection.weight.copy_(torch.eye(embed_dim))
+         self.dtype = dtype
+
+     def get_input_embeddings(self):
+         return self.text_model.embeddings.token_embedding
+
+     def set_input_embeddings(self, embeddings):
+         self.text_model.embeddings.token_embedding = embeddings
+
+     def forward(self, *args, **kwargs):
+         x = self.text_model(*args, **kwargs)
+         out = self.text_projection(x[2])
+         return (x[0], x[1], out, x[2])
+
+
+ class CLIPVisionEmbeddings(torch.nn.Module):
+     def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+         super().__init__()
+         self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+         self.patch_embedding = operations.Conv2d(
+             in_channels=num_channels,
+             out_channels=embed_dim,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=False,
+             dtype=dtype,
+             device=device
+         )
+
+         num_patches = (image_size // patch_size) ** 2
+         num_positions = num_patches + 1
+         self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+     def forward(self, pixel_values):
+         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
+         return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
+
+
+ class CLIPVision(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         num_layers = config_dict["num_hidden_layers"]
+         embed_dim = config_dict["hidden_size"]
+         heads = config_dict["num_attention_heads"]
+         intermediate_size = config_dict["intermediate_size"]
+         intermediate_activation = config_dict["hidden_act"]
+
+         self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
+         self.pre_layrnorm = operations.LayerNorm(embed_dim)
+         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+         self.post_layernorm = operations.LayerNorm(embed_dim)
+
+     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+         x = self.embeddings(pixel_values)
+         x = self.pre_layrnorm(x)
+         #TODO: attention_mask?
+         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
+         pooled_output = self.post_layernorm(x[:, 0, :])
+         return x, i, pooled_output
+
+ class CLIPVisionModelProjection(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         self.vision_model = CLIPVision(config_dict, dtype, device, operations)
+         self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+
+     def forward(self, *args, **kwargs):
+         x = self.vision_model(*args, **kwargs)
+         out = self.visual_projection(x[2])
+         return (x[0], x[1], out)
ComfyUI/comfy/clip_vision.py ADDED
@@ -0,0 +1,121 @@
+ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
+ import os
+ import torch
+ import json
+ import logging
+
+ import comfy.ops
+ import comfy.model_patcher
+ import comfy.model_management
+ import comfy.utils
+ import comfy.clip_model
+
+ class Output:
+     def __getitem__(self, key):
+         return getattr(self, key)
+     def __setitem__(self, key, item):
+         setattr(self, key, item)
+
+ def clip_preprocess(image, size=224):
+     mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
+     std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
+     image = image.movedim(-1, 1)
+     if not (image.shape[2] == size and image.shape[3] == size):
+         scale = (size / min(image.shape[2], image.shape[3]))
+         image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
+         h = (image.shape[2] - size)//2
+         w = (image.shape[3] - size)//2
+         image = image[:,:,h:h+size,w:w+size]
+     image = torch.clip((255. * image), 0, 255).round() / 255.0
+     return (image - mean.view([3,1,1])) / std.view([3,1,1])
+
+ class ClipVisionModel():
+     def __init__(self, json_config):
+         with open(json_config) as f:
+             config = json.load(f)
+
+         self.image_size = config.get("image_size", 224)
+         self.load_device = comfy.model_management.text_encoder_device()
+         offload_device = comfy.model_management.text_encoder_offload_device()
+         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+         self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
+         self.model.eval()
+
+         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+
+     def load_sd(self, sd):
+         return self.model.load_state_dict(sd, strict=False)
+
+     def get_sd(self):
+         return self.model.state_dict()
+
+     def encode_image(self, image):
+         comfy.model_management.load_model_gpu(self.patcher)
+         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
+         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+
+         outputs = Output()
+         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
+         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
+         outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+         return outputs
+
+ def convert_to_transformers(sd, prefix):
+     sd_k = sd.keys()
+     if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
+         keys_to_replace = {
+             "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
+             "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
+             "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
+             "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
+             "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
+             "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
+             "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
+         }
+
+         for x in keys_to_replace:
+             if x in sd_k:
+                 sd[keys_to_replace[x]] = sd.pop(x)
+
+         if "{}proj".format(prefix) in sd_k:
+             sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
+
+         sd = transformers_convert(sd, prefix, "vision_model.", 48)
+     else:
+         replace_prefix = {prefix: ""}
+         sd = state_dict_prefix_replace(sd, replace_prefix)
+     return sd
+
+ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
+     if convert_keys:
+         sd = convert_to_transformers(sd, prefix)
+     if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
+     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
+         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
+     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
+         if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
+             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+         else:
+             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+     else:
+         return None
+
+     clip = ClipVisionModel(json_config)
+     m, u = clip.load_sd(sd)
+     if len(m) > 0:
+         logging.warning("missing clip vision: {}".format(m))
+     u = set(u)
+     keys = list(sd.keys())
+     for k in keys:
+         if k not in u:
+             t = sd.pop(k)
+             del t
+     return clip
+
+ def load(ckpt_path):
+     sd = load_torch_file(ckpt_path)
+     if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
+         return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
+     else:
+         return load_clipvision_from_sd(sd)
ComfyUI/comfy/clip_vision_config_g.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1664,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 8192,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 48,
15
+ "patch_size": 14,
16
+ "projection_dim": 1280,
17
+ "torch_dtype": "float32"
18
+ }
ComfyUI/comfy/clip_vision_config_h.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1280,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 5120,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 32,
15
+ "patch_size": 14,
16
+ "projection_dim": 1024,
17
+ "torch_dtype": "float32"
18
+ }
ComfyUI/comfy/clip_vision_config_vitl.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
ComfyUI/comfy/clip_vision_config_vitl_336.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 336,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-5,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
ComfyUI/comfy/conds.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import comfy.utils
4
+
5
+
6
+ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
7
+ return abs(a*b) // math.gcd(a, b)
8
+
9
+ class CONDRegular:
10
+ def __init__(self, cond):
11
+ self.cond = cond
12
+
13
+ def _copy_with(self, cond):
14
+ return self.__class__(cond)
15
+
16
+ def process_cond(self, batch_size, device, **kwargs):
17
+ return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
18
+
19
+ def can_concat(self, other):
20
+ if self.cond.shape != other.cond.shape:
21
+ return False
22
+ return True
23
+
24
+ def concat(self, others):
25
+ conds = [self.cond]
26
+ for x in others:
27
+ conds.append(x.cond)
28
+ return torch.cat(conds)
29
+
30
+ class CONDNoiseShape(CONDRegular):
31
+ def process_cond(self, batch_size, device, area, **kwargs):
32
+ data = self.cond
33
+ if area is not None:
34
+ dims = len(area) // 2
35
+ for i in range(dims):
36
+ data = data.narrow(i + 2, area[i + dims], area[i])
37
+
38
+ return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
39
+
40
+
41
+ class CONDCrossAttn(CONDRegular):
42
+ def can_concat(self, other):
43
+ s1 = self.cond.shape
44
+ s2 = other.cond.shape
45
+ if s1 != s2:
46
+ if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
47
+ return False
48
+
49
+ mult_min = lcm(s1[1], s2[1])
50
+ diff = mult_min // min(s1[1], s2[1])
51
+ if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
52
+ return False
53
+ return True
54
+
55
+ def concat(self, others):
56
+ conds = [self.cond]
57
+ crossattn_max_len = self.cond.shape[1]
58
+ for x in others:
59
+ c = x.cond
60
+ crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
61
+ conds.append(c)
62
+
63
+ out = []
64
+ for c in conds:
65
+ if c.shape[1] < crossattn_max_len:
66
+ c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
67
+ out.append(c)
68
+ return torch.cat(out)
69
+
70
+ class CONDConstant(CONDRegular):
71
+ def __init__(self, cond):
72
+ self.cond = cond
73
+
74
+ def process_cond(self, batch_size, device, **kwargs):
75
+ return self._copy_with(self.cond)
76
+
77
+ def can_concat(self, other):
78
+ if self.cond != other.cond:
79
+ return False
80
+ return True
81
+
82
+ def concat(self, others):
83
+ return self.cond
ComfyUI/comfy/controlnet.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import os
4
+ import logging
5
+ import comfy.utils
6
+ import comfy.model_management
7
+ import comfy.model_detection
8
+ import comfy.model_patcher
9
+ import comfy.ops
10
+ import comfy.latent_formats
11
+
12
+ import comfy.cldm.cldm
13
+ import comfy.t2i_adapter.adapter
14
+ import comfy.ldm.cascade.controlnet
15
+ import comfy.cldm.mmdit
16
+
17
+
18
+ def broadcast_image_to(tensor, target_batch_size, batched_number):
19
+ current_batch_size = tensor.shape[0]
20
+ #print(current_batch_size, target_batch_size)
21
+ if current_batch_size == 1:
22
+ return tensor
23
+
24
+ per_batch = target_batch_size // batched_number
25
+ tensor = tensor[:per_batch]
26
+
27
+ if per_batch > tensor.shape[0]:
28
+ tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
29
+
30
+ current_batch_size = tensor.shape[0]
31
+ if current_batch_size == target_batch_size:
32
+ return tensor
33
+ else:
34
+ return torch.cat([tensor] * batched_number, dim=0)
35
+
36
+ class ControlBase:
37
+ def __init__(self, device=None):
38
+ self.cond_hint_original = None
39
+ self.cond_hint = None
40
+ self.strength = 1.0
41
+ self.timestep_percent_range = (0.0, 1.0)
42
+ self.latent_format = None
43
+ self.vae = None
44
+ self.global_average_pooling = False
45
+ self.timestep_range = None
46
+ self.compression_ratio = 8
47
+ self.upscale_algorithm = 'nearest-exact'
48
+ self.extra_args = {}
49
+
50
+ if device is None:
51
+ device = comfy.model_management.get_torch_device()
52
+ self.device = device
53
+ self.previous_controlnet = None
54
+
55
+ def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
56
+ self.cond_hint_original = cond_hint
57
+ self.strength = strength
58
+ self.timestep_percent_range = timestep_percent_range
59
+ if self.latent_format is not None:
60
+ self.vae = vae
61
+ return self
62
+
63
+ def pre_run(self, model, percent_to_timestep_function):
64
+ self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
65
+ if self.previous_controlnet is not None:
66
+ self.previous_controlnet.pre_run(model, percent_to_timestep_function)
67
+
68
+ def set_previous_controlnet(self, controlnet):
69
+ self.previous_controlnet = controlnet
70
+ return self
71
+
72
+ def cleanup(self):
73
+ if self.previous_controlnet is not None:
74
+ self.previous_controlnet.cleanup()
75
+ if self.cond_hint is not None:
76
+ del self.cond_hint
77
+ self.cond_hint = None
78
+ self.timestep_range = None
79
+
80
+ def get_models(self):
81
+ out = []
82
+ if self.previous_controlnet is not None:
83
+ out += self.previous_controlnet.get_models()
84
+ return out
85
+
86
+ def copy_to(self, c):
87
+ c.cond_hint_original = self.cond_hint_original
88
+ c.strength = self.strength
89
+ c.timestep_percent_range = self.timestep_percent_range
90
+ c.global_average_pooling = self.global_average_pooling
91
+ c.compression_ratio = self.compression_ratio
92
+ c.upscale_algorithm = self.upscale_algorithm
93
+ c.latent_format = self.latent_format
94
+ c.extra_args = self.extra_args.copy()
95
+ c.vae = self.vae
96
+
97
+ def inference_memory_requirements(self, dtype):
98
+ if self.previous_controlnet is not None:
99
+ return self.previous_controlnet.inference_memory_requirements(dtype)
100
+ return 0
101
+
102
+ def control_merge(self, control, control_prev, output_dtype):
103
+ out = {'input':[], 'middle':[], 'output': []}
104
+
105
+ for key in control:
106
+ control_output = control[key]
107
+ applied_to = set()
108
+ for i in range(len(control_output)):
109
+ x = control_output[i]
110
+ if x is not None:
111
+ if self.global_average_pooling:
112
+ x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
113
+
114
+ if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
115
+ applied_to.add(x)
116
+ x *= self.strength
117
+
118
+ if x.dtype != output_dtype:
119
+ x = x.to(output_dtype)
120
+
121
+ out[key].append(x)
122
+
123
+ if control_prev is not None:
124
+ for x in ['input', 'middle', 'output']:
125
+ o = out[x]
126
+ for i in range(len(control_prev[x])):
127
+ prev_val = control_prev[x][i]
128
+ if i >= len(o):
129
+ o.append(prev_val)
130
+ elif prev_val is not None:
131
+ if o[i] is None:
132
+ o[i] = prev_val
133
+ else:
134
+ if o[i].shape[0] < prev_val.shape[0]:
135
+ o[i] = prev_val + o[i]
136
+ else:
137
+ o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
138
+ return out
139
+
140
+ def set_extra_arg(self, argument, value=None):
141
+ self.extra_args[argument] = value
142
+
143
+
144
+ class ControlNet(ControlBase):
145
+ def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
146
+ super().__init__(device)
147
+ self.control_model = control_model
148
+ self.load_device = load_device
149
+ if control_model is not None:
150
+ self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
151
+
152
+ self.compression_ratio = compression_ratio
153
+ self.global_average_pooling = global_average_pooling
154
+ self.model_sampling_current = None
155
+ self.manual_cast_dtype = manual_cast_dtype
156
+ self.latent_format = latent_format
157
+
158
+ def get_control(self, x_noisy, t, cond, batched_number):
159
+ control_prev = None
160
+ if self.previous_controlnet is not None:
161
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
162
+
163
+ if self.timestep_range is not None:
164
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
165
+ if control_prev is not None:
166
+ return control_prev
167
+ else:
168
+ return None
169
+
170
+ dtype = self.control_model.dtype
171
+ if self.manual_cast_dtype is not None:
172
+ dtype = self.manual_cast_dtype
173
+
174
+ output_dtype = x_noisy.dtype
175
+ if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
176
+ if self.cond_hint is not None:
177
+ del self.cond_hint
178
+ self.cond_hint = None
179
+ compression_ratio = self.compression_ratio
180
+ if self.vae is not None:
181
+ compression_ratio *= self.vae.downscale_ratio
182
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
183
+ if self.vae is not None:
184
+ loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
185
+ self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
186
+ comfy.model_management.load_models_gpu(loaded_models)
187
+ if self.latent_format is not None:
188
+ self.cond_hint = self.latent_format.process_in(self.cond_hint)
189
+ self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
190
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
191
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
192
+
193
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
194
+ extra = self.extra_args.copy()
195
+ for c in ["y", "guidance"]: #TODO
196
+ temp = cond.get(c, None)
197
+ if temp is not None:
198
+ extra[c] = temp.to(dtype)
199
+
200
+ timestep = self.model_sampling_current.timestep(t)
201
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
202
+
203
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
204
+ return self.control_merge(control, control_prev, output_dtype)
205
+
206
+ def copy(self):
207
+ c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
208
+ c.control_model = self.control_model
209
+ c.control_model_wrapped = self.control_model_wrapped
210
+ self.copy_to(c)
211
+ return c
212
+
213
+ def get_models(self):
214
+ out = super().get_models()
215
+ out.append(self.control_model_wrapped)
216
+ return out
217
+
218
+ def pre_run(self, model, percent_to_timestep_function):
219
+ super().pre_run(model, percent_to_timestep_function)
220
+ self.model_sampling_current = model.model_sampling
221
+
222
+ def cleanup(self):
223
+ self.model_sampling_current = None
224
+ super().cleanup()
225
+
226
+ class ControlLoraOps:
227
+ class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
228
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
229
+ device=None, dtype=None) -> None:
230
+ factory_kwargs = {'device': device, 'dtype': dtype}
231
+ super().__init__()
232
+ self.in_features = in_features
233
+ self.out_features = out_features
234
+ self.weight = None
235
+ self.up = None
236
+ self.down = None
237
+ self.bias = None
238
+
239
+ def forward(self, input):
240
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
241
+ if self.up is not None:
242
+ return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
243
+ else:
244
+ return torch.nn.functional.linear(input, weight, bias)
245
+
246
+ class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
247
+ def __init__(
248
+ self,
249
+ in_channels,
250
+ out_channels,
251
+ kernel_size,
252
+ stride=1,
253
+ padding=0,
254
+ dilation=1,
255
+ groups=1,
256
+ bias=True,
257
+ padding_mode='zeros',
258
+ device=None,
259
+ dtype=None
260
+ ):
261
+ super().__init__()
262
+ self.in_channels = in_channels
263
+ self.out_channels = out_channels
264
+ self.kernel_size = kernel_size
265
+ self.stride = stride
266
+ self.padding = padding
267
+ self.dilation = dilation
268
+ self.transposed = False
269
+ self.output_padding = 0
270
+ self.groups = groups
271
+ self.padding_mode = padding_mode
272
+
273
+ self.weight = None
274
+ self.bias = None
275
+ self.up = None
276
+ self.down = None
277
+
278
+
279
+ def forward(self, input):
280
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
281
+ if self.up is not None:
282
+ return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
283
+ else:
284
+ return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
285
+
286
+
287
+ class ControlLora(ControlNet):
288
+ def __init__(self, control_weights, global_average_pooling=False, device=None):
289
+ ControlBase.__init__(self, device)
290
+ self.control_weights = control_weights
291
+ self.global_average_pooling = global_average_pooling
292
+
293
+ def pre_run(self, model, percent_to_timestep_function):
294
+ super().pre_run(model, percent_to_timestep_function)
295
+ controlnet_config = model.model_config.unet_config.copy()
296
+ controlnet_config.pop("out_channels")
297
+ controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
298
+ self.manual_cast_dtype = model.manual_cast_dtype
299
+ dtype = model.get_dtype()
300
+ if self.manual_cast_dtype is None:
301
+ class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
302
+ pass
303
+ else:
304
+ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
305
+ pass
306
+ dtype = self.manual_cast_dtype
307
+
308
+ controlnet_config["operations"] = control_lora_ops
309
+ controlnet_config["dtype"] = dtype
310
+ self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
311
+ self.control_model.to(comfy.model_management.get_torch_device())
312
+ diffusion_model = model.diffusion_model
313
+ sd = diffusion_model.state_dict()
314
+ cm = self.control_model.state_dict()
315
+
316
+ for k in sd:
317
+ weight = sd[k]
318
+ try:
319
+ comfy.utils.set_attr_param(self.control_model, k, weight)
320
+ except:
321
+ pass
322
+
323
+ for k in self.control_weights:
324
+ if k not in {"lora_controlnet"}:
325
+ comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
326
+
327
+ def copy(self):
328
+ c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
329
+ self.copy_to(c)
330
+ return c
331
+
332
+ def cleanup(self):
333
+ del self.control_model
334
+ self.control_model = None
335
+ super().cleanup()
336
+
337
+ def get_models(self):
338
+ out = ControlBase.get_models(self)
339
+ return out
340
+
341
+ def inference_memory_requirements(self, dtype):
342
+ return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
343
+
344
+ def controlnet_config(sd):
345
+ model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
346
+
347
+ supported_inference_dtypes = model_config.supported_inference_dtypes
348
+
349
+ controlnet_config = model_config.unet_config
350
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
351
+ load_device = comfy.model_management.get_torch_device()
352
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
353
+ if manual_cast_dtype is not None:
354
+ operations = comfy.ops.manual_cast
355
+ else:
356
+ operations = comfy.ops.disable_weight_init
357
+
358
+ return model_config, operations, load_device, unet_dtype, manual_cast_dtype
359
+
360
+ def controlnet_load_state_dict(control_model, sd):
361
+ missing, unexpected = control_model.load_state_dict(sd, strict=False)
362
+
363
+ if len(missing) > 0:
364
+ logging.warning("missing controlnet keys: {}".format(missing))
365
+
366
+ if len(unexpected) > 0:
367
+ logging.debug("unexpected controlnet keys: {}".format(unexpected))
368
+ return control_model
369
+
370
+ def load_controlnet_mmdit(sd):
371
+ new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
372
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
373
+ num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
374
+ for k in sd:
375
+ new_sd[k] = sd[k]
376
+
377
+ control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
378
+ control_model = controlnet_load_state_dict(control_model, new_sd)
379
+
380
+ latent_format = comfy.latent_formats.SD3()
381
+ latent_format.shift_factor = 0 #SD3 controlnet weirdness
382
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
383
+ return control
384
+
385
+
386
+ def load_controlnet(ckpt_path, model=None):
387
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
388
+ if "lora_controlnet" in controlnet_data:
389
+ return ControlLora(controlnet_data)
390
+
391
+ controlnet_config = None
392
+ supported_inference_dtypes = None
393
+
394
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
395
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
396
+ diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
397
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
398
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
399
+
400
+ count = 0
401
+ loop = True
402
+ while loop:
403
+ suffix = [".weight", ".bias"]
404
+ for s in suffix:
405
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
406
+ k_out = "zero_convs.{}.0{}".format(count, s)
407
+ if k_in not in controlnet_data:
408
+ loop = False
409
+ break
410
+ diffusers_keys[k_in] = k_out
411
+ count += 1
412
+
413
+ count = 0
414
+ loop = True
415
+ while loop:
416
+ suffix = [".weight", ".bias"]
417
+ for s in suffix:
418
+ if count == 0:
419
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
420
+ else:
421
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
422
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
423
+ if k_in not in controlnet_data:
424
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
425
+ loop = False
426
+ diffusers_keys[k_in] = k_out
427
+ count += 1
428
+
429
+ new_sd = {}
430
+ for k in diffusers_keys:
431
+ if k in controlnet_data:
432
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
433
+
434
+ if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
435
+ controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
436
+ for k in list(controlnet_data.keys()):
437
+ new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
438
+ new_sd[new_k] = controlnet_data.pop(k)
439
+
440
+ leftover_keys = controlnet_data.keys()
441
+ if len(leftover_keys) > 0:
442
+ logging.warning("leftover keys: {}".format(leftover_keys))
443
+ controlnet_data = new_sd
444
+ elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
445
+ return load_controlnet_mmdit(controlnet_data)
446
+
447
+ pth_key = 'control_model.zero_convs.0.0.weight'
448
+ pth = False
449
+ key = 'zero_convs.0.0.weight'
450
+ if pth_key in controlnet_data:
451
+ pth = True
452
+ key = pth_key
453
+ prefix = "control_model."
454
+ elif key in controlnet_data:
455
+ prefix = ""
456
+ else:
457
+ net = load_t2i_adapter(controlnet_data)
458
+ if net is None:
459
+ logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
460
+ return net
461
+
462
+ if controlnet_config is None:
463
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
464
+ supported_inference_dtypes = model_config.supported_inference_dtypes
465
+ controlnet_config = model_config.unet_config
466
+
467
+ load_device = comfy.model_management.get_torch_device()
468
+ if supported_inference_dtypes is None:
469
+ unet_dtype = comfy.model_management.unet_dtype()
470
+ else:
471
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
472
+
473
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
474
+ if manual_cast_dtype is not None:
475
+ controlnet_config["operations"] = comfy.ops.manual_cast
476
+ controlnet_config["dtype"] = unet_dtype
477
+ controlnet_config.pop("out_channels")
478
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
479
+ control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
480
+
481
+ if pth:
482
+ if 'difference' in controlnet_data:
483
+ if model is not None:
484
+ comfy.model_management.load_models_gpu([model])
485
+ model_sd = model.model_state_dict()
486
+ for x in controlnet_data:
487
+ c_m = "control_model."
488
+ if x.startswith(c_m):
489
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
490
+ if sd_key in model_sd:
491
+ cd = controlnet_data[x]
492
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
493
+ else:
494
+ logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
495
+
496
+ class WeightsLoader(torch.nn.Module):
497
+ pass
498
+ w = WeightsLoader()
499
+ w.control_model = control_model
500
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
501
+ else:
502
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
503
+
504
+ if len(missing) > 0:
505
+ logging.warning("missing controlnet keys: {}".format(missing))
506
+
507
+ if len(unexpected) > 0:
508
+ logging.debug("unexpected controlnet keys: {}".format(unexpected))
509
+
510
+ global_average_pooling = False
511
+ filename = os.path.splitext(ckpt_path)[0]
512
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
513
+ global_average_pooling = True
514
+
515
+ control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
516
+ return control
517
+
518
+ class T2IAdapter(ControlBase):
519
+ def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
520
+ super().__init__(device)
521
+ self.t2i_model = t2i_model
522
+ self.channels_in = channels_in
523
+ self.control_input = None
524
+ self.compression_ratio = compression_ratio
525
+ self.upscale_algorithm = upscale_algorithm
526
+
527
+ def scale_image_to(self, width, height):
528
+ unshuffle_amount = self.t2i_model.unshuffle_amount
529
+ width = math.ceil(width / unshuffle_amount) * unshuffle_amount
530
+ height = math.ceil(height / unshuffle_amount) * unshuffle_amount
531
+ return width, height
532
+
533
+ def get_control(self, x_noisy, t, cond, batched_number):
534
+ control_prev = None
535
+ if self.previous_controlnet is not None:
536
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
537
+
538
+ if self.timestep_range is not None:
539
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
540
+ if control_prev is not None:
541
+ return control_prev
542
+ else:
543
+ return None
544
+
545
+ if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
546
+ if self.cond_hint is not None:
547
+ del self.cond_hint
548
+ self.control_input = None
549
+ self.cond_hint = None
550
+ width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
551
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
552
+ if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
553
+ self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
554
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
555
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
556
+ if self.control_input is None:
557
+ self.t2i_model.to(x_noisy.dtype)
558
+ self.t2i_model.to(self.device)
559
+ self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
560
+ self.t2i_model.cpu()
561
+
562
+ control_input = {}
563
+ for k in self.control_input:
564
+ control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
565
+
566
+ return self.control_merge(control_input, control_prev, x_noisy.dtype)
567
+
568
+ def copy(self):
569
+ c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
570
+ self.copy_to(c)
571
+ return c
572
+
573
+ def load_t2i_adapter(t2i_data):
574
+ compression_ratio = 8
575
+ upscale_algorithm = 'nearest-exact'
576
+
577
+ if 'adapter' in t2i_data:
578
+ t2i_data = t2i_data['adapter']
579
+ if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
580
+ prefix_replace = {}
581
+ for i in range(4):
582
+ for j in range(2):
583
+ prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
584
+ prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
585
+ prefix_replace["adapter."] = ""
586
+ t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
587
+ keys = t2i_data.keys()
588
+
589
+ if "body.0.in_conv.weight" in keys:
590
+ cin = t2i_data['body.0.in_conv.weight'].shape[1]
591
+ model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
592
+ elif 'conv_in.weight' in keys:
593
+ cin = t2i_data['conv_in.weight'].shape[1]
594
+ channel = t2i_data['conv_in.weight'].shape[0]
595
+ ksize = t2i_data['body.0.block2.weight'].shape[2]
596
+ use_conv = False
597
+ down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
598
+ if len(down_opts) > 0:
599
+ use_conv = True
600
+ xl = False
601
+ if cin == 256 or cin == 768:
602
+ xl = True
603
+ model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
604
+ elif "backbone.0.0.weight" in keys:
605
+ model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
606
+ compression_ratio = 32
607
+ upscale_algorithm = 'bilinear'
608
+ elif "backbone.10.blocks.0.weight" in keys:
609
+ model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
610
+ compression_ratio = 1
611
+ upscale_algorithm = 'nearest-exact'
612
+ else:
613
+ return None
614
+
615
+ missing, unexpected = model_ad.load_state_dict(t2i_data)
616
+ if len(missing) > 0:
617
+ logging.warning("t2i missing {}".format(missing))
618
+
619
+ if len(unexpected) > 0:
620
+ logging.debug("t2i unexpected {}".format(unexpected))
621
+
622
+ return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
ComfyUI/comfy/diffusers_convert.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import logging
4
+
5
+ # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
6
+
7
+ # =================#
8
+ # UNet Conversion #
9
+ # =================#
10
+
11
+ unet_conversion_map = [
12
+ # (stable-diffusion, HF Diffusers)
13
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
14
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
15
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
16
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
17
+ ("input_blocks.0.0.weight", "conv_in.weight"),
18
+ ("input_blocks.0.0.bias", "conv_in.bias"),
19
+ ("out.0.weight", "conv_norm_out.weight"),
20
+ ("out.0.bias", "conv_norm_out.bias"),
21
+ ("out.2.weight", "conv_out.weight"),
22
+ ("out.2.bias", "conv_out.bias"),
23
+ ]
24
+
25
+ unet_conversion_map_resnet = [
26
+ # (stable-diffusion, HF Diffusers)
27
+ ("in_layers.0", "norm1"),
28
+ ("in_layers.2", "conv1"),
29
+ ("out_layers.0", "norm2"),
30
+ ("out_layers.3", "conv2"),
31
+ ("emb_layers.1", "time_emb_proj"),
32
+ ("skip_connection", "conv_shortcut"),
33
+ ]
34
+
35
+ unet_conversion_map_layer = []
36
+ # hardcoded number of downblocks and resnets/attentions...
37
+ # would need smarter logic for other networks.
38
+ for i in range(4):
39
+ # loop over downblocks/upblocks
40
+
41
+ for j in range(2):
42
+ # loop over resnets/attentions for downblocks
43
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
44
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
45
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
46
+
47
+ if i < 3:
48
+ # no attention layers in down_blocks.3
49
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
50
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
51
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
52
+
53
+ for j in range(3):
54
+ # loop over resnets/attentions for upblocks
55
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
56
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
57
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
58
+
59
+ if i > 0:
60
+ # no attention layers in up_blocks.0
61
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
62
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
63
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
64
+
65
+ if i < 3:
66
+ # no downsample in down_blocks.3
67
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
68
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
69
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
70
+
71
+ # no upsample in up_blocks.3
72
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
73
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
74
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
75
+
76
+ hf_mid_atn_prefix = "mid_block.attentions.0."
77
+ sd_mid_atn_prefix = "middle_block.1."
78
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
79
+
80
+ for j in range(2):
81
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
82
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
83
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
84
+
85
+
86
+ def convert_unet_state_dict(unet_state_dict):
87
+ # buyer beware: this is a *brittle* function,
88
+ # and correct output requires that all of these pieces interact in
89
+ # the exact order in which I have arranged them.
90
+ mapping = {k: k for k in unet_state_dict.keys()}
91
+ for sd_name, hf_name in unet_conversion_map:
92
+ mapping[hf_name] = sd_name
93
+ for k, v in mapping.items():
94
+ if "resnets" in k:
95
+ for sd_part, hf_part in unet_conversion_map_resnet:
96
+ v = v.replace(hf_part, sd_part)
97
+ mapping[k] = v
98
+ for k, v in mapping.items():
99
+ for sd_part, hf_part in unet_conversion_map_layer:
100
+ v = v.replace(hf_part, sd_part)
101
+ mapping[k] = v
102
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
103
+ return new_state_dict
104
+
105
+
106
+ # ================#
107
+ # VAE Conversion #
108
+ # ================#
109
+
110
+ vae_conversion_map = [
111
+ # (stable-diffusion, HF Diffusers)
112
+ ("nin_shortcut", "conv_shortcut"),
113
+ ("norm_out", "conv_norm_out"),
114
+ ("mid.attn_1.", "mid_block.attentions.0."),
115
+ ]
116
+
117
+ for i in range(4):
118
+ # down_blocks have two resnets
119
+ for j in range(2):
120
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
121
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
122
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
123
+
124
+ if i < 3:
125
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
126
+ sd_downsample_prefix = f"down.{i}.downsample."
127
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
128
+
129
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
130
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
131
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
132
+
133
+ # up_blocks have three resnets
134
+ # also, up blocks in hf are numbered in reverse from sd
135
+ for j in range(3):
136
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
137
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
138
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
139
+
140
+ # this part accounts for mid blocks in both the encoder and the decoder
141
+ for i in range(2):
142
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
143
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
144
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
145
+
146
+ vae_conversion_map_attn = [
147
+ # (stable-diffusion, HF Diffusers)
148
+ ("norm.", "group_norm."),
149
+ ("q.", "query."),
150
+ ("k.", "key."),
151
+ ("v.", "value."),
152
+ ("q.", "to_q."),
153
+ ("k.", "to_k."),
154
+ ("v.", "to_v."),
155
+ ("proj_out.", "to_out.0."),
156
+ ("proj_out.", "proj_attn."),
157
+ ]
158
+
159
+
160
+ def reshape_weight_for_sd(w):
161
+ # convert HF linear weights to SD conv2d weights
162
+ return w.reshape(*w.shape, 1, 1)
163
+
164
+
165
+ def convert_vae_state_dict(vae_state_dict):
166
+ mapping = {k: k for k in vae_state_dict.keys()}
167
+ for k, v in mapping.items():
168
+ for sd_part, hf_part in vae_conversion_map:
169
+ v = v.replace(hf_part, sd_part)
170
+ mapping[k] = v
171
+ for k, v in mapping.items():
172
+ if "attentions" in k:
173
+ for sd_part, hf_part in vae_conversion_map_attn:
174
+ v = v.replace(hf_part, sd_part)
175
+ mapping[k] = v
176
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
177
+ weights_to_convert = ["q", "k", "v", "proj_out"]
178
+ for k, v in new_state_dict.items():
179
+ for weight_name in weights_to_convert:
180
+ if f"mid.attn_1.{weight_name}.weight" in k:
181
+ logging.debug(f"Reshaping {k} for SD format")
182
+ new_state_dict[k] = reshape_weight_for_sd(v)
183
+ return new_state_dict
184
+
185
+
186
+ # =========================#
187
+ # Text Encoder Conversion #
188
+ # =========================#
189
+
190
+
191
+ textenc_conversion_lst = [
192
+ # (stable-diffusion, HF Diffusers)
193
+ ("resblocks.", "text_model.encoder.layers."),
194
+ ("ln_1", "layer_norm1"),
195
+ ("ln_2", "layer_norm2"),
196
+ (".c_fc.", ".fc1."),
197
+ (".c_proj.", ".fc2."),
198
+ (".attn", ".self_attn"),
199
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
200
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
201
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
202
+ ]
203
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
204
+ textenc_pattern = re.compile("|".join(protected.keys()))
205
+
206
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
207
+ code2idx = {"q": 0, "k": 1, "v": 2}
208
+
209
+ # This function exists because at the time of writing torch.cat can't do fp8 with cuda
210
+ def cat_tensors(tensors):
211
+ x = 0
212
+ for t in tensors:
213
+ x += t.shape[0]
214
+
215
+ shape = [x] + list(tensors[0].shape)[1:]
216
+ out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
217
+
218
+ x = 0
219
+ for t in tensors:
220
+ out[x:x + t.shape[0]] = t
221
+ x += t.shape[0]
222
+
223
+ return out
224
+
225
+ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
226
+ new_state_dict = {}
227
+ capture_qkv_weight = {}
228
+ capture_qkv_bias = {}
229
+ for k, v in text_enc_dict.items():
230
+ if not k.startswith(prefix):
231
+ continue
232
+ if (
233
+ k.endswith(".self_attn.q_proj.weight")
234
+ or k.endswith(".self_attn.k_proj.weight")
235
+ or k.endswith(".self_attn.v_proj.weight")
236
+ ):
237
+ k_pre = k[: -len(".q_proj.weight")]
238
+ k_code = k[-len("q_proj.weight")]
239
+ if k_pre not in capture_qkv_weight:
240
+ capture_qkv_weight[k_pre] = [None, None, None]
241
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
242
+ continue
243
+
244
+ if (
245
+ k.endswith(".self_attn.q_proj.bias")
246
+ or k.endswith(".self_attn.k_proj.bias")
247
+ or k.endswith(".self_attn.v_proj.bias")
248
+ ):
249
+ k_pre = k[: -len(".q_proj.bias")]
250
+ k_code = k[-len("q_proj.bias")]
251
+ if k_pre not in capture_qkv_bias:
252
+ capture_qkv_bias[k_pre] = [None, None, None]
253
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
254
+ continue
255
+
256
+ text_proj = "transformer.text_projection.weight"
257
+ if k.endswith(text_proj):
258
+ new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
259
+ else:
260
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
261
+ new_state_dict[relabelled_key] = v
262
+
263
+ for k_pre, tensors in capture_qkv_weight.items():
264
+ if None in tensors:
265
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
266
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
267
+ new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
268
+
269
+ for k_pre, tensors in capture_qkv_bias.items():
270
+ if None in tensors:
271
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
272
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
273
+ new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
274
+
275
+ return new_state_dict
276
+
277
+
278
+ def convert_text_enc_state_dict(text_enc_dict):
279
+ return text_enc_dict
280
+
281
+
ComfyUI/comfy/diffusers_load.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import comfy.sd
4
+
5
+ def first_file(path, filenames):
6
+ for f in filenames:
7
+ p = os.path.join(path, f)
8
+ if os.path.exists(p):
9
+ return p
10
+ return None
11
+
12
+ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
13
+ diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
14
+ unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
15
+ vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
16
+
17
+ text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
18
+ text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
19
+ text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
20
+
21
+ text_encoder_paths = [text_encoder1_path]
22
+ if text_encoder2_path is not None:
23
+ text_encoder_paths.append(text_encoder2_path)
24
+
25
+ unet = comfy.sd.load_unet(unet_path)
26
+
27
+ clip = None
28
+ if output_clip:
29
+ clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
30
+
31
+ vae = None
32
+ if output_vae:
33
+ sd = comfy.utils.load_torch_file(vae_path)
34
+ vae = comfy.sd.VAE(sd=sd)
35
+
36
+ return (unet, clip, vae)
ComfyUI/comfy/extra_samplers/__pycache__/uni_pc.cpython-310.pyc ADDED
Binary file (28.5 kB). View file