# Source: DiffusionSfM / diffusionsfm/model/memory_efficient_attention.py
# Uploaded by qitaoz in commit 4562a06 ("Upload 57 files", verified).
import ipdb
import torch.nn as nn
from xformers.ops import memory_efficient_attention
class MEAttention(nn.Module):
    """Multi-head self-attention backed by xformers' memory-efficient kernel.

    Mirrors the timm ``Attention`` layout (fused qkv projection, optional
    per-head q/k normalization, output projection) but computes attention
    with ``xformers.ops.memory_efficient_attention``.

    Args:
        dim: Embedding dimension of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused qkv projection has a bias.
        qk_norm: If True, apply ``norm_layer(head_dim)`` to queries and keys.
        attn_drop: Dropout probability on the attention weights (applied by
            the xformers kernel, training mode only).
        proj_drop: Dropout probability after the output projection.
        norm_layer: Normalization layer factory used when ``qk_norm`` is set.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        # Kept as a module for state/config compatibility; its probability is
        # handed to the xformers kernel in forward().
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, N, C) with C == ``dim``.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        # Keep tokens on axis 1 so q/k/v come out directly in the
        # [B, N, H, D] layout xformers expects (timm uses [B, H, N, D]).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        # Normalization acts on the last (head_dim) axis, so the layout
        # choice above does not affect it.
        q, k = self.q_norm(q), self.k_norm(k)

        # The xformers kernel applies dropout whenever p > 0 and knows nothing
        # about train/eval mode, so only enable it while training.
        drop_p = self.attn_drop.p if self.training else 0.0
        x = memory_efficient_attention(q, k, v, p=drop_p, scale=self.scale)

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x