# MOJO / mojo.py
# Uploaded by mgelard ("Upload MOJO", commit 73fb208, verified).
import logging
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from transformers import PretrainedConfig, PreTrainedModel
@dataclass
class RotaryEmbeddingConfig:
    """
    Parameters to initialize the RotaryEmbedding layer.

    The rescaling factor allows the rotary embeddings to be adapted to sequence
    lengths larger than those used during training. One such strategy is
    presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa

    Args:
        rescaling_factor: Factor used by ``RotaryEmbedding`` to enlarge its base
            frequency (``base * rescaling_factor ** (dim / (dim - 2))``).
            ``None`` keeps the standard base frequency.
    """

    rescaling_factor: Optional[float] = None
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation matrices which depend on their
    positions. ``forward`` expects the sequence axis at position -3, i.e. heads
    shaped ``(..., seq_len, num_heads, dim)``.

    Args:
        dim: Size of the last (per-head) dimension of the rotated tensors.
        rotary_embedding_config: Provides ``rescaling_factor``. When it is not
            ``None`` the base frequency becomes
            ``10000 * rescaling_factor ** (dim / (dim - 2))``.
    """

    def __init__(self, dim: int, rotary_embedding_config: "RotaryEmbeddingConfig"):
        super().__init__()
        # Extract argument from the config
        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim
        # Tables from the most recent call (they are recomputed on every call).
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rotate ``heads`` by the angles whose cosines/sines are ``cos``/``sin``.

        The last dimension is split into halves (x1, x2), which are mapped to
        (x1 * cos - x2 * sin, x2 * cos + x1 * sin) and concatenated back.
        """
        half = heads.shape[-1] // 2
        x_first = heads[..., :half]
        x_second = heads[..., half:]
        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin
        return torch.cat((first_part, second_part), dim=-1)

    def _inv_freq(self, device: torch.device) -> torch.Tensor:
        """Return the (optionally rescaled) inverse frequencies on ``device``."""
        base = self.upper_freq
        if self.rescaling_factor is not None:
            base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        return 1.0 / (base**exponents)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build cos/sin tables of shape ``(1, seq_len, 1, len(inv_freq))`` on
        ``x``'s device, where ``seq_len = x.shape[seq_dimension]``.
        """
        seq_len = x.shape[seq_dimension]
        # Everything must live on x's device: tables left on CPU cannot be
        # multiplied with CUDA heads.
        inv_freq = inv_freq.to(x.device)
        self._seq_len_cached = seq_len
        t = torch.arange(seq_len, device=x.device).to(inv_freq.dtype)
        freqs = torch.einsum("i, j -> ij", t, inv_freq)
        self._cos_cached = torch.cos(freqs)[None, :, None, :]
        self._sin_cached = torch.sin(freqs)[None, :, None, :]
        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to queries ``q`` and keys ``k``.

        The tables are sized from ``q``'s sequence axis (-3), so ``k`` must share
        that sequence length. Returns the rotated ``(q, k)``.
        """
        inv_freq = self._inv_freq(q.device)
        cos, sin = self._compute_cos_sin_tables(q, inv_freq, seq_dimension=-3)
        return (
            self._apply_rotary_pos_emb(q, cos, sin),
            self._apply_rotary_pos_emb(k, cos, sin),
        )
