# File copied from https://raw.githubusercontent.com/heidelberg-hepml/lorentz-gatr/refs/heads/main/experiments/baselines/transformer.py

from functools import partial
from typing import Optional, Tuple

import torch
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import checkpoint

from lgatr.layers import ApplyRotaryPositionalEncoding
from lgatr.primitives.attention import scaled_dot_product_attention


def to_nd(tensor, d):
    """Make tensor d-dimensional, grouping all extra leading dimensions into the first one."""
    return tensor.view(
        -1, *(1,) * (max(0, d - 1 - tensor.dim())), *tensor.shape[-(d - 1) :]
    )


class BaselineLayerNorm(nn.Module):
    """Baseline layer norm over the last (channel) dimension."""

    @staticmethod
    def forward(inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data

        Returns
        -------
        outputs : Tensor
            Normalized inputs.
        """
        return torch.nn.functional.layer_norm(
            inputs, normalized_shape=inputs.shape[-1:]
        )


class MultiHeadQKVLinear(nn.Module):
    """Computes queries, keys, and values for multi-head attention.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    hidden_channels : int
        Number of hidden channels = size of query, key, and value.
    num_heads : int
        Number of attention heads.
    """

    def __init__(self, in_channels, hidden_channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.linear = nn.Linear(in_channels, 3 * hidden_channels * num_heads)

    def forward(self, inputs):
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data

        Returns
        -------
        q : Tensor
            Queries
        k : Tensor
            Keys
        v : Tensor
            Values
        """
        qkv = self.linear(inputs)  # (..., num_items, 3 * hidden_channels * num_heads)
        q, k, v = rearrange(
            qkv,
            "... items (qkv hidden_channels num_heads) -> qkv ... num_heads items hidden_channels",
            num_heads=self.num_heads,
            qkv=3,
        )
        return q, k, v


class MultiQueryQKVLinear(nn.Module):
    """Computes queries, keys, and values for multi-query attention.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    hidden_channels : int
        Number of hidden channels = size of query, key, and value.
    num_heads : int
        Number of attention heads.
    """

    def __init__(self, in_channels, hidden_channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.q_linear = nn.Linear(in_channels, hidden_channels * num_heads)
        self.k_linear = nn.Linear(in_channels, hidden_channels)
        self.v_linear = nn.Linear(in_channels, hidden_channels)

    def forward(self, inputs):
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data

        Returns
        -------
        q : Tensor
            Queries
        k : Tensor
            Keys (shared across heads)
        v : Tensor
            Values (shared across heads)
        """
        q = rearrange(
            self.q_linear(inputs),
            "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels",
            num_heads=self.num_heads,
        )
        k = self.k_linear(inputs)[
            ..., None, :, :
        ]  # (..., head=1, item, hidden_channels)
        v = self.v_linear(inputs)[..., None, :, :]
        return q, k, v


class BaselineSelfAttention(nn.Module):
    """Baseline self-attention layer.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    hidden_channels : int
        Number of hidden channels = size of query, key, and value.
    num_heads : int
        Number of attention heads.
    pos_encoding : bool
        Whether to apply rotary positional embeddings along the item dimension to the scalar
        keys and queries.
    pos_enc_base : int
        Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
    multi_query : bool
        Use multi-query attention instead of multi-head attention.
    dropout_prob : float or None
        Dropout probability applied after the output projection. If None, no dropout is used.
""" def __init__( self, in_channels: int, out_channels: int, hidden_channels: int, num_heads: int = 8, pos_encoding: bool = False, pos_enc_base: int = 4096, multi_query: bool = True, dropout_prob=None, ) -> None: super().__init__() # Store settings self.num_heads = num_heads self.hidden_channels = hidden_channels # Linear maps qkv_class = MultiQueryQKVLinear if multi_query else MultiHeadQKVLinear self.qkv_linear = qkv_class(in_channels, hidden_channels, num_heads) self.out_linear = nn.Linear(hidden_channels * num_heads, out_channels) # Optional positional encoding if pos_encoding: self.pos_encoding = ApplyRotaryPositionalEncoding( hidden_channels, item_dim=-2, base=pos_enc_base ) else: self.pos_encoding = None if dropout_prob is not None: self.dropout = nn.Dropout(dropout_prob) else: self.dropout = None def forward( self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, is_causal: bool = False, ) -> torch.Tensor: """Forward pass. Parameters ---------- inputs : Tensor Input data attention_mask : None or Tensor or xformers.ops.AttentionBias Optional attention mask Returns ------- outputs : Tensor Outputs """ q, k, v = self.qkv_linear( inputs ) # each: (..., num_heads, num_items, num_channels, 16) # Rotary positional encoding if self.pos_encoding is not None: q = self.pos_encoding(q) k = self.pos_encoding(k) # Attention layer h = self._attend(q, k, v, attention_mask, is_causal=is_causal) # Concatenate heads and transform linearly h = rearrange( h, "... num_heads num_items hidden_channels -> ... num_items (num_heads hidden_channels)", ) outputs = self.out_linear(h) # (..., num_items, out_channels) if self.dropout is not None: outputs = self.dropout(outputs) return outputs @staticmethod def _attend(q, k, v, attention_mask=None, is_causal=False): """Scaled dot-product attention.""" # Add batch dimension if needed bh_shape = q.shape[:-2] q = to_nd(q, 4) k = to_nd(k, 4) v = to_nd(v, 4) # SDPA outputs = scaled_dot_product_attention( q.contiguous(), k.expand_as(q).contiguous(), v.expand_as(q).contiguous(), attn_mask=attention_mask, is_causal=is_causal, ) # Return batch dimensions to inputs outputs = outputs.view(*bh_shape, *outputs.shape[-2:]) return outputs class BaselineTransformerBlock(nn.Module): """Baseline transformer block. Inputs are first processed by a block consisting of LayerNorm, multi-head self-attention, and residual connection. Then the data is processed by a block consisting of another LayerNorm, an item-wise two-layer MLP with GeLU activations, and another residual connection. Parameters ---------- channels : int Number of input and output channels. num_heads : int Number of attention heads. pos_encoding : bool Whether to apply rotary positional embeddings along the item dimension to the scalar keys and queries. pos_encoding_base : int Maximum frequency used in positional encodings. (The minimum frequency is always 1.) increase_hidden_channels : int Factor by which the key, query, and value size is increased over the default value of hidden_channels / num_heads. multi_query : bool Use multi-query attention instead of multi-head attention. """ def __init__( self, channels, num_heads: int = 8, pos_encoding: bool = False, pos_encoding_base: int = 4096, increase_hidden_channels=1, multi_query: bool = True, dropout_prob=None, ) -> None: super().__init__() self.norm = BaselineLayerNorm() # When using positional encoding, the number of scalar hidden channels needs to be even. # It also should not be too small. 
        hidden_channels = channels // num_heads * increase_hidden_channels
        if pos_encoding:
            hidden_channels = (hidden_channels + 1) // 2 * 2
            hidden_channels = max(hidden_channels, 16)

        self.attention = BaselineSelfAttention(
            channels,
            channels,
            hidden_channels,
            num_heads=num_heads,
            pos_encoding=pos_encoding,
            pos_enc_base=pos_encoding_base,
            multi_query=multi_query,
            dropout_prob=dropout_prob,
        )

        self.mlp = nn.Sequential(
            nn.Linear(channels, 2 * channels),
            nn.Dropout(dropout_prob) if dropout_prob is not None else nn.Identity(),
            nn.GELU(),
            nn.Linear(2 * channels, channels),
            nn.Dropout(dropout_prob) if dropout_prob is not None else nn.Identity(),
        )

    def forward(
        self, inputs: torch.Tensor, attention_mask=None, is_causal=False
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data
        attention_mask : None or Tensor or xformers.ops.AttentionBias
            Optional attention mask
        is_causal : bool
            Whether to use a causal attention mask.

        Returns
        -------
        outputs : Tensor
            Outputs
        """

        # Residual attention
        h = self.norm(inputs)
        h = self.attention(h, attention_mask=attention_mask, is_causal=is_causal)
        outputs = inputs + h

        # Residual MLP
        h = self.norm(outputs)
        h = self.mlp(h)
        outputs = outputs + h

        return outputs


class Transformer(nn.Module):
    """Baseline transformer.

    Combines num_blocks transformer blocks, each consisting of multi-head self-attention layers,
    an MLP, residual connections, and normalization layers.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    hidden_channels : int
        Number of hidden channels.
    num_blocks : int
        Number of transformer blocks.
    num_heads : int
        Number of attention heads.
    pos_encoding : bool
        Whether to apply rotary positional embeddings along the item dimension to the scalar
        keys and queries.
    pos_encoding_base : int
        Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
    checkpoint_blocks : bool
        Whether to apply gradient checkpointing to each transformer block.
    increase_hidden_channels : int
        Factor by which the key, query, and value size is increased over the default value of
        hidden_channels / num_heads.
    multi_query : bool
        Use multi-query attention instead of multi-head attention.
    dropout_prob : float or None
        Dropout probability used in the attention layers and the MLPs. If None, no dropout is used.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        num_blocks: int = 10,
        num_heads: int = 8,
        pos_encoding: bool = False,
        pos_encoding_base: int = 4096,
        checkpoint_blocks: bool = False,
        increase_hidden_channels=1,
        multi_query: bool = False,
        dropout_prob=None,
    ) -> None:
        super().__init__()
        self.checkpoint_blocks = checkpoint_blocks
        self.linear_in = nn.Linear(in_channels, hidden_channels)
        self.blocks = nn.ModuleList(
            [
                BaselineTransformerBlock(
                    hidden_channels,
                    num_heads=num_heads,
                    pos_encoding=pos_encoding,
                    pos_encoding_base=pos_encoding_base,
                    increase_hidden_channels=increase_hidden_channels,
                    multi_query=multi_query,
                    dropout_prob=dropout_prob,
                )
                for _ in range(num_blocks)
            ]
        )
        self.linear_out = nn.Linear(hidden_channels, out_channels)

    def forward(
        self, inputs: torch.Tensor, attention_mask=None, is_causal=False
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor with shape (..., num_items, num_channels)
            Input data
        attention_mask : None or Tensor or xformers.ops.AttentionBias
            Optional attention mask
        is_causal : bool
            Whether to use a causal attention mask.

        Returns
        -------
        outputs : Tensor with shape (..., num_items, num_channels)
            Outputs
        """
        h = self.linear_in(inputs)
        for block in self.blocks:
            if self.checkpoint_blocks:
                fn = partial(block, attention_mask=attention_mask, is_causal=is_causal)
                h = checkpoint(fn, h)
            else:
                h = block(h, attention_mask=attention_mask, is_causal=is_causal)
        outputs = self.linear_out(h)
        return outputs


class AxialTransformer(nn.Module):
    """Baseline axial transformer for data with two token dimensions.

    Combines num_blocks transformer blocks, each consisting of multi-head self-attention layers,
    an MLP, residual connections, and normalization layers.

    Assumes input data with shape `(..., num_items_1, num_items_2, num_channels)`.

    The first, third, fifth, ... block computes attention over the `items_2` axis. The other
    blocks compute attention over the `items_1` axis. Positional encoding can be specified
    separately for both axes.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    hidden_channels : int
        Number of hidden channels.
    num_blocks : int
        Number of transformer blocks.
    num_heads : int
        Number of attention heads.
    pos_encodings : tuple of bool
        Whether to apply rotary positional embeddings along the item dimensions to the scalar
        keys and queries.
    pos_encoding_base : int
        Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        num_blocks: int = 20,
        num_heads: int = 8,
        pos_encodings: Tuple[bool, bool] = (False, False),
        pos_encoding_base: int = 4096,
    ) -> None:
        super().__init__()
        self.linear_in = nn.Linear(in_channels, hidden_channels)
        self.blocks = nn.ModuleList(
            [
                BaselineTransformerBlock(
                    hidden_channels,
                    num_heads=num_heads,
                    pos_encoding=pos_encodings[(block + 1) % 2],
                    pos_encoding_base=pos_encoding_base,
                )
                for block in range(num_blocks)
            ]
        )
        self.linear_out = nn.Linear(hidden_channels, out_channels)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor with shape (..., num_items_1, num_items_2, num_channels)
            Input data

        Returns
        -------
        outputs : Tensor with shape (..., num_items_1, num_items_2, num_channels)
            Outputs
        """
        rearrange_pattern = "... i j c -> ... j i c"

        h = self.linear_in(inputs)
        for i, block in enumerate(self.blocks):
            # For every second block (i = 1, 3, ...), we want to perform attention over the first
            # token dimension. We implement this by transposing the two item dimensions.
            if i % 2 == 1:
                h = rearrange(h, rearrange_pattern)

            h = block(h)

            # Transpose back to the standard axis order
            if i % 2 == 1:
                h = rearrange(h, rearrange_pattern)

        outputs = self.linear_out(h)

        return outputs
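

# Illustrative usage sketch (not part of the upstream file): a minimal smoke test showing how the
# two baselines might be instantiated and called. The channel counts, block counts, and tensor
# shapes below are arbitrary placeholders chosen for this example.
if __name__ == "__main__":
    # Plain transformer: 128 items with 7 input features each, mapped to 3 output features per item.
    model = Transformer(in_channels=7, out_channels=3, hidden_channels=64, num_blocks=4)
    x = torch.randn(2, 128, 7)
    y = model(x)  # (2, 128, 3)

    # Axial transformer on a grid of tokens: blocks alternate attention between the two item axes.
    axial = AxialTransformer(in_channels=5, out_channels=1, hidden_channels=32, num_blocks=4)
    z = torch.randn(2, 16, 24, 5)
    w = axial(z)  # (2, 16, 24, 1)

    print(y.shape, w.shape)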