"""Video models.""" |
|
|
|
import math |
|
from functools import partial |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.init import trunc_normal_ |
|
|
|
|
|
|
|
from .attention import MultiScaleBlock |
|
|
|
from .common import TwoStreamFusion |
|
from .reversible_mvit import ReversibleMViT |
|
from .utils import ( |
|
calc_mvit_feature_geometry, |
|
get_3d_sincos_pos_embed, |
|
round_width, |
|
validate_checkpoint_wrapper_import, |
|
) |
|
|
|
from . import head_helper, stem_helper |
|
|
|
|
|
class MViT(nn.Module):
    """
    Model builder for MViTv1 and MViTv2.

    "MViTv2: Improved Multiscale Vision Transformers for Classification and Detection"
    Yanghao Li, Chao-Yuan Wu, Haoqi Fan, Karttikeya Mangalam, Bo Xiong,
    Jitendra Malik, Christoph Feichtenhofer
    https://arxiv.org/abs/2112.01526

    "Multiscale Vision Transformers"
    Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan,
    Jitendra Malik, Christoph Feichtenhofer
    https://arxiv.org/abs/2104.11227
    """

    def __init__(self, cfg):
        super().__init__()

        assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
        self.cfg = cfg
        pool_first = cfg.MVIT.POOL_FIRST

        spatial_size = cfg.DATA.TRAIN_CROP_SIZE
        temporal_size = cfg.DATA.NUM_FRAMES
        in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]
        self.use_2d_patch = cfg.MVIT.PATCH_2D
        self.enable_detection = cfg.DETECTION.ENABLE
        self.enable_rev = cfg.MVIT.REV.ENABLE
        self.patch_stride = cfg.MVIT.PATCH_STRIDE
        if self.use_2d_patch:
            self.patch_stride = [1] + self.patch_stride
        self.T = cfg.DATA.NUM_FRAMES // self.patch_stride[0]
        self.H = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[1]
        self.W = cfg.DATA.TRAIN_CROP_SIZE // self.patch_stride[2]

        num_classes = cfg.MODEL.NUM_CLASSES
        embed_dim = cfg.MVIT.EMBED_DIM

        num_heads = cfg.MVIT.NUM_HEADS
        mlp_ratio = cfg.MVIT.MLP_RATIO
        qkv_bias = cfg.MVIT.QKV_BIAS
        self.drop_rate = cfg.MVIT.DROPOUT_RATE
        depth = cfg.MVIT.DEPTH
        drop_path_rate = cfg.MVIT.DROPPATH_RATE
        layer_scale_init_value = cfg.MVIT.LAYER_SCALE_INIT_VALUE
        head_init_scale = cfg.MVIT.HEAD_INIT_SCALE
        mode = cfg.MVIT.MODE
        self.cls_embed_on = cfg.MVIT.CLS_EMBED_ON
        self.use_mean_pooling = cfg.MVIT.USE_MEAN_POOLING

        self.use_abs_pos = cfg.MVIT.USE_ABS_POS
        self.use_fixed_sincos_pos = cfg.MVIT.USE_FIXED_SINCOS_POS
        self.sep_pos_embed = cfg.MVIT.SEP_POS_EMBED
        self.rel_pos_spatial = cfg.MVIT.REL_POS_SPATIAL
        self.rel_pos_temporal = cfg.MVIT.REL_POS_TEMPORAL
        if cfg.MVIT.NORM == "layernorm":
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        else:
            raise NotImplementedError("Only supports layernorm.")
        self.num_classes = num_classes
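        # Patch-embedding stem: a strided convolution (2D when
        # cfg.MVIT.PATCH_2D is set, otherwise 3D) that turns the input clip
        # into a sequence of embed_dim-dimensional patch tokens.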
        self.patch_embed = stem_helper.PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=cfg.MVIT.PATCH_KERNEL,
            stride=cfg.MVIT.PATCH_STRIDE,
            padding=cfg.MVIT.PATCH_PADDING,
            conv_2d=self.use_2d_patch,
        )

        self.input_dims = [temporal_size, spatial_size, spatial_size]
        assert self.input_dims[1] == self.input_dims[2]
        self.patch_dims = [
            self.input_dims[i] // self.patch_stride[i]
            for i in range(len(self.input_dims))
        ]
        num_patches = math.prod(self.patch_dims)

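        # Stochastic depth: per-block drop-path rates increase linearly from 0
        # to cfg.MVIT.DROPPATH_RATE over the depth of the network.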
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        if self.cls_embed_on:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            pos_embed_dim = num_patches + 1
        else:
            pos_embed_dim = num_patches

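        # Absolute positional embeddings: either separable space/time tables
        # (SEP_POS_EMBED) or a single joint table over all tokens, which can be
        # frozen to a fixed 3D sin-cos encoding (USE_FIXED_SINCOS_POS).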
        if self.use_abs_pos:
            if self.sep_pos_embed:
                self.pos_embed_spatial = nn.Parameter(
                    torch.zeros(
                        1, self.patch_dims[1] * self.patch_dims[2], embed_dim
                    )
                )
                self.pos_embed_temporal = nn.Parameter(
                    torch.zeros(1, self.patch_dims[0], embed_dim)
                )
                if self.cls_embed_on:
                    self.pos_embed_class = nn.Parameter(
                        torch.zeros(1, 1, embed_dim)
                    )
            else:
                self.pos_embed = nn.Parameter(
                    torch.zeros(
                        1,
                        pos_embed_dim,
                        embed_dim,
                    ),
                    requires_grad=not self.use_fixed_sincos_pos,
                )

        if self.drop_rate > 0.0:
            self.pos_drop = nn.Dropout(p=self.drop_rate)

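        # Stage schedule: DIM_MUL / HEAD_MUL list (block_index, multiplier)
        # pairs; channel width and attention heads are expanded at those blocks.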
        dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
        for i in range(len(cfg.MVIT.DIM_MUL)):
            dim_mul[cfg.MVIT.DIM_MUL[i][0]] = cfg.MVIT.DIM_MUL[i][1]
        for i in range(len(cfg.MVIT.HEAD_MUL)):
            head_mul[cfg.MVIT.HEAD_MUL[i][0]] = cfg.MVIT.HEAD_MUL[i][1]

        pool_q = [[] for _ in range(cfg.MVIT.DEPTH)]
        pool_kv = [[] for _ in range(cfg.MVIT.DEPTH)]
        stride_q = [[] for _ in range(cfg.MVIT.DEPTH)]
        stride_kv = [[] for _ in range(cfg.MVIT.DEPTH)]

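        # Per-block pooling for the attention: POOL_Q_STRIDE and POOL_KV_STRIDE
        # are lists of (block_index, stride_t, stride_h, stride_w); the pooling
        # kernel defaults to stride + 1 (for strides > 1) unless a fixed
        # POOL_KVQ_KERNEL is given.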
        for i in range(len(cfg.MVIT.POOL_Q_STRIDE)):
            stride_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_Q_STRIDE[i][
                1:
            ]
            if cfg.MVIT.POOL_KVQ_KERNEL is not None:
                pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = cfg.MVIT.POOL_KVQ_KERNEL
            else:
                pool_q[cfg.MVIT.POOL_Q_STRIDE[i][0]] = [
                    s + 1 if s > 1 else s for s in cfg.MVIT.POOL_Q_STRIDE[i][1:]
                ]

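        # POOL_KV_STRIDE_ADAPTIVE derives the key/value stride per block so the
        # kv path keeps roughly the same output resolution as the query path:
        # the initial kv stride is divided by every query stride applied so far.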
        if cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE is not None:
            _stride_kv = cfg.MVIT.POOL_KV_STRIDE_ADAPTIVE
            cfg.MVIT.POOL_KV_STRIDE = []
            for i in range(cfg.MVIT.DEPTH):
                if len(stride_q[i]) > 0:
                    _stride_kv = [
                        max(_stride_kv[d] // stride_q[i][d], 1)
                        for d in range(len(_stride_kv))
                    ]
                cfg.MVIT.POOL_KV_STRIDE.append([i] + _stride_kv)

        for i in range(len(cfg.MVIT.POOL_KV_STRIDE)):
            stride_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = cfg.MVIT.POOL_KV_STRIDE[
                i
            ][1:]
            if cfg.MVIT.POOL_KVQ_KERNEL is not None:
                pool_kv[
                    cfg.MVIT.POOL_KV_STRIDE[i][0]
                ] = cfg.MVIT.POOL_KVQ_KERNEL
            else:
                pool_kv[cfg.MVIT.POOL_KV_STRIDE[i][0]] = [
                    s + 1 if s > 1 else s
                    for s in cfg.MVIT.POOL_KV_STRIDE[i][1:]
                ]

        self.pool_q = pool_q
        self.pool_kv = pool_kv
        self.stride_q = stride_q
        self.stride_kv = stride_kv

        self.norm_stem = norm_layer(embed_dim) if cfg.MVIT.NORM_STEM else None

        input_size = self.patch_dims

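        # Two backbone variants: a reversible stack (ReversibleMViT), which
        # trades extra compute for activation memory, or the standard stack of
        # MultiScaleBlocks built below.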
        if self.enable_rev:
            assert not self.cls_embed_on

            self.rev_backbone = ReversibleMViT(cfg, self)

            embed_dim = round_width(
                embed_dim, dim_mul.prod(), divisor=num_heads
            )

            self.fuse = TwoStreamFusion(
                cfg.MVIT.REV.RESPATH_FUSE, dim=2 * embed_dim
            )

            if "concat" in self.cfg.MVIT.REV.RESPATH_FUSE:
                self.norm = norm_layer(2 * embed_dim)
            else:
                self.norm = norm_layer(embed_dim)

        else:
            self.blocks = nn.ModuleList()

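            # Build the stack of MultiScaleBlocks. embed_dim and num_heads grow
            # according to DIM_MUL / HEAD_MUL, and input_size (the token grid)
            # shrinks wherever a block pools queries with stride_q.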
            for i in range(depth):
                num_heads = round_width(num_heads, head_mul[i])
                if cfg.MVIT.DIM_MUL_IN_ATT:
                    dim_out = round_width(
                        embed_dim,
                        dim_mul[i],
                        divisor=round_width(num_heads, head_mul[i]),
                    )
                else:
                    dim_out = round_width(
                        embed_dim,
                        dim_mul[i + 1],
                        divisor=round_width(num_heads, head_mul[i + 1]),
                    )
                attention_block = MultiScaleBlock(
                    dim=embed_dim,
                    dim_out=dim_out,
                    num_heads=num_heads,
                    input_size=input_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_rate=self.drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    kernel_q=pool_q[i] if len(pool_q) > i else [],
                    kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                    stride_q=stride_q[i] if len(stride_q) > i else [],
                    stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                    mode=mode,
                    has_cls_embed=self.cls_embed_on,
                    pool_first=pool_first,
                    rel_pos_spatial=self.rel_pos_spatial,
                    rel_pos_temporal=self.rel_pos_temporal,
                    rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
                    residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
                    dim_mul_in_att=cfg.MVIT.DIM_MUL_IN_ATT,
                    separate_qkv=cfg.MVIT.SEPARATE_QKV,
                )

                self.blocks.append(attention_block)
                if len(stride_q[i]) > 0:
                    input_size = [
                        size // stride
                        for size, stride in zip(input_size, stride_q[i])
                    ]

                embed_dim = dim_out

            self.norm = norm_layer(embed_dim)

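        # Classification head; detection heads are not supported in this file.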
        if self.enable_detection:
            raise NotImplementedError("Detection is not supported.")
        else:
            self.head = head_helper.TransformerBasicHead(
                2 * embed_dim
                if ("concat" in cfg.MVIT.REV.RESPATH_FUSE and self.enable_rev)
                else embed_dim,
                num_classes,
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
                cfg=cfg,
            )

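        # Parameter initialization: truncated-normal for positional embeddings
        # and the class token (or a fixed 3D sin-cos table when
        # USE_FIXED_SINCOS_POS is set), then scale the head projection by
        # HEAD_INIT_SCALE.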
        if self.use_abs_pos:
            if self.sep_pos_embed:
                trunc_normal_(self.pos_embed_spatial, std=0.02)
                trunc_normal_(self.pos_embed_temporal, std=0.02)
                if self.cls_embed_on:
                    trunc_normal_(self.pos_embed_class, std=0.02)
            else:
                trunc_normal_(self.pos_embed, std=0.02)
                if self.use_fixed_sincos_pos:
                    pos_embed = get_3d_sincos_pos_embed(
                        self.pos_embed.shape[-1],
                        self.H,
                        self.T,
                        cls_token=self.cls_embed_on,
                    )
                    self.pos_embed.data.copy_(
                        torch.from_numpy(pos_embed).float().unsqueeze(0)
                    )

        if self.cls_embed_on:
            trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

        self.head.projection.weight.data.mul_(head_init_scale)
        self.head.projection.bias.data.mul_(head_init_scale)

        self.feat_size, self.feat_stride = calc_mvit_feature_geometry(cfg)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0.02)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        names = []
        if self.cfg.MVIT.ZERO_DECAY_POS_CLS:
            if self.use_abs_pos:
                if self.sep_pos_embed:
                    names.extend(
                        [
                            "pos_embed_spatial",
                            "pos_embed_temporal",
                            "pos_embed_class",
                        ]
                    )
                else:
                    names.append("pos_embed")
            if self.rel_pos_spatial:
                names.extend(["rel_pos_h", "rel_pos_w", "rel_pos_hw"])
            if self.rel_pos_temporal:
                names.extend(["rel_pos_t"])
            if self.cls_embed_on:
                names.append("cls_token")

        return names

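    # _get_pos_embed interpolates the learned (T, H, W) positional table to the
    # token grid of the current input when it differs from the grid the model
    # was configured for (e.g. a different crop size or clip length).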
    def _get_pos_embed(self, pos_embed, bcthw):
        if len(bcthw) == 4:
            t, h, w = 1, bcthw[-2], bcthw[-1]
        else:
            t, h, w = bcthw[-3], bcthw[-2], bcthw[-1]
        if self.cls_embed_on:
            cls_pos_embed = pos_embed[:, 0:1, :]
            pos_embed = pos_embed[:, 1:]
        txy_num = pos_embed.shape[1]
        p_t, p_h, p_w = self.patch_dims
        assert p_t * p_h * p_w == txy_num

        if (p_t, p_h, p_w) != (t, h, w):
            new_pos_embed = F.interpolate(
                pos_embed[:, :, :]
                .reshape(1, p_t, p_h, p_w, -1)
                .permute(0, 4, 1, 2, 3),
                size=(t, h, w),
                mode="trilinear",
            )
            pos_embed = new_pos_embed.reshape(1, -1, t * h * w).permute(0, 2, 1)

        if self.cls_embed_on:
            pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)

        return pos_embed

    def _forward_reversible(self, x):
        """
        Reversible specific code for forward computation.
        """
        assert not self.cls_embed_on
        assert not self.enable_detection

        x = self.rev_backbone(x)

        if self.use_mean_pooling:
            x = self.fuse(x)
            x = x.mean(1)
            x = self.norm(x)
        else:
            x = self.norm(x)
            x = self.fuse(x)
            x = x.mean(1)

        x = self.head(x)

        return x

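    # forward: tokenize the clip with the patch stem, add positional
    # information, run the reversible or standard transformer stack, then pool
    # the tokens and classify.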
    def forward(self, x, bboxes=None, return_attn=False):
        x = x[0]
        x, bcthw = self.patch_embed(x)
        bcthw = list(bcthw)
        if len(bcthw) == 4:
            bcthw.insert(2, torch.tensor(self.T))
        T, H, W = bcthw[-3], bcthw[-2], bcthw[-1]
        assert len(bcthw) == 5 and (T, H, W) == (self.T, self.H, self.W), bcthw
        B, N, C = x.shape
        s = 1 if self.cls_embed_on else 0
        if self.use_fixed_sincos_pos:
            x += self.pos_embed[:, s:, :]

        if self.cls_embed_on:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            if self.use_fixed_sincos_pos:
                cls_tokens = cls_tokens + self.pos_embed[:, :s, :]
            x = torch.cat((cls_tokens, x), dim=1)

        if self.use_abs_pos:
            if self.sep_pos_embed:
                pos_embed = self.pos_embed_spatial.repeat(
                    1, self.patch_dims[0], 1
                ) + torch.repeat_interleave(
                    self.pos_embed_temporal,
                    self.patch_dims[1] * self.patch_dims[2],
                    dim=1,
                )
                if self.cls_embed_on:
                    pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1)
                x += self._get_pos_embed(pos_embed, bcthw)
            else:
                x += self._get_pos_embed(self.pos_embed, bcthw)

        if self.drop_rate:
            x = self.pos_drop(x)

        if self.norm_stem:
            x = self.norm_stem(x)

        thw = [T, H, W]

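        # thw tracks the current (T, H, W) token grid; each MultiScaleBlock
        # returns the updated grid after any query pooling it performs.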
        if self.enable_rev:
            x = self._forward_reversible(x)

        else:
            for blk in self.blocks:
                x, thw = blk(x, thw)

            if self.enable_detection:
                assert not self.enable_rev

                x = self.norm(x)
                if self.cls_embed_on:
                    x = x[:, 1:]

                B, _, C = x.shape
                x = x.transpose(1, 2).reshape(B, C, thw[0], thw[1], thw[2])

                x = self.head([x], bboxes)

            else:
                if self.use_mean_pooling:
                    if self.cls_embed_on:
                        x = x[:, 1:]
                    x = x.mean(1)
                    x = self.norm(x)
                elif self.cls_embed_on:
                    x = self.norm(x)
                    x = x[:, 0]
                else:
                    x = self.norm(x)
                    x = x.mean(1)

                x = self.head(x)

        return x