# pooyanrg's picture
# initial commit
# ad4721b
"""
Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py
"""
import warnings
from collections import OrderedDict
from typing import Tuple, Union, Optional
import hashlib
import os
import urllib
import warnings
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from .module_gated_attention import gated_coattention
from torch import nn
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
_PT_NAME = {
"RN50": "RN50.pt",
"RN101": "RN101.pt",
"RN50x4": "RN50x4.pt",
"RN50x16": "RN50x16.pt",
"ViT-B/32": "ViT-B-32.pt",
"ViT-B/16": "ViT-B-16.pt",
}
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def available_models():
    """Return the names of the CLIP models that have a download URL in `_MODELS`."""
    return [model_name for model_name in _MODELS]
# =============================
class TABAttention(Module):
    r"""Multi-head attention whose forward pass is computed by ``gated_coattention``.

    Allows the model to jointly attend to information
    from different representation subspaces.
    See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> multihead_attn = TABAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    This is a version of multihead attention written to comply with the definition of TAB.
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(TABAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            # Separate projections are needed when key/value widths differ
            # from the query width; the packed weight is then absent.
            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            # One packed weight holds the query, key and value projections.
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the projection weights and zero the biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            # in_proj_bias and out_proj.bias are both controlled by the same
            # `bias` flag, so they exist (or not) together.
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old TABAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(TABAttention, self).__setstate__(state)

    @staticmethod
    def _as_half(tensor: Optional[Tensor]) -> Optional[Tensor]:
        """Return *tensor* cast to fp16, or ``None`` when the parameter is absent."""
        return None if tensor is None else tensor.half()

    def forward(self, query: Tensor, key: Tensor, value: Tensor, gt_attention_map: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        gt_attention_map: forwarded unchanged to ``gated_coattention``
            (presumably a ground-truth attention map; confirm against that function).
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

    Shapes for inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
          source sequence length.

          If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
          length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend
          the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

    Shapes for outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        half = self._as_half
        extra_kwargs = dict(
            training=self.training, gt_attention_map=gt_attention_map,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask)
        if not self._qkv_same_embed_dim:
            # in_proj_weight is None in this configuration; the separate
            # projections are cast to fp16 like every other weight passed below.
            extra_kwargs.update(
                use_separate_proj_weight=True,
                q_proj_weight=half(self.q_proj_weight),
                k_proj_weight=half(self.k_proj_weight),
                v_proj_weight=half(self.v_proj_weight))
        # Weights are handed over in fp16. Parameters that are absent
        # (bias=False, or the packed weight when projections are separate) stay
        # None instead of raising AttributeError on `None.half()`.
        return gated_coattention(
            query, key, value, self.embed_dim, self.num_heads,
            half(self.in_proj_weight), half(self.in_proj_bias),
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, half(self.out_proj.weight), half(self.out_proj.bias),
            **extra_kwargs)
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in fp32 and returns the result in the input's dtype."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Normalise in fp32, then cast back so fp16 callers get fp16 out.
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Transformer block: layer-normed self-attention and MLP, each added back residually.

    Inputs and outputs are tuples so that `video_frame` can ride along through
    an ``nn.Sequential`` of these blocks.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, d_model * 4)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(d_model * 4, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)
        # Either None, a mask tensor, or a callable that builds a mask from
        # the sequence length.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attend over `x` and return only the attention output."""
        mask = self.attn_mask
        if mask is not None and hasattr(mask, '__call__'):
            mask = mask(x.size(0))  # LND: dim 0 is the sequence length
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        attn_output = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attn_output[0]

    def forward(self, x_tuple: tuple):
        """Apply the block to ``(x, video_frame)`` and return the same tuple shape."""
        hidden, video_frame = x_tuple
        hidden = hidden + self.attention(self.ln_1(hidden))
        hidden = hidden + self.mlp(self.ln_2(hidden))
        return (hidden, video_frame)

    def visualize_attention(self, x: torch.Tensor):
        """Self-attend without any mask and return ``(output, weights)``."""
        return self.attn(x, x, x, need_weights=True, attn_mask=None)

    def visualize_forward(self, x_tuple: tuple):
        """Like `forward`, but unmasked and also returning the attention weights."""
        hidden, video_frame = x_tuple
        attn_outputs, attn_weights = self.visualize_attention(self.ln_1(hidden))
        hidden = hidden + attn_outputs
        hidden = hidden + self.mlp(self.ln_2(hidden))
        return (hidden, video_frame, attn_weights)
