# File size: 7,719 bytes | revision 506a2b4
# Copyright (c) 2025 Resemble AI
# Author: Manmay Nakhashi
# MIT License
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
class RelativePositionBias(nn.Module):
    """
    Learned per-head bias added to attention logits, indexed by the bucketed
    offset between key and query positions.

    :param scale: Multiplier applied to the looked-up bias before it is added
    :param causal: If True, offsets to later keys all fall into bucket 0
    :param num_buckets: Number of rows in the bias embedding table
    :param max_distance: Distance at which the logarithmic buckets saturate
    :param heads: Number of bias values stored per bucket (one per head)
    """

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """
        Map signed relative positions (key index minus query index, as built in
        ``forward``) to bucket ids.

        Small distances get one bucket each; larger distances share
        logarithmically spaced buckets, capped at the last bucket.

        :param relative_position: Integer tensor of relative positions
        :param causal: If False, half of the buckets are reserved for positive
            relative positions
        :param num_buckets: Total number of buckets
        :param max_distance: Distance mapped to the top of the logarithmic range
        :return: Integer tensor of bucket ids, same shape as the input
        """
        distance = -relative_position
        if causal:
            # Positive relative positions become distance 0.
            offset = 0
            distance = distance.clamp(min=0)
        else:
            # Split the table in two: the upper half is used where the
            # relative position is positive, the lower half otherwise.
            num_buckets //= 2
            offset = (distance < 0).long() * num_buckets
            distance = distance.abs()

        exact_limit = num_buckets // 2
        # log(0) yields -inf here, but those entries are always "small" and
        # are therefore never selected by torch.where below.
        log_bucket = exact_limit + (
            torch.log(distance.float() / exact_limit)
            / math.log(max_distance / exact_limit)
            * (num_buckets - exact_limit)
        ).long()
        log_bucket = log_bucket.clamp(max=num_buckets - 1)

        return offset + torch.where(distance < exact_limit, distance, log_bucket)

    def forward(self, qk_dots):
        """
        Add the scaled relative position bias to attention logits.

        :param qk_dots: Attention logits whose last two dimensions are
            (query length, key length)
        :return: ``qk_dots`` plus the bias broadcast over the leading dimension
        """
        q_len, k_len = qk_dots.shape[-2:]
        where = qk_dots.device
        query_idx = torch.arange(q_len, dtype=torch.long, device=where)
        key_idx = torch.arange(k_len, dtype=torch.long, device=where)
        offsets = key_idx[None, :] - query_idx[:, None]
        buckets = self._relative_position_bucket(
            offsets,
            causal=self.causal,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        per_head = self.relative_attention_bias(buckets)
        # (q_len, k_len, heads) -> (1, heads, q_len, k_len)
        bias = rearrange(per_head, 'i j h -> () h i j')
        return qk_dots + (bias * self.scale)
class AttentionQKV(nn.Module):
    """
    Multi-head attention over already-projected query, key and value tensors.

    ``forward`` takes tensors of shape ``(batch, length, n_heads * head_dim)``,
    splits them into heads, runs attention either through
    ``F.scaled_dot_product_attention`` (``flash=True``) or through an explicit
    softmax implementation (``flash=False``), and merges the heads again.

    :param n_heads: Number of attention heads
    :param head_dim: Number of channels per head
    :param dropout_rate: Dropout probability applied to the attention weights
    :param scale: Multiplier for the query-key logits; defaults to
        ``head_dim ** -0.5``
    :param flash: If True, use the fused PyTorch attention kernel
    """

    def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = scale if scale is not None else head_dim ** -0.5
        self.flash = flash
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(dropout_rate)
        self.flash_config = self.setup_flash_config() if flash else None

    def setup_flash_config(self):
        """Return the keyword arguments handed to ``torch.backends.cuda.sdp_kernel``."""
        flash_config = {
            'enable_flash': True,
            'enable_math': True,
            'enable_mem_efficient': True
        }
        return flash_config

    def forward(self, q, k, v, mask=None):
        """
        Attend ``q`` over ``k``/``v``.

        :param q: Queries, ``(batch, q_len, n_heads * head_dim)``
        :param k: Keys, ``(batch, kv_len, n_heads * head_dim)``
        :param v: Values, ``(batch, kv_len, n_heads * head_dim)``
        :param mask: Optional mask forwarded to the selected attention path
        :return: Tensor of shape ``(batch, q_len, n_heads * head_dim)``
        """
        q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]]
        if self.flash:
            out = self.flash_attention(q, k, v, mask=mask)
        else:
            out = self.scaled_dot_product_attention(q, k, v, mask=mask)
        return self.combine_heads(out)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """
        Explicit softmax attention on head-split tensors.

        :param q: ``(batch, heads, q_len, head_dim)``
        :param k: ``(batch, heads, kv_len, head_dim)``
        :param v: ``(batch, heads, kv_len, head_dim)``
        :param mask: Optional tensor broadcastable to
            ``(batch, heads, q_len, kv_len)``; positions equal to 0 are masked
        :return: ``(batch, heads, q_len, head_dim)``
        """
        # Contract over head_dim (d) so the logits are position-by-position,
        # matching the layout produced by split_heads and the flash path.
        # Contracting over the length axis instead would attend over channels
        # and would fail whenever q_len != kv_len.
        sim = torch.einsum("bhtd,bhsd->bhts", q, k) * self.scale
        if mask is not None:
            sim = sim.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(sim, dim=-1)
        attn = self.dropout(attn)
        return torch.einsum("bhts,bhsd->bhtd", attn, v)

    def flash_attention(self, q, k, v, mask=None):
        """
        Attention via ``F.scaled_dot_product_attention`` on head-split tensors.

        :param q: ``(batch, heads, q_len, head_dim)``
        :param k: ``(batch, heads, kv_len, head_dim)``
        :param v: ``(batch, heads, kv_len, head_dim)``
        :param mask: Passed unchanged as ``attn_mask``. NOTE(review): PyTorch
            treats a boolean mask as "True = attend" and a float mask as an
            additive bias, which differs from the ``mask == 0`` convention of
            the non-flash path for non-boolean masks; confirm with callers.
        :return: ``(batch, heads, q_len, head_dim)``
        """
        config = self.flash_config if self.flash_config else {}
        # The fused kernel always scales logits by head_dim ** -0.5. Fold any
        # custom scale into q so both attention paths honour self.scale; this
        # works on PyTorch versions that lack the `scale` keyword.
        default_scale = q.shape[-1] ** -0.5
        if self.scale != default_scale:
            q = q * (self.scale / default_scale)
        # NOTE: recent PyTorch releases deprecate this context manager in
        # favour of torch.nn.attention.sdpa_kernel; it is kept for
        # compatibility with the versions this file was written against.
        with torch.backends.cuda.sdp_kernel(**config):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout_rate if self.training else 0.
            )
        return out

    def split_heads(self, x):
        """Reshape ``(batch, length, heads * head_dim)`` to ``(batch, heads, length, head_dim)``."""
        bs, length, _ = x.shape
        x = x.view(bs, length, self.n_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        """Reshape ``(batch, heads, length, head_dim)`` back to ``(batch, length, heads * head_dim)``."""
        bs, _, length, _ = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(bs, length, -1)
class AttentionBlock2(nn.Module):
    """
    Residual attention block: both inputs pass through a shared LayerNorm,
    are projected by separate linear layers into Q (from ``x1``) and K/V
    (from ``x2``), attended with AttentionQKV, projected out, and added back
    onto ``x1``.

    :param channels: Size of the last (feature) dimension of the inputs
    :param num_heads: Number of heads, used when ``num_head_channels`` is -1
    :param num_head_channels: If not -1, channels per head; the head count is
        then ``channels // num_head_channels``
    :param relative_pos_embeddings: If True, a RelativePositionBias module is
        constructed. NOTE(review): ``forward`` never applies it.
    :param flash_attention: Forwarded to AttentionQKV as ``flash``
    :param dropout_rate: Forwarded to AttentionQKV
    :param scale: Forwarded to AttentionQKV
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        relative_pos_embeddings=False,
        flash_attention=True,
        dropout_rate=0.2,
        scale=None
    ):
        super().__init__()
        self.channels = channels

        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads
        per_head = channels // self.num_heads

        self.norm = nn.LayerNorm(channels)

        # Independent projections for queries, keys and values.
        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)

        self.attention = AttentionQKV(
            self.num_heads,
            per_head,
            dropout_rate=dropout_rate,
            flash=flash_attention,
            scale=scale,
        )
        self.proj_out = nn.Linear(channels, channels)

        # NOTE(review): the bias is sized with the `num_heads` argument, not
        # `self.num_heads`; the two differ when `num_head_channels` is given.
        self.relative_pos_embeddings = (
            RelativePositionBias(
                scale=per_head ** .5,
                causal=False,
                heads=num_heads,
                num_buckets=32,
                max_distance=64,
            )
            if relative_pos_embeddings
            else None
        )

    def forward(self, x1, x2, mask=None):
        """
        Attend ``x1`` (queries) over ``x2`` (keys and values).

        :param x1: Query-side input; its last dimension must be ``channels``
        :param x2: Key/value-side input; its last dimension must be ``channels``
        :param mask: Optional mask forwarded to AttentionQKV
        :return: ``x1`` plus the projected attention output, in ``x1``'s shape
        """
        target_shape = tuple(x1.shape)

        normed_q_side = self.norm(x1)
        normed_kv_side = self.norm(x2)

        queries = self.to_q(normed_q_side)
        keys = self.to_k(normed_kv_side)
        values = self.to_v(normed_kv_side)

        attended = self.attention(queries, keys, values, mask=mask)
        update = self.proj_out(attended)
        return (x1 + update).reshape(target_shape)
class Perceiver(nn.Module):
    """Inspired by https://arxiv.org/abs/2103.03206"""

    def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4):
        """
        Build the learned query tokens and the attention block.

        :param pre_attention_query_token: Number of learned query tokens
        :param pre_attention_query_size: Feature size of each query token
        :param embedding_dim: Channel count given to the attention block
        :param num_attn_heads: Number of attention heads
        """
        super().__init__()

        # Half-width of the uniform initialisation range.
        # NOTE(review): the denominator adds the token count to itself and
        # does not involve pre_attention_query_size; confirm this is intended.
        bound = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token))

        # Learned queries, shape (1, tokens, size), drawn from U(-bound, bound).
        initial_query = torch.empty(1, pre_attention_query_token, pre_attention_query_size)
        initial_query.uniform_(-bound, bound)
        self.pre_attention_query = torch.nn.Parameter(initial_query)

        # A single attention block serves both passes in forward().
        self.attn = AttentionBlock2(embedding_dim, num_attn_heads)

    def forward(self, h):
        """
        Summarise ``h`` into a fixed number of tokens.

        :param h: Input tensor; its first dimension is the batch size
            (assumed to be (batch, length, embedding_dim) - TODO confirm)
        :return: Result of cross-attending the learned queries over ``h`` and
            then self-attending that result, using the same attention block
        """
        batch_size = h.shape[0]
        # Broadcast the single set of learned queries across the batch.
        queries = self.pre_attention_query.expand(batch_size, -1, -1)
        # First pass: the learned queries attend over the input.
        summary = self.attn(queries, h)
        # Second pass: the summary attends over itself with the same weights.
        return self.attn(summary, summary)
|