import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from .common import DropPath, Mlp


def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
    # Apply an optional spatiotemporal pooling op to a token sequence and
    # return the pooled tokens together with the updated (T, H, W) shape.
    if pool is None:
        return tensor, thw_shape
    tensor_dim = tensor.ndim
    if tensor_dim == 4:
        pass
    elif tensor_dim == 3:
        tensor = tensor.unsqueeze(1)
    else:
        raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")

    if has_cls_embed:
        cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]

    B, N, L, C = tensor.shape
    T, H, W = thw_shape
    # (B, N, L, C) -> (B * N, C, T, H, W) so the 3D pooling op sees a volume.
    tensor = (
        tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
    )

    tensor = pool(tensor)

    thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
    L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
    # Flatten back to (B, N, L_pooled, C) and re-attach the class token.
    tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
    if has_cls_embed:
        tensor = torch.cat((cls_tok, tensor), dim=2)
    if norm is not None:
        tensor = norm(tensor)

    if tensor_dim == 4:
        pass
    else:
        tensor = tensor.squeeze(1)
    return tensor, thw_shape


def get_rel_pos(rel_pos, d):
    if isinstance(d, int):
        ori_d = rel_pos.shape[0]
        if ori_d == d:
            return rel_pos
        else:
            # Interpolate the relative position embeddings to length d.
            new_pos_embed = F.interpolate(
                rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
                size=d,
                mode="linear",
            )
            return new_pos_embed.reshape(-1, d).permute(1, 0)


def cal_rel_pos_spatial(
    attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h, rel_pos_w
):
    """
    Decomposed Spatial Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape
    dh = int(2 * max(q_h, k_h) - 1)
    dw = int(2 * max(q_w, k_w) - 1)

    # Scale up relative coordinates when the q and k grids have different sizes.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = (
        torch.arange(q_h)[:, None] * q_h_ratio
        - torch.arange(k_h)[None, :] * k_h_ratio
    )
    dist_h += (k_h - 1) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = (
        torch.arange(q_w)[:, None] * q_w_ratio
        - torch.arange(k_w)[None, :] * k_w_ratio
    )
    dist_w += (k_w - 1) * k_w_ratio

    # Interpolate the embeddings if needed, then index by relative distance.
    rel_pos_h = get_rel_pos(rel_pos_h, dh)
    rel_pos_w = get_rel_pos(rel_pos_w, dw)
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
    rel_h_q = torch.einsum(
        "bythwc,hkc->bythwk", r_q, Rh
    )  # [B, n_head, q_t, q_h, q_w, k_h]
    rel_w_q = torch.einsum(
        "bythwc,wkc->bythwk", r_q, Rw
    )  # [B, n_head, q_t, q_h, q_w, k_w]

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
        + rel_h_q[:, :, :, :, :, None, :, None]
        + rel_w_q[:, :, :, :, :, None, None, :]
    ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)

    return attn


def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
    """
    Temporal Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape
    dt = int(2 * max(q_t, k_t) - 1)

    rel_pos_t = get_rel_pos(rel_pos_t, dt)

    # Scale up relative coordinates when the q and k temporal extents differ.
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    dist_t = (
        torch.arange(q_t)[:, None] * q_t_ratio
        - torch.arange(k_t)[None, :] * k_t_ratio
    )
    dist_t += (k_t - 1) * k_t_ratio
    Rt = rel_pos_t[dist_t.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)

    # [q_t, B, n_head, q_h, q_w, dim] -> [q_t, B * n_head * q_h * q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(
        q_t, B * n_head * q_h * q_w, dim
    )

    # [q_t, B', dim] @ [q_t, dim, k_t] = [q_t, B', k_t] -> [B', q_t, k_t]
    rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)

    # [B, n_head, q_h, q_w, q_t, k_t] -> [B, n_head, q_t, q_h, q_w, k_t]
    rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
        + rel[:, :, :, :, :, :, None, None]
    ).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)

    return attn


class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        input_size,
        num_heads=8,
        qkv_bias=False,
        drop_rate=0.0,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
        has_cls_embed=True,
        # Options include `conv`, `conv_unshared`, `avg`, and `max`.
        mode="conv",
        # If True, pool q / k / v before the linear projections.
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=False,
        separate_qkv=False,
    ):
        super().__init__()
        self.pool_first = pool_first
        self.separate_qkv = separate_qkv
        self.drop_rate = drop_rate
        self.num_heads = num_heads
        self.dim_out = dim_out
        head_dim = dim_out // num_heads
        self.scale = head_dim**-0.5
        self.has_cls_embed = has_cls_embed
        self.mode = mode
        padding_q = [int(q // 2) for q in kernel_q]
        padding_kv = [int(kv // 2) for kv in kernel_kv]

        if pool_first or separate_qkv:
            self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)

        self.proj = nn.Linear(dim_out, dim_out)
        if drop_rate > 0.0:
            self.proj_drop = nn.Dropout(drop_rate)

        # Skip pooling when both kernel and stride are (1, 1, 1).
        if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
            kernel_q = ()
        if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
            kernel_kv = ()

        if mode in ("avg", "max"):
            pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
            self.pool_q = (
                pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
                if len(kernel_q) > 0
                else None
            )
            self.pool_k = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
            self.pool_v = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0
                else None
            )
        elif mode == "conv" or mode == "conv_unshared":
            if pool_first:
                dim_conv = dim // num_heads if mode == "conv" else dim
            else:
                dim_conv = dim_out // num_heads if mode == "conv" else dim_out
            self.pool_q = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_q) > 0
                else None
            )
            self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
            self.pool_k = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
            self.pool_v = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                if len(kernel_kv) > 0
                else None
            )
            self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
        else:
            raise NotImplementedError(f"Unsupported mode {mode}")

        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_temporal = rel_pos_temporal
        if self.rel_pos_spatial:
            assert input_size[1] == input_size[2]
            size = input_size[1]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)
        if self.rel_pos_temporal:
            self.rel_pos_t = nn.Parameter(
                torch.zeros(2 * input_size[0] - 1, head_dim)
            )
            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_t, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, thw_shape):
        B, N, _ = x.shape

        if self.pool_first:
            if self.mode == "conv_unshared":
                fold_dim = 1
            else:
                fold_dim = self.num_heads
            x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
            q = k = v = x
        else:
            assert self.mode != "conv_unshared"
            if not self.separate_qkv:
                qkv = (
                    self.qkv(x)
                    .reshape(B, N, 3, self.num_heads, -1)
                    .permute(2, 0, 3, 1, 4)
                )
                q, k, v = qkv[0], qkv[1], qkv[2]
            else:
                q = k = v = x
                q = (
                    self.q(q)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )
                k = (
                    self.k(k)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )
                v = (
                    self.v(v)
                    .reshape(B, N, self.num_heads, -1)
                    .permute(0, 2, 1, 3)
                )

        # Pool q / k / v independently; each returns its own (T, H, W) shape.
        q, q_shape = attention_pool(
            q,
            self.pool_q,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_q if hasattr(self, "norm_q") else None,
        )
        k, k_shape = attention_pool(
            k,
            self.pool_k,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_k if hasattr(self, "norm_k") else None,
        )
        v, v_shape = attention_pool(
            v,
            self.pool_v,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_v if hasattr(self, "norm_v") else None,
        )

        if self.pool_first:
            # When pooling first, project the pooled tokens to q / k / v per head.
            q_N = (
                numpy.prod(q_shape) + 1
                if self.has_cls_embed
                else numpy.prod(q_shape)
            )
            k_N = (
                numpy.prod(k_shape) + 1
                if self.has_cls_embed
                else numpy.prod(k_shape)
            )
            v_N = (
                numpy.prod(v_shape) + 1
                if self.has_cls_embed
                else numpy.prod(v_shape)
            )

            q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
            q = (
                self.q(q)
                .reshape(B, q_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

            v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
            v = (
                self.v(v)
                .reshape(B, v_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

            k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
            k = (
                self.k(k)
                .reshape(B, k_N, self.num_heads, -1)
                .permute(0, 2, 1, 3)
            )

        N = q.shape[2]
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_spatial:
            attn = cal_rel_pos_spatial(
                attn,
                q,
                k,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_h,
                self.rel_pos_w,
            )

        if self.rel_pos_temporal:
            attn = cal_rel_pos_temporal(
                attn,
                q,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        x = attn @ v

        if self.residual_pooling:
            # Add the pooled query back as a residual (skipping the class token).
            if self.has_cls_embed:
                x[:, :, 1:, :] += q[:, :, 1:, :]
            else:
                x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        if self.drop_rate > 0.0:
            x = self.proj_drop(x)
        return x, q_shape


class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        num_heads,
        input_size,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        drop_path=0.0,
        layer_scale_init_value=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        up_rate=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        mode="conv",
        has_cls_embed=True,
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=False,
        dim_mul_in_att=False,
        separate_qkv=False,
    ):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)
        self.dim_mul_in_att = dim_mul_in_att
        kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
        stride_skip = stride_q
        padding_skip = [int(skip // 2) for skip in kernel_skip]
        att_dim = dim_out if dim_mul_in_att else dim
        self.attn = MultiScaleAttention(
            dim,
            att_dim,
            num_heads=num_heads,
            input_size=input_size,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_embed=has_cls_embed,
            mode=mode,
            pool_first=pool_first,
            rel_pos_spatial=rel_pos_spatial,
            rel_pos_temporal=rel_pos_temporal,
            rel_pos_zero_init=rel_pos_zero_init,
            residual_pooling=residual_pooling,
            separate_qkv=separate_qkv,
        )
        self.drop_path = (
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )
        self.norm2 = norm_layer(att_dim)
        mlp_hidden_dim = int(att_dim * mlp_ratio)
        self.has_cls_embed = has_cls_embed

        # Optionally expand the MLP output dimension by `up_rate`.
        if up_rate is not None and up_rate > 1:
            mlp_dim_out = dim * up_rate
        else:
            mlp_dim_out = dim_out
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=mlp_hidden_dim,
            out_features=mlp_dim_out,
            act_layer=act_layer,
            drop_rate=drop_rate,
        )
        if layer_scale_init_value > 0:
            self.gamma_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim_out)),
                requires_grad=True,
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

        # Pool the skip connection with the same stride as the query path.
        self.pool_skip = (
            nn.MaxPool3d(
                kernel_skip, stride_skip, padding_skip, ceil_mode=False
            )
            if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
            else None
        )

    def forward(self, x, thw_shape=None):
        x_norm = self.norm1(x)
        x_block, thw_shape_new = self.attn(x_norm, thw_shape)
        if self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        # Pool the residual path so it matches the (pooled) attention output.
        x_res, _ = attention_pool(
            x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed
        )
        if self.gamma_1 is not None:
            x = x_res + self.drop_path(self.gamma_1 * x_block)
        else:
            x = x_res + self.drop_path(x_block)
        x_norm = self.norm2(x)
        x_mlp = self.mlp(x_norm)
        if not self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        if self.gamma_2 is not None:
            x = x + self.drop_path(self.gamma_2 * x_mlp)
        else:
            x = x + self.drop_path(x_mlp)
        if thw_shape:
            return x, thw_shape_new
        else:
            return x
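

# Minimal usage sketch. The hyperparameters below are illustrative only, not a
# reference configuration, and because of the relative import of `.common`
# (DropPath, Mlp) this block should be run as part of the package
# (e.g. `python -m <package>.<module>`) rather than as a loose script.
if __name__ == "__main__":
    T, H, W = 4, 8, 8
    block = MultiScaleBlock(
        dim=96,
        dim_out=192,
        num_heads=2,
        input_size=[T, H, W],
        stride_q=(1, 2, 2),
        stride_kv=(1, 4, 4),
        rel_pos_spatial=True,
        rel_pos_temporal=True,
        residual_pooling=True,
    )
    # One class token plus T * H * W patch tokens.
    x = torch.randn(2, 1 + T * H * W, 96)
    out, thw = block(x, [T, H, W])
    # stride_q=(1, 2, 2) halves the spatial grid, so thw becomes [4, 4, 4],
    # while the channel width grows to dim_out=192.
    print(out.shape, thw)  # torch.Size([2, 65, 192]) [4, 4, 4]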