#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from .common import DropPath, Mlp


def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
    if pool is None:
        return tensor, thw_shape
    tensor_dim = tensor.ndim
    if tensor_dim == 4:
        pass
    elif tensor_dim == 3:
        tensor = tensor.unsqueeze(1)
    else:
        raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")

    if has_cls_embed:
        cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]

    B, N, L, C = tensor.shape
    T, H, W = thw_shape
    tensor = (
        tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
    )

    tensor = pool(tensor)

    thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
    L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
    tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
    if has_cls_embed:
        tensor = torch.cat((cls_tok, tensor), dim=2)
    if norm is not None:
        tensor = norm(tensor)
    # Assert tensor_dim in [3, 4]
    if tensor_dim == 4:
        pass
    else:  # tensor_dim == 3:
        tensor = tensor.squeeze(1)
    return tensor, thw_shape
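

# Illustrative usage of attention_pool (not part of the original module): a
# minimal sketch, assuming a per-head layout of [B, heads, 1 + T*H*W, C] with a
# class token and a hypothetical average pool that halves H and W.
#
#   pool = nn.AvgPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
#   x = torch.randn(2, 8, 1 + 4 * 14 * 14, 96)
#   y, thw = attention_pool(x, pool, [4, 14, 14], has_cls_embed=True)
#   # y: [2, 8, 1 + 4 * 7 * 7, 96], thw: [4, 7, 7]

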
def get_rel_pos(rel_pos, d):
    if isinstance(d, int):
        ori_d = rel_pos.shape[0]
        if ori_d == d:
            return rel_pos
        else:
            # Interpolate rel pos.
            new_pos_embed = F.interpolate(
                rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
                size=d,
                mode="linear",
            )
            return new_pos_embed.reshape(-1, d).permute(1, 0)
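

# Illustrative note on get_rel_pos (not part of the original module): when the
# stored relative-position table does not match the required distance range
# d = 2 * max(q_size, k_size) - 1, it is linearly interpolated along the
# distance axis. A hypothetical example:
#
#   rel_pos = torch.randn(13, 64)        # table trained for a 7x7 grid (2 * 7 - 1)
#   resized = get_rel_pos(rel_pos, 27)   # reused for a 14x14 grid (2 * 14 - 1)
#   # resized: [27, 64]

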
def cal_rel_pos_spatial(
    attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h, rel_pos_w
):
    """
    Decomposed Spatial Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = (
        torch.arange(q_h)[:, None] * q_h_ratio
        - torch.arange(k_h)[None, :] * k_h_ratio
    )
    dist_h += (k_h - 1) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = (
        torch.arange(q_w)[:, None] * q_w_ratio
        - torch.arange(k_w)[None, :] * k_w_ratio
    )
    dist_w += (k_w - 1) * k_w_ratio

    # Interpolate rel pos if needed.
    rel_pos_h = get_rel_pos(rel_pos_h, dh)
    rel_pos_w = get_rel_pos(rel_pos_w, dw)
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum(
        "bythwc,hkc->bythwk", r_q, Rh
    )  # [B, H, q_t, qh, qw, k_h]
    rel_w_q = torch.einsum(
        "bythwc,wkc->bythwk", r_q, Rw
    )  # [B, H, q_t, qh, qw, k_w]

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
        + rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
    ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)

    return attn
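

# Illustrative sketch for cal_rel_pos_spatial (not part of the original module):
# the decomposed spatial bias is added in place to the [T*H*W, T*H*W] attention
# logits, skipping the class token when present. Toy shapes below are assumptions:
#
#   B, heads, dim = 2, 4, 32
#   q_shape = k_shape = [2, 7, 7]
#   L = 1 + 2 * 7 * 7
#   q = torch.randn(B, heads, L, dim)
#   k = torch.randn(B, heads, L, dim)
#   attn = torch.randn(B, heads, L, L)
#   rel_h = torch.randn(2 * 7 - 1, dim)
#   rel_w = torch.randn(2 * 7 - 1, dim)
#   attn = cal_rel_pos_spatial(attn, q, k, True, q_shape, k_shape, rel_h, rel_w)

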
def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
    """
    Temporal Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape
    dt = int(2 * max(q_t, k_t) - 1)
    # Interpolate rel pos if needed.
    rel_pos_t = get_rel_pos(rel_pos_t, dt)

    # Scale up rel pos if shapes for q and k are different.
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = (
        torch.arange(q_t)[:, None] * q_t_ratio
        - torch.arange(k_t)[None, :] * k_t_ratio
    )
    dist_t += (k_t - 1) * k_t_ratio
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(
        q_t, B * n_head * q_h * q_w, dim
    )
    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
        + rel[:, :, :, :, :, :, None, None]
    ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)

    return attn
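

# Illustrative note on cal_rel_pos_temporal (not part of the original module):
# analogous to the spatial case, but the bias depends only on the temporal
# offset between query and key frames and is broadcast over every spatial
# position of the key. A hypothetical call, assuming q_shape = k_shape = [2, 7, 7],
# q of shape [B, heads, 1 + 2 * 7 * 7, dim], attn of shape
# [B, heads, 1 + 2 * 7 * 7, 1 + 2 * 7 * 7], and rel_pos_t of shape [2 * 2 - 1, dim]:
#
#   attn = cal_rel_pos_temporal(attn, q, True, [2, 7, 7], [2, 7, 7], rel_pos_t)

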
class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        input_size,
        num_heads=8,
        qkv_bias=False,
        drop_rate=0.0,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
        has_cls_embed=True,
        # Options include `conv`, `avg`, and `max`.
        mode="conv",
        # If True, perform pool before projection.
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=False,
        separate_qkv=False,
    ):
        super().__init__()
        self.pool_first = pool_first
        self.separate_qkv = separate_qkv
        self.drop_rate = drop_rate
        self.num_heads = num_heads
        self.dim_out = dim_out
        head_dim = dim_out // num_heads
        self.scale = head_dim**-0.5
        self.has_cls_embed = has_cls_embed
        self.mode = mode
        padding_q = [int(q // 2) for q in kernel_q]
        padding_kv = [int(kv // 2) for kv in kernel_kv]

        if pool_first or separate_qkv:
            self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)
        if drop_rate > 0.0:
            self.proj_drop = nn.Dropout(drop_rate)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
            kernel_q = ()
        if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
            kernel_kv = ()

        if mode in ("avg", "max"):
            pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
            self.pool_q = (
                pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
                if len(kernel_q) > 0
                else None
            )
            self.pool_k = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
            self.pool_v = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
        elif mode == "conv" or mode == "conv_unshared":
            if pool_first:
                dim_conv = dim // num_heads if mode == "conv" else dim
            else:
                dim_conv = dim_out // num_heads if mode == "conv" else dim_out
            self.pool_q = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_q) > 0
                else None
            )
            self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
            self.pool_k = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
            self.pool_v = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
        else:
            raise NotImplementedError(f"Unsupported mode {mode}")

        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_temporal = rel_pos_temporal
        if self.rel_pos_spatial:
            assert input_size[1] == input_size[2]
            size = input_size[1]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)
        if self.rel_pos_temporal:
            self.rel_pos_t = nn.Parameter(
                torch.zeros(2 * input_size[0] - 1, head_dim)
            )
            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_t, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, thw_shape):
        B, N, _ = x.shape

        if self.pool_first:
            if self.mode == "conv_unshared":
                fold_dim = 1
            else:
                fold_dim = self.num_heads
            x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
            q = k = v = x
        else:
            assert self.mode != "conv_unshared"
            if not self.separate_qkv:
                qkv = (
                    self.qkv(x)
                    .reshape(B, N, 3, self.num_heads, -1)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k, v = qkv[0], qkv[1], qkv[2]
            else:
                q = k = v = x
                q = (
                    self.q(q)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )
                k = (
                    self.k(k)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )
                v = (
                    self.v(v)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )

        q, q_shape = attention_pool(
            q,
            self.pool_q,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_q if hasattr(self, "norm_q") else None,
        )
        k, k_shape = attention_pool(
            k,
            self.pool_k,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_k if hasattr(self, "norm_k") else None,
        )
        v, v_shape = attention_pool(
            v,
            self.pool_v,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_v if hasattr(self, "norm_v") else None,
        )

        if self.pool_first:
            q_N = (
                numpy.prod(q_shape) + 1
                if self.has_cls_embed
                else numpy.prod(q_shape)
            )
            k_N = (
                numpy.prod(k_shape) + 1
                if self.has_cls_embed
                else numpy.prod(k_shape)
            )
            v_N = (
                numpy.prod(v_shape) + 1
                if self.has_cls_embed
                else numpy.prod(v_shape)
            )

            q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
            q = (
                self.q(q)
                .reshape(B, q_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

            v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
            v = (
                self.v(v)
                .reshape(B, v_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

            k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
            k = (
                self.k(k)
                .reshape(B, k_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

        N = q.shape[2]
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_spatial:
            attn = cal_rel_pos_spatial(
                attn,
                q,
                k,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_h,
                self.rel_pos_w,
            )

        if self.rel_pos_temporal:
            attn = cal_rel_pos_temporal(
                attn,
                q,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        x = attn @ v

        if self.residual_pooling:
            if self.has_cls_embed:
                x[:, :, 1:, :] += q[:, :, 1:, :]
            else:
                x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        if self.drop_rate > 0.0:
            x = self.proj_drop(x)
        return x, q_shape
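

# Illustrative usage of MultiScaleAttention (not part of the original module): a
# minimal sketch, assuming a [T, H, W] token grid of [4, 14, 14] with a class
# token and a spatial query/key stride of 2 (pooling attention). All shapes and
# hyperparameters below are assumptions for illustration.
#
#   attn = MultiScaleAttention(
#       dim=96,
#       dim_out=96,
#       input_size=[4, 14, 14],
#       num_heads=4,
#       kernel_q=(3, 3, 3),
#       kernel_kv=(3, 3, 3),
#       stride_q=(1, 2, 2),
#       stride_kv=(1, 2, 2),
#       mode="conv",
#   )
#   x = torch.randn(2, 1 + 4 * 14 * 14, 96)
#   out, out_shape = attn(x, [4, 14, 14])
#   # out: [2, 1 + 4 * 7 * 7, 96], out_shape: [4, 7, 7]

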
class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        num_heads,
        input_size,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        drop_path=0.0,
        layer_scale_init_value=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        up_rate=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        mode="conv",
        has_cls_embed=True,
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=False,
        dim_mul_in_att=False,
        separate_qkv=False,
    ):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)
        self.dim_mul_in_att = dim_mul_in_att
        kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
        stride_skip = stride_q
        padding_skip = [int(skip // 2) for skip in kernel_skip]
        att_dim = dim_out if dim_mul_in_att else dim
        self.attn = MultiScaleAttention(
            dim,
            att_dim,
            num_heads=num_heads,
            input_size=input_size,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_embed=has_cls_embed,
            mode=mode,
            pool_first=pool_first,
            rel_pos_spatial=rel_pos_spatial,
            rel_pos_temporal=rel_pos_temporal,
            rel_pos_zero_init=rel_pos_zero_init,
            residual_pooling=residual_pooling,
            separate_qkv=separate_qkv,
        )
        self.drop_path = (
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )
        self.norm2 = norm_layer(att_dim)
        mlp_hidden_dim = int(att_dim * mlp_ratio)
        self.has_cls_embed = has_cls_embed
        # TODO: check the use case for up_rate, and merge the following lines
        if up_rate is not None and up_rate > 1:
            mlp_dim_out = dim * up_rate
        else:
            mlp_dim_out = dim_out
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=mlp_hidden_dim,
            out_features=mlp_dim_out,
            act_layer=act_layer,
            drop_rate=drop_rate,
        )
        if layer_scale_init_value > 0:
            self.gamma_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim_out)),
                requires_grad=True,
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

        self.pool_skip = (
            nn.MaxPool3d(
                kernel_skip, stride_skip, padding_skip, ceil_mode=False
            )
            if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
            else None
        )

    def forward(self, x, thw_shape=None):
        x_norm = self.norm1(x)
        x_block, thw_shape_new = self.attn(x_norm, thw_shape)
        if self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        x_res, _ = attention_pool(
            x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed
        )
        if self.gamma_1 is not None:
            x = x_res + self.drop_path(self.gamma_1 * x_block)
        else:
            x = x_res + self.drop_path(x_block)
        x_norm = self.norm2(x)
        x_mlp = self.mlp(x_norm)
        if not self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        if self.gamma_2 is not None:
            x = x + self.drop_path(self.gamma_2 * x_mlp)
        else:
            x = x + self.drop_path(x_mlp)
        if thw_shape:
            return x, thw_shape_new
        else:
            return x
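

# Illustrative usage of MultiScaleBlock (not part of the original module): a
# minimal sketch of one block that both pools the spatial grid (stride_q of 2)
# and expands the channel dimension from 96 to 192; the residual path is pooled
# by pool_skip and projected to the new width. Shapes below are assumptions.
#
#   block = MultiScaleBlock(
#       dim=96,
#       dim_out=192,
#       num_heads=4,
#       input_size=[4, 14, 14],
#       kernel_q=(3, 3, 3),
#       kernel_kv=(3, 3, 3),
#       stride_q=(1, 2, 2),
#       stride_kv=(1, 2, 2),
#   )
#   x = torch.randn(2, 1 + 4 * 14 * 14, 96)
#   out, out_shape = block(x, [4, 14, 14])
#   # out: [2, 1 + 4 * 7 * 7, 192], out_shape: [4, 7, 7]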