# File copied from https://raw.githubusercontent.com/heidelberg-hepml/lorentz-gatr/refs/heads/main/experiments/baselines/transformer.py

from functools import partial
from typing import Optional, Tuple

import torch
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import checkpoint

from lgatr.layers import ApplyRotaryPositionalEncoding
from lgatr.primitives.attention import scaled_dot_product_attention


def to_nd(tensor, d):
    """Make tensor n-dimensional, grouping extra dimensions in the first."""
    return tensor.view(
        -1, *(1,) * (max(0, d - 1 - tensor.dim())), *tensor.shape[-(d - 1) :]
    )
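
# Illustrative shape examples for `to_nd` (added for clarity; not in the upstream file):
# extra leading dimensions are flattened into one, while missing leading dimensions are
# padded with singleton axes, e.g. for d=4:
#   to_nd(torch.randn(2, 3, 8, 5, 16), 4).shape  # -> (6, 8, 5, 16)
#   to_nd(torch.randn(5, 16), 4).shape           # -> (1, 1, 5, 16)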


class BaselineLayerNorm(nn.Module):
    """Baseline layer norm over the last (channel) dimension."""

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data

        Returns
        -------
        outputs : Tensor
            Normalized inputs.
        """
        return torch.nn.functional.layer_norm(
            inputs, normalized_shape=inputs.shape[-1:]
        )


class MultiHeadQKVLinear(nn.Module):
    """Compute queries, keys, and values for multi-head attention.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    hidden_channels : int
        Number of hidden channels = size of query, key, and value.
    num_heads : int
        Number of attention heads.
    """

    def __init__(self, in_channels, hidden_channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.linear = nn.Linear(in_channels, 3 * hidden_channels * num_heads)

    def forward(self, inputs):
        """Forward pass.

        Returns
        -------
        q : Tensor
            Queries
        k : Tensor
            Keys
        v : Tensor
            Values
        """
        qkv = self.linear(inputs)  # (..., num_items, 3 * hidden_channels * num_heads)
        q, k, v = rearrange(
            qkv,
            "... items (qkv hidden_channels num_heads) -> qkv ... num_heads items hidden_channels",
            num_heads=self.num_heads,
            qkv=3,
        )
        return q, k, v


class MultiQueryQKVLinear(nn.Module):
    """Compute queries, keys, and values for multi-query attention.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    hidden_channels : int
        Number of hidden channels = size of query, key, and value.
    num_heads : int
        Number of attention heads.
    """

    def __init__(self, in_channels, hidden_channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.q_linear = nn.Linear(in_channels, hidden_channels * num_heads)
        self.k_linear = nn.Linear(in_channels, hidden_channels)
        self.v_linear = nn.Linear(in_channels, hidden_channels)

    def forward(self, inputs):
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data

        Returns
        -------
        q : Tensor
            Queries
        k : Tensor
            Keys
        v : Tensor
            Values
        """
        q = rearrange(
            self.q_linear(inputs),
            "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels",
            num_heads=self.num_heads,
        )
        k = self.k_linear(inputs)[
            ..., None, :, :
        ]  # (..., head=1, item, hidden_channels)
        v = self.v_linear(inputs)[..., None, :, :]
        return q, k, v
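
# Note (added for clarity; not in the upstream file): in multi-query attention only the
# queries get a per-head projection, so `q` has shape (..., num_heads, items, hidden_channels)
# while `k` and `v` carry a singleton head axis, (..., 1, items, hidden_channels).
# BaselineSelfAttention._attend below broadcasts them across heads via `expand_as(q)`,
# which shrinks the key/value projections by a factor of num_heads.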


class BaselineSelfAttention(nn.Module):
    """Baseline self-attention layer.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    hidden_channels : int
        Number of hidden channels = size of query, key, and value.
    num_heads : int
        Number of attention heads.
    pos_encoding : bool
        Whether to apply rotary positional embeddings along the item dimension to the scalar keys
        and queries.
    pos_enc_base : int
        Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
    multi_query : bool
        Use multi-query attention instead of multi-head attention.
    dropout_prob : float or None
        Dropout probability applied to the attention output. If None, no dropout is used.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        num_heads: int = 8,
        pos_encoding: bool = False,
        pos_enc_base: int = 4096,
        multi_query: bool = True,
        dropout_prob=None,
    ) -> None:
        super().__init__()

        # Store settings
        self.num_heads = num_heads
        self.hidden_channels = hidden_channels

        # Linear maps
        qkv_class = MultiQueryQKVLinear if multi_query else MultiHeadQKVLinear
        self.qkv_linear = qkv_class(in_channels, hidden_channels, num_heads)
        self.out_linear = nn.Linear(hidden_channels * num_heads, out_channels)

        # Optional positional encoding
        if pos_encoding:
            self.pos_encoding = ApplyRotaryPositionalEncoding(
                hidden_channels, item_dim=-2, base=pos_enc_base
            )
        else:
            self.pos_encoding = None

        if dropout_prob is not None:
            self.dropout = nn.Dropout(dropout_prob)
        else:
            self.dropout = None

    def forward(
        self,
        inputs: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data
        attention_mask : None or Tensor or xformers.ops.AttentionBias
            Optional attention mask

        Returns
        -------
        outputs : Tensor
            Outputs
        """
        q, k, v = self.qkv_linear(
            inputs
        )  # each: (..., heads, num_items, hidden_channels)

        # Rotary positional encoding
        if self.pos_encoding is not None:
            q = self.pos_encoding(q)
            k = self.pos_encoding(k)

        # Attention layer
        h = self._attend(q, k, v, attention_mask, is_causal=is_causal)

        # Concatenate heads and transform linearly
        h = rearrange(
            h,
            "... num_heads num_items hidden_channels -> ... num_items (num_heads hidden_channels)",
        )
        outputs = self.out_linear(h)  # (..., num_items, out_channels)

        if self.dropout is not None:
            outputs = self.dropout(outputs)

        return outputs

    @staticmethod
    def _attend(q, k, v, attention_mask=None, is_causal=False):
        """Scaled dot-product attention."""

        # Add batch dimension if needed
        bh_shape = q.shape[:-2]
        q = to_nd(q, 4)
        k = to_nd(k, 4)
        v = to_nd(v, 4)

        # SDPA
        outputs = scaled_dot_product_attention(
            q.contiguous(),
            k.expand_as(q).contiguous(),
            v.expand_as(q).contiguous(),
            attn_mask=attention_mask,
            is_causal=is_causal,
        )

        # Return batch dimensions to inputs
        outputs = outputs.view(*bh_shape, *outputs.shape[-2:])

        return outputs
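
# Usage sketch (added for clarity; not in the upstream file; shapes assume the default
# multi_query=True):
#
#     attn = BaselineSelfAttention(in_channels=64, out_channels=64, hidden_channels=16)
#     x = torch.randn(32, 100, 64)   # (batch, num_items, in_channels)
#     y = attn(x)                    # (32, 100, 64)
#
# `to_nd` flattens any extra leading batch dimensions into a single one before calling
# scaled_dot_product_attention, and `_attend` restores them afterwards.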


class BaselineTransformerBlock(nn.Module):
    """Baseline transformer block.

    Inputs are first processed by a block consisting of LayerNorm, multi-head self-attention, and
    a residual connection. Then the data is processed by a block consisting of another LayerNorm,
    an item-wise two-layer MLP with GELU activations, and another residual connection.

    Parameters
    ----------
    channels : int
        Number of input and output channels.
    num_heads : int
        Number of attention heads.
    pos_encoding : bool
        Whether to apply rotary positional embeddings along the item dimension to the scalar keys
        and queries.
    pos_encoding_base : int
        Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
    increase_hidden_channels : int
        Factor by which the key, query, and value size is increased over the default value of
        channels / num_heads.
    multi_query : bool
        Use multi-query attention instead of multi-head attention.
    dropout_prob : float or None
        Dropout probability applied in the attention layer and the MLP. If None, no dropout is used.
    """

    def __init__(
        self,
        channels,
        num_heads: int = 8,
        pos_encoding: bool = False,
        pos_encoding_base: int = 4096,
        increase_hidden_channels=1,
        multi_query: bool = True,
        dropout_prob=None,
    ) -> None:
        super().__init__()

        self.norm = BaselineLayerNorm()

        # When using positional encoding, the number of scalar hidden channels needs to be even.
        # It also should not be too small.
        hidden_channels = channels // num_heads * increase_hidden_channels
        if pos_encoding:
            hidden_channels = (hidden_channels + 1) // 2 * 2
            hidden_channels = max(hidden_channels, 16)

        self.attention = BaselineSelfAttention(
            channels,
            channels,
            hidden_channels,
            num_heads=num_heads,
            pos_encoding=pos_encoding,
            pos_enc_base=pos_encoding_base,
            multi_query=multi_query,
            dropout_prob=dropout_prob,
        )

        self.mlp = nn.Sequential(
            nn.Linear(channels, 2 * channels),
            nn.Dropout(dropout_prob) if dropout_prob is not None else nn.Identity(),
            nn.GELU(),
            nn.Linear(2 * channels, channels),
            nn.Dropout(dropout_prob) if dropout_prob is not None else nn.Identity(),
        )

    def forward(
        self, inputs: torch.Tensor, attention_mask=None, is_causal=False
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor
            Input data
        attention_mask : None or Tensor or xformers.ops.AttentionBias
            Optional attention mask

        Returns
        -------
        outputs : Tensor
            Outputs
        """

        # Residual attention
        h = self.norm(inputs)
        h = self.attention(h, attention_mask=attention_mask, is_causal=is_causal)
        outputs = inputs + h

        # Residual MLP
        h = self.norm(outputs)
        h = self.mlp(h)
        outputs = outputs + h

        return outputs
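
# Worked example for the hidden_channels computation in BaselineTransformerBlock.__init__
# (added for clarity; not in the upstream file): with channels=128, num_heads=8, and
# increase_hidden_channels=1, each head uses hidden_channels = 128 // 8 * 1 = 16 query/key/value
# channels; with pos_encoding=True this is rounded up to an even number and clamped to at
# least 16, so it stays 16.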


class Transformer(nn.Module):
    """Baseline transformer.

    Combines num_blocks transformer blocks, each consisting of multi-head self-attention layers, an
    MLP, residual connections, and normalization layers.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    hidden_channels : int
        Number of hidden channels.
    num_blocks : int
        Number of transformer blocks.
    num_heads : int
        Number of attention heads.
    pos_encoding : bool
        Whether to apply rotary positional embeddings along the item dimension to the scalar keys
        and queries.
    pos_encoding_base : int
        Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
    checkpoint_blocks : bool
        Whether to use gradient checkpointing for each transformer block to reduce memory use.
    increase_hidden_channels : int
        Factor by which the key, query, and value size is increased over the default value of
        hidden_channels / num_heads.
    multi_query : bool
        Use multi-query attention instead of multi-head attention.
    dropout_prob : float or None
        Dropout probability applied in the attention layers and MLPs. If None, no dropout is used.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        num_blocks: int = 10,
        num_heads: int = 8,
        pos_encoding: bool = False,
        pos_encoding_base: int = 4096,
        checkpoint_blocks: bool = False,
        increase_hidden_channels=1,
        multi_query: bool = False,
        dropout_prob=None,
    ) -> None:
        super().__init__()
        self.checkpoint_blocks = checkpoint_blocks
        self.linear_in = nn.Linear(in_channels, hidden_channels)
        self.blocks = nn.ModuleList(
            [
                BaselineTransformerBlock(
                    hidden_channels,
                    num_heads=num_heads,
                    pos_encoding=pos_encoding,
                    pos_encoding_base=pos_encoding_base,
                    increase_hidden_channels=increase_hidden_channels,
                    multi_query=multi_query,
                    dropout_prob=dropout_prob,
                )
                for _ in range(num_blocks)
            ]
        )
        self.linear_out = nn.Linear(hidden_channels, out_channels)

    def forward(
        self, inputs: torch.Tensor, attention_mask=None, is_causal=False
    ) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor with shape (..., num_items, num_channels)
            Input data
        attention_mask : None or Tensor or xformers.ops.AttentionBias
            Optional attention mask
        is_causal : bool
            Whether to apply a causal attention mask.

        Returns
        -------
        outputs : Tensor with shape (..., num_items, num_channels)
            Outputs
        """
        h = self.linear_in(inputs)
        for block in self.blocks:
            if self.checkpoint_blocks:
                fn = partial(block, attention_mask=attention_mask, is_causal=is_causal)
                h = checkpoint(fn, h)
            else:
                h = block(h, attention_mask=attention_mask, is_causal=is_causal)
        outputs = self.linear_out(h)
        return outputs


class AxialTransformer(nn.Module):
    """Baseline axial transformer for data with two token dimensions.

    Combines num_blocks transformer blocks, each consisting of multi-head self-attention layers, an
    MLP, residual connections, and normalization layers.

    Assumes input data with shape `(..., num_items_1, num_items_2, num_channels)`.

    The first, third, fifth, ... block computes attention over the `items_2` axis. The other blocks
    compute attention over the `items_1` axis. Positional encoding can be specified separately for
    both axes.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    hidden_channels : int
        Number of hidden channels.
    num_blocks : int
        Number of transformer blocks.
    num_heads : int
        Number of attention heads.
    pos_encodings : tuple of bool
        Whether to apply rotary positional embeddings along the item dimensions to the scalar keys
        and queries. The first entry refers to the `items_1` axis, the second to the `items_2` axis.
    pos_encoding_base : int
        Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        num_blocks: int = 20,
        num_heads: int = 8,
        pos_encodings: Tuple[bool, bool] = (False, False),
        pos_encoding_base: int = 4096,
    ) -> None:
        super().__init__()
        self.linear_in = nn.Linear(in_channels, hidden_channels)
        self.blocks = nn.ModuleList(
            [
                BaselineTransformerBlock(
                    hidden_channels,
                    num_heads=num_heads,
                    pos_encoding=pos_encodings[(block + 1) % 2],
                    pos_encoding_base=pos_encoding_base,
                )
                for block in range(num_blocks)
            ]
        )
        self.linear_out = nn.Linear(hidden_channels, out_channels)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        inputs : Tensor with shape (..., num_items1, num_items2, num_channels)
            Input data

        Returns
        -------
        outputs : Tensor with shape (..., num_items1, num_items2, num_channels)
            Outputs
        """
        rearrange_pattern = "... i j c -> ... j i c"

        h = self.linear_in(inputs)
        for i, block in enumerate(self.blocks):
            # Every second block (i = 1, 3, ...) performs attention over the first item
            # dimension. We implement this by transposing the two item dimensions before
            # and after the block.
            if i % 2 == 1:
                h = rearrange(h, rearrange_pattern)
            h = block(h)
            # Transpose back to the standard axis order
            if i % 2 == 1:
                h = rearrange(h, rearrange_pattern)
        outputs = self.linear_out(h)
        return outputs
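

# Minimal smoke test (added for clarity; not part of the upstream file). It only checks
# tensor shapes on random data; the hyperparameters below are illustrative and not taken
# from any lorentz-gatr experiment configuration.
if __name__ == "__main__":
    batch, items, items_2 = 4, 100, 6

    model = Transformer(
        in_channels=7,
        out_channels=3,
        hidden_channels=64,
        num_blocks=2,
        num_heads=8,
    )
    x = torch.randn(batch, items, 7)
    print(model(x).shape)  # expected: torch.Size([4, 100, 3])

    axial = AxialTransformer(
        in_channels=7,
        out_channels=3,
        hidden_channels=64,
        num_blocks=2,
        num_heads=8,
    )
    x2 = torch.randn(batch, items, items_2, 7)
    print(axial(x2).shape)  # expected: torch.Size([4, 100, 6, 3])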