class ResidualConvBlock(nn.Module):
    """
    ``ConvBlock`` wrapped in a residual (skip) connection.

    The input is reshaped to the block output's shape before being added, so the
    input and output must contain the same number of elements.
    """

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        block_kwargs = dict(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
        )
        self.conv_block = ConvBlock(**block_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv_block(x)
        skip = x.reshape(out.shape)
        return skip + out
class ConvBlock(nn.Module):
    """
    Layer norm over channels, then a 1D convolution with "same" padding, then a
    tanh-approximated GELU.

    Inputs are channel-first, ``(batch, channels, length)``, as required by
    ``nn.Conv1d``. The layer norm is applied over the channel axis.
    """

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            padding="same",
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last for LayerNorm, then back to channel-first for Conv1d.
        normed = self.layer_norm(x.transpose(1, 2)).transpose(1, 2)
        return F.gelu(self.conv(normed), approximate="tanh")
class ConvTowerBlock(nn.Module):
    """
    One downsampling stage: ``ConvBlock`` -> ``ResidualConvBlock`` (kernel 1) ->
    average pooling with kernel 2 and stride 2.

    ``forward`` returns ``(pooled_output, block_input)``. The caller can keep the
    input as a skip connection.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        kernel_size: int,
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        skip = x
        pooled = self.avg_pool(self.res_conv(self.conv_layer(x)))
        return pooled, skip
class ResidualDeConvBlock(nn.Module):
    """
    ``DeConvBlock`` wrapped in a residual (skip) connection.

    The input is reshaped to the block output's shape before being added, so the
    input and output must contain the same number of elements.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        block_kwargs = dict(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.deconv_block = DeConvBlock(**block_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.deconv_block(x)
        skip = x.reshape(out.shape)
        return skip + out
class DeConvBlock(nn.Module):
    """
    Layer norm over channels, then a 1D transposed convolution (no padding), then
    a tanh-approximated GELU.

    Inputs are channel-first, ``(batch, channels, length)``. When
    ``kernel_size == 5`` the output is cropped by one step at the start and two
    at the end, to match Haiku's transposed convolution, which drops that
    padding itself. With stride 2 this yields exactly twice the input length.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channels last for LayerNorm, then channel-first for ConvTranspose1d.
        h = self.layer_norm(x.transpose(1, 2)).transpose(1, 2)
        h = self.deconv(h)
        if self.kernel_size == 5:
            h = h[..., 1:-2]
        return F.gelu(h, approximate="tanh")
class DeConvTowerBlock(nn.Module):
    """
    One upsampling stage: ``DeConvBlock`` (stride ``stride``, default 2), then
    ``ResidualDeConvBlock`` (kernel 1). The skip tensor ``res`` is added to the
    result.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        stride: int = 2,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        upsampled = self.res_deconv_block(self.deconv_block(x))
        return upsampled + res
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with optional rotary embeddings.

    Args:
        num_heads: Number of attention heads.
        key_size: Per-head size of queries and keys.
        rotary_embedding_config: If given, rotary position embeddings are applied
            to queries and keys before computing attention logits.
        add_bias_kv: Stored as an attribute; not used by this module.
        value_size: Per-head size of values (falls back to ``key_size``).
        model_size: Input/output feature size (falls back to ``key_size``).
        name: Optional name, stored as an attribute.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to per-head queries and keys."""
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute attention of ``query`` over ``key``/``value``.

        Args:
            query, key, value: Tensors whose last dimension is ``model_size``.
                The second-to-last dimension is the sequence axis.
            attention_mask: Optional boolean tensor, broadcastable to the logits
                ``(..., num_heads, q_len, k_len)``. Positions where it is False
                get logit -1e30.
            attention_weight_bias: Optional tensor, broadcastable to the logits,
                added after masking and before the softmax.

        Returns:
            Dictionary with ``"attention_weights"`` (softmax probabilities,
            ``(..., num_heads, q_len, k_len)``) and ``"embeddings"``
            (``(..., q_len, model_size)``).
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        attention_weights = attention_weights / math.sqrt(self.key_size)
        # Compare with None: truth-testing a multi-element tensor raises.
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = attention_weights + attention_weight_bias
        attention_weights = F.softmax(attention_weights, dim=-1)

        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the heads back into a single feature axis.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)
        return {"attention_weights": attention_weights, "embeddings": embeddings}
class SelfAttentionBlock(nn.Module):
    """
    Transformer block: multi-head self-attention followed by a feed-forward
    network (optionally gated, GLU-style). Each sub-layer has a residual
    connection and a layer norm, applied before (pre-norm) or after (post-norm)
    the sub-layer depending on ``pre_layer_norm``.

    Args:
        num_heads: Number of attention heads.
        embed_dim: Model (embedding) size.
        ffn_embed_dim: Hidden size of the feed-forward network.
        key_size: Per-head key size; defaults to ``embed_dim // num_heads``.
        add_bias_kv: Forwarded to ``MultiHeadAttention``.
        add_bias_fnn: Whether the feed-forward linear layers have biases.
        ffn_activation_name: ``"swish"`` (SiLU), ``"gelu-no-approx"`` (exact
            GELU), the name of a function in ``torch.nn.functional`` (e.g.
            ``"relu"``), or the name of an activation module class in
            ``torch.nn`` (e.g. ``"ReLU"``).
        use_glu_in_ffn: Use a gated linear unit in the feed-forward network.
        layer_norm_eps: Epsilon of both layer norms.
        pre_layer_norm: Pre-norm if True, post-norm otherwise.
        name: Not used by this block.
        rotary_embedding_config: Forwarded to ``MultiHeadAttention``.

    Raises:
        ValueError: If ``key_size`` is None and ``embed_dim`` is not divisible
            by ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,  # this is the default haiku value
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn
        # Define layers
        if use_glu_in_ffn:
            # user should multiply ffn_embed_dim by 2/3 when using GLU
            # to keep total number of parameters equal
            # see https://arxiv.org/pdf/2002.05202.pdf. for more details
            # we multiply by 2 here as the output will be split in 2 for GLU
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
        else:
            self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self._ffn_activation_fn = self._resolve_activation(ffn_activation_name)

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    @staticmethod
    def _resolve_activation(name: str):
        """Map an activation name to a callable that is applied to tensors."""
        if name == "swish":
            return nn.SiLU()
        if name == "gelu-no-approx":
            return lambda x: F.gelu(x, approximate="none")
        functional = getattr(F, name, None)
        if callable(functional):
            return functional
        # Module classes (e.g. "ReLU") must be instantiated: calling the class
        # directly on a tensor would construct a module, not apply it.
        return getattr(torch.nn, name)()

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Feed-forward sub-layer. In post-norm mode it also applies the residual
        connection and layer norm.
        """
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)

        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply attention and feed-forward sub-layers to ``x``.

        Returns the ``MultiHeadAttention`` output dictionary, with
        ``"embeddings"`` replaced by the block output.
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if not self._pre_layer_norm:
            output["embeddings"] = self.layer_norm_self_attention(
                output["embeddings"] + res
            )
            x = output["embeddings"]
        else:
            x = output["embeddings"]
            x = res + x

        # MLP
        if not self._pre_layer_norm:
            x = self.mlp(x)
        else:
            x = x + self.mlp(x)

        output["embeddings"] = x
        return output
class LMHead(nn.Module):
    """
    Prediction head. Applies GELU to the input, then a stack of linear layers
    each followed by GELU, then a final linear projection to ``dim_out`` logits.

    The stack holds ``max(num_hidden_layers, 1)`` layers: the first maps
    ``dim_in -> embed_dim`` and the rest map ``embed_dim -> embed_dim``.
    All GELUs use the tanh approximation.
    """

    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
        # Each extra element must be a Module itself (not a list of Modules),
        # otherwise ModuleList raises TypeError when num_hidden_layers > 1.
        self.linear_layers.extend(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)]
        )
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = F.gelu(layer(x), approximate="tanh")
        return self.linear_out(x)
class MOJOConfig(PretrainedConfig):  # noqa: N801
    """
    Configuration for the ``MOJO`` model.

    Hyper-parameters are read from ``kwargs``, with the defaults below.
    ``__post_init__`` then derives the dependent fields:

    - ``key_size`` defaults to ``embed_dim // num_attention_heads``.
    - ``project_gene_embedding`` is forced to True when the gene and token
      embedding sizes differ.
    - ``fixed_sequence_length`` and ``filter_list`` are always recomputed, so
      any value passed for them is overwritten.
    """

    model_type = "MOJO"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Vocabulary size per omic; its keys define the omics the model handles.
        self.alphabet_size = kwargs.get(
            "alphabet_size", {"rnaseq": 66, "methylation": 66}
        )
        self.token_embed_dim = kwargs.get("token_embed_dim", 256)
        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
        self.sequence_length = kwargs.get("sequence_length", 17_116)  # n_genes
        self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None)
        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512)
        self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15)
        self.embed_dim = kwargs.get("embed_dim", 512)
        self.filter_list = kwargs.get("filter_list", [])
        self.num_attention_heads = kwargs.get("num_attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024)
        self.num_layers = kwargs.get("num_layers", 8)
        self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1)
        # return
        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )
        self.__post_init__()

    def __post_init__(self):
        """
        Validate the configuration and derive the dependent fields.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``num_attention_heads``.
        """
        # Validate attention key size
        if self.key_size is None:
            embed_dim = self.embed_dim
            num_attention_heads = self.num_attention_heads
            if embed_dim % num_attention_heads != 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {embed_dim} and the number of heads is "
                    f"{num_attention_heads}."
                )
            self.key_size = embed_dim // num_attention_heads

        # Validate gene embedding projection: a projection is required whenever
        # the gene and token embedding sizes differ.
        if (
            self.use_gene_embedding
            and self.init_gene_embed_dim != self.token_embed_dim
            and not self.project_gene_embedding
        ):
            logging.warning(
                f"Init gene embedding dimension ({self.init_gene_embed_dim}) "
                f"is different from token embedding dimension "
                f"({self.token_embed_dim}). "
                f"Setting `project_gene_embedding` to True"
            )
            self.project_gene_embedding = True

        # Pad the sequence length up to a multiple of 2**num_downsamples, so
        # that each of the num_downsamples halving stages divides it exactly.
        downsample_factor = 2**self.num_downsamples
        self.fixed_sequence_length = (
            math.ceil(self.sequence_length / downsample_factor) * downsample_factor
        )

        # Channel sizes of the conv tower: num_downsamples + 1 values linearly
        # spaced from conv_init_embed_dim to embed_dim (truncated to int).
        self.filter_list = (  # noqa
            np.linspace(
                self.conv_init_embed_dim,
                self.embed_dim,
                self.num_downsamples + 1,
            )
            .astype(int)
            .tolist()
        )
class MOJO(PreTrainedModel):  # noqa: N801
    """
    Multi-omic sequence model.

    Pipeline:
    - Per-omic token embeddings are summed, plus an optional learned
      per-position ("gene") embedding.
    - A stem convolution and a downsampling convolution tower follow.
    - Transformer layers process the downsampled sequence.
    - An upsampling deconvolution tower consumes the conv tower's saved inputs
      as skip connections.
    - One LM head per omic produces logits.
    """

    config_class = MOJOConfig

    def __init__(self, config: MOJOConfig):
        super().__init__(config=config)
        # Embeddings
        # One token-embedding table per omic (the keys of `config.alphabet_size`).
        self.embedding_layers = nn.ModuleDict(
            {
                omic: nn.Embedding(config.alphabet_size[omic], config.token_embed_dim)
                for omic in config.alphabet_size
            }
        )
        # Learned embedding per position 0..fixed_sequence_length-1. It is always
        # created and only used when `config.use_gene_embedding` is set.
        self.gene_embedding_layer = nn.Embedding(
            config.fixed_sequence_length,
            config.init_gene_embed_dim,
        )
        # Projection from gene-embedding size to token-embedding size. It is
        # always created and only applied when `config.project_gene_embedding`
        # is set.
        self.fc_gene_embedding = nn.Linear(
            config.init_gene_embed_dim, config.token_embed_dim
        )

        # Convolutions
        self.stem_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=config.token_embed_dim,
                out_channels=config.conv_init_embed_dim,
                kernel_size=config.stem_kernel_shape,
                padding="same",
            ),
            nn.GELU(approximate="tanh"),
        )

        # Downsampling tower: block i maps filter_list[i] -> filter_list[i + 1]
        # channels, and each ConvTowerBlock halves the length with pooling.
        self.conv_tower = nn.ModuleList(
            [
                ConvTowerBlock(
                    dim_in=config.filter_list[i],
                    dim_out=config.filter_list[i + 1],
                    kernel_size=5,
                    conv_layer_norm_shape=config.filter_list[i],
                    resconv_layer_norm_shape=config.filter_list[i + 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )

        # Transformer
        # `attention_maps_to_save` holds (layer, map_number) pairs. Layers are
        # 1-indexed, matching the `layer_idx + 1` lookups in `forward`.
        attention_maps_to_save = config.attention_maps_to_save
        self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})

        self._attention_maps_per_layer_to_save = {
            layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
            for layer in self._attention_layers_to_save
        }

        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)
        self.transformer_layers = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=config.num_attention_heads,
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    key_size=config.key_size,
                    add_bias_kv=False,
                    add_bias_fnn=False,
                    ffn_activation_name="swish",
                    use_glu_in_ffn=True,
                    layer_norm_eps=1e-5,  # this is the default haiku value
                    pre_layer_norm=True,
                    name=f"attention_layer_{layer_idx}",
                    rotary_embedding_config=self._rotary_embedding_config,
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        # Deconvolutions
        # Upsampling tower mirrors the conv tower: block i maps
        # filter_list[-1 - i] -> filter_list[-2 - i] channels.
        self.deconv_tower = nn.ModuleList(
            [
                DeConvTowerBlock(
                    dim_in=config.filter_list[-1 - i],
                    dim_out=config.filter_list[-1 - i - 1],
                    kernel_size=5,
                    stride=2,
                    conv_layer_norm_shape=config.filter_list[-1 - i],
                    resconv_layer_norm_shape=config.filter_list[-1 - i - 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )

        # Language Modeling heads
        # One head per omic, reading the deconv tower output
        # (conv_init_embed_dim features).
        self.omic_lm_heads = nn.ModuleDict(
            {
                omic: LMHead(
                    dim_in=config.conv_init_embed_dim,
                    embed_dim=config.embed_dim,
                    dim_out=config.alphabet_size[omic],
                    num_hidden_layers=config.num_hidden_layers_head,
                )
                for omic in self.config.alphabet_size
            }
        )

    def get_embeddings(
        self,
        input_ids: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Embed each omic's token ids with that omic's embedding table.

        Returns a dictionary mapping omic name -> token embeddings.
        """
        omic_embeddings = {}
        for omic, omic_tokens in input_ids.items():
            omic_embeddings[omic] = self.embedding_layers[omic](omic_tokens)
        return omic_embeddings

    def forward(self, input_ids: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Run the model on per-omic token ids.

        Args:
            input_ids: Mapping omic name -> token ids. Every key must exist in
                `self.embedding_layers`. Presumably each tensor is
                (batch, seq_len) — TODO confirm. When `use_gene_embedding` is
                set, seq_len must equal `config.fixed_sequence_length` for the
                gene-embedding addition to broadcast.

        Returns:
            Dictionary of intermediate outputs: "omic_embeddings",
            "embeddings", "stem", "conv_tower", "conv_tower_residuals",
            "after_transformer_embedding" and "deconv_tower". It also holds
            optional "embeddings_with_gene_embedding", "embeddings_{layer}" and
            "attention_map_layer_{layer}_number_{n}" entries. Under "logits"
            it holds a mapping omic -> logits for every omic in
            `config.alphabet_size`.
        """
        outs = {}
        embeddings = self.get_embeddings(input_ids)
        outs["omic_embeddings"] = embeddings
        # Sum the per-omic embeddings into one sequence representation.
        x = torch.stack(list(embeddings.values()), dim=0).sum(dim=0)  # [B, T, C]
        outs["embeddings"] = x

        if self.config.use_gene_embedding:
            # One learned vector per position, broadcast-added over the batch.
            gene_indices = torch.arange(
                self.config.fixed_sequence_length, device=x.device
            )
            gene_embedding = self.gene_embedding_layer(gene_indices)
            if self.config.project_gene_embedding:
                gene_embedding = self.fc_gene_embedding(gene_embedding)
            x = x + gene_embedding
            outs["embeddings_with_gene_embedding"] = x

        # Convolutions operate channel-first: [B, C, T].
        x = x.permute(0, 2, 1)
        x = self.stem_conv(x)
        outs["stem"] = x

        # Each tower block returns (pooled output, block input). The inputs are
        # kept as skip connections for the deconvolution tower.
        residuals = []
        for conv_block in self.conv_tower:
            x, res = conv_block(x)
            residuals.append(res)
        # Back to [B, T', C] for the transformer layers.
        x = x.permute(0, 2, 1)
        outs["conv_tower"] = x
        outs["conv_tower_residuals"] = residuals  # type: ignore
        # The deepest skip connection is consumed first when upsampling.
        residuals = residuals[::-1]

        for layer_idx, transformer in enumerate(self.transformer_layers):
            output = transformer(x)
            x = output["embeddings"]
            # Saved layers are 1-indexed.
            if (layer_idx + 1) in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"]

            if (layer_idx + 1) in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
                    dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                    # NOTE(review): this selects index `map_number + 1` along dim 1
                    # (heads) of the attention weights. Confirm whether map numbers
                    # are meant to be shifted by one here; a 0-based map_number equal
                    # to num_heads - 1 would be out of range.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]

        outs["after_transformer_embedding"] = x

        # Upsample channel-first, adding the matching skip connection per block.
        x = x.permute(0, 2, 1)
        for deconv_block, res in zip(self.deconv_tower, residuals):
            x = deconv_block(x, res)
        x = x.permute(0, 2, 1)
        outs["deconv_tower"] = x

        outs["logits"] = {
            omic: self.omic_lm_heads[omic](x) for omic in self.config.alphabet_size
        }

        return outs