import torch
from torch import nn, Tensor
import torch.nn.functional as F
import numpy as np
from typing import List, Optional, Dict, Tuple, Union
from copy import deepcopy

from .vit import vit_names_and_weights, _vit
from .convnext import convnext_names_and_weights, _convnext
from .resnet import resnet_names_and_weights, _resnet
from .mobileclip import mobileclip_names_and_weights, _mobileclip
from .utils import encode_text, optimize_text_prompts
from ..utils import conv1x1

supported_models_and_weights = deepcopy(vit_names_and_weights)
supported_models_and_weights.update(convnext_names_and_weights)
supported_models_and_weights.update(resnet_names_and_weights)
supported_models_and_weights.update(mobileclip_names_and_weights)


class CLIP_EBC(nn.Module):
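    """CLIP-based Enhanced Blockwise Classification (EBC) model for crowd counting.

    A pretrained CLIP image encoder produces a feature map, and each spatial
    block is classified into a count bin by matching its (projected, L2-normalized)
    image features against encoded text prompts. With ``zero_inflated=True``,
    the prediction factors into a zero/non-zero head (``pi``) and a
    non-zero-count head (``lambda``).
    """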
    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: Optional[int] = None,
        bins: Optional[List[Tuple[float, float]]] = None,
        bin_centers: Optional[List[float]] = None,
        zero_inflated: Optional[bool] = True,
        num_vpt: Optional[int] = None,
        vpt_drop: Optional[float] = None,
        input_size: Optional[int] = None,
        adapter: Optional[bool] = False,
        adapter_reduction: Optional[int] = None,
        lora: Optional[bool] = False,
        lora_rank: Optional[int] = None,
        lora_alpha: Optional[float] = None,
        lora_dropout: Optional[float] = None,
        text_prompts: Optional[Union[Dict[str, List[str]], List[str]]] = None,
        norm: Optional[str] = "none",
        act: Optional[str] = "none",
    ) -> None:
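        """
        Args:
            model_name: backbone name; must be a key of ``supported_models_and_weights``.
            weight_name: pretrained weight identifier for the chosen backbone.
            block_size: spatial block size for blockwise prediction; forwarded to the backbone.
            bins: list of ``(low, high)`` count ranges, one per class.
            bin_centers: representative count of each bin; must lie inside the bin.
            zero_inflated: if True, model zero vs. non-zero counts with separate heads.
            num_vpt / vpt_drop: number and dropout of visual prompt tokens (ViT backbones only).
            input_size: input resolution, used as a square size by the ViT backbone.
            adapter / adapter_reduction, lora / lora_rank / lora_alpha / lora_dropout:
                optional parameter-efficient fine-tuning settings forwarded to the backbone.
            text_prompts: optional precomputed prompts; if None, prompts are optimized via
                ``optimize_text_prompts``. Must be a dict with keys ``"pi"`` and ``"lambda"``
                when ``zero_inflated`` is True, otherwise a flat list of prompts.
            norm / act: normalization and activation settings forwarded to the backbone.
        """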
        super().__init__()
        if "mobileclip" in model_name.lower() or "vit" in model_name.lower():
            model_name = model_name.replace("_", "-")
        assert model_name in supported_models_and_weights, f"Model name should be one of {list(supported_models_and_weights.keys())}, but got {model_name}."
        assert weight_name in supported_models_and_weights[model_name], f"Weight name should be one of {supported_models_and_weights[model_name]}, but got {weight_name}."
        assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}."
        assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}."
        assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}."
        bins = [(float(b[0]), float(b[1])) for b in bins]
        assert all(low <= center <= high for (low, high), center in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}."
        self.model_name = model_name
        self.weight_name = weight_name
        self.block_size = block_size
        self.bins = bins
        self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32).view(1, -1, 1, 1))  # shape (1, N, 1, 1)
        self.zero_inflated = zero_inflated
        self.text_prompts = text_prompts

        # Image encoder
        if model_name in vit_names_and_weights:
            assert num_vpt is not None and num_vpt >= 0, f"Number of VPT tokens should be non-negative, but got {num_vpt}."
            vpt_drop = 0. if vpt_drop is None else vpt_drop
            self.backbone = _vit(
                model_name=model_name,
                weight_name=weight_name,
                num_vpt=num_vpt,
                vpt_drop=vpt_drop,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                input_size=(input_size, input_size),
                norm=norm,
                act=act,
            )
        elif model_name in convnext_names_and_weights:
            self.backbone = _convnext(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                norm=norm,
                act=act,
            )
        elif model_name in resnet_names_and_weights:
            self.backbone = _resnet(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                norm=norm,
                act=act,
            )
        elif model_name in mobileclip_names_and_weights:
            self.backbone = _mobileclip(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                adapter=adapter,
                adapter_reduction=adapter_reduction,
                lora=lora,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                norm=norm,
                act=act,
            )

        self._build_text_feats()
        self._build_head()

    def _build_text_feats(self) -> None:
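        """Encode (and, if not provided, optimize) the text prompts for each count bin.

        The resulting text features are stored as non-trainable buffers:
        ``pi_text_feats`` and ``lambda_text_feats`` in the zero-inflated case,
        or a single ``text_feats`` buffer otherwise.
        """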
        model_name, weight_name = self.model_name, self.weight_name
        text_prompts = self.text_prompts
        if text_prompts is None:
            bins = [b[0] if b[0] == b[1] else b for b in self.bins]  # if the bin is a single value (e.g., [0, 0]), use that value
            if self.zero_inflated:  # separate 0 from the rest
                assert bins[0] == 0, f"Expected the first bin to be 0, got {bins[0]}."
                bins_pi = [0, (1, float("inf"))]
                bins_lambda = bins[1:]
                pi_text_prompts = optimize_text_prompts(model_name, weight_name, bins_pi)
                lambda_text_prompts = optimize_text_prompts(model_name, weight_name, bins_lambda)
                self.text_prompts = {"pi": pi_text_prompts, "lambda": lambda_text_prompts}
                pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
                lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
                pi_text_feats.requires_grad = False
                lambda_text_feats.requires_grad = False
                self.register_buffer("pi_text_feats", pi_text_feats)
                self.register_buffer("lambda_text_feats", lambda_text_feats)
            else:
                text_prompts = optimize_text_prompts(model_name, weight_name, bins)
                self.text_prompts = text_prompts
                text_feats = encode_text(model_name, weight_name, text_prompts)
                text_feats.requires_grad = False
                self.register_buffer("text_feats", text_feats)
        else:
            if self.zero_inflated:
                assert "pi" in text_prompts and "lambda" in text_prompts, f"Expected text_prompts to have keys 'pi' and 'lambda', got {text_prompts.keys()}."
                pi_text_prompts = text_prompts["pi"]
                lambda_text_prompts = text_prompts["lambda"]
                pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
                lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
                pi_text_feats.requires_grad = False
                lambda_text_feats.requires_grad = False
                self.register_buffer("pi_text_feats", pi_text_feats)
                self.register_buffer("lambda_text_feats", lambda_text_feats)
            else:
                text_feats = encode_text(model_name, weight_name, text_prompts)
                text_feats.requires_grad = False
                self.register_buffer("text_feats", text_feats)

    def _build_head(self) -> None:
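        """Create the 1x1-conv projection head(s) and learnable logit scale(s).

        The logit scales are initialized to log(1 / 0.07), matching the
        temperature initialization used by CLIP.
        """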
        in_channels = self.backbone.in_features
        out_channels = self.backbone.out_features
        if self.zero_inflated:
            self.pi_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.lambda_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.pi_head = conv1x1(in_channels, out_channels, bias=False)
            self.lambda_head = conv1x1(in_channels, out_channels, bias=False)
        else:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.head = conv1x1(in_channels, out_channels, bias=False)

    def forward(self, image: Tensor):
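        """Predict a blockwise density map for ``image`` of shape (B, 3, H_in, W_in).

        In eval mode, returns the density map of shape (B, 1, H, W). In
        training mode, the intermediate logit maps are returned as well so
        that classification losses can be computed on them.
        """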
        image_feats = self.backbone(image)
        # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
        if self.zero_inflated:
            pi_image_feats, lambda_image_feats = self.pi_head(image_feats), self.lambda_head(image_feats)
            pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            pi_text_feats, lambda_text_feats = self.pi_text_feats, self.lambda_text_feats
            pi_logit_scale, lambda_logit_scale = self.pi_logit_scale.exp(), self.lambda_logit_scale.exp()
            pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t()  # (B, H, W, 2), logits per image
            lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t()  # (B, H, W, N - 1), logits per image
            pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
            lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)
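            # Zero-inflated expectation: the lambda softmax is a distribution over
            # the non-zero count bins, so its dot product with the bin centers is
            # E[count | count > 0]; channel 1 of the pi softmax is P(count > 0).
            # The expected count per block is P(count > 0) * E[count | count > 0].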
            lambda_map = (lambda_logit_map.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            # pi_logit_map.softmax(dim=1)[:, 0] is the probability of zeros
            den_map = pi_logit_map.softmax(dim=1)[:, 1:] * lambda_map  # (B, 1, H, W)
            if self.training:
                return pi_logit_map, lambda_logit_map, lambda_map, den_map
            else:
                return den_map
        else:
            image_feats = self.head(image_feats)
            image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)
            text_feats = self.text_feats
            logit_scale = self.logit_scale.exp()
            logit_map = logit_scale * image_feats @ text_feats.t()  # (B, H, W, N), logits per image
            logit_map = logit_map.permute(0, 3, 1, 2)  # (B, N, H, W)
            den_map = (logit_map.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            if self.training:
                return logit_map, den_map
            else:
                return den_map


def _clip_ebc(
    model_name: str,
    weight_name: str,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    adapter: Optional[bool] = False,
    adapter_reduction: Optional[int] = None,
    lora: Optional[bool] = False,
    lora_rank: Optional[int] = None,
    lora_alpha: Optional[float] = None,
    lora_dropout: Optional[float] = None,
    text_prompts: Optional[Union[Dict[str, List[str]], List[str]]] = None,  # matches CLIP_EBC's annotation
    norm: Optional[str] = "none",
    act: Optional[str] = "none",
) -> CLIP_EBC:
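    """Factory wrapper that forwards all arguments to ``CLIP_EBC``."""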
    return CLIP_EBC(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        bins=bins,
        bin_centers=bin_centers,
        zero_inflated=zero_inflated,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        input_size=input_size,
        adapter=adapter,
        adapter_reduction=adapter_reduction,
        lora=lora,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        text_prompts=text_prompts,
        norm=norm,
        act=act,
    )
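

# Minimal usage sketch: builds a zero-inflated model and runs a forward pass.
# The model and weight names below are placeholders, not values confirmed by
# this repo; substitute any key/value pair from ``supported_models_and_weights``.
# The bins and bin centers are illustrative only. Because this module uses
# relative imports, run it as a module (python -m <package>.<module>).
if __name__ == "__main__":
    bins = [(0, 0), (1, 1), (2, 2), (3, float("inf"))]
    bin_centers = [0.0, 1.0, 2.0, 4.0]  # each center lies inside its bin
    model = _clip_ebc(
        model_name="clip-vit-b-16",  # placeholder; must be a supported model name
        weight_name="openai",        # placeholder; must be a supported weight name
        block_size=16,
        bins=bins,
        bin_centers=bin_centers,
        num_vpt=32,      # ViT-specific: number of visual prompt tokens
        vpt_drop=0.1,    # ViT-specific: VPT dropout rate
        input_size=224,
    )
    model.eval()
    with torch.no_grad():
        den_map = model(torch.randn(1, 3, 224, 224))  # (1, 1, H, W) density map
    print("Predicted count:", den_map.sum().item())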