class TABLayer(nn.Module):
    """Cross-attention block built on `TABAttention`: `x` queries, `y` supplies keys and values.

    Unlike `ResidualAttentionBlock`, the attention output replaces `x` (no
    residual add around the attention); only the MLP is residual.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super().__init__()

        self.attn = TABAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, d_model * 4)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(d_model * 4, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)
        # Either None, a mask tensor, or a callable that builds a mask from
        # the sequence length.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor, y: torch.Tensor):
        """Attend from `x` to `y` and return only the attention output."""
        mask = self.attn_mask
        if mask is not None and hasattr(mask, '__call__'):
            mask = mask(x.size(0))  # LND: dim 0 is the sequence length
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        attn_output = self.attn(x, y, y, need_weights=False, attn_mask=mask)
        return attn_output[0]

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Cross-attend the normed `x` to the normed `y`, then apply the residual MLP."""
        attended = self.attention(self.ln_1(x), self.ln_1(y))
        return attended + self.mlp(self.ln_2(attended))

    def visualize_attention(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
        """Attend from `x` to `y` without a mask and return ``(output, weights)``."""
        return self.attn(x, y, y, gt_attention_map=gt_attention_map, need_weights=True, attn_mask=None)

    def visualize_forward(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
        """Like `forward`, but unmasked, passing `gt_attention_map` on and returning the weights too."""
        attended, attn_weights = self.visualize_attention(self.ln_1(x), self.ln_1(y), gt_attention_map)
        hidden = attended + self.mlp(self.ln_2(attended))
        return (hidden, attn_weights)
class visionTransformer(nn.Module):
    """Stack of ``layers - 1`` `ResidualAttentionBlock`s followed by one single-head `TABLayer`."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask = None):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers - 1)]
        if layers > 0:
            # The last position of the stack is always the TAB layer.
            blocks.append(TABLayer(width, 1, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor, video_frame=-1):
        # NOTE(review): nn.Sequential passes a single argument to each block,
        # while the trailing TABLayer.forward takes (x, y), so this call looks
        # like it cannot succeed when layers > 0. VisualTransformer.forward
        # indexes `resblocks` directly instead of calling this method.
        outputs = self.resblocks((x, video_frame))
        return outputs[0]
class Transformer(nn.Module):
    """Plain stack of `layers` `ResidualAttentionBlock`s sharing one `attn_mask`."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask = None):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor, video_frame=-1):
        # Blocks exchange (x, video_frame) tuples; only the tensor is returned.
        outputs = self.resblocks((x, video_frame))
        return outputs[0]
class VisualTransformer(nn.Module):
    """Patch-based vision encoder that processes frames separately, then jointly.

    The first `intra_layers` blocks run on each frame on its own. Frames 0 and
    1 of every sample are then concatenated along the token axis (with
    before/after and joint positional embeddings added) and run through the
    remaining blocks; the final `TABLayer` cross-attends each half to the other.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
                 linear_patch: str = '2d', intra_layers: int = 9):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.intra_layers = intra_layers

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        # One class token plus one token per patch.
        tokens_per_frame = (input_resolution // patch_size) ** 2 + 1
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(tokens_per_frame, width))
        self.ln_pre = LayerNorm(width)

        # Embeddings used once the two frames are joined into one sequence.
        self.joint_positional_embedding = nn.Parameter(scale * torch.randn(2 * tokens_per_frame, width))
        self.bef_embedding = nn.Parameter(scale * torch.randn(width))
        self.aft_embedding = nn.Parameter(scale * torch.randn(width))
        self.ln_mid = LayerNorm(width)

        self.transformer = visionTransformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

        # For 3D
        assert linear_patch in ['2d', '3d']
        self.linear_patch = linear_patch
        if self.linear_patch == '3d':
            self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size),
                                   stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False)

    def forward(self, x: torch.Tensor, left_gt_map, right_gt_map, video_frame=-1, visualize=False):
        """Encode frame pairs.

        Args:
            x: input images; frames of one sample are consecutive along dim 0
                (the tensor is later viewed as ``(batch, video_frame, ...)``).
            left_gt_map, right_gt_map: passed as `gt_attention_map` to the final
                TAB layer; the left half is attended with `right_gt_map` and
                the right half with `left_gt_map`.
            video_frame: number of frames per sample; it divides dim 0 below.
            visualize: when exactly ``True``, run the blocks through their
                `visualize_forward` variants and collect attention weights.

        Returns:
            ``(x, all_attn_weights)`` when `visualize` is ``True``, otherwise
            ``(x, left_attn_weights, right_attn_weights)``. `x` is the
            concatenation of the left and right features along dim 1.
        """
        blocks = self.transformer.resblocks
        collect = visualize is True

        if self.linear_patch == '3d':
            assert video_frame != -1
            clip = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1])
            clip = clip.permute(0, 2, 1, 3, 4)
            clip = self.conv2(clip)  # shape = [*, width, frame, grid, grid]
            clip = clip.permute(0, 2, 1, 3, 4)  # shape = [*, frame, width, grid, grid]
            x = clip.reshape(-1, clip.shape[-3], clip.shape[-2], clip.shape[-1]).contiguous()  # shape = [*, width, grid, grid]
        else:
            x = self.conv1(x)  # shape = [*, width, grid, grid]

        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Broadcast the class embedding to one token per input image.
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        # Stage 1: per-frame blocks.
        if collect:
            all_attn_weights = []
            for i in range(self.intra_layers):
                x, _, attn_weights = blocks[i].visualize_forward((x, video_frame))
                attn_weights = attn_weights.view(x.size(1) // video_frame, -1, attn_weights.size(-2),
                                                 attn_weights.size(-1))
                all_attn_weights.append(attn_weights)
        else:
            for i in range(self.intra_layers):
                x = blocks[i]((x, video_frame))[0]

        x = x.permute(1, 0, 2)  # LND -> NLD

        # Join frame 0 and frame 1 of every sample into one token sequence.
        bs = x.size(0) // video_frame
        x = x.view(bs, video_frame, x.size(-2), x.size(-1))
        before_tokens = x[:, 0] + self.bef_embedding.to(x.dtype)
        after_tokens = x[:, 1] + self.aft_embedding.to(x.dtype)
        x = torch.cat([before_tokens, after_tokens], dim=1)
        x = x + self.joint_positional_embedding.to(x.dtype)
        x = self.ln_mid(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        # Stage 2: joint blocks, excluding the trailing TAB layer.
        for i in range(self.intra_layers, self.transformer.layers - 1):
            if collect:
                x, _, attn_weights = blocks[i].visualize_forward((x, video_frame))
                all_attn_weights.append(attn_weights)
            else:
                x = blocks[i]((x, video_frame))[0]

        # Stage 3: the TAB layer cross-attends each half to the other one.
        cls_index = x.size(0) // 2
        first_half = x[:cls_index, :, :]
        second_half = x[cls_index:, :, :]
        tab_layer = blocks[-1]
        left_features, left_attn_weights = tab_layer.visualize_forward(first_half, second_half, right_gt_map)
        right_features, right_attn_weights = tab_layer.visualize_forward(second_half, first_half, left_gt_map)
        if collect:
            all_attn_weights.append(left_attn_weights)
            all_attn_weights.append(right_attn_weights)

        left_features = left_features.permute(1, 0, 2)  # LND -> NLD
        right_features = right_features.permute(1, 0, 2)  # LND -> NLD
        x = torch.cat([left_features, right_features], 1)

        # ln_post and proj are applied by the caller (`encode_image`) so that
        # the entire hidden sequence is available here.

        if collect:
            return x, all_attn_weights
        return x, left_attn_weights, right_attn_weights
class CLIP(nn.Module):
    """CLIP model pairing a `VisualTransformer` image tower with a text `Transformer`.

    Args:
        embed_dim: width of the shared image/text embedding space.
        image_resolution, vision_layers, vision_width, vision_patch_size:
            configuration of the visual tower.
        context_length, vocab_size, transformer_width, transformer_heads,
            transformer_layers: configuration of the text tower.
        linear_patch: ``'2d'`` or ``'3d'`` patch embedding for the visual tower.
        intra_layers: number of visual blocks applied per frame before the
            frames are joined.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 # vision linear of patch
                 linear_patch: str = '2d',
                 intra_layers: int = 9,
                 ):
        super().__init__()

        self.context_length = context_length

        vision_heads = vision_width // 64
        self.visual = VisualTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
            linear_patch=linear_patch,
            intra_layers=intra_layers,
        )

        # The text blocks receive the mask builder itself and call it with the
        # actual sequence length on every forward pass.
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialise the text-tower embeddings, blocks and projection with scaled normals."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    @staticmethod
    def get_config(pretrained_clip_name="ViT-B/32"):
        """Return the state dict of a pretrained checkpoint.

        A ViT-B/32 or ViT-B/16 checkpoint lying next to this module is used
        directly. Otherwise a known model name is downloaded, and any other
        value is treated as a path to a checkpoint file.

        Raises:
            RuntimeError: if the name is neither a known model nor an existing file.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(module_dir, "ViT-B-32.pt")
        if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME:
            model_path = os.path.join(module_dir, _PT_NAME[pretrained_clip_name])

        has_local_copy = pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path)
        if not has_local_copy:
            if pretrained_clip_name in _MODELS:
                model_path = _download(_MODELS[pretrained_clip_name])
            elif os.path.isfile(pretrained_clip_name):
                model_path = pretrained_clip_name
            else:
                raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}")

        try:
            # loading JIT archive
            model = torch.jit.load(model_path, map_location="cpu").eval()
            state_dict = model.state_dict()
        except RuntimeError:
            # NOTE: torch.load unpickles the file; only pass checkpoint paths
            # that come from a trusted source.
            state_dict = torch.load(model_path, map_location="cpu")

        return state_dict

    def build_attention_mask(self, context_length):
        """Return an additive causal mask of shape ``(context_length, context_length)``."""
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.zeros(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        """Dtype of the visual patch-embedding weights, used as the model's compute dtype."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, left_gt_map, right_gt_map, return_hidden=False, video_frame=-1):
        """Encode frame pairs into one embedding per sample.

        The visual hidden sequence is layer-normed and projected; the tokens
        at index 0 and at the midpoint of the sequence (the first token of
        each half) are averaged into the returned embedding.

        Returns:
            ``(x, left_map, right_map)``, or ``(x, hidden2, left_map, right_map)``
            when `return_hidden` is true, where `hidden2` stacks the two
            selected tokens.
        """
        hidden, left_map, right_map = self.visual(image.type(self.dtype), left_gt_map, right_gt_map, video_frame=video_frame)

        hidden = self.visual.ln_post(hidden) @ self.visual.proj

        cls_index = hidden.size(1) // 2
        hidden2 = torch.cat([hidden[:, 0, :].unsqueeze(1), hidden[:, cls_index, :].unsqueeze(1)], 1)

        x = torch.mean(hidden2, 1)

        if return_hidden:
            return x, hidden2, left_map, right_map

        return x, left_map, right_map

    def encode_text(self, text, return_hidden=False):
        """Encode token ids into one embedding per sequence.

        Returns:
            The projected feature at each sequence's arg-max token id, plus the
            full projected hidden sequence when `return_hidden` is true.
        """
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
        x = x + pos_emd
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        hidden = self.ln_final(x).type(self.dtype) @ self.text_projection

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]

        if return_hidden:
            return x, hidden

        return x

    def forward(self, image, text, left_gt_map=None, right_gt_map=None, video_frame=2):
        """Return the scaled cosine-similarity logits between images and texts.

        Args:
            image, text: inputs for `encode_image` and `encode_text`.
            left_gt_map, right_gt_map: forwarded to `encode_image`; ``None`` is
                the default `gt_attention_map` of `TABAttention.forward`.
            video_frame: frames per sample, forwarded to `encode_image`. The
                default of 2 matches the before/after pair that
                `VisualTransformer.forward` joins.

        Returns:
            ``(logits_per_image, logits_per_text)``.
        """
        # encode_image requires the gt maps and returns the attention maps
        # alongside the features; only the features are needed here.
        image_features, _, _ = self.encode_image(image, left_gt_map, right_gt_map, video_frame=video_frame)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Affects the weights and biases of Conv1d/Conv2d/Conv3d/Linear layers, the
    projection weights and biases of ``nn.MultiheadAttention``, and any
    ``text_projection`` / ``proj`` attribute found on a module.
    """
    attention_attrs = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )

    def _halve(tensor):
        # Absent (None) parameters are simply skipped.
        if tensor is not None:
            tensor.data = tensor.data.half()

    def _convert_weights_to_fp16(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            _halve(module.bias)

        if isinstance(module, nn.MultiheadAttention):
            for attr_name in attention_attrs:
                _halve(getattr(module, attr_name))

        for name in ("text_projection", "proj"):
            if hasattr(module, name):
                _halve(getattr(module, name))

    model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
    """Build a `CLIP` model whose hyper-parameters are inferred from `state_dict`.

    The model's weights are converted to fp16, `state_dict` is loaded into it,
    and the model is returned in eval mode. The keys ``input_resolution``,
    ``context_length`` and ``vocab_size`` are removed from the passed-in
    `state_dict` before loading.
    """
    if "visual.proj" in state_dict:
        # ViT checkpoint: read the geometry off the patch embedding.
        patch_embed = state_dict["visual.conv1.weight"]
        vision_width = patch_embed.shape[0]
        vision_layers = sum(
            1 for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight"))
        vision_patch_size = patch_embed.shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # ResNet-style checkpoint: count the distinct blocks in each of the four stages.
        counts: list = [
            len({k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}")})
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        attnpool_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((attnpool_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    text_block_ids = {k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")}
    transformer_layers = len(text_block_ids)

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # These entries are not model parameters, so drop them before loading.
    for key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(key, None)

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()