JasonSmithSO's picture
Upload 578 files
8866644 verified
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
import comfy.conds
import torch
import math
from comfy import model_management
from .diffusers_convert import convert_state_dict
class EXM_PixArt(comfy.supported_models_base.BASE):
    """
    Model-config object built from a plain PixArt config dict.

    The dict is read with ``.get`` for the keys ``"target"``,
    ``"unet_config"`` and ``"sampling_settings"``. The base-class
    ``__init__`` is intentionally not called; all attributes are set here.
    """
    unet_config = {}
    unet_extra_config = {}
    latent_format = comfy.latent_formats.SD15

    def __init__(self, model_conf):
        """
        Args:
            model_conf: Mapping with optional ``"target"``, ``"unet_config"``
                and ``"sampling_settings"`` entries.
        """
        self.model_target = model_conf.get("target")
        # Shallow-copy so the flag written below does not leak into the
        # caller's dict (which may be a shared, predefined config table).
        self.unet_config = dict(model_conf.get("unet_config", {}))
        self.sampling_settings = model_conf.get("sampling_settings", {})
        # The class attribute holds the latent format class; instantiate it.
        self.latent_format = self.latent_format()
        # UNET is handled by extension
        self.unet_config["disable_unet_model_creation"] = True

    def model_type(self, state_dict, prefix=""):
        """Always report EPS; ``state_dict`` and ``prefix`` are not inspected."""
        return comfy.model_base.ModelType.EPS
