Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
# Copyright 2020 Ross Wightman | |
# Modified Model definition | |
"""Video models.""" | |
import math | |
import torch | |
import torch.nn as nn | |
from einops import rearrange, repeat | |
from timm.layers import to_2tuple | |
from torch import einsum | |
from torch.nn import functional as F | |
# Download URLs of the pretrained ViT checkpoints, keyed by the name that
# `cfg.VIT.PRETRAINED_WEIGHTS` is looked up with in `load_pretrained`.
default_cfgs = dict(
    vit_1k=(
        "https://github.com/rwightman/pytorch-image-models/releases/download/"
        "v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth"
    ),
    vit_1k_large=(
        "https://github.com/rwightman/pytorch-image-models/releases/download/"
        "v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth"
    ),
)
def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
    """Softmax attention over already-projected queries, keys and values.

    No scaling is applied to the scores inside this function.

    Args:
        q: queries, indexed as (batch, query, dim) by the einsum pattern.
        k: keys, indexed as (batch, key, dim).
        v: values, indexed as (batch, key, dim).
        tok_mask: optional 2-D (batch, key) mask. Key positions whose entry
            equals 0 are set to ``-inf`` before the softmax, for every query.

    Returns:
        The attention-weighted values, indexed as (batch, query, dim).
    """
    # Pairwise query/key scores: (batch, query, key).
    scores = einsum("b i d, b j d -> b i j", q, k)

    if tok_mask is not None:
        n_rows, n_keys = tok_mask.shape
        # The singleton middle axis broadcasts the per-key mask over all queries.
        dropped = tok_mask.view(n_rows, 1, n_keys) == 0
        scores = scores.masked_fill(dropped, float("-inf"))

    weights = torch.softmax(scores, dim=-1)
    return einsum("b i j, b j d -> b i d", weights, v)
