import os

import torch
from typing import List, Tuple, Optional, Union, Dict

from .ebc import _ebc, EBC
from .clip_ebc import _clip_ebc, CLIP_EBC

def get_model(
    model_info_path: str,
    model_name: Optional[str] = None,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    # parameters for CLIP_EBC
    clip_weight_name: Optional[str] = None,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    adapter: bool = False,
    adapter_reduction: Optional[int] = None,
    lora: bool = False,
    lora_rank: Optional[int] = None,
    lora_alpha: Optional[int] = None,
    lora_dropout: Optional[float] = None,
    norm: str = "none",
    act: str = "none",
    text_prompts: Optional[List[str]] = None
) -> Union[EBC, CLIP_EBC]:
    """Build an EBC or CLIP_EBC model.

    If ``model_info_path`` points to an existing file, the configuration and
    weights are loaded from it and the remaining arguments are overridden.
    Otherwise the model is built from the provided arguments and its
    configuration is saved to ``model_info_path`` (with ``weights`` set to None).
    """
    if os.path.exists(model_info_path):
        # Load the saved configuration and weights; they override any arguments.
        model_info = torch.load(model_info_path, map_location="cpu", weights_only=False)
        model_name = model_info["config"]["model_name"]
        block_size = model_info["config"]["block_size"]
        bins = model_info["config"]["bins"]
        bin_centers = model_info["config"]["bin_centers"]
        zero_inflated = model_info["config"]["zero_inflated"]
        clip_weight_name = model_info["config"].get("clip_weight_name", None)
        num_vpt = model_info["config"].get("num_vpt", None)
        vpt_drop = model_info["config"].get("vpt_drop", None)
        adapter = model_info["config"].get("adapter", False)
        adapter_reduction = model_info["config"].get("adapter_reduction", None)
        lora = model_info["config"].get("lora", False)
        lora_rank = model_info["config"].get("lora_rank", None)
        lora_alpha = model_info["config"].get("lora_alpha", None)
        lora_dropout = model_info["config"].get("lora_dropout", None)
        input_size = model_info["config"].get("input_size", None)
        text_prompts = model_info["config"].get("text_prompts", None)
        norm = model_info["config"].get("norm", "none")
        act = model_info["config"].get("act", "none")
        weights = model_info["weights"]
    else:
        # Without a saved checkpoint, the core configuration must be given explicitly.
        assert model_name is not None, "model_name should be provided if model_info_path does not exist"
        assert block_size is not None, "block_size should be provided"
        assert bins is not None, "bins should be provided"
        assert bin_centers is not None, "bin_centers should be provided"
        weights = None

if "ViT" in model_name:
assert num_vpt is not None, f"num_vpt should be provided for ViT models, got {num_vpt}"
assert vpt_drop is not None, f"vpt_drop should be provided for ViT models, got {vpt_drop}"
    if model_name.startswith("CLIP_") or model_name.startswith("CLIP-"):
        assert clip_weight_name is not None, f"clip_weight_name should be provided for CLIP models, got {clip_weight_name}"
        model = _clip_ebc(
            model_name=model_name[5:],  # strip the "CLIP_"/"CLIP-" prefix
            weight_name=clip_weight_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            adapter=adapter,
            adapter_reduction=adapter_reduction,
            lora=lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            text_prompts=text_prompts,
            norm=norm,
            act=act
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "clip_weight_name": clip_weight_name,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "adapter": adapter,
            "adapter_reduction": adapter_reduction,
            "lora": lora,
            "lora_rank": lora_rank,
            "lora_alpha": lora_alpha,
            "lora_dropout": lora_dropout,
            "text_prompts": model.text_prompts,
            "norm": norm,
            "act": act
        }
    else:
        assert not adapter, "adapter for non-CLIP models is not implemented yet"
        assert not lora, "lora for non-CLIP models is not implemented yet"
        model = _ebc(
            model_name=model_name,
            block_size=block_size,
            bins=bins,
            bin_centers=bin_centers,
            zero_inflated=zero_inflated,
            num_vpt=num_vpt,
            vpt_drop=vpt_drop,
            input_size=input_size,
            norm=norm,
            act=act
        )
        model_config = {
            "model_name": model_name,
            "block_size": block_size,
            "bins": bins,
            "bin_centers": bin_centers,
            "zero_inflated": zero_inflated,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "norm": norm,
            "act": act
        }

    model.config = model_config
    model_info = {"config": model_config, "weights": weights}

    # Restore pretrained weights when a checkpoint was loaded.
    if weights is not None:
        model.load_state_dict(weights)

    # Persist the configuration of a freshly built model for later reuse.
    if not os.path.exists(model_info_path):
        torch.save(model_info, model_info_path)

    return model


__all__ = ["get_model"]
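

# Illustrative usage sketch (not part of the module's public interface). All
# argument values below are hypothetical placeholders: the backbone name, the
# CLIP weight identifier, and the bin layout depend on the checkpoints and
# training configuration actually used with this repository.
if __name__ == "__main__":
    # If "model_info.pth" does not exist yet, the configuration must be passed
    # explicitly; it is then saved (with weights=None) for later reuse.
    example_bins = [(0.0, 0.0), (1.0, 1.0), (2.0, float("inf"))]  # hypothetical count bins
    example_bin_centers = [0.0, 1.0, 3.0]                         # hypothetical bin centers
    model = get_model(
        model_info_path="model_info.pth",
        model_name="CLIP_ViT-B-16",   # hypothetical CLIP backbone name
        clip_weight_name="openai",    # hypothetical pretrained-weight identifier
        block_size=16,
        bins=example_bins,
        bin_centers=example_bin_centers,
        num_vpt=32,                   # required because the backbone is a ViT
        vpt_drop=0.0,
        input_size=224,
    )
    print(type(model).__name__)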