import os
import json
import copy
import math
import torch

import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.lora
import comfy.utils
import comfy.conds
from comfy import model_management
from .diffusers_convert import convert_state_dict
from .lora import convert_lora_state_dict  # assumed local helper, used by load_pixart_lora below

# checkpoint
class EXM_PixArt(comfy.supported_models_base.BASE):
    unet_config = {}
    unet_extra_config = {}
    latent_format = comfy.latent_formats.SD15

    def __init__(self, model_conf):
        self.model_target = model_conf.get("target")
        self.unet_config = model_conf.get("unet_config", {})
        self.sampling_settings = model_conf.get("sampling_settings", {})
        self.latent_format = self.latent_format()
        # UNET is handled by extension
        self.unet_config["disable_unet_model_creation"] = True

    def model_type(self, state_dict, prefix=""):
        return comfy.model_base.ModelType.EPS


class EXM_PixArt_Model(comfy.model_base.BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)

        img_hw = kwargs.get("img_hw", None)
        if img_hw is not None:
            out["img_hw"] = comfy.conds.CONDRegular(torch.tensor(img_hw))

        aspect_ratio = kwargs.get("aspect_ratio", None)
        if aspect_ratio is not None:
            out["aspect_ratio"] = comfy.conds.CONDRegular(torch.tensor(aspect_ratio))

        cn_hint = kwargs.get("cn_hint", None)
        if cn_hint is not None:
            out["cn_hint"] = comfy.conds.CONDRegular(cn_hint)

        return out
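
# A minimal sketch of where these extra conds come from: the conditioning dicts
# passed to the sampler can carry "img_hw"/"aspect_ratio"/"cn_hint" entries,
# which ComfyUI forwards to extra_conds() above as kwargs. The values below are
# purely illustrative, not part of this file:
#
#   positive = [[t5_embeds, {"img_hw": [[1024, 1024]], "aspect_ratio": [[1.0]]}]]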


def load_pixart(model_path, model_conf=None):
    state_dict = comfy.utils.load_torch_file(model_path)
    state_dict = state_dict.get("model", state_dict)

    # strip any known key prefix
    for prefix in ["model.diffusion_model.", ]:
        if any(x.startswith(prefix) for x in state_dict):
            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

    # diffusers
    if "adaln_single.linear.weight" in state_dict:
        state_dict = convert_state_dict(state_dict)  # Diffusers

    # auto-guess the config from the state dict if none was provided
    if model_conf is None:
        model_conf = guess_pixart_config(state_dict)

    parameters = comfy.utils.calculate_parameters(state_dict)
    unet_dtype = model_management.unet_dtype(model_params=parameters)
    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()

    # ignore fp8/etc and use directly for now
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype:
        print(f"PixArt: falling back to {manual_cast_dtype}")
        unet_dtype = manual_cast_dtype

    model_conf = EXM_PixArt(model_conf)  # convert to object
    model = EXM_PixArt_Model(  # same as comfy.model_base.BaseModel
        model_conf,
        model_type=comfy.model_base.ModelType.EPS,
        device=model_management.get_torch_device()
    )

    if model_conf.model_target == "PixArtMS":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
    elif model_conf.model_target == "PixArt":
        from .models.PixArt import PixArt
        model.diffusion_model = PixArt(**model_conf.unet_config)
    elif model_conf.model_target == "PixArtMSSigma":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        model.latent_format = comfy.latent_formats.SDXL()
    elif model_conf.model_target == "ControlPixArtMSHalf":
        from .models.PixArtMS import PixArtMS
        from .models.pixart_controlnet import ControlPixArtMSHalf
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
    elif model_conf.model_target == "ControlPixArtHalf":
        from .models.PixArt import PixArt
        from .models.pixart_controlnet import ControlPixArtHalf
        model.diffusion_model = PixArt(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
    else:
        raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")

    m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
    if len(m) > 0:
        print("Missing UNET keys", m)
    if len(u) > 0:
        print("Leftover UNET keys", u)
    model.diffusion_model.dtype = unet_dtype
    model.diffusion_model.eval()
    model.diffusion_model.to(unet_dtype)

    model_patcher = comfy.model_patcher.ModelPatcher(
        model,
        load_device=load_device,
        offload_device=offload_device,
    )
    return model_patcher
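
# Usage sketch (hypothetical checkpoint path; mirrors how a loader node would
# call this):
#
#   model_patcher = load_pixart("models/checkpoints/PixArt-XL-2-1024-MS.pth")
#   # or, with an explicit config dict instead of auto-guessing:
#   model_patcher = load_pixart(model_path, model_conf=my_conf)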


def guess_pixart_config(sd):
    """
    Guess config based on converted state dict.
    """
    # Shared settings based on DiT_XL_2 - could be enumerated
    config = {
        "num_heads": 16,  # get from attention
        "patch_size": 2,  # final layer I guess?
        "hidden_size": 1152,  # pos_embed.shape[2]
    }
    config["depth"] = sum([key.endswith(".attn.proj.weight") for key in sd.keys()]) or 28

    try:
        # this is not present in the diffusers version for sigma?
        config["model_max_length"] = sd["y_embedder.y_embedding"].shape[0]
    except KeyError:
        # need better logic to guess this
        config["model_max_length"] = 300

    if "pos_embed" in sd:
        config["input_size"] = int(math.sqrt(sd["pos_embed"].shape[1])) * config["patch_size"]
        config["pe_interpolation"] = config["input_size"] // (512 // 8)  # dumb guess

    target_arch = "PixArtMS"
    if config["model_max_length"] == 300:
        # Sigma
        target_arch = "PixArtMSSigma"
        config["micro_condition"] = False
        if "input_size" not in config:
            # The diffusers weights for 1K/2K are exactly the same...?
            # replace patch embed logic with HyDiT?
            print(f"PixArt: diffusers weights - 2K model will be broken, use manual loading!")
            config["input_size"] = 1024 // 8
    else:
        # Alpha
        if "csize_embedder.mlp.0.weight" in sd:
            # MS (microconds)
            target_arch = "PixArtMS"
            config["micro_condition"] = True
            if "input_size" not in config:
                config["input_size"] = 1024 // 8
                config["pe_interpolation"] = 2
        else:
            # PixArt
            target_arch = "PixArt"
            if "input_size" not in config:
                config["input_size"] = 512 // 8
                config["pe_interpolation"] = 1

    print("PixArt guessed config:", target_arch, config)
    return {
        "target": target_arch,
        "unet_config": config,
        "sampling_settings": {
            "beta_schedule": "sqrt_linear",
            "linear_start": 0.0001,
            "linear_end": 0.02,
            "timesteps": 1000,
        }
    }
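
# For reference, a 1024px Sigma checkpoint with pos_embed present would be
# guessed roughly as (illustrative, derived from the rules above):
#   target "PixArtMSSigma", with num_heads=16, patch_size=2, hidden_size=1152,
#   depth=28, model_max_length=300, input_size=128, pe_interpolation=2,
#   micro_condition=False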

# lora
class EXM_PixArt_ModelPatcher(comfy.model_patcher.ModelPatcher):
    def calculate_weight(self, patches, weight, key):
        """
        This is almost the same as the comfy function, but stripped down to just the LoRA patch code.
        The problem with the original code is the q/k/v keys being combined into one for the attention.
        In the diffusers code, they're treated as separate keys, but in the reference code they're recombined (q+kv|qkv).
        This means, for example, that the [1152,1152] weights become [3456,1152] in the state dict.
        The issue with this is that the LoRA weights are [128,1152],[1152,128] and become [384,1152],[3456,128] instead.

        This is the best thing I could think of that would fix that, but it's very fragile.
         - Check key shape to determine if it needs the fallback logic
         - Cut the input into parts based on the shape (undoing the torch.cat)
         - Do the matrix multiplication logic
         - Recombine them to match the expected shape
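
        Worked example (hypothetical rank-128 LoRA on a fused qkv, hidden size 1152):
          weight           [3456,1152]  (q/k/v stacked along dim 0)
          mat1 (lora_up)   [3456,128]  -> chunk(3, dim=0) -> 3x [1152,128]
          mat2 (lora_down) [384,1152]  -> chunk(3, dim=0) -> 3x [128,1152]
          each mm gives [1152,1152]; torch.cat along dim 0 restores [3456,1152]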
        """
        for p in patches:
            alpha = p[0]
            v = p[1]
            strength_model = p[2]
            if strength_model != 1.0:
                weight *= strength_model

            if isinstance(v, list):
                v = (self.calculate_weight(v[1:], v[0].clone(), key),)

            patch_type = ""
            if len(v) == 2:
                patch_type = v[0]
                v = v[1]
            if patch_type == "lora":
                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                try:
                    mat1 = mat1.flatten(start_dim=1)
                    mat2 = mat2.flatten(start_dim=1)

                    ch1 = mat1.shape[0] // mat2.shape[1]
                    ch2 = mat2.shape[0] // mat1.shape[1]
                    ### Fallback logic for shape mismatch ###
                    if mat1.shape[0] != mat2.shape[1] and ch1 == ch2 and (mat1.shape[0] / mat2.shape[1]) % 1 == 0:
                        mat1 = mat1.chunk(ch1, dim=0)
                        mat2 = mat2.chunk(ch1, dim=0)
                        weight += torch.cat(
                            [alpha * torch.mm(mat1[x], mat2[x]) for x in range(ch1)],
                            dim=0,
                        ).reshape(weight.shape).type(weight.dtype)
                    else:
                        weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape).type(weight.dtype)
                except Exception as e:
                    print("ERROR", key, e)
        return weight

    def clone(self):
        n = EXM_PixArt_ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
                                    weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]

        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.model_keys = self.model_keys
        return n


def replace_model_patcher(model):
    n = EXM_PixArt_ModelPatcher(
        model=model.model,
        size=model.size,
        load_device=model.load_device,
        offload_device=model.offload_device,
        current_device=model.current_device,
        weight_inplace_update=model.weight_inplace_update,
    )
    n.patches = {}
    for k in model.patches:
        n.patches[k] = model.patches[k][:]

    n.object_patches = model.object_patches.copy()
    n.model_options = copy.deepcopy(model.model_options)
    return n


def find_peft_alpha(path):
    def load_json(json_path):
        with open(json_path) as f:
            data = json.load(f)
        alpha = data.get("lora_alpha")
        alpha = alpha or data.get("alpha")
        if not alpha:
            print(" Found config but `lora_alpha` is missing!")
        else:
            print(f" Found config at {json_path} [alpha:{alpha}]")
        return alpha

    # For some weird reason peft doesn't include the alpha in the actual model
    print("PixArt: Warning! This is a PEFT LoRA. Trying to find config...")
    files = [
        f"{os.path.splitext(path)[0]}.json",
        f"{os.path.splitext(path)[0]}.config.json",
        os.path.join(os.path.dirname(path), "adapter_config.json"),
    ]
    for file in files:
        if os.path.isfile(file):
            return load_json(file)

    print(" Missing config/alpha! assuming alpha of 8. Consider converting it/adding a config json to it.")
    return 8.0
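
# For reference, a PEFT adapter_config.json typically stores the alpha as
# "lora_alpha" (illustrative snippet, not from this repo):
#   {"r": 64, "lora_alpha": 32, "target_modules": ["qkv", "proj"], ...}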


def load_pixart_lora(model, lora, lora_path, strength):
    k_back = lambda x: x.replace(".lora_up.weight", "")
    # need to convert the actual weights for this to work.
    if any(x.endswith("adaln_single.linear.lora_A.weight") for x in lora):
        lora = convert_lora_state_dict(lora, peft=True)
        alpha = find_peft_alpha(lora_path)
        lora.update({f"{k_back(x)}.alpha": torch.tensor(alpha) for x in lora.keys() if "lora_up" in x})
    else:  # OneTrainer
        lora = convert_lora_state_dict(lora, peft=False)

    key_map = {k_back(x): f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x}  # fake key map; the converted LoRA keys already match the model's

    loaded = comfy.lora.load_lora(lora, key_map)
    if model is not None:
        # switch to custom model patcher when using LoRAs
        if isinstance(model, EXM_PixArt_ModelPatcher):
            new_modelpatcher = model.clone()
        else:
            new_modelpatcher = replace_model_patcher(model)
        k = new_modelpatcher.add_patches(loaded, strength)
    else:
        k = ()
        new_modelpatcher = None

    k = set(k)
    for x in loaded:
        if x not in k:
            print("NOT LOADED", x)

    return new_modelpatcher
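
# Usage sketch (hypothetical paths; mirrors the LoRA loader node flow):
#
#   lora_path = "loras/pixart_style.safetensors"
#   lora_sd = comfy.utils.load_torch_file(lora_path)
#   patched = load_pixart_lora(model_patcher, lora_sd, lora_path, strength=0.8)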