class DividedAttention(nn.Module):
    """Multi-head attention restricted to one axis (time or space) of a token grid.

    The first token of the sequence is treated as a CLS token: it attends to
    every token, while the remaining tokens are regrouped with the einops
    patterns passed to ``forward`` and attend only within their group (plus
    the CLS key/value).

    Args:
        dim: feature size of the input tokens.
        num_heads: number of attention heads; ``dim // num_heads`` is the head size.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        attn_drop: dropout probability stored in ``self.attn_drop``; that module
            is created here but not applied anywhere in ``forward``.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # Constant init: qkv weight (and bias, if any) to zero, proj weight to
        # one, proj bias to zero.
        self.qkv.weight.data.fill_(0)
        # nn.Linear(bias=False) leaves `.bias` as None, so only touch it when it
        # exists; otherwise the default `qkv_bias=False` fails on construction.
        if self.qkv.bias is not None:
            self.qkv.bias.data.fill_(0)
        self.proj.weight.data.fill_(1)
        self.proj.bias.data.fill_(0)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
        """Apply CLS attention plus axis-restricted attention.

        Args:
            x: tokens indexed as (b, n, dim); position 0 along ``n`` is the CLS token.
            einops_from: einops pattern describing the non-CLS tokens as they
                arrive, e.g. ``"b (f n) d"``.
            einops_to: einops pattern of the grouping to attend within,
                e.g. ``"(b f) n d"``.
            tok_mask: optional 2-D (b, n) mask aligned with ``x``; entries equal
                to 0 mark tokens that must not be attended to.
            **einops_dims: axis sizes needed to resolve the patterns (e.g. ``f=8``).

        Returns:
            Tensor with the same (b, n, dim) layout as ``x``.
        """
        h = self.num_heads

        # Project to q, k, v and fold the heads into the batch axis.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))

        if tok_mask is not None:
            # Replicate the token mask across heads: (b, n) -> (b, h, n) -> (b*h, n),
            # i.e. the same layout as q/k/v without the feature axis.
            assert len(tok_mask.shape) == 2
            tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])

        # Scale the queries (out of place, so the projection output is not mutated).
        q = q * self.scale

        # Separate the CLS token from the remaining tokens, for q/k/v and the mask.
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = ((t[:, 0:1], t[:, 1:]) for t in (q, k, v))
        if tok_mask is not None:
            cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
        else:
            cls_mask, mask_ = None, None

        # The CLS query attends to the keys/values of all tokens.
        cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)

        # Regroup the remaining tokens along time or space.
        q_, k_, v_ = (rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims) for t in (q_, k_, v_))

        # Repeat the CLS key/value once per group and prepend it to each group.
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = (repeat(t, "b () d -> (b r) () d", r=r) for t in (cls_k, cls_v))
        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

        if tok_mask is not None:
            # The mask has no feature axis, so " d" is stripped from the patterns
            # before reusing them.
            mask_ = rearrange(mask_, f"{einops_from} -> {einops_to}".replace(" d", ""), **einops_dims)
            cls_mask = repeat(cls_mask, "b () -> (b r) ()", r=r)
            mask_ = torch.cat((cls_mask, mask_), dim=1)

        # Attention within each group.
        out = qkv_attn(q_, k_, v_, tok_mask=mask_)

        # Undo the regrouping, put the CLS output back in front, merge the heads.
        out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims)
        out = torch.cat((cls_out, out), dim=1)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)

        # Output projection.
        x = self.proj(out)
        x = self.proj_drop(x)
        return x
class DividedSpaceTimeBlock(nn.Module):
    """Transformer block: temporal attention, then spatial attention, then an MLP.

    Each of the three stages is pre-normalised and added back residually.
    ``attn_type`` and ``drop_path`` are accepted but not used by this block:
    ``self.drop_path`` is always ``nn.Identity()``.
    """

    def __init__(
        self,
        dim=768,
        num_heads=12,
        attn_type="divided",
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        # einops patterns: tokens arrive as "b (f n) d" and are regrouped per
        # frame for spatial attention or per spatial location for temporal attention.
        self.einops_from_space = "b (f n) d"
        self.einops_to_space = "(b f) n d"
        self.einops_from_time = "b (f n) d"
        self.einops_to_time = "(b n) f d"

        # Both attention modules are configured identically.
        attention_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.norm1 = norm_layer(dim)
        self.attn = DividedAttention(dim, **attention_kwargs)
        self.timeattn = DividedAttention(dim, **attention_kwargs)

        # Stochastic depth is disabled; `drop_path` is ignored.
        self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm3 = norm_layer(dim)

    def forward(self, x, seq_len=196, num_frames=8, approx="none", num_landmarks=128, tok_mask: torch.Tensor = None):
        """Run the block on ``x``.

        Args:
            x: input tokens, passed to ``DividedAttention`` together with the
                "b (f n) d" patterns above.
            seq_len: value bound to the ``n`` axis for temporal attention.
            num_frames: value bound to the ``f`` axis for spatial attention.
            approx: unused.
            num_landmarks: unused.
            tok_mask: optional token mask forwarded to both attention modules.

        Returns:
            The transformed tokens.
        """
        # Temporal attention (normalised by norm3), added without drop_path.
        temporal = self.timeattn(
            self.norm3(x),
            self.einops_from_time,
            self.einops_to_time,
            n=seq_len,
            tok_mask=tok_mask,
        )
        after_time = x + temporal

        # Spatial attention (normalised by norm1).
        spatial = self.attn(
            self.norm1(after_time),
            self.einops_from_space,
            self.einops_to_space,
            f=num_frames,
            tok_mask=tok_mask,
        )
        after_space = after_time + self.drop_path(spatial)

        # Feed-forward (normalised by norm2).
        return after_space + self.drop_path(self.mlp(self.norm2(after_space)))
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    The same dropout module is applied after the activation and after the
    second linear layer.

    Args:
        in_features: size of the input features.
        hidden_features: size of the hidden layer; falls back to ``in_features``
            when falsy.
        out_features: size of the output; falls back to ``in_features`` when falsy.
        act_layer: zero-argument factory for the activation module.
        drop: dropout probability.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        n_out = out_features if out_features else in_features
        n_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, n_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(n_hidden, n_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Return ``drop(fc2(drop(act(fc1(x)))))``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class PatchEmbed(nn.Module):
    """Image to patch embedding.

    A strided ``nn.Conv2d`` whose kernel and stride both equal the patch size
    turns each non-overlapping patch into one ``embed_dim``-sized token.

    Args:
        img_size: image size, as an int (square) or a tuple.
        patch_size: patch size, as an int (square) or a tuple.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = img_size if isinstance(img_size, tuple) else to_2tuple(img_size)
        # A tuple patch size must be kept as given; the earlier code replaced it
        # with `img_size`, which turned the whole image into a single patch.
        patch_size = patch_size if isinstance(patch_size, tuple) else to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed ``x`` of shape (B, C, H, W) into tokens of shape (B, patches, embed_dim)."""
        # The unpacking is kept only as a check that the input is 4-D.
        B, C, H, W = x.shape
        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
class PatchEmbed3D(nn.Module):
    """Video to patch embedding via one strided 3-D convolution.

    The kernel and stride are both ``(z_block_size, patch_size, patch_size)``,
    so every non-overlapping block of the clip becomes one token.

    Args:
        img_size: frame side length (int).
        temporal_resolution: accepted but not used by this module.
        in_chans: number of input channels.
        patch_size: spatial patch side length (int).
        z_block_size: number of frames covered by one token.
        embed_dim: size of each token embedding.
        flatten: if true, ``forward`` returns a token sequence instead of the
            raw convolution output.
    """

    def __init__(
        self,
        img_size=224,
        temporal_resolution=4,
        in_chans=3,
        patch_size=16,
        z_block_size=2,
        embed_dim=768,
        flatten=True,
    ):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.height = patches_per_side
        self.width = patches_per_side
        # `frames` and `num_patches` are deliberately not stored: the earlier
        # revision (note by v-iashin) flagged their formulas as incorrect.
        self.z_block_size = z_block_size
        block_shape = (z_block_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=block_shape, stride=block_shape)
        self.flatten = flatten

    def forward(self, x):
        """Embed ``x`` of shape (B, C, T, H, W).

        Returns (B, tokens, embed_dim) when ``self.flatten`` is true, otherwise
        the 5-D convolution output.
        """
        # The unpacking is kept only as a check that the input is 5-D.
        _batch, _chans, _frames, _rows, _cols = x.shape
        embedded = self.proj(x)
        if not self.flatten:
            return embedded
        return embedded.flatten(2).transpose(1, 2)
class HeadMLP(nn.Module):
    """Classification head: a linear classifier or a one-hidden-layer MLP.

    Args:
        n_input: size of the input features.
        n_classes: number of output classes.
        n_hidden: hidden layer size; ``None`` selects the plain linear classifier.
        p: dropout probability used by every dropout layer in the head.
    """

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden

        # Every variant starts with dropout on the input features.
        layers = [nn.Dropout(p=p)]
        if n_hidden is None:
            # Linear classifier.
            layers.append(nn.Linear(n_input, n_classes, bias=True))
        else:
            # Simple MLP classifier.
            layers.extend(
                [
                    nn.Linear(n_input, n_hidden, bias=True),
                    nn.BatchNorm1d(n_hidden),
                    nn.ReLU(inplace=True),
                    nn.Dropout(p=p),
                    nn.Linear(n_hidden, n_classes, bias=True),
                ]
            )
        self.block_forward = nn.Sequential(*layers)
        print(f"Dropout-NLP: {p}")

    def forward(self, x):
        """Return the output of the classifier stack for ``x``."""
        return self.block_forward(x)
def _conv_filter(state_dict, patch_size=16): | |
"""convert patch embedding weight from manual patchify + linear proj to conv""" | |
out_dict = {} | |
for k, v in state_dict.items(): | |
if "patch_embed.proj.weight" in k: | |
v = v.reshape((v.shape[0], 3, patch_size, patch_size)) | |
out_dict[k] = v | |
return out_dict | |
def adapt_input_conv(in_chans, conv_weight, agg="sum"):
    """Adapt a 4-D input-conv weight to a different number of input channels.

    The computation runs in float32 and the result is cast back to the dtype
    of ``conv_weight``.

    Args:
        in_chans: number of input channels the adapted weight must accept.
        conv_weight: weight tensor of shape (out, in, kh, kw).
        agg: ``"sum"`` selects the summing strategy; any other value selects
            averaging.

    Returns:
        The adapted weight. With ``in_chans == 3`` the weight is only round-
        tripped through the float cast.

    Raises:
        NotImplementedError: ``in_chans`` is neither 1 nor 3 and the weight
            does not have exactly 3 input channels.
    """
    source_dtype = conv_weight.dtype
    weight = conv_weight.float()
    n_out, n_in, k_h, k_w = weight.shape
    summing = agg == "sum"

    if in_chans == 1:
        if n_in > 3:
            assert weight.shape[1] % 3 == 0
            # For models with space2depth stems: collapse each RGB triplet.
            weight = weight.reshape(n_out, n_in // 3, 3, k_h, k_w).sum(dim=2, keepdim=False)
        elif summing:
            print("Summing conv1 weights")
            weight = weight.sum(dim=1, keepdim=True)
        else:
            print("Averaging conv1 weights")
            weight = weight.mean(dim=1, keepdim=True)
    elif in_chans != 3:
        if n_in != 3:
            raise NotImplementedError("Weight format not supported by conversion.")
        if summing:
            print("Summing conv1 weights")
            # Tile the RGB kernels until there are enough channels, truncate,
            # then rescale by 3 / in_chans.
            n_tiles = int(math.ceil(in_chans / 3))
            weight = weight.repeat(1, n_tiles, 1, 1)[:, :in_chans, :, :]
            weight *= 3 / float(in_chans)
        else:
            print("Averaging conv1 weights")
            # Use the channel mean for every new input channel.
            weight = weight.mean(dim=1, keepdim=True).repeat(1, in_chans, 1, 1)

    return weight.to(source_dtype)
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
    """Download pretrained ViT weights and copy the matching tensors into ``model``.

    The checkpoint URL is ``default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS]``. Tensors
    are copied in place into ``model.state_dict()`` when both name and shape
    match. Checkpoint entries that do not match are printed, and the model keys
    that received no weights are printed at the end.

    Args:
        model: module to receive the weights.
        cfg: config object; ``cfg.VIT.PRETRAINED_WEIGHTS`` and
            ``cfg.get("label_offset", 0)`` are read from it.
        num_classes: number of classes of ``model``; when it differs from the
            1000 pretrained classes the ``head`` weights are discarded.
        in_chans: number of input channels of ``model``; when it is not 3 the
            ``patch_embed.proj`` weight is converted with ``adapt_input_conv``.
        filter_fn: optional callable applied to the downloaded state dict.
        strict: reassigned below but never read in this function.
        progress: not used in this function.

    Raises:
        KeyError: ``cfg.VIT.PRETRAINED_WEIGHTS`` is not a key of ``default_cfgs``.
    """
    # The earlier `assert f"..."` asserted a non-empty string and so never
    # fired. KeyError is raised because that is what the lookup itself raised
    # for an unknown name.
    weights_key = cfg.VIT.PRETRAINED_WEIGHTS
    if weights_key not in default_cfgs:
        raise KeyError(f"{weights_key} not in {sorted(default_cfgs)}")

    # Load state dict
    state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[weights_key])
    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    # Adapt the input convolution when the model does not take 3-channel input.
    if in_chans != 3:
        for input_conv_name in ("patch_embed.proj",):
            weight_name = input_conv_name + ".weight"
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name], agg="avg")
            except NotImplementedError:
                del state_dict[weight_name]
                strict = False
                print(f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer.")
            else:
                print(f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)")

    classifier_name = "head"
    label_offset = cfg.get("label_offset", 0)
    pretrain_classes = 1000
    if num_classes != pretrain_classes:
        # completely discard fully connected if model num_classes doesn't match pretrained weights
        del state_dict[classifier_name + ".weight"]
        del state_dict[classifier_name + ".bias"]
        strict = False
    elif label_offset > 0:
        # special case for pretrained weights with an extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + ".weight"]
        state_dict[classifier_name + ".weight"] = classifier_weight[label_offset:]
        classifier_bias = state_dict[classifier_name + ".bias"]
        state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:]

    # Copy every tensor whose (prefix-stripped) name and shape match the model.
    self_state = model.state_dict()
    saved_names = set()
    for name, param in state_dict.items():
        if "module." in name:
            name = name.replace("module.", "")
        if name in self_state and param.shape == self_state[name].shape:
            saved_names.add(name)
            self_state[name].copy_(param)
        else:
            print(f"didnt load: {name} of shape: {param.shape}")

    print("Missing Keys:")
    print(set(self_state) - saved_names)