class EXM_PixArt_Model(comfy.model_base.BaseModel):
    """BaseModel subclass that forwards PixArt-specific conditioning values."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def extra_conds(self, **kwargs):
        """
        Extend the base conds with ``img_hw``, ``aspect_ratio`` and ``cn_hint``.

        Each entry is added only when the matching keyword argument is
        present and not None. ``img_hw`` and ``aspect_ratio`` are converted
        with ``torch.tensor`` first; ``cn_hint`` is wrapped as given.
        """
        conds = super().extra_conds(**kwargs)

        # These two arrive as plain values and are turned into tensors.
        for name in ("img_hw", "aspect_ratio"):
            value = kwargs.get(name, None)
            if value is not None:
                conds[name] = comfy.conds.CONDRegular(torch.tensor(value))

        # The controlnet hint is passed through without conversion.
        hint = kwargs.get("cn_hint", None)
        if hint is not None:
            conds["cn_hint"] = comfy.conds.CONDRegular(hint)

        return conds
def load_pixart(model_path, model_conf=None):
    """
    Load a PixArt checkpoint and wrap it in a ModelPatcher.

    Args:
        model_path: Path passed to ``comfy.utils.load_torch_file``.
        model_conf: Optional config dict with ``"target"``, ``"unet_config"``
            and ``"sampling_settings"`` entries. When None, the config is
            derived from the state dict by ``guess_pixart_config``.

    Returns:
        A ``comfy.model_patcher.ModelPatcher`` wrapping the loaded model.

    Raises:
        NotImplementedError: If the config's target is not one of the
            architectures handled below.
    """
    state_dict = comfy.utils.load_torch_file(model_path)
    # Some checkpoints nest the weights under a top-level "model" entry.
    state_dict = state_dict.get("model", state_dict)

    # prefix
    for prefix in ["model.diffusion_model.",]:
        if any(key.startswith(prefix) for key in state_dict):
            # Strip the prefix only from keys that actually carry it; keys
            # without it are kept unchanged instead of being truncated.
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }

    # diffusers
    if "adaln_single.linear.weight" in state_dict:
        state_dict = convert_state_dict(state_dict) # Diffusers

    # guess auto config
    if model_conf is None:
        model_conf = guess_pixart_config(state_dict)

    parameters = comfy.utils.calculate_parameters(state_dict)
    unet_dtype = model_management.unet_dtype(model_params=parameters)
    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()

    # ignore fp8/etc and use directly for now
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype:
        print(f"PixArt: falling back to {manual_cast_dtype}")
        unet_dtype = manual_cast_dtype

    model_conf = EXM_PixArt(model_conf) # convert to object
    model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel
        model_conf,
        model_type=comfy.model_base.ModelType.EPS,
        device=load_device,
    )

    # Architecture modules are imported lazily, only for the chosen target.
    target = model_conf.model_target
    if target == "PixArtMS":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
    elif target == "PixArt":
        from .models.PixArt import PixArt
        model.diffusion_model = PixArt(**model_conf.unet_config)
    elif target == "PixArtMSSigma":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        # Sigma overrides the class-level SD15 latent format with SDXL.
        model.latent_format = comfy.latent_formats.SDXL()
    elif target == "ControlPixArtMSHalf":
        from .models.PixArtMS import PixArtMS
        from .models.pixart_controlnet import ControlPixArtMSHalf
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
    elif target == "ControlPixArtHalf":
        from .models.PixArt import PixArt
        from .models.pixart_controlnet import ControlPixArtHalf
        model.diffusion_model = PixArt(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
    else:
        raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")

    # Non-strict load: mismatched keys are reported, not raised.
    m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
    if len(m) > 0:
        print("Missing UNET keys", m)
    if len(u) > 0:
        print("Leftover UNET keys", u)

    model.diffusion_model.dtype = unet_dtype
    model.diffusion_model.eval()
    model.diffusion_model.to(unet_dtype)

    model_patcher = comfy.model_patcher.ModelPatcher(
        model,
        load_device = load_device,
        offload_device = offload_device,
        current_device = "cpu",
    )
    return model_patcher
def guess_pixart_config(sd):
    """
    Derive a PixArt model config from an already-converted state dict.

    Only key names and tensor ``.shape`` values are inspected. The result is
    a dict with ``"target"`` (architecture name), ``"unet_config"`` (kwargs
    for the architecture) and fixed ``"sampling_settings"``.
    """
    # Baseline values taken from DiT_XL_2; they are not read from the weights.
    unet_cfg = {
        "num_heads" : 16, # get from attention
        "patch_size" : 2, # final layer I guess?
        "hidden_size" : 1152, # pos_embed.shape[2]
    }

    # Count the attention projection weights; use 28 when none are found.
    block_count = sum(1 for name in sd if name.endswith(".attn.proj.weight"))
    unet_cfg["depth"] = block_count if block_count else 28

    # The embedding may be absent (diffusers Sigma weights?); default to 300.
    # A better way to guess this is still needed.
    if "y_embedder.y_embedding" in sd:
        unet_cfg["model_max_length"] = sd["y_embedder.y_embedding"].shape[0]
    else:
        unet_cfg["model_max_length"] = 300

    # With a stored positional embedding the input size can be computed.
    has_pos_embed = "pos_embed" in sd
    if has_pos_embed:
        grid_side = int(math.sqrt(sd["pos_embed"].shape[1]))
        unet_cfg["input_size"] = grid_side * unet_cfg["patch_size"]
        unet_cfg["pe_interpolation"] = unet_cfg["input_size"] // (512//8) # dumb guess

    if unet_cfg["model_max_length"] == 300:
        # Sigma
        arch = "PixArtMSSigma"
        unet_cfg["micro_condition"] = False
        if not has_pos_embed:
            # The diffusers weights for 1K/2K appear to be identical, so the
            # resolution cannot be told apart here; assume 1K.
            print(f"PixArt: diffusers weights - 2K model will be broken, use manual loading!")
            unet_cfg["input_size"] = 1024//8
    elif "csize_embedder.mlp.0.weight" in sd:
        # Alpha, multi-scale variant (micro-conditioning embedders present)
        arch = "PixArtMS"
        unet_cfg["micro_condition"] = True
        if not has_pos_embed:
            unet_cfg["input_size"] = 1024//8
            unet_cfg["pe_interpolation"] = 2
    else:
        # Alpha, plain PixArt
        arch = "PixArt"
        if not has_pos_embed:
            unet_cfg["input_size"] = 512//8
            unet_cfg["pe_interpolation"] = 1

    print("PixArt guessed config:", arch, unet_cfg)
    sampling = {
        "beta_schedule" : "sqrt_linear",
        "linear_start" : 0.0001,
        "linear_end" : 0.02,
        "timesteps" : 1000,
    }
    return {
        "target": arch,
        "unet_config": unet_cfg,
        "sampling_settings": sampling,
    }