import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange
from einops._torch_specific import allow_ops_in_compiled_graph

from models.positional_embeddings import PositionalEmbedding, FourierEmbedding

# Register einops.rearrange so it can be used inside torch.fx / torch.compile graphs.
torch.fx.wrap("rearrange")
allow_ops_in_compiled_graph()


class FusedMLP(nn.Sequential):
    """Feed-forward block: Linear -> activation -> dropout -> Linear."""

    def __init__(
        self,
        dim_model: int,
        dropout: float,
        activation: Type[nn.Module],
        hidden_layer_multiplier: int = 4,
        bias: bool = True,
    ):
        super().__init__(
            nn.Linear(dim_model, dim_model * hidden_layer_multiplier, bias=bias),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(dim_model * hidden_layer_multiplier, dim_model, bias=bias),
        )


def _cast_if_autocast_enabled(tensor):
    # When autocast is active, cast the tensor to the autocast dtype so the
    # caller can run the op outside the autocast context in reduced precision.
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LayerNorm16Bits(torch.nn.LayerNorm):
    """
    16-bit friendly version of torch.nn.LayerNorm
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )


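# Example (sketch, shapes assumed for illustration): under autocast the
# normalization runs in the autocast dtype (e.g. bfloat16) instead of being
# upcast to float32 as a plain nn.LayerNorm would be.
#
#   ln = LayerNorm16Bits(256).cuda()
#   x = torch.randn(2, 16, 256, device="cuda")
#   with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#       y = ln(x)  # y.dtype == torch.bfloat16

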
class StochasticDepth(nn.Module):
    """Randomly drops the whole residual branch for each sample during training."""

    def __init__(self, p: float):
        super().__init__()
        self.survival_prob = 1.0 - p

    def forward(self, x: Tensor) -> Tensor:
        if self.training and self.survival_prob < 1:
            # Per-sample binary mask: 1 with probability survival_prob, 0 otherwise.
            mask = (
                torch.empty(x.shape[0], 1, 1, device=x.device).uniform_()
                + self.survival_prob
            )
            mask = mask.floor()
            if self.survival_prob > 0:
                # Rescale so the expected output matches evaluation mode.
                mask = mask / self.survival_prob
            return x * mask
        else:
            return x


# Backward-compatible alias for the original class name.
StochatichDepth = StochasticDepth


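# Example (sketch, shapes assumed for illustration): with p=0.1 each sample's
# residual branch is dropped with probability 0.1 during training and passed
# through unchanged at evaluation time.
#
#   sd = StochasticDepth(p=0.1)
#   x = torch.randn(4, 16, 128)  # (batch, tokens, dim)
#   out = sd(x)                  # per sample: x / 0.9 (kept) or zeros (dropped)

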
class CrossAttentionOp(nn.Module):
    """Multi-head attention supporting both self-attention (is_sa=True) and cross-attention."""

    def __init__(
        self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False
    ):
        super().__init__()
        self.dim_q = dim_q
        self.dim_kv = dim_kv
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.use_biases = use_biases
        self.is_sa = is_sa
        if self.is_sa:
            self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases)
        else:
            self.q = nn.Linear(dim_q, attention_dim, bias=use_biases)
            self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases)
        self.out = nn.Linear(attention_dim, dim_q, bias=use_biases)

    def forward(self, x_to, x_from=None, attention_mask=None, materialize_sdpa=False):
        if x_from is None:
            x_from = x_to
        if self.is_sa:
            q, k, v = self.qkv(x_to).chunk(3, dim=-1)
        else:
            q = self.q(x_to)
            k, v = self.kv(x_from).chunk(2, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1)
        if materialize_sdpa:
            # Explicitly materializes the attention matrix (used to retrieve scores).
            x = self.materialize_sdpa(q, k, v, attention_mask)
        else:
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask
            )
        x = rearrange(x, "b h n d -> b n (h d)")
        x = self.out(x)
        return x

    def materialize_sdpa(self, q, k, v, attn_mask=None):
        scale = 1.0 / math.sqrt(q.shape[-1])
        attn_matrix = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
        if attn_mask is not None:
            # Match scaled_dot_product_attention semantics: boolean masks exclude
            # positions where the mask is False, float masks are added to the logits.
            if attn_mask.dtype == torch.bool:
                attn_matrix = attn_matrix.masked_fill(~attn_mask, float("-inf"))
            else:
                attn_matrix = attn_matrix + attn_mask
        attn_matrix = torch.nn.functional.softmax(attn_matrix, dim=-1)
        return torch.einsum("b h i j, b h j d -> b h i d", attn_matrix, v)


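# Example (sketch, shapes assumed for illustration): cross-attention from 16
# query tokens of width 512 onto 32 context tokens of width 256 with 8 heads.
#
#   ca = CrossAttentionOp(attention_dim=512, num_heads=8, dim_q=512, dim_kv=256)
#   x_to = torch.randn(2, 16, 512)    # (batch, query tokens, dim_q)
#   x_from = torch.randn(2, 32, 256)  # (batch, context tokens, dim_kv)
#   out = ca(x_to, x_from)            # -> (2, 16, 512)

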
class CrossAttentionBlock(nn.Module):
    """Pre-norm cross-attention block followed by a feed-forward network."""

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        retrieve_attention_scores: bool = False,
        use_16_bits_layer_norm: bool = False,
    ):
        super().__init__()
        if use_16_bits_layer_norm and not retrieve_attention_scores:
            LayerNorm = LayerNorm16Bits
        else:
            LayerNorm = nn.LayerNorm
        self.retrieve_attention_scores = retrieve_attention_scores
        self.initial_to_ln = LayerNorm(dim_q, eps=1e-6)
        attention_dim = min(dim_q, dim_kv) if attention_dim == 0 else attention_dim
        self.ca = CrossAttentionOp(
            attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases
        )
        self.ca_stochastic_depth = StochasticDepth(stochastic_depth)
        self.middle_ln = LayerNorm(dim_q, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_q,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)

        # Constant True mask used to expand a missing token mask to full size.
        self.register_parameter(
            "attention_mask_dummy",
            nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False),
        )

    def forward(
        self,
        to_tokens: Tensor,
        from_tokens: Tensor,
        to_token_mask: Optional[Tensor] = None,
        from_token_mask: Optional[Tensor] = None,
    ) -> Tensor:
        if to_token_mask is None and from_token_mask is None:
            attention_mask = None
        else:
            if to_token_mask is None:
                to_token_mask = self.attention_mask_dummy.expand(
                    to_tokens.shape[0],
                    to_tokens.shape[1],
                )
            if from_token_mask is None:
                from_token_mask = self.attention_mask_dummy.expand(
                    from_tokens.shape[0],
                    from_tokens.shape[1],
                )
            attention_mask = from_token_mask.unsqueeze(1) * to_token_mask.unsqueeze(2)
        if self.retrieve_attention_scores:
            attention_output = self.ca(
                self.initial_to_ln(to_tokens),
                from_tokens,
                attention_mask=attention_mask,
                materialize_sdpa=True,
            )
        else:
            attention_output = self.ca(
                self.initial_to_ln(to_tokens),
                from_tokens,
                attention_mask=attention_mask,
            )
        to_tokens = to_tokens + self.ca_stochastic_depth(attention_output)
        to_tokens = to_tokens + self.ffn_stochastic_depth(
            self.ffn(self.middle_ln(to_tokens))
        )
        return to_tokens


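# Example (sketch, shapes assumed for illustration): queries of width 512
# attending to context tokens of width 256, with a boolean context mask.
#
#   block = CrossAttentionBlock(dim_q=512, dim_kv=256, num_heads=8)
#   to_tokens = torch.randn(2, 16, 512)
#   from_tokens = torch.randn(2, 32, 256)
#   from_mask = torch.ones(2, 32, dtype=torch.bool)  # True = valid context token
#   out = block(to_tokens, from_tokens, from_token_mask=from_mask)  # -> (2, 16, 512)

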
class SelfAttentionBlock(nn.Module):
    """Pre-norm self-attention block with optional layer scale."""

    def __init__(
        self,
        dim_qkv: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        use_layer_scale: bool = False,
        layer_scale_value: float = 0.1,
        retrieve_attention_scores: bool = False,
        use_16_bits_layer_norm: bool = False,
    ):
        super().__init__()
        if use_16_bits_layer_norm and not retrieve_attention_scores:
            LayerNorm = LayerNorm16Bits
        else:
            LayerNorm = nn.LayerNorm
        self.retrieve_attention_scores = retrieve_attention_scores
        self.initial_ln = LayerNorm(dim_qkv, eps=1e-6)
        attention_dim = dim_qkv if attention_dim == 0 else attention_dim
        self.sa = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_qkv,
            is_sa=True,
            use_biases=use_biases,
        )
        self.sa_stochastic_depth = StochasticDepth(stochastic_depth)
        self.middle_ln = LayerNorm(dim_qkv, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_qkv,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )

        # Constant True mask used to expand a missing token mask to full size.
        self.register_parameter(
            "attention_mask_dummy",
            nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False),
        )

    def forward(
        self,
        tokens: torch.Tensor,
        token_mask: Optional[torch.Tensor] = None,
    ):
        if token_mask is None:
            attention_mask = None
        else:
            attention_mask = token_mask.unsqueeze(1) * self.attention_mask_dummy.expand(
                tokens.shape[0],
                tokens.shape[1],
            ).unsqueeze(2)
        if self.retrieve_attention_scores:
            attention_output = self.sa(
                self.initial_ln(tokens),
                attention_mask=attention_mask,
                materialize_sdpa=True,
            )
        else:
            attention_output = self.sa(
                self.initial_ln(tokens),
                attention_mask=attention_mask,
            )
        if self.use_layer_scale:
            tokens = tokens + self.sa_stochastic_depth(
                self.layer_scale_1 * attention_output
            )
            tokens = tokens + self.ffn_stochastic_depth(
                self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
            )
        else:
            tokens = tokens + self.sa_stochastic_depth(attention_output)
            tokens = tokens + self.ffn_stochastic_depth(
                self.ffn(self.middle_ln(tokens))
            )
        return tokens


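# Example (sketch, shapes assumed for illustration): self-attention over 64
# tokens of width 384 with stochastic depth and layer scale enabled.
#
#   block = SelfAttentionBlock(dim_qkv=384, num_heads=6, stochastic_depth=0.1,
#                              use_layer_scale=True)
#   tokens = torch.randn(2, 64, 384)
#   mask = torch.ones(2, 64, dtype=torch.bool)  # True = attend to this token
#   out = block(tokens, token_mask=mask)        # -> (2, 64, 384)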