Spaces:
Runtime error
Runtime error
# Copyright (c) 2021 NVIDIA CORPORATION. Licensed under the MIT license. | |
# Written by Chen Zhu during an internship at NVIDIA, [email protected] | |
import math | |
from torch import nn | |
import torch | |
from timm.models.layers import trunc_normal_ | |
import torch.nn.functional as F | |
class AttentionLS(nn.Module):
    """Implementation for long-short term attention.

    Flexible options for using window attention, global token and dynamic projection.
    The first ``nglo`` tokens of the input are treated as global tokens; the remaining
    tokens are expected to form a square ``img_size x img_size`` grid (row-major).

    Args:
        dim: input and output feature dimension.
        num_heads: number of attention heads.
        qkv_bias: whether to use bias for the projection of query, key and values.
        qk_scale: scale factor applied to the queries before the dot products.
            By default, set to 1 / sqrt(head_dim).
        attn_drop: dropout probability for attention matrix.
        proj_drop: dropout probability for the final output.
        rpe: whether to add relative position encoding to the window attention scores.
        nglo: number of global tokens (e.g., CLS).
        dp_rank: rank ``r`` of the dynamic projection (long-range attention);
            values <= 0 disable the dynamic projection.
        w: segment (window) size of the short-term window attention;
            values <= 0 disable the window attention.

    Raises:
        ValueError: if both ``dp_rank`` and ``w`` are <= 0, because the non-global
            tokens would then have no attention scores and ``forward`` cannot run.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        rpe=False,
        nglo=1,
        dp_rank=2,
        w=2,
    ):
        super().__init__()
        if dp_rank <= 0 and w <= 0:
            raise ValueError(
                "AttentionLS needs dynamic projection (dp_rank > 0) and/or "
                "window attention (w > 0) to be enabled."
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.nglo = nglo

        # Equals to segment size (w) in the paper.
        self.window_size = w
        # Equals to r in the paper.
        self.dp_rank = dp_rank

        if self.dp_rank > 0:
            self.to_dynamic_projection = nn.Linear(dim, dp_rank * num_heads)
        # The LN of DualLN corresponding to dynamic projection
        self.dual_ln_dp = nn.LayerNorm(dim)
        # The LN of DualLN corresponding to all the tokens
        self.dual_ln_full = nn.LayerNorm(dim)

        # Adapted from ViL: https://github.com/microsoft/vision-longformer/blob/main/src/models/layers/longformer2d.py#L55-L100
        # We only add RPE to window attention.
        # Unnecessary to add bias for global tokens, since DualLN already adds biases.
        self.rpe = rpe
        if rpe:
            # Border width of each key window around its w x w query window.
            # NOTE(review): for w == 1 this is 0, while the window tiles use a border of
            # max(w // 2, 1) == 1, so the bias has one key position and is broadcast
            # over all 9 window keys. Confirm before using rpe with w == 1.
            w_pad = int(w * 0.5)
            self.local_relative_position_bias_table = nn.Parameter(
                torch.zeros(2 * (w + w_pad - 1) * (2 * w_pad + w + 1) + 1, num_heads)
            )
            trunc_normal_(self.local_relative_position_bias_table, std=0.02)

            # get pair-wise relative position index
            coords_h = torch.arange(-w_pad, w_pad + w)
            coords_w = torch.arange(-w_pad, w_pad + w)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, 2w, 2w
            coords = (
                coords.view(2, (w + w_pad * 2) ** 2).transpose(0, 1).unsqueeze(0)
            )  # 1, 4w**2, 2
            q_coords_hw = torch.arange(0, w)
            q_coords = torch.stack(
                torch.meshgrid([q_coords_hw, q_coords_hw])
            )  # 2, w, w
            q_coords = q_coords.view(2, w**2).transpose(0, 1).unsqueeze(1)  # w**2, 1, 2
            relative_coords = q_coords - coords
            relative_coords += w_pad + w - 1  # shift to start from 0
            # NOTE(review): the row multiplier (2 * w_pad + w) is smaller than the number
            # of distinct column offsets (2 * (w + w_pad) - 1) when w >= 2, so some
            # distinct (dy, dx) offsets share a table row. Kept unchanged so existing
            # checkpoints stay compatible.
            relative_coords[:, :, 0] *= 2 * w_pad + w
            relative_position_index = relative_coords.sum(-1)  # w^2, 4w^2
            self.register_buffer("relative_position_index", relative_position_index)

    def _split_heads(self, t):
        """Reshape a (B, n, C) tensor to (B, num_heads, n, C // num_heads)."""
        b, n, c = t.shape
        return t.reshape(b, n, self.num_heads, c // self.num_heads).transpose(1, 2)

    def forward(self, x, nx=None, ny=None):
        """Apply long-short term attention.

        Args:
            x: tokens of shape (B, N, dim). The first ``self.nglo`` tokens are global
                tokens; the remaining ``N - self.nglo`` tokens form a square grid.
            nx: side length of that grid. If None, it is inferred as
                ``int(sqrt(N - self.nglo))``.
            ny: unused; the grid is assumed to be square (nx x nx).

        Returns:
            Tensor of shape (B, N, dim).

        Side effect:
            Stores the grid side length in ``self.img_size``; the window helpers read it.
            When window attention is enabled, the grid side must be divisible by ``w``,
            otherwise the reshapes in the window helpers fail.
        """
        B, N, C = x.shape
        N_feat = N - self.nglo
        # The global tokens are not part of the spatial grid, so the side length must
        # be derived from N_feat (using N only worked by accident for small nglo).
        self.img_size = int(math.sqrt(N_feat)) if nx is None else nx

        qkv = self.qkv(x)
        # query, key, value
        q, k, v = qkv.chunk(3, dim=2)
        q = q.mul(self.scale)

        # Layer norm on the projected keys and values (the "full" LN of DualLN)
        k = self.dual_ln_full(k)
        v = self.dual_ln_full(v)

        # Split off the global tokens, then reshape to bsz x n_heads x seqlen x d.
        if self.nglo > 0:
            q_cls, q = q[:, : self.nglo], q[:, self.nglo :]
            k_cls, k = k[:, : self.nglo], k[:, self.nglo :]
            v_cls, v = v[:, : self.nglo], v[:, self.nglo :]
            q_cls = self._split_heads(q_cls)
            k_cls = self._split_heads(k_cls)
            v_cls = self._split_heads(v_cls)
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Long-range Attention (Dynamic Projection)
        if self.dp_rank > 0:
            # Compute the projection matrix (P_i in the paper): b x h x r x (l w)
            c_scores = (
                self.to_dynamic_projection(x[:, self.nglo :])
                .transpose(1, 2)
                .contiguous()
                .view(B, self.num_heads, self.dp_rank, -1)
            )
            # Softmax runs in the input dtype (it used to be forced to float32);
            # changed when experimenting with mixed precision (Johannes S.)
            c_scores = c_scores.softmax(dim=-1).to(x)
            # Projected keys: b x h x r x d
            k_lms = c_scores.matmul(k)
            k_lms = k_lms.transpose(1, 2).contiguous().view(B, self.dp_rank, -1)
            # DualLN on the projected keys, then lay out as b x h x d x r
            k_lms = (
                self.dual_ln_dp(k_lms)
                .view(B, self.dp_rank, self.num_heads, -1)
                .contiguous()
                .permute(0, 2, 3, 1)
            )
            # b x h x (lw) x r
            dots_all = q.matmul(k_lms)
            if self.window_size > 0:
                # Switch the order of dimensions if using window attention.
                dots_all = self.group_dots(dots_all)
        else:
            dots_all = None

        # Short-term Attention (Window Attention)
        # Each query attends to the (2 * max(w // 2, 1) + w)^2 keys around its window
        # (4w^2 for even w); zero-padded keys are masked with -inf.
        if self.window_size > 0:
            dots_win = self.compute_window_scores(q, k)
            w2 = int(self.window_size * self.window_size)

            if self.rpe:
                w_pad = int(0.5 * self.window_size)
                local_relative_position_bias = self.local_relative_position_bias_table[
                    self.relative_position_index.view(-1)
                ].view(
                    1, w2, (w_pad * 2 + self.window_size) ** 2, -1
                )  # 1, w^2, kv_nums, H
                local_relative_position_bias = (
                    local_relative_position_bias.permute(0, 3, 1, 2)
                    .expand(B, -1, -1, -1)
                    .unsqueeze(2)
                    .unsqueeze(2)
                )

                dots_win += local_relative_position_bias
            if dots_all is None:
                dots_all = dots_win
            else:
                dots_all = torch.cat([dots_all, dots_win], dim=-1)

        # Global token.
        if self.nglo > 0:
            # and compute the scores of queries on CLS
            dots_q_cls = q.matmul(k_cls.transpose(-1, -2))

            if self.window_size > 0:
                dots_q_cls = self.group_dots(dots_q_cls)
            dots_all = torch.cat([dots_all, dots_q_cls], dim=-1)

        # Softmax in the input dtype (mixed-precision change, Johannes S.)
        attn = dots_all.softmax(dim=-1).to(x)
        attn = self.attn_drop(attn)
        out = 0
        if self.window_size > 0:
            # Window scores sit after the r dynamic-projection scores along the key axis.
            offset = max(0, self.dp_rank)
            kv_group_size = self.window_size
            total_win_size = max(1, self.window_size // 2) * 2 + kv_group_size
            attn_win = attn[:, :, :, :, :, offset : offset + total_win_size**2]
            out += self.compute_window_pv(attn_win, v)
            attn = self.ungroup_dots(attn)

        # attn will be b x h x lw x n_k from now on
        if self.dp_rank > 0:
            attn_lm = attn[:, :, :, : self.dp_rank]
            v_lms = (
                c_scores.matmul(v)  # input dtype (mixed-precision change, Johannes S.)
                .to(v)
                .transpose(1, 2)
                .contiguous()
                .view(B, self.dp_rank, -1)
            )
            v_lms = (
                self.dual_ln_dp(v_lms)
                .view(B, self.dp_rank, self.num_heads, -1)
                .contiguous()
                .transpose(1, 2)
            )

            out += attn_lm.matmul(v_lms)

        if self.nglo > 0:
            attn_cls = attn[:, :, :, -self.nglo :]
            out += attn_cls.matmul(v_cls)  # Changed. Was `.mul` instead of `.matmul`. (JWS)

            # Global tokens attend to all tokens: b x h x nglo x (nglo + lw)
            cls_inner = q_cls.matmul(k_cls.transpose(-1, -2))
            cls_dots = q_cls.matmul(k.transpose(-1, -2))  # Changed. Was `out` instead of `k`. (JWS)
            cls_dots = torch.cat([cls_inner, cls_dots], dim=-1)

            cls_dots = cls_dots.softmax(dim=-1).to(x)  # input dtype (mixed-precision change)
            # the post_cls variant. Changed. Was `out` instead of `v`. (JWS)
            cls_next = cls_dots[:, :, :, self.nglo :].matmul(v)
            cls_next += cls_dots[:, :, :, : self.nglo].matmul(v_cls)

            out = torch.cat([cls_next, out], dim=2)
        out = out.transpose(1, 2).contiguous().view(B, N, -1)

        out = self.proj(out)
        out = self.proj_drop(out)
        return out

    def compute_window_scores(self, q, k):
        """Compute the inner products for the window attention.

        First, divide the queries into non-overlapping w x w windows.
        Then, use ``torch.as_strided`` (implemented in ``self.get_overlapping_tiles``)
        to view the keys of each window plus a border of ``max(w // 2, 1)`` tokens on
        every side. That view is copied into contiguous tiles, which take
        ``(kv_width / w)^2`` times the memory of ``k`` (4x for even w).
        Finally, compute the inner products and mask the zero-padded border keys with -inf.

        Args:
            q, k: b x h x (l w) x d, with (l w) == self.img_size ** 2.

        Returns:
            b x h x n_group x n_group x w^2 x kv_width^2 scores, where
            n_group = img_size // w and kv_width = 2 * max(w // 2, 1) + w.
        """
        # q: b h (l w) d
        b, h, _, d = q.shape
        side_size = max(self.window_size // 2, 1)
        # q_group_size: segment size
        kv_width = 2 * side_size + self.window_size  # assuming q_stride=1
        q_n_group = self.img_size // self.window_size
        q_tiles = q.reshape(
            b, h, q_n_group, self.window_size, q_n_group, self.window_size, d
        ).permute(0, 1, 2, 4, 3, 5, 6)
        # q_tiles: b x h x n_group x n_group x w^2 x d
        q_tiles = q_tiles.contiguous().view(b, h, q_n_group, q_n_group, -1, d)

        # k_tiles: b x h x n_group x n_group x kv_width^2 x d
        k_tiles = (
            self.get_overlapping_tiles(k)
            .contiguous()
            .view(b, h, q_n_group, q_n_group, -1, d)
        )
        # dot_tiles: b x h x n_group x n_group x w^2 x kv_width^2
        dot_tiles = q_tiles.matmul(k_tiles.transpose(-1, -2))

        # fill "-inf" into the zero-padding parts: the outer key rows of the first/last
        # row of windows and the outer key columns of the first/last column of windows.
        dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width, kv_width)

        dot_tiles[:, :, 0, :, :, :side_size].fill_(float("-inf"))
        dot_tiles[:, :, -1, :, :, -side_size:].fill_(float("-inf"))
        dot_tiles[:, :, :, 0, :, :, :side_size].fill_(float("-inf"))
        dot_tiles[:, :, :, -1, :, :, -side_size:].fill_(float("-inf"))

        dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width**2)
        return dot_tiles

    def get_overlapping_tiles(self, x):
        """Get overlapping tiles in the 2D spatial domain, ensuring each query computes correlation with all neighbors.

        Args:
            x: b x h x (l w) x d, with (l w) == self.img_size ** 2.

        Returns:
            A strided view of shape b x h x n_group x n_group x total_size x total_size x d
            over the zero-padded grid, where n_group = img_size // w and
            total_size = 2 * max(w // 2, 1) + w. Neighbouring tiles share memory.
        """
        # x: b h (l w) d
        b, h, _, d = x.shape
        side_size = max(self.window_size // 2, 1)
        total_size = 2 * side_size + self.window_size
        kv_group_size = self.window_size
        kv_width = self.img_size

        x = x.view(b, h, kv_width, kv_width, d)
        # Zero-pad side_size tokens on every spatial side.
        x = F.pad(x, [0, 0, side_size, side_size, side_size, side_size], value=0)

        out_shape = [
            b,
            h,
            kv_width // kv_group_size,
            kv_width // kv_group_size,
            total_size,
            total_size,
            d,
        ]
        in_stride = x.stride()
        # Tiles start every kv_group_size tokens (outer strides) and span total_size
        # tokens (inner strides), so neighbouring tiles overlap.
        out_stride = [
            in_stride[0],
            in_stride[1],
            in_stride[2] * kv_group_size,
            in_stride[3] * kv_group_size,
            in_stride[2],
            in_stride[3],
            in_stride[4],
        ]

        # The zero-padded keys are not excluded here; compute_window_scores masks them.
        return x.as_strided(size=out_shape, stride=out_stride)

    def compute_window_pv(self, attn, v):
        """Compute the inner product of attention matrix and the values for the window attention.

        Args:
            attn: b x h x n_group x n_group x w^2 x kv_width^2 window attention weights.
            v: b x h x (l w) x d values.

        Returns:
            b x h x (l w) x d, in the original token order.
        """
        b, h, n_group, _, w2, n_k = attn.shape
        d = v.shape[-1]
        v_tiles = (
            self.get_overlapping_tiles(v)
            .contiguous()
            .view(b, h, n_group, n_group, -1, d)
        )

        # b x h x n_group x n_group x w^2 x d
        pv = attn.matmul(v_tiles)
        # return: b x h x (lw) x d
        ret = self.ungroup_dots(pv)

        return ret

    def group_dots(self, dots):
        """Regroup b x h x (l w) x k scores into b x h x n_group x n_group x w^2 x k.

        The w x w queries of each window become contiguous along the w^2 axis.
        """
        b, h = dots.shape[:2]
        n_group = self.img_size // self.window_size
        dots = dots.reshape(
            b, h, n_group, self.window_size, n_group, self.window_size, -1
        ).permute(0, 1, 2, 4, 3, 5, 6)
        dots = dots.contiguous().view(
            b, h, n_group, n_group, self.window_size * self.window_size, -1
        )
        return dots

    def ungroup_dots(self, dots):
        """Inverse of ``group_dots``: b x h x n_group x n_group x w^2 x k -> b x h x (l w) x k."""
        b, h, n_group, _, _, n_keys = dots.shape
        dots = dots.reshape(
            b, h, n_group, n_group, self.window_size, self.window_size, -1
        ).permute(0, 1, 2, 4, 3, 5, 6)
        dots = dots.contiguous().view(b, h, -1, n_keys)
        return dots