import torch.nn as nn
from xformers.ops import memory_efficient_attention


class MEAttention(nn.Module):
    """Multi-head self-attention backed by xFormers' memory-efficient attention kernel."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Project to q, k, v and split heads: [B, N, C] -> [3, B, H, N, D]
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # memory_efficient_attention expects [B, N, H, D], whereas timm uses [B, H, N, D].
        # Attention dropout is passed via `p`, applied only in training mode.
        x = memory_efficient_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            p=self.attn_drop.p if self.training else 0.0,
            scale=self.scale,
        )

        # Output is [B, N, H, D]; merge heads back to [B, N, C]
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
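

# Minimal usage sketch (not part of the module): the sizes below are illustrative
# assumptions. xFormers' fused kernels generally expect a CUDA device and
# fp16/bf16 tensors, so this smoke test only runs when a GPU is available.
if __name__ == "__main__":
    import torch

    if torch.cuda.is_available():
        attn = MEAttention(dim=768, num_heads=12, qk_norm=True).cuda().half()
        x = torch.randn(2, 197, 768, device="cuda", dtype=torch.float16)  # [B, N, C]
        with torch.no_grad():
            out = attn(x)
        print(out.shape)  # expected: torch.Size([2, 197, 768])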