VTBench / src /model_processing.py
huaweilin's picture
update
7d3842f
import requests
import os
import yaml
from .utils import get_ckpt, get_yaml_config
def download_ckpt_yaml(model_path, model_name, ckpt_path, yaml_url=None):
def download_file(url, save_path):
response = requests.get(url)
response.raise_for_status()
with open(save_path, 'wb') as f:
f.write(response.content)
os.makedirs(model_path, exist_ok=True)
local_dir = os.path.join(model_path, model_name)
os.makedirs(local_dir, exist_ok=True)
ckpt_name = ckpt_path.split("/")[-1]
local_ckpt_path = os.path.join(local_dir, ckpt_name)
if not os.path.exists(local_ckpt_path):
print(f"Downloading CKPT to {local_ckpt_path}")
download_file(ckpt_path, local_ckpt_path)
if yaml_url:
yaml_name = yaml_url.split("/")[-1]
local_yaml_path = os.path.join(local_dir, yaml_name)
if not os.path.exists(local_yaml_path):
print(f"Downloading YAML to {local_yaml_path}")
download_file(yaml_url, local_yaml_path)
return local_ckpt_path, local_yaml_path
return local_ckpt_path, None
def get_model(model_path, model_name):
model = None
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN", None)
data_params = {
"target_image_size": (512, 512),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
if model_name.lower() == "anole":
from src.vqvaes.anole.anole import VQModel
yaml_url = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.yaml"
ckpt_path = "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.ckpt"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "anole", ckpt_path, yaml_url)
config = get_yaml_config(yaml_url)
params = config["model"]["params"]
if "lossconfig" in params:
del params["lossconfig"]
params["ckpt_path"] = ckpt_path
model = VQModel(**params)
data_params = {
"target_image_size": (512, 512),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif model_name.lower() == "chameleon":
from src.vqvaes.anole.anole import VQModel
yaml_url = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.yaml"
ckpt_path = "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.ckpt"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "chameleon", ckpt_path, yaml_url)
config = get_yaml_config(yaml_url)
params = config["model"]["params"]
if "lossconfig" in params:
del params["lossconfig"]
params["ckpt_path"] = ckpt_path
model = VQModel(**params)
data_params = {
"target_image_size": (512, 512),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif model_name.lower() == "llamagen-ds16":
from src.vqvaes.llamagen.llamagen import VQ_models
ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt"
if model_path is not None:
ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16", ckpt_path, None)
model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8)
model.load_state_dict(get_ckpt(ckpt_path, key="model"))
data_params = {
"target_image_size": (512, 512),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif model_name.lower() == "llamagen-ds16-t2i":
from src.vqvaes.llamagen.llamagen import VQ_models
ckpt_path = "https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt"
if model_path is not None:
ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds16-t2i", ckpt_path, None)
model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8)
model.load_state_dict(get_ckpt(ckpt_path, key="model"))
data_params = {
"target_image_size": (512, 512),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif model_name.lower() == "llamagen-ds8":
from src.vqvaes.llamagen.llamagen import VQ_models
ckpt_path = "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt"
if model_path is not None:
ckpt_path, _ = download_ckpt_yaml(model_path, "llamagen-ds8", ckpt_path, None)
model = VQ_models["VQ-8"](codebook_size=16384, codebook_embed_dim=8)
model.load_state_dict(get_ckpt(ckpt_path, key="model"))
data_params = {
"target_image_size": (256, 256),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif model_name.lower() == "flowmo_lo":
from src.vqvaes.flowmo.flowmo import build_model
yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_lo.pth"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_lo", ckpt_path, yaml_url)
config = get_yaml_config(yaml_url)
config.model.context_dim = 18
model = build_model(config)
model.load_state_dict(
get_ckpt(ckpt_path, key="model_ema_state_dict")
)
data_params = {
"target_image_size": (256, 256),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif model_name.lower() == "flowmo_hi":
from src.vqvaes.flowmo.flowmo import build_model
yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
ckpt_path = "https://huggingface.co/ksarge/FlowMo/resolve/main/flowmo_hi.pth"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "flowmo_hi", ckpt_path, yaml_url)
config = get_yaml_config(yaml_url)
config.model.context_dim = 56
config.model.codebook_size_for_entropy = 14
model = build_model(config)
model.load_state_dict(
get_ckpt(ckpt_path, key="model_ema_state_dict")
)
data_params = {
"target_image_size": (256, 256),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif model_name.lower() == "open_magvit2":
from src.vqvaes.open_magvit2.open_magvit2 import VQModel
yaml_url = "https://raw.githubusercontent.com/TencentARC/SEED-Voken/refs/heads/main/configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml"
ckpt_path = "https://huggingface.co/TencentARC/Open-MAGVIT2-Tokenizer-256-resolution/resolve/main/imagenet_256_L.ckpt"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "open_magvit2", ckpt_path, yaml_url)
config = get_yaml_config(yaml_url)
model = VQModel(**config.model.init_args)
model.load_state_dict(get_ckpt(ckpt_path, key="state_dict"))
data_params = {
"target_image_size": (256, 256),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif "maskbit" in model_name.lower():
from src.vqvaes.maskbit.maskbit import ConvVQModel
if "16bit" in model_name.lower():
yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_16bit.yaml"
ckpt_path = "https://huggingface.co/markweber/maskbit_tokenizer_16bit/resolve/main/maskbit_tokenizer_16bit.bin"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-16bit", ckpt_path, yaml_url)
elif "18bit" in model_name.lower():
yaml_url = "https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_18bit.yaml"
ckpt_path = "https://huggingface.co/markweber/maskbit_tokenizer_18bit/resolve/main/maskbit_tokenizer_18bit.bin"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "maskbit-18bit", ckpt_path, yaml_url)
else:
raise Exception(f"Unsupported model: {model_name}")
config = get_yaml_config(yaml_url)
model = ConvVQModel(config.model.vq_model, legacy=False)
model.load_pretrained(get_ckpt(ckpt_path, key=None))
data_params = {
"target_image_size": (256, 256),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": False,
}
elif "bsqvit" in model_name.lower():
from src.vqvaes.bsqvit.bsqvit import VITBSQModel
yaml_url = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/config.yaml"
ckpt_path = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/checkpoint.pt"
if model_path is not None:
ckpt_path, yaml_url = download_ckpt_yaml(model_path, "bsqvit", ckpt_path, yaml_url)
config = get_yaml_config(yaml_url)
model = VITBSQModel(**config["model"]["params"])
model.init_from_ckpt(get_ckpt(ckpt_path, key="state_dict"))
data_params = {
"target_image_size": (256, 256),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": False,
}
elif "titok" in model_name.lower():
from src.vqvaes.titok.titok import TiTok
ckpt_path = None
if "bl64" in model_name.lower():
ckpt_path = "yucornetto/tokenizer_titok_bl64_vq8k_imagenet"
elif "bl128" in model_name.lower():
ckpt_path = "yucornetto/tokenizer_titok_bl128_vq8k_imagenet"
elif "sl256" in model_name.lower():
ckpt_path = "yucornetto/tokenizer_titok_sl256_vq8k_imagenet"
elif "l32" in model_name.lower():
ckpt_path = "yucornetto/tokenizer_titok_l32_imagenet"
elif "b64" in model_name.lower():
ckpt_path = "yucornetto/tokenizer_titok_b64_imagenet"
elif "s128" in model_name.lower():
ckpt_path = "yucornetto/tokenizer_titok_s128_imagenet"
else:
raise Exception(f"Unsupported model: {model_name}")
model = TiTok.from_pretrained(
ckpt_path,
token=huggingface_token
)
data_params = {
"target_image_size": (256, 256),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": False,
}
elif "janus_pro" in model_name.lower():
from janus.models import MultiModalityCausalLM
from src.vqvaes.janus_pro.janus_pro import forward
import types
model = MultiModalityCausalLM.from_pretrained(
"deepseek-ai/Janus-Pro-7B", trust_remote_code=True,
token=huggingface_token
).gen_vision_model
model.forward = types.MethodType(forward, model)
data_params = {
"target_image_size": (384, 384),
"lock_ratio": True,
"center_crop": True,
"padding": False,
}
elif "var" in model_name.lower():
from src.vqvaes.var.var_vq import VQVAE
ckpt_path = "https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth"
if model_path is not None:
ckpt_path, _ = download_ckpt_yaml(model_path, "var", ckpt_path, None)
v_patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
if "512" in model_name.lower():
v_patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)
model = VQVAE(
vocab_size=4096,
z_channels=32,
ch=160,
test_mode=True,
share_quant_resi=4,
v_patch_nums=v_patch_nums,
)
model.load_state_dict(get_ckpt(ckpt_path, key=None))
data_params = {
"target_image_size": (
(512, 512) if "512" in model_name.lower() else (256, 256)
),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": False,
}
elif (
"infinity" in model_name.lower()
): # "infinity_d32", "infinity_d64", "infinity_d56_f8_14_patchify"
from src.vqvaes.infinity.vae import vae_model
if "d32" in model_name:
ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d32.pth"
codebook_dim = 32
if model_path is not None:
ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d32", ckpt_path, None)
elif "d64" in model_name:
ckpt_path = "https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d64.pth"
codebook_dim = 64
if model_path is not None:
ckpt_path, _ = download_ckpt_yaml(model_path, "infinity-d64", ckpt_path, None)
schedule_mode = "dynamic"
codebook_size = 2**codebook_dim
patch_size = 16
encoder_ch_mult = [1, 2, 4, 4, 4]
decoder_ch_mult = [1, 2, 4, 4, 4]
ckpt = get_ckpt(ckpt_path, key=None)
model = vae_model(
ckpt,
schedule_mode,
codebook_dim,
codebook_size,
patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult,
decoder_ch_mult=decoder_ch_mult,
test_mode=True,
)
data_params = {
"target_image_size": (1024, 1024),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": False,
}
elif "sd3.5l" in model_name.lower(): # SD3.5L
from src.vaes.stable_diffusion.vae import forward
from diffusers import AutoencoderKL
import types
model = AutoencoderKL.from_pretrained(
"huaweilin/stable-diffusion-3.5-large-vae", subfolder="vae",
token=huggingface_token
)
model.forward = types.MethodType(forward, model)
data_params = {
"target_image_size": (1024, 1024),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": True,
}
elif "FLUX.1-dev".lower() in model_name.lower(): # SD3.5L
from src.vaes.stable_diffusion.vae import forward
from diffusers import AutoencoderKL
import types
model = AutoencoderKL.from_pretrained(
"black-forest-labs/FLUX.1-dev", subfolder="vae",
token=huggingface_token
)
model.forward = types.MethodType(forward, model)
data_params = {
"target_image_size": (1024, 1024),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": True,
}
elif "gpt4o" in model_name.lower():
from src.vaes.gpt_image.gpt_image import GPTImage
data_params = {
"target_image_size": (1024, 1024),
"lock_ratio": True,
"center_crop": True,
"padding": False,
"standardize": False,
}
model = GPTImage(data_params)
else:
raise Exception(f"Unsupported model: \"{model_name}\"")
try:
trainable_params = sum(p.numel() for p in model.parameters())
print("trainable_params:", trainable_params)
except Exception as e:
print(e)
pass
model.eval()
return model, data_params