# Copyright (c) 2021 NVIDIA CORPORATION. Licensed under the MIT license.
# Written by Chen Zhu during an internship at NVIDIA, [email protected]
import math
from torch import nn
import torch
from timm.models.layers import trunc_normal_
import torch.nn.functional as F


class AttentionLS(nn.Module):
"""Implementation for long-short term attention.
Flexible options for using window attention, global token and dynamic projection.
Args:
dim: input and output feature dimension.
num_heads: number of attention heads.
qkv_bias: whether to use bias for the projection of query, key and values.
        qk_scale: scale factor on query and key for numerical stability.
            By default, set to the inverse square root of the head dimension.
attn_drop: dropout probability for attention matrix.
proj_drop: dropout probability for the final output.
rpe: whether to use relative position encoding.
        nglo: number of global tokens (e.g., CLS).
        dp_rank: rank (r in the paper) of the dynamic projection, i.e. the number of
            projected key/value tokens per head used for long-range attention.
        w: window (segment) size for the short-term window attention.
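    Example (illustrative shapes: a 32x32 token grid with one global token):
        attn = AttentionLS(dim=96, num_heads=4, nglo=1, dp_rank=2, w=8)
        x = torch.randn(2, 1 + 32 * 32, 96)
        out = attn(x, nx=32, ny=32)  # out: (2, 1 + 32 * 32, 96)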
"""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
rpe=False,
nglo=1,
dp_rank=2,
w=2,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
        # NOTE: the scale factor was wrong in an earlier version; it can be set manually
        # via qk_scale to stay compatible with previously trained weights.
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.nglo = nglo
# Equals to segment size (w) in the paper.
self.window_size = w
# Equals to r in the paper.
self.dp_rank = dp_rank
if self.dp_rank > 0:
self.to_dynamic_projection = nn.Linear(dim, dp_rank * num_heads)
# The LN of DualLN corresponding to dynamic projection
self.dual_ln_dp = nn.LayerNorm(dim)
# The LN of DualLN corresponding to all the tokens
self.dual_ln_full = nn.LayerNorm(dim)
# Adapted from ViL: https://github.com/microsoft/vision-longformer/blob/main/src/models/layers/longformer2d.py#L55-L100
# We only add RPE to window attention.
# Unnecessary to add bias for global tokens, since DualLN already adds biases.
self.rpe = rpe
if rpe:
            # handle the border conditions...
w_pad = int(w * 0.5)
self.local_relative_position_bias_table = nn.Parameter(
torch.zeros(2 * (w + w_pad - 1) * (2 * w_pad + w + 1) + 1, num_heads)
)
trunc_normal_(self.local_relative_position_bias_table, std=0.02)
# get pair-wise relative position index
coords_h = torch.arange(-w_pad, w_pad + w)
coords_w = torch.arange(-w_pad, w_pad + w)
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 2w, 2w
coords = (
coords.view(2, (w + w_pad * 2) ** 2).transpose(0, 1).unsqueeze(0)
) # 1, 4w**2, 2
q_coords_hw = torch.arange(0, w)
q_coords = torch.stack(
torch.meshgrid([q_coords_hw, q_coords_hw])
) # 2, w, w
q_coords = q_coords.view(2, w**2).transpose(0, 1).unsqueeze(1) # w**2, 1, 2
relative_coords = q_coords - coords
relative_coords += w_pad + w - 1 # shift to start from 0
relative_coords[:, :, 0] *= 2 * w_pad + w
relative_position_index = relative_coords.sum(-1) # w^2, 4w^2
self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, nx=None, ny=None):
B, N, C = x.shape
N_feat = N - self.nglo
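        # nx, ny give the 2D token grid size; if absent, fall back to a square grid
        # inferred from N (int(sqrt(N)) still equals the side length when only a few
        # global tokens are appended).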
self.img_size = int(math.sqrt(N)) if nx is None else nx
qkv = self.qkv(x)
# query, key, value
q, k, v = qkv.chunk(3, dim=2)
q = q.mul(self.scale)
# Layer norm on the projected keys and values
k = self.dual_ln_full(k)
v = self.dual_ln_full(v)
# output size: bsz x n_heads x seqlen x d
if self.nglo > 0:
q_cls, q = q[:, : self.nglo], q[:, self.nglo :]
k_cls, k = k[:, : self.nglo], k[:, self.nglo :]
v_cls, v = v[:, : self.nglo], v[:, self.nglo :]
q_cls = q_cls.reshape(
B, self.nglo, self.num_heads, C // self.num_heads
).transpose(1, 2)
k_cls = k_cls.reshape(
B, self.nglo, self.num_heads, C // self.num_heads
).transpose(1, 2)
v_cls = v_cls.reshape(
B, self.nglo, self.num_heads, C // self.num_heads
).transpose(1, 2)
q = q.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
k = k.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
v = v.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
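        # The per-query attention scores are assembled along the last dimension in the
        # order [dynamic-projection keys | window keys | global keys]; the slicing via
        # `offset`, `[..., :self.dp_rank]` and `[..., -self.nglo:]` further below relies
        # on this ordering.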
# Long-range Attention (Dynamic Projection)
if self.dp_rank > 0:
# b x h x r x (l w)
# Compute the projection matrix (P_i in the paper)
c_scores = (
self.to_dynamic_projection(x[:, self.nglo :])
.transpose(1, 2)
.contiguous()
.view(B, self.num_heads, self.dp_rank, -1)
)
# c_scores = c_scores.softmax(dim=-1, dtype=torch.float32).to(x)
c_scores = c_scores.softmax(dim=-1).to(
x
) # Changed when experimenting with mixed precision (Johannes S.)
# b x h x r x d
k_lms = c_scores.matmul(k)
k_lms = k_lms.transpose(1, 2).contiguous().view(B, self.dp_rank, -1)
k_lms = (
self.dual_ln_dp(k_lms)
.view(B, self.dp_rank, self.num_heads, -1)
.contiguous()
.permute(0, 2, 3, 1)
)
# b x h x (lw) x r
dots_all = q.matmul(k_lms)
if self.window_size > 0:
# Switch the order of dimensions if using window attention.
dots_all = self.group_dots(dots_all)
else:
dots_all = None
# Short-term Attention (Window Attention)
# In our window attention, each token attends to at most (4w^2) tokens.
if self.window_size > 0:
dots_win = self.compute_window_scores(q, k)
w2 = int(self.window_size * self.window_size)
if self.rpe:
w_pad = int(0.5 * self.window_size)
local_relative_position_bias = self.local_relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
1, w2, (w_pad * 2 + self.window_size) ** 2, -1
) # w^2, kv_nums,H
local_relative_position_bias = (
local_relative_position_bias.permute(0, 3, 1, 2)
.expand(B, -1, -1, -1)
.unsqueeze(2)
.unsqueeze(2)
)
dots_win += local_relative_position_bias
if dots_all is None:
dots_all = dots_win
else:
dots_all = torch.cat([dots_all, dots_win], dim=-1)
# Global token.
if self.nglo > 0:
# and compute the scores of queries on CLS
dots_q_cls = q.matmul(k_cls.transpose(-1, -2))
if self.window_size > 0:
dots_q_cls = self.group_dots(dots_q_cls)
dots_all = torch.cat([dots_all, dots_q_cls], dim=-1)
# attn = dots_all.softmax(dim=-1, dtype=torch.float32).to(x)
attn = dots_all.softmax(dim=-1).to(
x
) # Changed when experimenting with mixed precision (Johannes S.)
attn = self.attn_drop(attn)
out = 0
if self.window_size > 0:
offset = max(0, self.dp_rank)
kv_group_size = self.window_size
total_win_size = max(1, self.window_size // 2) * 2 + kv_group_size
attn_win = attn[:, :, :, :, :, offset : offset + total_win_size**2]
out += self.compute_window_pv(attn_win, v)
attn = self.ungroup_dots(attn)
# attn will be b x h x lw x n_k from now on
if self.dp_rank > 0:
attn_lm = attn[:, :, :, : self.dp_rank]
v_lms = (
# c_scores.matmul(v.float())
c_scores.matmul(
v
) # Changed when experimenting with mixed precision (Johannes S.)
.to(v)
.transpose(1, 2)
.contiguous()
.view(B, self.dp_rank, -1)
)
v_lms = (
self.dual_ln_dp(v_lms)
.view(B, self.dp_rank, self.num_heads, -1)
.contiguous()
.transpose(1, 2)
)
out += attn_lm.matmul(v_lms)
if self.nglo > 0:
attn_cls = attn[:, :, :, -self.nglo :]
out += attn_cls.matmul(
v_cls
) # Changed. Was `.mul` instead of `.matmul`. (JWS)
            # cls_inner: b x h x nglo x nglo; cls_dots: b x h x nglo x (l w)
cls_inner = q_cls.matmul(k_cls.transpose(-1, -2))
cls_dots = q_cls.matmul(
k.transpose(-1, -2)
) # Changed. Was `out` instead of `k`. (JWS)
cls_dots = torch.cat([cls_inner, cls_dots], dim=-1)
# cls_dots = cls_dots.softmax(dim=-1, dtype=torch.float32).to(x)
cls_dots = cls_dots.softmax(dim=-1).to(
x
) # Changed when experimenting with mixed precision (Johannes S.)
cls_next = cls_dots[:, :, :, self.nglo :].matmul(
v
) # the post_cls variant # Changed. Was `out` instead of `v`. (JWS)
cls_next += cls_dots[:, :, :, : self.nglo].matmul(v_cls)
out = torch.cat([cls_next, out], dim=2)
out = out.transpose(1, 2).contiguous().view(B, N, -1)
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.proj(out)
out = self.proj_drop(out)
return out

    def compute_window_scores(self, q, k):
"""Compute the inner products for the window attention.
        First, divide the query into non-overlapping windows.
        Then, use torch.as_strided (implemented in self.get_overlapping_tiles) to create a view of the keys
that corresponds to the windows with at most 2x memory overhead.
Finally, compute the inner product.
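        For example, with img_size=32 and window_size=8 (so side_size=4 and kv_width=16):
        q is (b, h, 1024, d), q_tiles is (b, h, 4, 4, 64, d), k_tiles is (b, h, 4, 4, 256, d),
        and the returned dot_tiles is (b, h, 4, 4, 64, 256).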
"""
# q: b h (l w) d
b, h, _, d = q.shape
side_size = max(self.window_size // 2, 1)
# q_group_size: segment size
kv_width = 2 * side_size + self.window_size # assuming q_stride=1
q_n_group = self.img_size // self.window_size
q_tiles = q.reshape(
b, h, q_n_group, self.window_size, q_n_group, self.window_size, d
).permute(0, 1, 2, 4, 3, 5, 6)
# q_tiles: b x h x n_group x n_group x w^2 x d
q_tiles = q_tiles.contiguous().view(b, h, q_n_group, q_n_group, -1, d)
        # k_tiles: b x h x n_group x n_group x 4w^2 x d
k_tiles = (
self.get_overlapping_tiles(k)
.contiguous()
.view(b, h, q_n_group, q_n_group, -1, d)
)
        # dot_tiles: b x h x n_group x n_group x w^2 x 4w^2
dot_tiles = q_tiles.matmul(k_tiles.transpose(-1, -2))
# fill "-inf" into the zero-padding parts
dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width, kv_width)
dot_tiles[:, :, 0, :, :, :side_size].fill_(float("-inf"))
dot_tiles[:, :, -1, :, :, -side_size:].fill_(float("-inf"))
dot_tiles[:, :, :, 0, :, :, :side_size].fill_(float("-inf"))
dot_tiles[:, :, :, -1, :, :, -side_size:].fill_(float("-inf"))
dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width**2)
return dot_tiles

    def get_overlapping_tiles(self, x):
"""Get overlapping tiles in the 2D spatial domain, ensuring each query computes correlation with all neighbors"""
# x: b h (l w) d
b, h, _, d = x.shape
side_size = max(self.window_size // 2, 1)
total_size = 2 * side_size + self.window_size
kv_group_size = self.window_size
kv_width = self.img_size
x = x.view(b, h, kv_width, kv_width, d)
x = F.pad(x, [0, 0, side_size, side_size, side_size, side_size], value=0)
out_shape = [
b,
h,
kv_width // kv_group_size,
kv_width // kv_group_size,
total_size,
total_size,
d,
]
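        # Reuse the strides of the padded tensor so that consecutive (group_y, group_x)
        # tiles start kv_group_size rows/columns apart while each tile spans a
        # total_size x total_size neighborhood; as_strided then yields a zero-copy,
        # overlapping view of shape out_shape.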
in_stride = x.stride()
out_stride = [
in_stride[0],
in_stride[1],
in_stride[2] * kv_group_size,
in_stride[3] * kv_group_size,
in_stride[2],
in_stride[3],
in_stride[4],
]
# note we ignored the boundary here
return x.as_strided(size=out_shape, stride=out_stride)

    def compute_window_pv(self, attn, v):
"""Compute the inner product of attention matrix and the values for the window attention."""
b, h, n_group, _, w2, n_k = attn.shape
d = v.shape[-1]
v_tiles = (
self.get_overlapping_tiles(v)
.contiguous()
.view(b, h, n_group, n_group, -1, d)
)
# b x h x n_group x n_group x w^2 x d
pv = attn.matmul(v_tiles)
# return: b x h x (lw) x d
ret = self.ungroup_dots(pv)
return ret

    def group_dots(self, dots):
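        # Reshape scores from b x h x (l w) x n_k to b x h x n_group x n_group x w^2 x n_k,
        # grouping the flattened 2D tokens into non-overlapping w x w windows.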
b, h = dots.shape[:2]
n_group = self.img_size // self.window_size
dots = dots.reshape(
b, h, n_group, self.window_size, n_group, self.window_size, -1
).permute(0, 1, 2, 4, 3, 5, 6)
dots = dots.contiguous().view(
b, h, n_group, n_group, self.window_size * self.window_size, -1
)
return dots

    def ungroup_dots(self, dots):
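        # Inverse of group_dots: flatten b x h x n_group x n_group x w^2 x n_k
        # back to b x h x (l w) x n_k.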
b, h, n_group, _, _, n_keys = dots.shape
dots = dots.reshape(
b, h, n_group, n_group, self.window_size, self.window_size, -1
).permute(0, 1, 2, 4, 3, 5, 6)
dots = dots.contiguous().view(b, h, -1, n_keys)
return dots