""" | |
Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py | |
""" | |
import warnings | |
from collections import OrderedDict | |
from typing import Tuple, Union, Optional | |
import hashlib | |
import os | |
import urllib | |
import warnings | |
from tqdm import tqdm | |
import torch | |
import torch.nn.functional as F | |
from torch import Tensor | |
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear | |
from torch.nn.init import xavier_uniform_ | |
from torch.nn.init import constant_ | |
from torch.nn.init import xavier_normal_ | |
from torch.nn.parameter import Parameter | |
from torch.nn.modules.module import Module | |
from .module_gated_attention import gated_coattention | |
from torch import nn | |
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}

_PT_NAME = {
    "RN50": "RN50.pt",
    "RN101": "RN101.pt",
    "RN50x4": "RN50x4.pt",
    "RN50x16": "RN50x16.pt",
    "ViT-B/32": "ViT-B-32.pt",
    "ViT-B/16": "ViT-B-16.pt",
}
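# Download a checkpoint into the local cache, verifying the SHA-256 digest embedded in the
# download URL (its second-to-last path component).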
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
def available_models():
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())
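# Example (hypothetical usage): list the named models and fetch one into the local cache.
#   print(available_models())            # e.g. ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
#   ckpt_path = _download(_MODELS["ViT-B/32"])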
# =============================
class TABAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

        Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
        to :attr:`embed_dim` such that query, key, and value have the same
        number of features.

    Examples::

        >>> multihead_attn = TABAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    This is a version of multi-head attention written to comply with the definition of TAB.
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(TABAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()
    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old TABAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(TABAttention, self).__setstate__(state)
    def forward(self, query: Tensor, key: Tensor, value: Tensor, gt_attention_map: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be
                broadcast across the batch while a 3D mask allows a different mask to be specified
                for each entry of the batch.

        Shapes for inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the zero
              positions will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
              source sequence length.
              If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
              length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
              to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

        Shapes for outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """
        if not self._qkv_same_embed_dim:
            # Separate q/k/v projection weights are in use, so in_proj_weight is None here;
            # guard against calling .half() on None.
            return gated_coattention(
                query, key, value, self.embed_dim, self.num_heads,
                None if self.in_proj_weight is None else self.in_proj_weight.half(),
                None if self.in_proj_bias is None else self.in_proj_bias.half(),
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(),
                training=self.training, gt_attention_map=gt_attention_map,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return gated_coattention(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight.half(), self.in_proj_bias.half(),
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(),
                training=self.training, gt_attention_map=gt_attention_map,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
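# A standard pre-LayerNorm Transformer block. `forward` takes and returns an
# (activations, video_frame) tuple so the blocks can be chained inside nn.Sequential.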
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        attn_mask_ = self.attn_mask
        if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
            attn_mask_ = self.attn_mask(x.size(0))  # LND
        attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]

    def forward(self, x_tuple: tuple):
        x, video_frame = x_tuple
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return (x, video_frame)

    def visualize_attention(self, x: torch.Tensor):
        attn_outputs, attn_weights = self.attn(x, x, x, need_weights=True, attn_mask=None)
        return attn_outputs, attn_weights

    def visualize_forward(self, x_tuple: tuple):
        x, video_frame = x_tuple
        attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x))
        x = x + attn_outputs
        x = x + self.mlp(self.ln_2(x))
        return (x, video_frame, attn_weights)
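# A Transformer block built on TABAttention: queries come from `x`, keys/values from `y`
# (cross-attention between the two token streams). Note that the attention output replaces
# `x` rather than being added to it; only the MLP keeps a residual connection.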
class TABLayer(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super().__init__()

        self.attn = TABAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, y: torch.Tensor):
        attn_mask_ = self.attn_mask
        if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
            attn_mask_ = self.attn_mask(x.size(0))  # LND
        attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
        return self.attn(x, y, y, need_weights=False, attn_mask=attn_mask_)[0]

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = self.attention(self.ln_1(x), self.ln_1(y))
        x = x + self.mlp(self.ln_2(x))
        return x

    def visualize_attention(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
        attn_outputs, attn_weights = self.attn(x, y, y, gt_attention_map=gt_attention_map, need_weights=True, attn_mask=None)
        return attn_outputs, attn_weights

    def visualize_forward(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
        attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x), self.ln_1(y), gt_attention_map)
        x = attn_outputs
        x = x + self.mlp(self.ln_2(x))
        return (x, attn_weights)
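# visionTransformer: a stack of `layers` blocks whose final block is a single-head TABLayer,
# with all earlier blocks being standard ResidualAttentionBlocks. Transformer (below) is the
# unmodified stack used on the text side.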
class visionTransformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask) if i < (layers - 1) else TABLayer(width, 1, attn_mask)
            for i in range(layers)
        ])

    def forward(self, x: torch.Tensor, video_frame=-1):
        return self.resblocks((x, video_frame))[0]


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor, video_frame=-1):
        return self.resblocks((x, video_frame))[0]
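# Two-image visual encoder: each image is patch-embedded (conv1, or conv2 for the '3d' patch mode),
# given a class token and positional embedding, and passed through `intra_layers` per-image blocks.
# The two images' token sequences (the code assumes video_frame == 2) are then concatenated with
# learned bef/aft embeddings plus a joint positional embedding, processed by the remaining joint
# blocks, and finally cross-attended in both directions by the last TABLayer.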
class VisualTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
                 linear_patch: str = '2d', intra_layers: int = 9):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.intra_layers = intra_layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.joint_positional_embedding = nn.Parameter(scale * torch.randn(2 * ((input_resolution // patch_size) ** 2 + 1), width))
        self.bef_embedding = nn.Parameter(scale * torch.randn(width))
        self.aft_embedding = nn.Parameter(scale * torch.randn(width))
        self.ln_mid = LayerNorm(width)

        self.transformer = visionTransformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

        # For 3D
        assert linear_patch in ['2d', '3d']
        self.linear_patch = linear_patch
        if self.linear_patch == '3d':
            self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size),
                                   stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False)
    def forward(self, x: torch.Tensor, left_gt_map, right_gt_map, video_frame=-1, visualize=False):
        if self.linear_patch == '3d':
            assert video_frame != -1
            x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1])
            x_3d = x_3d.permute(0, 2, 1, 3, 4)
            x_3d = self.conv2(x_3d)  # shape = [*, width, frame, grid, grid]
            x_3d = x_3d.permute(0, 2, 1, 3, 4)  # shape = [*, frame, width, grid, grid]
            x = x_3d.reshape(-1, x_3d.shape[-3], x_3d.shape[-2], x_3d.shape[-1]).contiguous()  # shape = [*, width, grid, grid]
        else:
            x = self.conv1(x)  # shape = [*, width, grid, grid]

        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        if visualize is True:
            all_attn_weights = []
            for i in range(self.intra_layers):
                x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame))
                attn_weights = attn_weights.view(x.size(1) // video_frame, -1, attn_weights.size(-2),
                                                 attn_weights.size(-1))
                all_attn_weights.append(attn_weights)
        else:
            for i in range(self.intra_layers):
                x = self.transformer.resblocks[i]((x, video_frame))[0]

        x = x.permute(1, 0, 2)  # LND -> NLD
        bs = x.size(0) // video_frame
        x = x.view(bs, video_frame, x.size(-2), x.size(-1))
        x = torch.cat([x[:, 0] + self.bef_embedding.to(x.dtype),
                       x[:, 1] + self.aft_embedding.to(x.dtype)], dim=1)
        x = x + self.joint_positional_embedding.to(x.dtype)
        x = self.ln_mid(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        if visualize is True:
            for i in range(self.intra_layers, self.transformer.layers - 1):
                x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame))
                all_attn_weights.append(attn_weights)
            cls_index = int(x.size(0) / 2)
            left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map)
            right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map)
            all_attn_weights.append(left_attn_weights)
            all_attn_weights.append(right_attn_weights)
        else:
            for i in range(self.intra_layers, self.transformer.layers - 1):
                x = self.transformer.resblocks[i]((x, video_frame))[0]
            cls_index = int(x.size(0) / 2)
            left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map)
            right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map)

        left_features = left_features.permute(1, 0, 2)  # LND -> NLD
        right_features = right_features.permute(1, 0, 2)  # LND -> NLD
        x = torch.cat([left_features, right_features], 1)

        # Move the three lines below to `encode_image` for entire hidden sequence
        # x = self.ln_post(x[:, 0, :])
        # if self.proj is not None:
        #     x = x @ self.proj

        if visualize is True:
            return x, all_attn_weights
        return x, left_attn_weights, right_attn_weights
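# CLIP model combining the two-image VisualTransformer above with the original CLIP text Transformer.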
class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 # vision linear of patch
                 linear_patch: str = '2d',
                 intra_layers: int = 9,
                 ):
        super().__init__()

        self.context_length = context_length

        vision_heads = vision_width // 64
        self.visual = VisualTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
            linear_patch=linear_patch,
            intra_layers=intra_layers,
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]))

        self.initialize_parameters()
    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
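    # Resolve a pretrained CLIP checkpoint (a known model name or a local file path,
    # downloading it if necessary) and return its state dict.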
    @staticmethod
    def get_config(pretrained_clip_name="ViT-B/32"):
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt")
        if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME:
            model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name])

        if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path):
            pass
        else:
            if pretrained_clip_name in _MODELS:
                model_path = _download(_MODELS[pretrained_clip_name])
            elif os.path.isfile(pretrained_clip_name):
                model_path = pretrained_clip_name
            else:
                raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}")

        try:
            # loading JIT archive
            model = torch.jit.load(model_path, map_location="cpu").eval()
            state_dict = model.state_dict()
        except RuntimeError:
            state_dict = torch.load(model_path, map_location="cpu")

        return state_dict
    def build_attention_mask(self, context_length):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.zeros(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype
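    # Encode an image pair: run the visual encoder, project the hidden states, then average
    # the two CLS tokens (one per image) into a single image embedding.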
    def encode_image(self, image, left_gt_map, right_gt_map, return_hidden=False, video_frame=-1):
        hidden, left_map, right_map = self.visual(image.type(self.dtype), left_gt_map, right_gt_map, video_frame=video_frame)
        hidden = self.visual.ln_post(hidden) @ self.visual.proj

        cls_index = int(hidden.size(1) / 2)
        hidden2 = torch.cat([hidden[:, 0, :].unsqueeze(1), hidden[:, cls_index, :].unsqueeze(1)], 1)
        x = torch.mean(hidden2, 1)

        if return_hidden:
            return x, hidden2, left_map, right_map
        return x, left_map, right_map
    def encode_text(self, text, return_hidden=False):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
        x = x + pos_emd
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        hidden = self.ln_final(x).type(self.dtype) @ self.text_projection

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]

        if return_hidden:
            return x, hidden

        return x
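    # Note: this contrastive forward pass is kept from the original CLIP. With the modified
    # encode_image above (which requires the gt attention maps and returns extra outputs),
    # it does not appear to be usable as written without adapting the call.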
    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
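# Instantiate a CLIP model whose architecture hyperparameters are inferred from a pretrained
# state dict, convert its weights to fp16, and load the state dict.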
def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()