import sys
from functools import partial

import torch
from torch import nn
from torch.autograd import Function as Function

from .attention import MultiScaleAttention, attention_pool
from .common import Mlp, TwoStreamFusion, drop_path
from .utils import round_width


class ReversibleMViT(nn.Module):
    """
    Reversible model builder. This builds the reversible transformer encoder
    and allows reversible training.

    Karttikeya Mangalam, Haoqi Fan, Yanghao Li, Chao-Yuan Wu, Bo Xiong,
    Christoph Feichtenhofer, Jitendra Malik
    "Reversible Vision Transformers"

    https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf
    """
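    # Usage sketch (illustrative only, not taken from this file): the parent
    # MViT is expected to build this module from its own config and then call
    # it on the patch-embedded tokens, e.g.
    #
    #     self.rev_backbone = ReversibleMViT(cfg, self)  # attribute name is hypothetical
    #     ...
    #     x = self.rev_backbone(x)
    #
    # Only the (config, model) constructor signature and the forward(x) call
    # below are taken from this code.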
    def __init__(self, config, model):
        """
        The `__init__` method of any subclass should also contain these
        arguments.
        Args:
            config (CfgNode): model building configs, details are in the
                comments of the config file.
            model (nn.Module): parent MViT module this module forms
                a reversible encoder in.
        """

        super().__init__()
        self.cfg = config

        embed_dim = self.cfg.MVIT.EMBED_DIM
        depth = self.cfg.MVIT.DEPTH
        num_heads = self.cfg.MVIT.NUM_HEADS
        mlp_ratio = self.cfg.MVIT.MLP_RATIO
        qkv_bias = self.cfg.MVIT.QKV_BIAS

        drop_path_rate = self.cfg.MVIT.DROPPATH_RATE
        self.dropout = config.MVIT.DROPOUT_RATE
        self.pre_q_fusion = self.cfg.MVIT.REV.PRE_Q_FUSION
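        # Stochastic depth: the drop-path rate increases linearly from 0 to
        # cfg.MVIT.DROPPATH_RATE across the depth of the network.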
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]

        input_size = model.patch_dims

        self.layers = nn.ModuleList([])
        self.no_custom_backward = False

        if self.cfg.MVIT.NORM == "layernorm":
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        else:
            raise NotImplementedError("Only supports layernorm.")
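        # cfg.MVIT.DIM_MUL and cfg.MVIT.HEAD_MUL hold (layer_index, multiplier)
        # pairs; layers that are not listed keep a multiplier of 1.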
        dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
        for i in range(len(self.cfg.MVIT.DIM_MUL)):
            dim_mul[self.cfg.MVIT.DIM_MUL[i][0]] = self.cfg.MVIT.DIM_MUL[i][1]
        for i in range(len(self.cfg.MVIT.HEAD_MUL)):
            head_mul[self.cfg.MVIT.HEAD_MUL[i][0]] = self.cfg.MVIT.HEAD_MUL[i][
                1
            ]

        pool_q = model.pool_q
        pool_kv = model.pool_kv
        stride_q = model.stride_q
        stride_kv = model.stride_kv

        for i in range(depth):
            num_heads = round_width(num_heads, head_mul[i])

            embed_dim = round_width(
                embed_dim, dim_mul[i - 1] if i > 0 else 1.0, divisor=num_heads
            )
            dim_out = round_width(
                embed_dim,
                dim_mul[i],
                divisor=round_width(num_heads, head_mul[i + 1]),
            )
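            # Layers listed in cfg.MVIT.REV.BUFFER_LAYERS sit at stage
            # transitions: they are built as (non-reversible) StageTransitionBlocks
            # whose activations get buffered, while every other layer is a
            # ReversibleBlock that recomputes its activations in the backward pass.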
            if i in self.cfg.MVIT.REV.BUFFER_LAYERS:
                layer_type = StageTransitionBlock
                input_mult = 2 if "concat" in self.pre_q_fusion else 1
            else:
                layer_type = ReversibleBlock
                input_mult = 1

            dimout_correction = (
                2 if (input_mult == 2 and "concat" in self.pre_q_fusion) else 1
            )

            self.layers.append(
                layer_type(
                    dim=embed_dim * input_mult,
                    input_size=input_size,
                    dim_out=dim_out * input_mult // dimout_correction,
                    num_heads=num_heads,
                    cfg=self.cfg,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    kernel_q=pool_q[i] if len(pool_q) > i else [],
                    kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                    stride_q=stride_q[i] if len(stride_q) > i else [],
                    stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                    layer_id=i,
                    pre_q_fusion=self.pre_q_fusion,
                )
            )

            self.layers[-1].F.thw = input_size

            if len(stride_q[i]) > 0:
                input_size = [
                    size // stride
                    for size, stride in zip(input_size, stride_q[i])
                ]

            embed_dim = dim_out

    @staticmethod
    def vanilla_backward(h, layers, buffer):
        """
        Using rev layers without rev backpropagation. Debugging purposes only.
        Activated with self.no_custom_backward.
        """
        h, a = torch.chunk(h, 2, dim=-1)
        for _, layer in enumerate(layers):
            a, h = layer(a, h)

        return torch.cat([a, h], dim=-1)

    def forward(self, x):
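        """
        Group the encoder into alternating segments: stage-transition (buffer)
        layers run as ordinary residual blocks, while each run of consecutive
        reversible layers is executed as one two-stream segment, using
        RevBackProp during training (or vanilla_backward when debugging or
        not training).
        """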
        stack = []
        for l_i in range(len(self.layers)):
            if isinstance(self.layers[l_i], StageTransitionBlock):
                stack.append(("StageTransition", l_i))
            else:
                if len(stack) == 0 or stack[-1][0] == "StageTransition":
                    stack.append(("Reversible", []))
                stack[-1][1].append(l_i)

        for layer_seq in stack:
            if layer_seq[0] == "StageTransition":
                x = self.layers[layer_seq[1]](x)

            else:
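                # Duplicate the input into the two residual streams
                # (concatenated along the last dim) expected by the
                # reversible layers.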
                x = torch.cat([x, x], dim=-1)

                if not self.training or self.no_custom_backward:
                    executing_fn = ReversibleMViT.vanilla_backward
                else:
                    executing_fn = RevBackProp.apply

                x = executing_fn(
                    x,
                    self.layers[layer_seq[1][0] : layer_seq[1][-1] + 1],
                    [],
                )

        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        return x


class RevBackProp(Function):
    """
    Custom Backpropagation function to allow (A) flushing memory in forward
    and (B) activation recomputation reversibly in backward for gradient
    calculation.

    Inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
    """

    @staticmethod
    def forward(
        ctx,
        x,
        layers,
        buffer_layers,
    ):
        """
        Reversible Forward pass. Any intermediate activations from
        `buffer_layers` are cached in ctx during the forward pass. This is not
        necessary for standard use cases. Each reversible layer implements its
        own forward pass logic.
        """
        buffer_layers.sort()

        X_1, X_2 = torch.chunk(x, 2, dim=-1)

        intermediate = []

        for layer in layers:
            X_1, X_2 = layer(X_1, X_2)

            if layer.layer_id in buffer_layers:
                intermediate.extend([X_1.detach(), X_2.detach()])
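        # Only the final (detached) activations, plus any buffered
        # intermediates, are saved; everything else is recomputed from them
        # during the backward pass.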
        if len(buffer_layers) == 0:
            all_tensors = [X_1.detach(), X_2.detach()]
        else:
            intermediate = [torch.LongTensor(buffer_layers), *intermediate]
            all_tensors = [X_1.detach(), X_2.detach(), *intermediate]

        ctx.save_for_backward(*all_tensors)
        ctx.layers = layers

        return torch.cat([X_1, X_2], dim=-1)

    @staticmethod
    def backward(ctx, dx):
        """
        Reversible Backward pass. Any intermediate activations from
        `buffer_layers` are recovered from ctx. Each layer implements its own
        logic for the backward pass (both activation recomputation and grad
        calculation).
        """
        dX_1, dX_2 = torch.chunk(dx, 2, dim=-1)

        X_1, X_2, *int_tensors = ctx.saved_tensors

        if len(int_tensors) != 0:
            buffer_layers = int_tensors[0].tolist()
        else:
            buffer_layers = []

        layers = ctx.layers
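        # Walk the layers in reverse order, reconstructing the inputs of each
        # layer from its outputs and accumulating gradients as we go.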
        for _, layer in enumerate(layers[::-1]):
            if layer.layer_id in buffer_layers:
                X_1, X_2, dX_1, dX_2 = layer.backward_pass(
                    Y_1=int_tensors[
                        buffer_layers.index(layer.layer_id) * 2 + 1
                    ],
                    Y_2=int_tensors[
                        buffer_layers.index(layer.layer_id) * 2 + 2
                    ],
                    dY_1=dX_1,
                    dY_2=dX_2,
                )

            else:
                X_1, X_2, dX_1, dX_2 = layer.backward_pass(
                    Y_1=X_1,
                    Y_2=X_2,
                    dY_1=dX_1,
                    dY_2=dX_2,
                )

        dx = torch.cat([dX_1, dX_2], dim=-1)

        del int_tensors
        del dX_1, dX_2, X_1, X_2

        return dx, None, None


class StageTransitionBlock(nn.Module):
    """
    Blocks for changing the feature dimensions in MViT (using Q-pooling).
    See Section 3.3.1 in paper for details.
    """

    def __init__(
        self,
        dim,
        input_size,
        dim_out,
        num_heads,
        mlp_ratio,
        qkv_bias,
        drop_path,
        kernel_q,
        kernel_kv,
        stride_q,
        stride_kv,
        cfg,
        norm_layer=nn.LayerNorm,
        pre_q_fusion=None,
        layer_id=0,
    ):
        """
        Uses the same structure of F and G functions as ReversibleBlock,
        but without the reversible forward (and backward) pass.
        """
        super().__init__()

        self.drop_path_rate = drop_path

        embed_dim = dim

        self.F = AttentionSubBlock(
            dim=embed_dim,
            input_size=input_size,
            num_heads=num_heads,
            cfg=cfg,
            dim_out=dim_out,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
        )

        self.G = MLPSubblock(
            dim=dim_out,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
        )

        self.layer_id = layer_id

        self.is_proj = False
        self.has_cls_embed = cfg.MVIT.CLS_EMBED_ON

        self.is_conv = False
        self.pool_first = cfg.MVIT.POOL_FIRST
        self.mode = cfg.MVIT.MODE
        self.pre_q_fuse = TwoStreamFusion(pre_q_fusion, dim=dim)
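        # Skip-connection path for Q-pooling: either max-pool the residual to
        # match the attention's Q-pooled resolution, or (RES_PATH == "conv")
        # reuse the attention's own conv pooling on the residual in forward().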
if cfg.MVIT.REV.RES_PATH == "max": |
|
self.res_conv = False |
|
self.pool_skip = nn.MaxPool3d( |
|
|
|
[s + 1 if s > 1 else s for s in self.F.attn.pool_q.stride], |
|
self.F.attn.pool_q.stride, |
|
[int(k // 2) for k in self.F.attn.pool_q.stride], |
|
|
|
ceil_mode=False, |
|
) |
|
|
|
elif cfg.MVIT.REV.RES_PATH == "conv": |
|
self.res_conv = True |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
if embed_dim != dim_out: |
|
self.is_proj = True |
|
self.res_proj = nn.Linear(embed_dim, dim_out, bias=True) |
|
|
|
def forward( |
|
self, |
|
x, |
|
): |
|
""" |
|
Forward logic is similar to MultiScaleBlock with Q-pooling. |
|
""" |
|
x = self.pre_q_fuse(x) |
|
|
|
|
|
x_res = x |
|
|
|
|
|
|
|
if self.is_proj and not self.pool_first: |
|
x_res = self.res_proj(x_res) |
|
|
|
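        # Pool the skip connection so it matches the Q-pooled attention output:
        # either with the attention's conv pooling (folding channels into heads
        # as needed) or with the max-pool defined in __init__.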
        if self.res_conv:
            N, L, C = x_res.shape

            if self.mode == "conv_unshared":
                fold_dim = 1
            else:
                fold_dim = self.F.attn.num_heads

            x_res = x_res.reshape(N, L, fold_dim, C // fold_dim).permute(
                0, 2, 1, 3
            )

            x_res, _ = attention_pool(
                x_res,
                self.F.attn.pool_q,
                thw_shape=self.F.thw,
                has_cls_embed=self.has_cls_embed,
                norm=self.F.attn.norm_q
                if hasattr(self.F.attn, "norm_q")
                else None,
            )
            x_res = x_res.permute(0, 2, 1, 3).reshape(N, x_res.shape[2], C)

        else:
            x_res, _ = attention_pool(
                x_res,
                self.pool_skip,
                thw_shape=self.F.attn.thw,
                has_cls_embed=self.has_cls_embed,
            )

        if self.is_proj and self.pool_first:
            x_res = self.res_proj(x_res)

        x = self.F(x)
        x = x_res + x
        x = x + self.G(x)

        x = drop_path(x, drop_prob=self.drop_path_rate, training=self.training)

        return x


class ReversibleBlock(nn.Module):
    """
    Reversible Blocks for Reversible Vision Transformer and also
    for state-preserving blocks in Reversible MViT. See Section
    3.3.2 in paper for details.
    """
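    # Reversibility sketch (illustrative): with the RNG seeds fixed, the
    # forward map
    #     Y_1 = X_1 + F(X_2),  Y_2 = X_2 + G(Y_1)
    # can be inverted exactly as
    #     X_2 = Y_2 - G(Y_1),  X_1 = Y_1 - F(X_2)
    # which is what backward_pass() below exploits to rebuild inputs instead
    # of caching activations.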
    def __init__(
        self,
        dim,
        input_size,
        dim_out,
        num_heads,
        mlp_ratio,
        qkv_bias,
        drop_path,
        kernel_q,
        kernel_kv,
        stride_q,
        stride_kv,
        cfg,
        norm_layer=nn.LayerNorm,
        layer_id=0,
        **kwargs
    ):
        """
        Block is composed entirely of function F (Attention
        sub-block) and G (MLP sub-block) including layernorm.
        """
        super().__init__()

        self.drop_path_rate = drop_path

        self.F = AttentionSubBlock(
            dim=dim,
            input_size=input_size,
            num_heads=num_heads,
            cfg=cfg,
            dim_out=dim_out,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
        )

        self.G = MLPSubblock(
            dim=dim,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
        )

        self.layer_id = layer_id

        self.seeds = {}

    def seed_cuda(self, key):
        """
        Fix seeds to allow for stochastic elements such as
        dropout to be reproduced exactly in activation
        recomputation in the backward pass.
        """
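        # Draw a fresh random seed (from the CUDA generator when available,
        # otherwise from the CPU RNG), record it under `key`, and seed the RNG
        # with it so the same dropout / drop-path pattern can be replayed when
        # activations are recomputed in backward_pass.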
        if (
            hasattr(torch.cuda, "default_generators")
            and len(torch.cuda.default_generators) > 0
        ):
            device_idx = torch.cuda.current_device()
            seed = torch.cuda.default_generators[device_idx].seed()
        else:
            seed = int(torch.seed() % sys.maxsize)

        self.seeds[key] = seed
        torch.manual_seed(self.seeds[key])

    def forward(self, X_1, X_2):
        """
        forward pass equations:
        Y_1 = X_1 + Attention(X_2), F = Attention
        Y_2 = X_2 + MLP(Y_1), G = MLP
        """

        self.seed_cuda("attn")
        f_X_2 = self.F(X_2)

        self.seed_cuda("droppath")
        f_X_2_dropped = drop_path(
            f_X_2, drop_prob=self.drop_path_rate, training=self.training
        )

        Y_1 = X_1 + f_X_2_dropped

        del X_1

        self.seed_cuda("FFN")
        g_Y_1 = self.G(Y_1)

        torch.manual_seed(self.seeds["droppath"])
        g_Y_1_dropped = drop_path(
            g_Y_1, drop_prob=self.drop_path_rate, training=self.training
        )

        Y_2 = X_2 + g_Y_1_dropped

        del X_2

        return Y_1, Y_2

    def backward_pass(
        self,
        Y_1,
        Y_2,
        dY_1,
        dY_2,
    ):
        """
        equations for activation recomputation:
        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention
        """
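        # Steps 1-2: recompute G(Y_1) with grad enabled, backprop dY_2 through
        # it (accumulating G's parameter gradients and Y_1.grad), then
        # reconstruct X_2 = Y_2 - G(Y_1) and fold Y_1.grad into dY_1.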
        with torch.enable_grad():
            Y_1.requires_grad = True

            torch.manual_seed(self.seeds["FFN"])
            g_Y_1 = self.G(Y_1)

            torch.manual_seed(self.seeds["droppath"])
            g_Y_1 = drop_path(
                g_Y_1, drop_prob=self.drop_path_rate, training=self.training
            )

            g_Y_1.backward(dY_2, retain_graph=True)

        with torch.no_grad():
            X_2 = Y_2 - g_Y_1
            del g_Y_1

            dY_1 = dY_1 + Y_1.grad
            Y_1.grad = None
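        # Steps 3-4: recompute F(X_2) with grad enabled, backprop dY_1 through
        # it (accumulating F's parameter gradients and X_2.grad), then
        # reconstruct X_1 = Y_1 - F(X_2) and fold X_2.grad into dY_2.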
        with torch.enable_grad():
            X_2.requires_grad = True

            torch.manual_seed(self.seeds["attn"])
            f_X_2 = self.F(X_2)

            torch.manual_seed(self.seeds["droppath"])
            f_X_2 = drop_path(
                f_X_2, drop_prob=self.drop_path_rate, training=self.training
            )

            f_X_2.backward(dY_1, retain_graph=True)

        with torch.no_grad():
            X_1 = Y_1 - f_X_2

            del f_X_2, Y_1
            dY_2 = dY_2 + X_2.grad

            X_2.grad = None
            X_2 = X_2.detach()

        return X_1, X_2, dY_1, dY_2


class MLPSubblock(nn.Module):
    """
    This creates the function G such that the entire block can be
    expressed as F(G(X)). Includes pre-LayerNorm.
    """

    def __init__(
        self,
        dim,
        mlp_ratio,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=nn.GELU,
        )

    def forward(self, x):
        return self.mlp(self.norm(x))


class AttentionSubBlock(nn.Module):
    """
    This creates the function F such that the entire block can be
    expressed as F(G(X)). Includes pre-LayerNorm.
    """

    def __init__(
        self,
        dim,
        input_size,
        num_heads,
        cfg,
        dim_out=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm = norm_layer(dim, eps=1e-6, elementwise_affine=True)
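        # Input resolution (T, H, W) used by the attention pooling; set
        # externally by the builder (ReversibleMViT assigns layer.F.thw).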
        self.thw = None

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            input_size=input_size,
            num_heads=num_heads,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            drop_rate=cfg.MVIT.DROPOUT_RATE,
            qkv_bias=cfg.MVIT.QKV_BIAS,
            has_cls_embed=cfg.MVIT.CLS_EMBED_ON,
            mode=cfg.MVIT.MODE,
            pool_first=cfg.MVIT.POOL_FIRST,
            rel_pos_spatial=cfg.MVIT.REL_POS_SPATIAL,
            rel_pos_temporal=cfg.MVIT.REL_POS_TEMPORAL,
            rel_pos_zero_init=cfg.MVIT.REL_POS_ZERO_INIT,
            residual_pooling=cfg.MVIT.RESIDUAL_POOLING,
            separate_qkv=cfg.MVIT.SEPARATE_QKV,
        )

    def forward(self, x):
        out, _ = self.attn(self.norm(x), self.thw)
        return out