# models/__init__.py
import os
from typing import List, Optional, Tuple, Union

import torch

from .ebc import _ebc, EBC
from .clip_ebc import _clip_ebc, CLIP_EBC


def get_model(
    model_info_path: str,
    model_name: Optional[str] = None,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    # Parameters specific to CLIP_EBC.
    clip_weight_name: Optional[str] = None,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    norm: str = "none",
    act: str = "none",
    text_prompts: Optional[List[str]] = None,
) -> Union[EBC, CLIP_EBC]:
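    """Build an EBC or CLIP_EBC model.

    If ``model_info_path`` points to an existing file, the model configuration
    (and weights, if present) are restored from it and the remaining arguments
    are ignored. Otherwise, the model is built from the arguments and the
    resulting config is cached to ``model_info_path`` for later calls.
    """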
    if os.path.exists(model_info_path):
        # Restore the configuration (and optionally the weights) from the checkpoint.
        model_info = torch.load(model_info_path, map_location="cpu", weights_only=False)
        config = model_info["config"]
        model_name = config["model_name"]
        block_size = config["block_size"]
        bins = config["bins"]
        bin_centers = config["bin_centers"]
        zero_inflated = config["zero_inflated"]
        clip_weight_name = config.get("clip_weight_name", None)
        num_vpt = config.get("num_vpt", None)
        vpt_drop = config.get("vpt_drop", None)
        input_size = config.get("input_size", None)
        text_prompts = config.get("text_prompts", None)
        norm = config.get("norm", "none")
        act = config.get("act", "none")
        weights = model_info["weights"]
    else:
        assert model_name is not None, "model_name must be provided when model_info_path does not point to an existing file"
        assert block_size is not None, "block_size must be provided"
        assert bins is not None, "bins must be provided"
        assert bin_centers is not None, "bin_centers must be provided"
        weights = None

        if "ViT" in model_name:
            assert num_vpt is not None, f"num_vpt must be provided for ViT models, got {num_vpt}"
            assert vpt_drop is not None, f"vpt_drop must be provided for ViT models, got {vpt_drop}"

    if model_name.startswith(("CLIP_", "CLIP-")):
        assert clip_weight_name is not None, f"clip_weight_name must be provided for CLIP models, got {clip_weight_name}"
        model = _clip_ebc(
            model_name=model_name[5:],  # Strip the "CLIP_"/"CLIP-" prefix.
            weight_name=clip_weight_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            text_prompts=text_prompts,
            norm=norm,
            act=act,
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "clip_weight_name": clip_weight_name,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "text_prompts": model.text_prompts,  # Record the prompts the model actually resolved.
            "norm": norm,
            "act": act,
        }
    else:
        model = _ebc(
            model_name=model_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            norm=norm,
            act=act,
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "norm": norm,
            "act": act,
        }
    model.config = model_config
    model_info = {"config": model_config, "weights": weights}

    if weights is not None:
        model.load_state_dict(weights)

    # Cache the config (and weights, if any) so later calls can rebuild the model.
    if not os.path.exists(model_info_path):
        torch.save(model_info, model_info_path)

    return model


__all__ = ["get_model"]
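
# A minimal usage sketch (not part of the module). The checkpoint path and all
# configuration values below are illustrative assumptions, not shipped defaults.
#
#     from models import get_model
#
#     # First call: no checkpoint exists yet, so the model is built from the
#     # arguments and its config is cached to "model_info.pth".
#     model = get_model(
#         model_info_path="model_info.pth",
#         model_name="CLIP_ViT-B-16",
#         block_size=16,
#         bins=[(0.0, 0.0), (1.0, 1.0), (2.0, float("inf"))],
#         bin_centers=[0.0, 1.0, 2.5],
#         clip_weight_name="openai",
#         num_vpt=32,
#         vpt_drop=0.1,
#     )
#
#     # Later calls can rebuild the same model from the cached config alone.
#     model = get_model(model_info_path="model_info.pth")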