ZIP / models /clip_ebc /model.py
Yiming-M's picture
2025-07-31 18:59 🐣
a7dedf9
raw
history blame
12.3 kB
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor

from .vit import vit_names_and_weights, _vit
from .convnext import convnext_names_and_weights, _convnext
from .resnet import resnet_names_and_weights, _resnet
from .mobileclip import mobileclip_names_and_weights, _mobileclip
from .utils import encode_text, optimize_text_prompts
from ..utils import conv1x1
# Mapping from every supported backbone name to the pretrained weight names
# accepted for it, merged from the per-family registries. The ViT registry is
# deep-copied first so that merging the other families does not mutate it.
supported_models_and_weights = deepcopy(vit_names_and_weights)
for _family_registry in (
    convnext_names_and_weights,
    resnet_names_and_weights,
    mobileclip_names_and_weights,
):
    supported_models_and_weights.update(_family_registry)
del _family_registry
class CLIP_EBC(nn.Module):
    """Count-bin classifier that scores image locations against frozen text-prompt features.

    A backbone produces a feature map. A 1x1 convolution head projects it, and
    the result is L2-normalised at every spatial location. It is then
    multiplied with a frozen matrix of text features (one row per bin) and
    scaled by a learnable ``exp(logit_scale)``. A softmax over bins turns the
    logits into a distribution. Its expectation over ``bin_centers`` gives the
    value at each location of the returned ``den_map``.

    With ``zero_inflated=True``, two heads are used:

    * a ``pi`` head scoring "zero" vs. "one or more";
    * a ``lambda`` head over the remaining ``N - 1`` bins.

    ``den_map`` is then the probability of the non-zero ``pi`` class times the
    ``lambda`` expectation.

    Args:
        model_name: Backbone name. It must be a key of
            ``supported_models_and_weights``. For names containing
            "mobileclip" or "vit", ``_`` is replaced by ``-`` first.
        weight_name: Pretrained weight name listed for ``model_name``.
        block_size: Forwarded to the backbone constructor.
        bins: ``(low, high)`` intervals, at least two. Required despite the
            ``None`` default.
        bin_centers: One value per bin, lying within it. Required despite the
            ``None`` default.
        zero_inflated: Use the two-head (``pi`` / ``lambda``) formulation.
            With the automatically generated prompts, the first bin must be
            exactly ``0``.
        num_vpt, vpt_drop, input_size: Only used for ViT backbones.
            ``num_vpt`` must be a non-negative int there; a ``None``
            ``vpt_drop`` becomes ``0.0``.
        adapter, adapter_reduction, lora, lora_rank, lora_alpha, lora_dropout, norm, act:
            Forwarded unchanged to the backbone constructor.
        text_prompts: Pre-built prompts, or ``None`` to generate them with
            ``optimize_text_prompts``. Zero-inflated models need a dict with
            ``"pi"`` and ``"lambda"`` entries. Otherwise it is passed directly
            to ``encode_text`` (presumably a list of strings).
    """

    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: Optional[int] = None,
        bins: Optional[List[Tuple[float, float]]] = None,
        bin_centers: Optional[List[float]] = None,
        zero_inflated: Optional[bool] = True,
        num_vpt: Optional[int] = None,
        vpt_drop: Optional[float] = None,
        input_size: Optional[int] = None,
        adapter: Optional[bool] = False,
        adapter_reduction: Optional[int] = None,
        lora: Optional[bool] = False,
        lora_rank: Optional[int] = None,
        lora_alpha: Optional[float] = None,
        lora_dropout: Optional[float] = None,
        text_prompts: Optional[Union[List[str], Dict[str, List[str]]]] = None,
        norm: Optional[str] = "none",
        act: Optional[str] = "none",
    ) -> None:
        super().__init__()
        if "mobileclip" in model_name.lower() or "vit" in model_name.lower():
            # Accept "_" as separator for these families; presumably the registry keys use "-".
            model_name = model_name.replace("_", "-")
        assert model_name in supported_models_and_weights, f"Model name should be one of {list(supported_models_and_weights.keys())}, but got {model_name}."
        assert weight_name in supported_models_and_weights[model_name], f"Pretrained should be one of {supported_models_and_weights[model_name]}, but got {weight_name}."
        # The None defaults exist only for signature compatibility; both lists are required.
        assert bins is not None and bin_centers is not None, f"Expected bins and bin_centers to be provided, got {bins} and {bin_centers}"
        assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}"
        assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}"
        assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
        bins = [(float(low), float(high)) for low, high in bins]
        assert all(low <= center <= high for (low, high), center in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"

        self.model_name = model_name
        self.weight_name = weight_name
        self.block_size = block_size
        self.bins = bins
        # Shape (1, N, 1, 1) so it broadcasts against (B, N, H, W) probability maps.
        self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32).view(1, -1, 1, 1))
        self.zero_inflated = zero_inflated
        self.text_prompts = text_prompts

        # Image encoder. These keyword arguments are shared by every backbone family.
        backbone_kwargs = dict(
            model_name=model_name,
            weight_name=weight_name,
            block_size=block_size,
            adapter=adapter,
            adapter_reduction=adapter_reduction,
            lora=lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            norm=norm,
            act=act,
        )
        if model_name in vit_names_and_weights:
            assert num_vpt is not None and num_vpt >= 0, f"Number of VPT tokens should be non-negative, but got {num_vpt}."
            self.backbone = _vit(
                num_vpt=num_vpt,
                vpt_drop=0. if vpt_drop is None else vpt_drop,
                input_size=(input_size, input_size),
                **backbone_kwargs,
            )
        elif model_name in convnext_names_and_weights:
            self.backbone = _convnext(**backbone_kwargs)
        elif model_name in resnet_names_and_weights:
            self.backbone = _resnet(**backbone_kwargs)
        elif model_name in mobileclip_names_and_weights:
            self.backbone = _mobileclip(**backbone_kwargs)
        else:
            # Unreachable while assertions are on. Under `python -O` it stops the model
            # from being built without a backbone.
            raise ValueError(f"Unsupported model name: {model_name}.")

        self._build_text_feats()
        self._build_head()

    def _register_text_feats(self, buffer_name: str, prompts) -> None:
        """Encode ``prompts`` with ``encode_text`` and store the result as a frozen buffer ``buffer_name``."""
        feats = encode_text(self.model_name, self.weight_name, prompts)
        feats.requires_grad = False
        self.register_buffer(buffer_name, feats)

    def _build_text_feats(self) -> None:
        """Generate prompts if none were given, then register the encoded text features.

        Zero-inflated models get ``pi_text_feats`` and ``lambda_text_feats``
        buffers. Otherwise a single ``text_feats`` buffer is registered.
        Generated prompts are stored back on ``self.text_prompts``.
        """
        model_name, weight_name = self.model_name, self.weight_name
        text_prompts = self.text_prompts
        if text_prompts is None:
            # A degenerate bin such as (0, 0) is described by its single value.
            bins = [low if low == high else (low, high) for low, high in self.bins]
            if self.zero_inflated:
                # pi separates 0 from everything >= 1; lambda covers the remaining bins.
                assert bins[0] == 0, f"Expected the first bin to be 0, got {bins[0]}."
                text_prompts = {
                    "pi": optimize_text_prompts(model_name, weight_name, [0, (1, float("inf"))]),
                    "lambda": optimize_text_prompts(model_name, weight_name, bins[1:]),
                }
            else:
                text_prompts = optimize_text_prompts(model_name, weight_name, bins)
            self.text_prompts = text_prompts
        elif self.zero_inflated:
            assert "pi" in text_prompts and "lambda" in text_prompts, f"Expected text_prompts to have keys 'pi' and 'lambda', got {text_prompts.keys()}."

        if self.zero_inflated:
            self._register_text_feats("pi_text_feats", text_prompts["pi"])
            self._register_text_feats("lambda_text_feats", text_prompts["lambda"])
        else:
            self._register_text_feats("text_feats", text_prompts)

    def _build_head(self) -> None:
        """Create the 1x1 projection head(s) and their learnable logit scales.

        Each head maps ``backbone.in_features`` channels to
        ``backbone.out_features`` channels. The output width presumably matches
        the text-feature width; confirm in the backbone.
        """
        in_channels = self.backbone.in_features
        out_channels = self.backbone.out_features
        # Log-temperature initialised to log(1 / 0.07), the value CLIP uses.
        if self.zero_inflated:
            self.pi_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.lambda_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.pi_head = conv1x1(in_channels, out_channels, bias=False)
            self.lambda_head = conv1x1(in_channels, out_channels, bias=False)
        else:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.head = conv1x1(in_channels, out_channels, bias=False)

    @staticmethod
    def _similarity_logits(feats: Tensor, head: nn.Module, text_feats: Tensor, logit_scale: Tensor) -> Tensor:
        """Project ``feats`` with ``head`` and score every location against ``text_feats``.

        The projected features are L2-normalised over channels at each location.
        They are multiplied with ``text_feats.t()`` and scaled by
        ``exp(logit_scale)``. Returns logits of shape (B, K, H, W), where K is
        the number of rows of ``text_feats``.
        """
        image_feats = F.normalize(head(feats).permute(0, 2, 3, 1), p=2, dim=-1)  # (B, H, W, C)
        logits = logit_scale.exp() * image_feats @ text_feats.t()  # (B, H, W, K)
        return logits.permute(0, 3, 1, 2)  # (B, K, H, W)

    def forward(self, image: Tensor):
        """Run the model on ``image``.

        Returns:
            In eval mode, ``den_map`` of shape (B, 1, H, W) at the backbone's
            output resolution. In training mode, a tuple: ``(pi_logit_map,
            lambda_logit_map, lambda_map, den_map)`` when zero-inflated,
            otherwise ``(logit_map, den_map)``.
        """
        image_feats = self.backbone(image)
        if self.zero_inflated:
            pi_logit_map = self._similarity_logits(image_feats, self.pi_head, self.pi_text_feats, self.pi_logit_scale)  # (B, 2, H, W)
            lambda_logit_map = self._similarity_logits(image_feats, self.lambda_head, self.lambda_text_feats, self.lambda_logit_scale)  # (B, N - 1, H, W)
            # Expected value over the non-zero bins.
            lambda_map = (lambda_logit_map.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            # Channel 0 of the pi softmax is the probability of zeros; weight lambda by the rest.
            den_map = pi_logit_map.softmax(dim=1)[:, 1:] * lambda_map  # (B, 1, H, W)
            if self.training:
                return pi_logit_map, lambda_logit_map, lambda_map, den_map
            return den_map

        logit_map = self._similarity_logits(image_feats, self.head, self.text_feats, self.logit_scale)  # (B, N, H, W)
        den_map = (logit_map.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        if self.training:
            return logit_map, den_map
        return den_map
def _clip_ebc(
    model_name: str,
    weight_name: str,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    adapter: Optional[bool] = False,
    adapter_reduction: Optional[int] = None,
    lora: Optional[bool] = False,
    lora_rank: Optional[int] = None,
    lora_alpha: Optional[float] = None,
    lora_dropout: Optional[float] = None,
    text_prompts: Optional[Union[List[str], Dict[str, List[str]]]] = None,
    norm: Optional[str] = "none",
    act: Optional[str] = "none",
) -> CLIP_EBC:
    """Build a ``CLIP_EBC`` model, forwarding every argument unchanged by keyword.

    ``bins`` and ``bin_centers`` must be provided even though they default to
    ``None``. ``text_prompts`` may be ``None`` (prompts are then generated).
    For zero-inflated models (the default), it must be a dict with ``"pi"`` and
    ``"lambda"`` entries. Otherwise it is passed straight to the text encoder
    (presumably a list of prompt strings).
    """
    return CLIP_EBC(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        bins=bins,
        bin_centers=bin_centers,
        zero_inflated=zero_inflated,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        input_size=input_size,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        lora=lora,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        text_prompts=text_prompts,
        norm=norm,
        act=act,
    )