|
import requests |
|
import os |
|
import yaml |
|
from .utils import get_ckpt, get_yaml_config |
|
|
|
|
|
def download_ckpt_yaml(model_path, model_name, ckpt_path, yaml_url=None):
    """Download a checkpoint (and optionally a YAML config) into a local cache.

    Files are stored under ``<model_path>/<model_name>/``. Each file is named
    after the last path component of its URL; any query string or fragment is
    ignored. A file that already exists locally is reused as-is, without being
    re-downloaded or validated.

    Args:
        model_path: Root cache directory; created if missing.
        model_name: Sub-directory for this model inside ``model_path``;
            created if missing.
        ckpt_path: URL of the checkpoint file.
        yaml_url: Optional URL of a YAML config file.

    Returns:
        ``(local_ckpt_path, local_yaml_path)``. ``local_yaml_path`` is None
        when ``yaml_url`` is falsy.

    Raises:
        requests.HTTPError: if a download responds with an error status.
    """
    import tempfile
    from urllib.parse import urlparse

    def url_filename(url):
        # Last path component of the URL, ignoring any ?query or #fragment.
        return os.path.basename(urlparse(url).path)

    def download_file(url, save_path, chunk_size=1 << 20, timeout=60):
        # Stream in chunks so large checkpoints are never held fully in memory.
        # Write to a temp file next to the target and rename it only once the
        # download completes. An interrupted download therefore never leaves a
        # truncated file that the existence checks below would later accept.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            fd, tmp_path = tempfile.mkstemp(
                dir=os.path.dirname(save_path) or ".", suffix=".part"
            )
            try:
                with os.fdopen(fd, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                os.replace(tmp_path, save_path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise

    local_dir = os.path.join(model_path, model_name)
    # makedirs also creates model_path, so one call covers both directories.
    os.makedirs(local_dir, exist_ok=True)

    local_ckpt_path = os.path.join(local_dir, url_filename(ckpt_path))
    if not os.path.exists(local_ckpt_path):
        print(f"Downloading CKPT to {local_ckpt_path}")
        download_file(ckpt_path, local_ckpt_path)

    if not yaml_url:
        return local_ckpt_path, None

    local_yaml_path = os.path.join(local_dir, url_filename(yaml_url))
    if not os.path.exists(local_yaml_path):
        print(f"Downloading YAML to {local_yaml_path}")
        download_file(yaml_url, local_yaml_path)
    return local_ckpt_path, local_yaml_path
|
|
|
|
|
def _make_data_params(size, standardize=None):
    """Return the data/preprocessing settings for a square ``size`` x ``size`` input.

    Every model uses ``lock_ratio=True``, ``center_crop=True`` and
    ``padding=False``. Only the target size and the optional ``standardize``
    flag differ between models. The ``standardize`` key is omitted entirely
    when ``standardize`` is None, because several models never set it.
    """
    params = {
        "target_image_size": (size, size),
        "lock_ratio": True,
        "center_crop": True,
        "padding": False,
    }
    if standardize is not None:
        params["standardize"] = standardize
    return params


def _load_anole_vqgan(model_path, local_name, ckpt_path, yaml_url):
    """Build an Anole ``VQModel`` (also used for Chameleon) from a YAML config and checkpoint."""
    from src.vqvaes.anole.anole import VQModel

    if model_path is not None:
        ckpt_path, yaml_url = download_ckpt_yaml(model_path, local_name, ckpt_path, yaml_url)
    config = get_yaml_config(yaml_url)

    params = config["model"]["params"]
    # The loss config is not passed to VQModel; presumably it is training-only.
    if "lossconfig" in params:
        del params["lossconfig"]
    params["ckpt_path"] = ckpt_path
    return VQModel(**params)


def _load_llamagen(model_path, local_name, ckpt_path, arch):
    """Build a LlamaGen VQ model of architecture ``arch`` and load its ``"model"`` weights."""
    from src.vqvaes.llamagen.llamagen import VQ_models

    if model_path is not None:
        ckpt_path, _ = download_ckpt_yaml(model_path, local_name, ckpt_path, None)
    model = VQ_models[arch](codebook_size=16384, codebook_embed_dim=8)
    model.load_state_dict(get_ckpt(ckpt_path, key="model"))
    return model


def _load_flowmo(model_path, local_name, ckpt_path, **model_overrides):
    """Build a FlowMo model from the shared base config with ``config.model`` overrides applied in order."""
    from src.vqvaes.flowmo.flowmo import build_model

    yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
    if model_path is not None:
        ckpt_path, yaml_url = download_ckpt_yaml(model_path, local_name, ckpt_path, yaml_url)
    config = get_yaml_config(yaml_url)
    for key, value in model_overrides.items():
        setattr(config.model, key, value)
    model = build_model(config)
    model.load_state_dict(get_ckpt(ckpt_path, key="model_ema_state_dict"))
    return model


def _load_diffusers_vae(repo_id, token):
    """Load a diffusers ``AutoencoderKL`` from ``repo_id`` and bind the project's ``forward`` to it."""
    import types

    from src.vaes.stable_diffusion.vae import forward
    from diffusers import AutoencoderKL

    model = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", token=token)
    model.forward = types.MethodType(forward, model)
    return model


def get_model(model_path, model_name):
    """Instantiate a pretrained image tokenizer / autoencoder by name.

    Args:
        model_path: Optional local cache directory. When not None, checkpoints
            (and YAML configs) referenced by URL are first downloaded into
            ``<model_path>/<variant>/`` via ``download_ckpt_yaml`` and loaded
            from there. When None, the loaders are handed the remote URLs
            directly. Models loaded via ``from_pretrained`` (TiTok, Janus-Pro,
            SD3.5-L, FLUX.1-dev) and GPT-4o ignore it.
        model_name: Model identifier, matched case-insensitively. Some names
            are matched exactly (e.g. ``"anole"``, ``"llamagen-ds8"``) and
            others by substring (e.g. anything containing ``"titok"``), so the
            order of the checks below matters.

    The ``HUGGINGFACE_TOKEN`` environment variable, if set, is forwarded as
    ``token=`` to the ``from_pretrained`` calls.

    Returns:
        ``(model, data_params)``: the model after ``model.eval()``, and the
        data/preprocessing parameters dict for that model.

    Raises:
        Exception: if ``model_name`` or its variant suffix is not supported.
    """
    name = model_name.lower()
    huggingface_token = os.environ.get("HUGGINGFACE_TOKEN", None)

    if name == "anole":
        model = _load_anole_vqgan(
            model_path,
            "anole",
            "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.ckpt",
            "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.yaml",
        )
        data_params = _make_data_params(512)

    elif name == "chameleon":
        model = _load_anole_vqgan(
            model_path,
            "chameleon",
            "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.ckpt",
            "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.yaml",
        )
        data_params = _make_data_params(512)

    elif name == "llamagen-ds16":
        model = _load_llamagen(
            model_path,
            "llamagen-ds16",
            "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt",
            "VQ-16",
        )
        data_params = _make_data_params(512)

    elif name == "llamagen-ds16-t2i":
        model = _load_llamagen(
            model_path,
            "llamagen-ds16-t2i",
            "https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt",
            "VQ-16",
        )
        data_params = _make_data_params(512)

    elif name == "llamagen-ds8":
        model = _load_llamagen(
            model_path,
            "llamagen-ds8",
            "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt",
            "VQ-8",
        )
        data_params = _make_data_params(256)

    elif name == "flowmo_lo":
        model = _load_flowmo(
            model_path,
            "flowmo_lo",
            "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_lo.pth",
            context_dim=18,
        )
        data_params = _make_data_params(256)

    elif name == "flowmo_hi":
        model = _load_flowmo(
            model_path,
            "flowmo_hi",
            "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_hi.pth",
            context_dim=56,
            codebook_size_for_entropy=14,
        )
        data_params = _make_data_params(256)

    elif name == "open_magvit2":
        from src.vqvaes.open_magvit2.open_magvit2 import VQModel

        yaml_url = "https://raw.githubusercontent.com/TencentARC/SEED-Voken/refs/heads/main/configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml"
        ckpt_path = "https://huggingface.co/TencentARC/Open-MAGVIT2-Tokenizer-256-resolution/resolve/main/imagenet_256_L.ckpt"
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "open_magvit2", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)
        model = VQModel(**config.model.init_args)
        model.load_state_dict(get_ckpt(ckpt_path, key="state_dict"))
        data_params = _make_data_params(256)

    elif "maskbit" in name:
        # Resolve the variant before the heavy import so unsupported names fail fast.
        if "16bit" in name:
            variant = "16bit"
        elif "18bit" in name:
            variant = "18bit"
        else:
            raise Exception(f"Unsupported model: {model_name}")
        from src.vqvaes.maskbit.maskbit import ConvVQModel

        yaml_url = (
            "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/"
            f"configs/tokenizer/maskbit_tokenizer_{variant}.yaml"
        )
        ckpt_path = (
            f"https://huggingface.co/markweber/maskbit_tokenizer_{variant}/"
            f"resolve/main/maskbit_tokenizer_{variant}.bin"
        )
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, f"maskbit-{variant}", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)
        model = ConvVQModel(config.model.vq_model, legacy=False)
        model.load_pretrained(get_ckpt(ckpt_path, key=None))
        data_params = _make_data_params(256, standardize=False)

    elif "bsqvit" in name:
        from src.vqvaes.bsqvit.bsqvit import VITBSQModel

        yaml_url = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/config.yaml"
        ckpt_path = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/checkpoint.pt"
        if model_path is not None:
            ckpt_path, yaml_url = download_ckpt_yaml(model_path, "bsqvit", ckpt_path, yaml_url)
        config = get_yaml_config(yaml_url)
        model = VITBSQModel(**config["model"]["params"])
        model.init_from_ckpt(get_ckpt(ckpt_path, key="state_dict"))
        data_params = _make_data_params(256, standardize=False)

    elif "titok" in name:
        # The first key (in this order) found in the name selects the Hub repo.
        titok_repos = (
            ("bl64", "yucornetto/tokenizer_titok_bl64_vq8k_imagenet"),
            ("bl128", "yucornetto/tokenizer_titok_bl128_vq8k_imagenet"),
            ("sl256", "yucornetto/tokenizer_titok_sl256_vq8k_imagenet"),
            ("l32", "yucornetto/tokenizer_titok_l32_imagenet"),
            ("b64", "yucornetto/tokenizer_titok_b64_imagenet"),
            ("s128", "yucornetto/tokenizer_titok_s128_imagenet"),
        )
        repo_id = next((repo for key, repo in titok_repos if key in name), None)
        if repo_id is None:
            raise Exception(f"Unsupported model: {model_name}")
        from src.vqvaes.titok.titok import TiTok

        model = TiTok.from_pretrained(repo_id, token=huggingface_token)
        data_params = _make_data_params(256, standardize=False)

    elif "janus_pro" in name:
        import types

        from janus.models import MultiModalityCausalLM
        from src.vqvaes.janus_pro.janus_pro import forward

        # Only the image-generation VQ sub-module of the full model is kept.
        model = MultiModalityCausalLM.from_pretrained(
            "deepseek-ai/Janus-Pro-7B", trust_remote_code=True, token=huggingface_token
        ).gen_vision_model
        model.forward = types.MethodType(forward, model)
        data_params = _make_data_params(384)

    elif "var" in name:
        from src.vqvaes.var.var_vq import VQVAE

        ckpt_path = "https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth"
        if model_path is not None:
            ckpt_path, _ = download_ckpt_yaml(model_path, "var", ckpt_path, None)

        is_512 = "512" in name
        # The 512px variant uses a different v_patch_nums schedule.
        if is_512:
            v_patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)
        else:
            v_patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        model = VQVAE(
            vocab_size=4096,
            z_channels=32,
            ch=160,
            test_mode=True,
            share_quant_resi=4,
            v_patch_nums=v_patch_nums,
        )
        model.load_state_dict(get_ckpt(ckpt_path, key=None))
        data_params = _make_data_params(512 if is_512 else 256, standardize=False)

    elif "infinity" in name:
        # Match the variant case-insensitively (like every other branch) and
        # reject unknown variants before the heavy import, so ckpt_path and
        # codebook_dim are always bound below.
        if "d32" in name:
            codebook_dim = 32
        elif "d64" in name:
            codebook_dim = 64
        else:
            raise Exception(f"Unsupported model: {model_name}")
        from src.vqvaes.infinity.vae import vae_model

        ckpt_path = (
            "https://huggingface.co/FoundationVision/Infinity/resolve/main/"
            f"infinity_vae_d{codebook_dim}.pth"
        )
        if model_path is not None:
            ckpt_path, _ = download_ckpt_yaml(model_path, f"infinity-d{codebook_dim}", ckpt_path, None)

        schedule_mode = "dynamic"
        codebook_size = 2**codebook_dim
        patch_size = 16
        encoder_ch_mult = [1, 2, 4, 4, 4]
        decoder_ch_mult = [1, 2, 4, 4, 4]

        ckpt = get_ckpt(ckpt_path, key=None)
        model = vae_model(
            ckpt,
            schedule_mode,
            codebook_dim,
            codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        )
        data_params = _make_data_params(1024, standardize=False)

    elif "sd3.5l" in name:
        model = _load_diffusers_vae("huaweilin/stable-diffusion-3.5-large-vae", huggingface_token)
        data_params = _make_data_params(1024, standardize=True)

    elif "flux.1-dev" in name:
        model = _load_diffusers_vae("black-forest-labs/FLUX.1-dev", huggingface_token)
        data_params = _make_data_params(1024, standardize=True)

    elif "gpt4o" in name:
        from src.vaes.gpt_image.gpt_image import GPTImage

        data_params = _make_data_params(1024, standardize=False)
        model = GPTImage(data_params)

    else:
        raise Exception(f"Unsupported model: \"{model_name}\"")

    # Best-effort parameter count. Some backends (presumably the GPTImage
    # wrapper) may not expose torch-style .parameters(), so failure is only
    # reported. NOTE: this counts all parameters, not only those with
    # requires_grad, despite the printed label.
    try:
        trainable_params = sum(p.numel() for p in model.parameters())
        print("trainable_params:", trainable_params)
    except Exception as e:
        print(f"Could not count model parameters: {e}")

    model.eval()
    return model, data_params
|
|