#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from .common import DropPath, Mlp
def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
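"""
Apply the 3-D pooling module `pool` to a token tensor whose spatio-temporal
positions are flattened according to `thw_shape` ([T, H, W]); the cls token,
if present, is kept out of the pooling. Returns the pooled tokens (optionally
normalized by `norm`) together with the updated [T, H, W] shape.
"""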
if pool is None:
return tensor, thw_shape
tensor_dim = tensor.ndim
if tensor_dim == 4:
pass
elif tensor_dim == 3:
tensor = tensor.unsqueeze(1)
else:
raise NotImplementedError(f"Unsupported input dimension {tensor.shape}")
if has_cls_embed:
cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
B, N, L, C = tensor.shape
T, H, W = thw_shape
tensor = (
tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous()
)
tensor = pool(tensor)
thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
if has_cls_embed:
tensor = torch.cat((cls_tok, tensor), dim=2)
if norm is not None:
tensor = norm(tensor)
# Restore the original rank (tensor_dim is 3 or 4; other ranks were rejected above).
if tensor_dim == 4:
pass
else: # tensor_dim == 3:
tensor = tensor.squeeze(1)
return tensor, thw_shape
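# Illustrative sketch, not part of the original module: how `attention_pool`
# shrinks the token sequence. With a (1, 2, 2) max pool, a clip with
# thw_shape == [4, 8, 8] plus a cls token goes from 1 + 4*8*8 tokens to
# 1 + 4*4*4 tokens, and thw_shape is updated accordingly.
def _attention_pool_example():
    B, heads, head_dim = 2, 4, 24
    thw = [4, 8, 8]
    tokens = torch.randn(B, heads, 1 + thw[0] * thw[1] * thw[2], head_dim)
    pool = nn.MaxPool3d((1, 2, 2), (1, 2, 2), (0, 0, 0))
    pooled, new_thw = attention_pool(tokens, pool, thw, has_cls_embed=True)
    assert list(pooled.shape) == [B, heads, 1 + 4 * 4 * 4, head_dim]
    assert new_thw == [4, 4, 4]
    return pooled, new_thw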
def get_rel_pos(rel_pos, d):
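"""
Return a relative positional embedding table of length `d` along the first
dimension, linearly interpolating `rel_pos` when its length differs from `d`.
"""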
if isinstance(d, int):
ori_d = rel_pos.shape[0]
if ori_d == d:
return rel_pos
else:
# Interpolate rel pos.
new_pos_embed = F.interpolate(
rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
size=d,
mode="linear",
)
return new_pos_embed.reshape(-1, d).permute(1, 0)
def cal_rel_pos_spatial(
attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h, rel_pos_w
):
"""
Decomposed Spatial Relative Positional Embeddings.
"""
sp_idx = 1 if has_cls_embed else 0
q_t, q_h, q_w = q_shape
k_t, k_h, k_w = k_shape
dh = int(2 * max(q_h, k_h) - 1)
dw = int(2 * max(q_w, k_w) - 1)
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = (
torch.arange(q_h)[:, None] * q_h_ratio
- torch.arange(k_h)[None, :] * k_h_ratio
)
dist_h += (k_h - 1) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = (
torch.arange(q_w)[:, None] * q_w_ratio
- torch.arange(k_w)[None, :] * k_w_ratio
)
dist_w += (k_w - 1) * k_w_ratio
# Interpolate rel pos if needed.
rel_pos_h = get_rel_pos(rel_pos_h, dh)
rel_pos_w = get_rel_pos(rel_pos_w, dw)
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
B, n_head, q_N, dim = q.shape
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
rel_h_q = torch.einsum(
"bythwc,hkc->bythwk", r_q, Rh
) # [B, H, q_t, q_h, q_w, k_h]
rel_w_q = torch.einsum(
"bythwc,wkc->bythwk", r_q, Rw
) # [B, H, q_t, q_h, q_w, k_w]
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
+ rel_h_q[:, :, :, :, :, None, :, None]
+ rel_w_q[:, :, :, :, :, None, None, :]
).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)
return attn
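# The decomposition above (as in MViTv2) avoids a full joint table over all
# query/key positions: the spatial bias added to attn is
#   attn[q, k] += q . Rh[h(q) - h(k)] + q . Rw[w(q) - w(k)],
# with the relative coordinates rescaled when q and k have different spatial
# resolutions.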
def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
"""
Temporal Relative Positional Embeddings.
"""
sp_idx = 1 if has_cls_embed else 0
q_t, q_h, q_w = q_shape
k_t, k_h, k_w = k_shape
dt = int(2 * max(q_t, k_t) - 1)
# Interpolate rel pos if needed.
rel_pos_t = get_rel_pos(rel_pos_t, dt)
# Scale up rel pos if shapes for q and k are different.
q_t_ratio = max(k_t / q_t, 1.0)
k_t_ratio = max(q_t / k_t, 1.0)
dist_t = (
torch.arange(q_t)[:, None] * q_t_ratio
- torch.arange(k_t)[None, :] * k_t_ratio
)
dist_t += (k_t - 1) * k_t_ratio
Rt = rel_pos_t[dist_t.long()]
B, n_head, q_N, dim = q.shape
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
# [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(
q_t, B * n_head * q_h * q_w, dim
)
# [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
# [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
+ rel[:, :, :, :, :, :, None, None]
).view(B, -1, q_t * q_h * q_w, k_t * k_h * k_w)
return attn
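# Same idea along time: a single table Rt indexed by the (scaled) relative
# temporal distance contributes q . Rt[t(q) - t(k)], broadcast over the key's
# spatial positions.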
class MultiScaleAttention(nn.Module):
def __init__(
self,
dim,
dim_out,
input_size,
num_heads=8,
qkv_bias=False,
drop_rate=0.0,
kernel_q=(1, 1, 1),
kernel_kv=(1, 1, 1),
stride_q=(1, 1, 1),
stride_kv=(1, 1, 1),
norm_layer=nn.LayerNorm,
has_cls_embed=True,
# Pooling mode: one of `conv`, `conv_unshared`, `avg`, or `max`.
mode="conv",
# If True, pool q/k/v before applying the linear projections.
pool_first=False,
rel_pos_spatial=False,
rel_pos_temporal=False,
rel_pos_zero_init=False,
residual_pooling=False,
separate_qkv=False,
):
super().__init__()
self.pool_first = pool_first
self.separate_qkv = separate_qkv
self.drop_rate = drop_rate
self.num_heads = num_heads
self.dim_out = dim_out
head_dim = dim_out // num_heads
self.scale = head_dim**-0.5
self.has_cls_embed = has_cls_embed
self.mode = mode
padding_q = [int(q // 2) for q in kernel_q]
padding_kv = [int(kv // 2) for kv in kernel_kv]
if pool_first or separate_qkv:
self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
else:
self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
self.proj = nn.Linear(dim_out, dim_out)
if drop_rate > 0.0:
self.proj_drop = nn.Dropout(drop_rate)
# Skip pooling with kernel and stride size of (1, 1, 1).
if numpy.prod(kernel_q) == 1 and numpy.prod(stride_q) == 1:
kernel_q = ()
if numpy.prod(kernel_kv) == 1 and numpy.prod(stride_kv) == 1:
kernel_kv = ()
if mode in ("avg", "max"):
pool_op = nn.MaxPool3d if mode == "max" else nn.AvgPool3d
self.pool_q = (
pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
if len(kernel_q) > 0
else None
)
self.pool_k = (
pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
if len(kernel_kv) > 0
else None
)
self.pool_v = (
pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
if len(kernel_kv) > 0
else None
)
elif mode == "conv" or mode == "conv_unshared":
if pool_first:
dim_conv = dim // num_heads if mode == "conv" else dim
else:
dim_conv = dim_out // num_heads if mode == "conv" else dim_out
self.pool_q = (
nn.Conv3d(
dim_conv,
dim_conv,
kernel_q,
stride=stride_q,
padding=padding_q,
groups=dim_conv,
bias=False,
)
if len(kernel_q) > 0
else None
)
self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
self.pool_k = (
nn.Conv3d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
if len(kernel_kv) > 0
else None
)
self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
self.pool_v = (
nn.Conv3d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
if len(kernel_kv) > 0
else None
)
self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
else:
raise NotImplementedError(f"Unsupported model {mode}")
self.rel_pos_spatial = rel_pos_spatial
self.rel_pos_temporal = rel_pos_temporal
if self.rel_pos_spatial:
assert input_size[1] == input_size[2]
size = input_size[1]
q_size = size // stride_q[1] if len(stride_q) > 0 else size
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
rel_sp_dim = 2 * max(q_size, kv_size) - 1
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
if not rel_pos_zero_init:
trunc_normal_(self.rel_pos_h, std=0.02)
trunc_normal_(self.rel_pos_w, std=0.02)
if self.rel_pos_temporal:
self.rel_pos_t = nn.Parameter(
torch.zeros(2 * input_size[0] - 1, head_dim)
)
if not rel_pos_zero_init:
trunc_normal_(self.rel_pos_t, std=0.02)
self.residual_pooling = residual_pooling
def forward(self, x, thw_shape):
B, N, _ = x.shape
if self.pool_first:
if self.mode == "conv_unshared":
fold_dim = 1
else:
fold_dim = self.num_heads
x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
q = k = v = x
else:
assert self.mode != "conv_unshared"
if not self.separate_qkv:
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, -1)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
else:
q = k = v = x
q = (
self.q(q)
.reshape(B, N, self.num_heads, -1)
.permute(0, 2, 1, 3)
)
k = (
self.k(k)
.reshape(B, N, self.num_heads, -1)
.permute(0, 2, 1, 3)
)
v = (
self.v(v)
.reshape(B, N, self.num_heads, -1)
.permute(0, 2, 1, 3)
)
q, q_shape = attention_pool(
q,
self.pool_q,
thw_shape,
has_cls_embed=self.has_cls_embed,
norm=self.norm_q if hasattr(self, "norm_q") else None,
)
k, k_shape = attention_pool(
k,
self.pool_k,
thw_shape,
has_cls_embed=self.has_cls_embed,
norm=self.norm_k if hasattr(self, "norm_k") else None,
)
v, v_shape = attention_pool(
v,
self.pool_v,
thw_shape,
has_cls_embed=self.has_cls_embed,
norm=self.norm_v if hasattr(self, "norm_v") else None,
)
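# In the pool_first path, q/k/v above are still the raw (unprojected) tokens,
# pooled per head fold; the linear q/k/v projections are applied only now, on
# the already-shortened sequences.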
if self.pool_first:
q_N = (
numpy.prod(q_shape) + 1
if self.has_cls_embed
else numpy.prod(q_shape)
)
k_N = (
numpy.prod(k_shape) + 1
if self.has_cls_embed
else numpy.prod(k_shape)
)
v_N = (
numpy.prod(v_shape) + 1
if self.has_cls_embed
else numpy.prod(v_shape)
)
q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
q = (
self.q(q)
.reshape(B, q_N, self.num_heads, -1)
.permute(0, 2, 1, 3)
)
v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
v = (
self.v(v)
.reshape(B, v_N, self.num_heads, -1)
.permute(0, 2, 1, 3)
)
k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
k = (
self.k(k)
.reshape(B, k_N, self.num_heads, -1)
.permute(0, 2, 1, 3)
)
N = q.shape[2]
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.rel_pos_spatial:
attn = cal_rel_pos_spatial(
attn,
q,
k,
self.has_cls_embed,
q_shape,
k_shape,
self.rel_pos_h,
self.rel_pos_w,
)
if self.rel_pos_temporal:
attn = cal_rel_pos_temporal(
attn,
q,
self.has_cls_embed,
q_shape,
k_shape,
self.rel_pos_t,
)
attn = attn.softmax(dim=-1)
x = attn @ v
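# Residual pooling connection (MViTv2): add the pooled queries back to the
# attention output, skipping the cls token when one is present.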
if self.residual_pooling:
if self.has_cls_embed:
x[:, :, 1:, :] += q[:, :, 1:, :]
else:
x = x + q
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
x = self.proj(x)
if self.drop_rate > 0.0:
x = self.proj_drop(x)
return x, q_shape
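# Illustrative sketch, not part of the original module (hyperparameters are
# assumptions): a MultiScaleAttention layer with conv pooling that halves the
# spatial resolution of both queries and keys/values, so the output sequence
# is shorter than the input.
def _multi_scale_attention_example():
    thw = [4, 8, 8]
    layer = MultiScaleAttention(
        dim=96,
        dim_out=96,
        input_size=thw,
        num_heads=4,
        kernel_q=(3, 3, 3),
        kernel_kv=(3, 3, 3),
        stride_q=(1, 2, 2),
        stride_kv=(1, 2, 2),
    )
    x = torch.randn(2, 1 + thw[0] * thw[1] * thw[2], 96)
    out, out_thw = layer(x, thw)
    # out: [2, 1 + 4 * 4 * 4, 96], out_thw: [4, 4, 4]
    return out, out_thw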
class MultiScaleBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
num_heads,
input_size,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_rate=0.0,
drop_path=0.0,
layer_scale_init_value=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
up_rate=None,
kernel_q=(1, 1, 1),
kernel_kv=(1, 1, 1),
stride_q=(1, 1, 1),
stride_kv=(1, 1, 1),
mode="conv",
has_cls_embed=True,
pool_first=False,
rel_pos_spatial=False,
rel_pos_temporal=False,
rel_pos_zero_init=False,
residual_pooling=False,
dim_mul_in_att=False,
separate_qkv=False,
):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.norm1 = norm_layer(dim)
self.dim_mul_in_att = dim_mul_in_att
kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
stride_skip = stride_q
padding_skip = [int(skip // 2) for skip in kernel_skip]
att_dim = dim_out if dim_mul_in_att else dim
self.attn = MultiScaleAttention(
dim,
att_dim,
num_heads=num_heads,
input_size=input_size,
qkv_bias=qkv_bias,
drop_rate=drop_rate,
kernel_q=kernel_q,
kernel_kv=kernel_kv,
stride_q=stride_q,
stride_kv=stride_kv,
norm_layer=norm_layer,
has_cls_embed=has_cls_embed,
mode=mode,
pool_first=pool_first,
rel_pos_spatial=rel_pos_spatial,
rel_pos_temporal=rel_pos_temporal,
rel_pos_zero_init=rel_pos_zero_init,
residual_pooling=residual_pooling,
separate_qkv=separate_qkv,
)
self.drop_path = (
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
)
self.norm2 = norm_layer(att_dim)
mlp_hidden_dim = int(att_dim * mlp_ratio)
self.has_cls_embed = has_cls_embed
# TODO: check the use case for up_rate, and merge the following lines
if up_rate is not None and up_rate > 1:
mlp_dim_out = dim * up_rate
else:
mlp_dim_out = dim_out
self.mlp = Mlp(
in_features=att_dim,
hidden_features=mlp_hidden_dim,
out_features=mlp_dim_out,
act_layer=act_layer,
drop_rate=drop_rate,
)
if layer_scale_init_value > 0:
self.gamma_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True
)
self.gamma_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim_out)),
requires_grad=True,
)
else:
self.gamma_1, self.gamma_2 = None, None
if dim != dim_out:
self.proj = nn.Linear(dim, dim_out)
self.pool_skip = (
nn.MaxPool3d(
kernel_skip, stride_skip, padding_skip, ceil_mode=False
)
if len(stride_skip) > 0 and numpy.prod(stride_skip) > 1
else None
)
def forward(self, x, thw_shape=None):
x_norm = self.norm1(x)
x_block, thw_shape_new = self.attn(x_norm, thw_shape)
if self.dim_mul_in_att and self.dim != self.dim_out:
x = self.proj(x_norm)
x_res, _ = attention_pool(
x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed
)
if self.gamma_1 is not None:
x = x_res + self.drop_path(self.gamma_1 * x_block)
else:
x = x_res + self.drop_path(x_block)
x_norm = self.norm2(x)
x_mlp = self.mlp(x_norm)
if not self.dim_mul_in_att and self.dim != self.dim_out:
x = self.proj(x_norm)
if self.gamma_2 is not None:
x = x + self.drop_path(self.gamma_2 * x_mlp)
else:
x = x + self.drop_path(x_mlp)
if thw_shape:
return x, thw_shape_new
else:
return x
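# Illustrative sketch, not part of the original module (hyperparameters are
# assumptions): a stage-transition MultiScaleBlock that halves the spatial
# token grid and doubles the channel dimension.
def _multi_scale_block_example():
    thw = [4, 8, 8]
    block = MultiScaleBlock(
        dim=96,
        dim_out=192,
        num_heads=4,
        input_size=thw,
        kernel_q=(3, 3, 3),
        kernel_kv=(3, 3, 3),
        stride_q=(1, 2, 2),
        stride_kv=(1, 2, 2),
    )
    x = torch.randn(2, 1 + thw[0] * thw[1] * thw[2], 96)
    out, out_thw = block(x, thw)
    # out: [2, 1 + 4 * 4 * 4, 192], out_thw: [4, 4, 4]
    return out, out_thw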