|
import logging |
|
import math |
|
from dataclasses import dataclass |
|
from typing import Any, Optional, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
@dataclass |
|
class RotaryEmbeddingConfig: |
|
""" |
|
    Parameters to initialize the RotaryEmbedding layer. The rescaling factor makes it
    possible to adapt the rotary embeddings to sequence lengths longer than those seen
    during training. One such strategy is presented in the YaRN paper:
    https://arxiv.org/pdf/2309.00071.pdf.  # noqa

    Args:
        rescaling_factor: Factor used to rescale the rotary frequency base. ``None``
            disables rescaling.
    """
|
|
|
rescaling_factor: Optional[float] |
|
|
|
|
|
class RotaryEmbedding(torch.nn.Module): |
|
""" |
|
Rotary position embeddings based on those in |
|
[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). |
|
Query and keys are transformed by rotation |
|
matrices which depend on their relative positions. |
|
""" |
|
|
|
def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfig): |
|
super().__init__() |
|
|
|
|
|
self.rescaling_factor = rotary_embedding_config.rescaling_factor |
|
self.upper_freq = 10000 |
|
self.dim = dim |
|
|
|
self._seq_len_cached = None |
|
self._cos_cached = None |
|
self._sin_cached = None |
|
|
|
def _apply_rotary_pos_emb( |
|
self, |
|
heads: torch.Tensor, |
|
cos: torch.Tensor, |
|
sin: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" """ |
|
x_first, x_second = ( |
|
heads[..., : heads.shape[-1] // 2], |
|
heads[..., heads.shape[-1] // 2 :], |
|
) |
|
|
|
first_part = x_first * cos - x_second * sin |
|
second_part = x_second * cos + x_first * sin |
|
|
|
return torch.cat((first_part, second_part), dim=-1) |
|
|
|
def _compute_cos_sin_tables( |
|
self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2 |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
seq_len = x.shape[seq_dimension] |
|
|
|
|
|
        # Note: the tables are recomputed on every call; the cached attributes are
        # kept for introspection and are not used to skip recomputation here.
        self._seq_len_cached = seq_len
        t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
|
freqs = torch.einsum("i, j -> ij", t, inv_freq) |
|
|
|
self._cos_cached = torch.cos(freqs)[None, :, None, :] |
|
self._sin_cached = torch.sin(freqs)[None, :, None, :] |
|
return self._cos_cached, self._sin_cached |
|
|
|
def forward( |
|
self, q: torch.Tensor, k: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if self.rescaling_factor is None: |
|
inv_freq = 1.0 / ( |
|
self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim) |
|
) |
|
        else:
            # NTK-style rescaling: enlarge the rotary base so that the embeddings
            # extrapolate to sequence lengths beyond the training context.
            updated_base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
|
inv_freq = 1.0 / ( |
|
updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim) |
|
) |
|
|
|
self._cos_cached, self._sin_cached = self._compute_cos_sin_tables( |
|
q, |
|
inv_freq, |
|
seq_dimension=-3, |
|
) |
|
|
|
return ( |
|
self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), |
|
self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), |
|
) |
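

# Illustrative usage sketch (added for clarity, not part of the original model).
# It shows how RotaryEmbedding rotates query/key head tensors of shape
# (batch, seq_len, num_heads, head_dim); all shapes below are arbitrary choices.
def _demo_rotary_embedding() -> None:
    rope = RotaryEmbedding(
        dim=64,
        rotary_embedding_config=RotaryEmbeddingConfig(rescaling_factor=None),
    )
    q = torch.randn(2, 16, 8, 64)  # (batch, seq_len, num_heads, head_dim)
    k = torch.randn(2, 16, 8, 64)
    q_rot, k_rot = rope(q, k)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape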
|
|
|
|
|
class ResidualConvBlock(nn.Module): |
|
""" |
|
Conv Block with Residual connection. |
|
""" |
|
|
|
def __init__( |
|
self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1 |
|
): |
|
super().__init__() |
|
self.conv_block = ConvBlock( |
|
dim_in=dim_in, |
|
dim_out=dim_out, |
|
layer_norm_shape=layer_norm_shape, |
|
kernel_size=kernel_size, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
y = self.conv_block(x) |
|
return x.reshape(y.shape) + y |
|
|
|
|
|
class ConvBlock(nn.Module): |
|
""" |
|
Conv Block. |
|
""" |
|
|
|
def __init__( |
|
self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1 |
|
): |
|
super().__init__() |
|
self.conv = nn.Conv1d( |
|
in_channels=dim_in, |
|
out_channels=dim_out, |
|
kernel_size=kernel_size, |
|
padding="same", |
|
) |
|
self.layer_norm = nn.LayerNorm(layer_norm_shape, eps=1e-5) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = x.permute(0, 2, 1) |
|
x = self.layer_norm(x) |
|
x = x.permute(0, 2, 1) |
|
x = self.conv(x) |
|
x = F.gelu(x, approximate="tanh") |
|
return x |
|
|
|
|
|
class ConvTowerBlock(nn.Module): |
|
def __init__( |
|
self, |
|
dim_in: int, |
|
dim_out: int, |
|
conv_layer_norm_shape: int, |
|
        resconv_layer_norm_shape: int,
|
kernel_size: int, |
|
) -> None: |
|
super().__init__() |
|
self.conv_layer = ConvBlock( |
|
dim_in=dim_in, |
|
dim_out=dim_out, |
|
layer_norm_shape=conv_layer_norm_shape, |
|
kernel_size=kernel_size, |
|
) |
|
self.res_conv = ResidualConvBlock( |
|
dim_in=dim_out, |
|
dim_out=dim_out, |
|
layer_norm_shape=resconv_layer_norm_shape, |
|
kernel_size=1, |
|
) |
|
self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2) |
|
|
|
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
|
residual = x |
|
x = self.conv_layer(x) |
|
x = self.res_conv(x) |
|
x = self.avg_pool(x) |
|
return x, residual |
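

# Minimal usage sketch (illustrative, arbitrary shapes): a ConvTowerBlock widens the
# channel dimension and halves the sequence length via average pooling, returning
# both the pooled output and the pre-block residual.
def _demo_conv_tower_block() -> None:
    block = ConvTowerBlock(
        dim_in=512,
        dim_out=640,
        conv_layer_norm_shape=512,
        resconv_layer_norm_shape=640,
        kernel_size=5,
    )
    x = torch.randn(2, 512, 128)  # (batch, channels, length)
    pooled, residual = block(x)
    assert pooled.shape == (2, 640, 64)  # length halved, channels widened
    assert residual.shape == x.shape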
|
|
|
|
|
class ResidualDeConvBlock(nn.Module): |
|
""" |
|
Conv Block with Residual connection. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim_in: int, |
|
dim_out: int, |
|
layer_norm_shape: int, |
|
kernel_size: int = 1, |
|
stride: int = 1, |
|
): |
|
super().__init__() |
|
self.deconv_block = DeConvBlock( |
|
dim_in=dim_in, |
|
dim_out=dim_out, |
|
layer_norm_shape=layer_norm_shape, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
y = self.deconv_block(x) |
|
return x.reshape(y.shape) + y |
|
|
|
|
|
class DeConvBlock(nn.Module): |
|
""" |
|
DeConv Block. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim_in: int, |
|
dim_out: int, |
|
layer_norm_shape: int, |
|
kernel_size: int = 1, |
|
stride: int = 1, |
|
): |
|
super().__init__() |
|
self.deconv = nn.ConvTranspose1d( |
|
in_channels=dim_in, |
|
out_channels=dim_out, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=0, |
|
) |
|
self.layer_norm = nn.LayerNorm(layer_norm_shape) |
|
self.kernel_size = kernel_size |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = x.permute(0, 2, 1) |
|
x = self.layer_norm(x) |
|
x = x.permute(0, 2, 1) |
|
x = self.deconv(x) |
|
        if self.kernel_size == 5:
            # With kernel_size=5 and stride=2 (as used in DeConvTowerBlock), the
            # transposed convolution yields 2 * L_in + 3 positions; trimming one at
            # the start and two at the end restores a length of exactly 2 * L_in.
            x = x[:, :, 1:-2]
|
x = F.gelu(x, approximate="tanh") |
|
return x |
|
|
|
|
|
class DeConvTowerBlock(nn.Module): |
|
def __init__( |
|
self, |
|
dim_in: int, |
|
dim_out: int, |
|
kernel_size: int, |
|
conv_layer_norm_shape: int, |
|
resconv_layer_norm_shape: int, |
|
stride: int = 2, |
|
): |
|
super().__init__() |
|
self.deconv_block = DeConvBlock( |
|
dim_in=dim_in, |
|
dim_out=dim_out, |
|
layer_norm_shape=conv_layer_norm_shape, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
) |
|
self.res_deconv_block = ResidualDeConvBlock( |
|
dim_in=dim_out, |
|
dim_out=dim_out, |
|
layer_norm_shape=resconv_layer_norm_shape, |
|
kernel_size=1, |
|
) |
|
|
|
def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor: |
|
x = self.deconv_block(x) |
|
x = self.res_deconv_block(x) |
|
x = x + res |
|
return x |
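

# Minimal usage sketch (illustrative, arbitrary shapes): a DeConvTowerBlock doubles
# the sequence length with a stride-2 transposed convolution and adds back the
# residual saved by the matching ConvTowerBlock.
def _demo_deconv_tower_block() -> None:
    block = DeConvTowerBlock(
        dim_in=640,
        dim_out=512,
        kernel_size=5,
        conv_layer_norm_shape=640,
        resconv_layer_norm_shape=512,
        stride=2,
    )
    x = torch.randn(2, 640, 64)
    res = torch.randn(2, 512, 128)  # residual from the corresponding ConvTowerBlock
    y = block(x, res)
    assert y.shape == (2, 512, 128)  # length doubled back, channels reduced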
|
|
|
|
|
class MultiHeadAttention(nn.Module): |
|
def __init__( |
|
self, |
|
num_heads: int, |
|
key_size: int, |
|
rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None, |
|
add_bias_kv: bool = False, |
|
value_size: Optional[int] = None, |
|
model_size: Optional[int] = None, |
|
name: Optional[str] = None, |
|
): |
|
super().__init__() |
|
        if model_size is None:
            model_size = key_size
        if value_size is None:
            value_size = key_size
|
self.model_size = model_size |
|
self.key_size = key_size |
|
self.value_size = value_size |
|
self.add_bias_kv = add_bias_kv |
|
self.name = name |
|
self.num_heads = num_heads |
|
self._rotary_embedding_config = rotary_embedding_config |
|
|
|
self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size) |
|
self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size) |
|
self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size) |
|
self.output = nn.Linear(self.num_heads * self.value_size, self.model_size) |
|
if self._rotary_embedding_config: |
|
self._rotary_embedding = RotaryEmbedding( |
|
self.key_size, self._rotary_embedding_config |
|
) |
|
|
|
def apply_rotary_embeddings( |
|
self, |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
""" """ |
|
query, key = self._rotary_embedding(query, key) |
|
return query, key |
|
|
|
def forward( |
|
self, |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
value: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
attention_weight_bias: Optional[torch.Tensor] = None, |
|
) -> dict[str, torch.Tensor]: |
|
""" |
|
        Returns:
            Dictionary containing the attention weights and the output embeddings.
|
""" |
|
key_heads = self.w_k(key).reshape( |
|
(*key.shape[:-1], self.num_heads, self.key_size) |
|
) |
|
query_heads = self.w_q(query).reshape( |
|
(*query.shape[:-1], self.num_heads, self.key_size) |
|
) |
|
value_heads = self.w_v(value).reshape( |
|
(*value.shape[:-1], self.num_heads, self.value_size) |
|
) |
|
if self._rotary_embedding_config: |
|
query_heads, key_heads = self.apply_rotary_embeddings( |
|
query_heads, key_heads |
|
) |
|
attention_weights = torch.einsum( |
|
"...thd, ...Thd -> ...htT", query_heads, key_heads |
|
) |
|
sqrt_key_size = np.sqrt(self.key_size) |
|
attention_weights = attention_weights / sqrt_key_size |
|
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
|
attention_weights = F.softmax( |
|
attention_weights + attention_weight_bias, dim=-1 |
|
) |
|
else: |
|
attention_weights = F.softmax(attention_weights, dim=-1) |
|
value_out = torch.einsum( |
|
"...htT, ...Thd->...thd", attention_weights, value_heads |
|
) |
|
value_out = value_out.reshape((*value_out.shape[:-2], -1)) |
|
embeddings = self.output(value_out) |
|
|
|
return {"attention_weights": attention_weights, "embeddings": embeddings} |
|
|
|
|
|
class SelfAttentionBlock(nn.Module): |
|
def __init__( |
|
self, |
|
num_heads: int, |
|
embed_dim: int, |
|
ffn_embed_dim: int, |
|
key_size: Optional[int] = None, |
|
add_bias_kv: bool = False, |
|
add_bias_fnn: bool = True, |
|
ffn_activation_name: str = "gelu-no-approx", |
|
use_glu_in_ffn: bool = False, |
|
layer_norm_eps: float = 1e-5, |
|
pre_layer_norm: bool = True, |
|
name: Optional[str] = None, |
|
rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None, |
|
): |
|
super().__init__() |
|
if key_size is None: |
|
if embed_dim % num_heads != 0: |
|
raise ValueError( |
|
f"The embedding dimension should be divisible by the number of " |
|
f"heads, however provided embedding dimension is {embed_dim} and " |
|
f"the number of heads is {num_heads}." |
|
) |
|
else: |
|
key_size = embed_dim // num_heads |
|
|
|
|
|
self._pre_layer_norm = pre_layer_norm |
|
        self._use_glu_in_ffn = use_glu_in_ffn
|
|
|
        if use_glu_in_ffn:
            # With a Gated Linear Unit the first projection produces both the gate
            # and the value, so its output dimension is doubled and later split in
            # two inside `mlp`.
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
|
else: |
|
self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn) |
|
|
|
self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn) |
|
|
|
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
|
if ffn_activation_name == "swish": |
|
self._ffn_activation_fn = nn.SiLU() |
|
elif ffn_activation_name == "gelu-no-approx": |
|
self._ffn_activation_fn = lambda x: F.gelu(x, approximate="none") |
|
        else:
            # Fall back to the torch.nn.functional activation of the same name
            # (e.g. "relu" -> F.relu); getattr on torch.nn would return a class,
            # not a callable activation function.
            self._ffn_activation_fn = getattr(F, ffn_activation_name)
|
|
|
self.mha = MultiHeadAttention( |
|
num_heads=num_heads, |
|
key_size=key_size, |
|
add_bias_kv=add_bias_kv, |
|
model_size=embed_dim, |
|
name="self_attention", |
|
rotary_embedding_config=rotary_embedding_config, |
|
) |
|
|
|
def mlp(self, embed: torch.Tensor) -> torch.Tensor: |
|
|
|
if self._pre_layer_norm: |
|
x = self.layer_norm_mlp(embed) |
|
else: |
|
x = embed |
|
|
|
        if self._use_glu_in_ffn:
|
x = self.fc1(x) |
|
x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1) |
|
x = self._ffn_activation_fn(x1) * x2 |
|
else: |
|
x = self._ffn_activation_fn(self.fc1(x)) |
|
x = self.fc2(x) |
|
|
|
if not self._pre_layer_norm: |
|
x = self.layer_norm_mlp(x + embed) |
|
return x |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
attention_weight_bias: Optional[torch.Tensor] = None, |
|
    ) -> dict[str, torch.Tensor]:
|
|
|
res = x |
|
if self._pre_layer_norm: |
|
x = self.layer_norm_self_attention(x) |
|
|
|
output = self.mha( |
|
x, |
|
x, |
|
x, |
|
attention_mask=attention_mask, |
|
attention_weight_bias=attention_weight_bias, |
|
) |
|
|
|
if not self._pre_layer_norm: |
|
output["embeddings"] = self.layer_norm_self_attention( |
|
output["embeddings"] + res |
|
) |
|
|
|
x = output["embeddings"] |
|
else: |
|
x = output["embeddings"] |
|
x = res + x |
|
|
|
|
|
if not self._pre_layer_norm: |
|
x = self.mlp(x) |
|
else: |
|
x = x + self.mlp(x) |
|
|
|
output["embeddings"] = x |
|
return output |
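

# Minimal usage sketch (illustrative, arbitrary sizes): a pre-layer-norm transformer
# block with a SwiGLU feed-forward network, mirroring the settings used by MOJO below.
def _demo_self_attention_block() -> None:
    block = SelfAttentionBlock(
        num_heads=16,
        embed_dim=512,
        ffn_embed_dim=1024,
        ffn_activation_name="swish",
        use_glu_in_ffn=True,
        rotary_embedding_config=RotaryEmbeddingConfig(rescaling_factor=None),
    )
    x = torch.randn(2, 64, 512)  # (batch, seq_len, embed_dim)
    out = block(x)
    assert out["embeddings"].shape == x.shape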
|
|
|
|
|
class LMHead(nn.Module): |
|
def __init__( |
|
self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int |
|
) -> None: |
|
""" """ |
|
super().__init__() |
|
self.num_hidden_layers = num_hidden_layers |
|
        self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
        # Additional hidden layers are registered directly; wrapping each layer in a
        # list would prevent nn.ModuleList from registering it as a module.
        self.linear_layers.extend(
            nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)
        )
|
self.linear_out = nn.Linear(embed_dim, dim_out) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = F.gelu(x, approximate="tanh") |
|
for layer in self.linear_layers: |
|
x = layer(x) |
|
x = F.gelu(x, approximate="tanh") |
|
out = self.linear_out(x) |
|
return out |
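

# Minimal usage sketch (illustrative, arbitrary sizes): the LMHead maps per-position
# embeddings to per-token logits through a small MLP.
def _demo_lm_head() -> None:
    head = LMHead(dim_in=512, embed_dim=512, dim_out=66, num_hidden_layers=2)
    x = torch.randn(2, 128, 512)  # (batch, seq_len, dim_in)
    logits = head(x)
    assert logits.shape == (2, 128, 66)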
|
|
|
|
|
class MOJOConfig(PretrainedConfig): |
|
model_type = "MOJO" |
|
|
|
def __init__(self, **kwargs: Any) -> None: |
|
super().__init__(**kwargs) |
|
self.alphabet_size = kwargs.get( |
|
"alphabet_size", {"rnaseq": 66, "methylation": 66} |
|
) |
|
self.token_embed_dim = kwargs.get("token_embed_dim", 256) |
|
self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200) |
|
self.use_gene_embedding = kwargs.get("use_gene_embedding", True) |
|
self.project_gene_embedding = kwargs.get("project_gene_embedding", True) |
|
self.sequence_length = kwargs.get("sequence_length", 17_116) |
|
self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None) |
|
self.num_downsamples = kwargs.get("num_downsamples", 8) |
|
self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512) |
|
self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15) |
|
self.embed_dim = kwargs.get("embed_dim", 512) |
|
self.filter_list = kwargs.get("filter_list", []) |
|
self.num_attention_heads = kwargs.get("num_attention_heads", 16) |
|
self.key_size = kwargs.get("key_size", None) |
|
self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024) |
|
self.num_layers = kwargs.get("num_layers", 8) |
|
self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1) |
|
|
|
|
|
self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get( |
|
"embeddings_layers_to_save", () |
|
) |
|
self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get( |
|
"attention_maps_to_save", [] |
|
) |
|
|
|
self.__post_init__() |
|
|
|
def __post_init__(self): |
|
|
|
key_size = self.key_size |
|
if key_size is None: |
|
embed_dim = self.embed_dim |
|
num_attention_heads = self.num_attention_heads |
|
if not embed_dim % num_attention_heads == 0: |
|
raise ValueError( |
|
f"When no key size is provided, the embedding dimension should be " |
|
f"divisible by the number of heads, however provided embedding " |
|
f"dimension is {embed_dim} and the number of heads is " |
|
f"{num_attention_heads}." |
|
) |
|
self.key_size = embed_dim // num_attention_heads |
|
|
|
|
|
use_gene_embedding = self.use_gene_embedding |
|
if use_gene_embedding: |
|
init_gene_embed_dim = self.init_gene_embed_dim |
|
token_embed_dim = self.token_embed_dim |
|
if init_gene_embed_dim != token_embed_dim: |
|
project_gene_embedding = self.project_gene_embedding |
|
if not project_gene_embedding: |
|
logging.warning( |
|
f"Init gene embedding dimension ({init_gene_embed_dim})" |
|
f"different than token embedding dimension ({token_embed_dim})." |
|
f"Setting `project_gene_embedding` to True" |
|
) |
|
self.project_gene_embedding = True |
|
|
|
|
|
num_downsamples = self.num_downsamples |
|
sequence_length = self.sequence_length |
|
downsample_factor = 2**num_downsamples |
|
fixed_sequence_length = ( |
|
math.ceil(sequence_length / downsample_factor) * downsample_factor |
|
) |
|
self.fixed_sequence_length = fixed_sequence_length |
|
|
|
|
|
num_downsamples = self.num_downsamples |
|
filter_list = ( |
|
np.linspace( |
|
self.conv_init_embed_dim, |
|
self.embed_dim, |
|
num_downsamples + 1, |
|
) |
|
.astype(int) |
|
.tolist() |
|
) |
|
self.filter_list = filter_list |
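

# Minimal configuration sketch (illustrative): __post_init__ derives
# `fixed_sequence_length` (the sequence length rounded up to a multiple of
# 2 ** num_downsamples), `filter_list` (num_downsamples + 1 channel sizes), and
# `key_size` when it is not provided.
def _demo_mojo_config() -> None:
    config = MOJOConfig(sequence_length=17_116, num_downsamples=8)
    assert config.fixed_sequence_length == 17_152  # ceil(17116 / 256) * 256
    assert len(config.filter_list) == config.num_downsamples + 1
    assert config.key_size == config.embed_dim // config.num_attention_heads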
|
|
|
|
|
class MOJO(PreTrainedModel): |
|
config_class = MOJOConfig |
|
|
|
def __init__(self, config: MOJOConfig): |
|
super().__init__(config=config) |
|
|
|
|
|
self.embedding_layers = nn.ModuleDict( |
|
{ |
|
omic: nn.Embedding(config.alphabet_size[omic], config.token_embed_dim) |
|
for omic in config.alphabet_size |
|
} |
|
) |
|
|
|
self.gene_embedding_layer = nn.Embedding( |
|
config.fixed_sequence_length, |
|
config.init_gene_embed_dim, |
|
) |
|
self.fc_gene_embedding = nn.Linear( |
|
config.init_gene_embed_dim, config.token_embed_dim |
|
) |
|
|
|
|
|
self.stem_conv = nn.Sequential( |
|
nn.Conv1d( |
|
in_channels=config.token_embed_dim, |
|
out_channels=config.conv_init_embed_dim, |
|
kernel_size=config.stem_kernel_shape, |
|
padding="same", |
|
), |
|
nn.GELU(approximate="tanh"), |
|
) |
|
|
|
self.conv_tower = nn.ModuleList( |
|
[ |
|
ConvTowerBlock( |
|
dim_in=config.filter_list[i], |
|
dim_out=config.filter_list[i + 1], |
|
kernel_size=5, |
|
conv_layer_norm_shape=config.filter_list[i], |
|
resconv_layer_norm_shape=config.filter_list[i + 1], |
|
) |
|
for i in range(len(config.filter_list) - 1) |
|
] |
|
) |
|
|
|
|
|
attention_maps_to_save = config.attention_maps_to_save |
|
self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save}) |
|
|
|
self._attention_maps_per_layer_to_save = { |
|
layer: [t[1] for t in attention_maps_to_save if t[0] == layer] |
|
for layer in self._attention_layers_to_save |
|
} |
|
|
|
max_layer = max(self._attention_layers_to_save + [0]) |
|
if max_layer > config.num_layers: |
|
raise ValueError( |
|
f"You are requiring attention maps for layer {max_layer}, " |
|
f"while the model has {config.num_layers} layers only." |
|
) |
|
self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None) |
|
self.transformer_layers = nn.ModuleList( |
|
[ |
|
SelfAttentionBlock( |
|
num_heads=config.num_attention_heads, |
|
embed_dim=config.embed_dim, |
|
ffn_embed_dim=config.ffn_embed_dim, |
|
key_size=config.key_size, |
|
add_bias_kv=False, |
|
add_bias_fnn=False, |
|
ffn_activation_name="swish", |
|
use_glu_in_ffn=True, |
|
layer_norm_eps=1e-5, |
|
pre_layer_norm=True, |
|
name=f"attention_layer_{layer_idx}", |
|
rotary_embedding_config=self._rotary_embedding_config, |
|
) |
|
for layer_idx in range(config.num_layers) |
|
] |
|
) |
|
|
|
|
|
self.deconv_tower = nn.ModuleList( |
|
[ |
|
DeConvTowerBlock( |
|
dim_in=config.filter_list[-1 - i], |
|
dim_out=config.filter_list[-1 - i - 1], |
|
kernel_size=5, |
|
stride=2, |
|
conv_layer_norm_shape=config.filter_list[-1 - i], |
|
resconv_layer_norm_shape=config.filter_list[-1 - i - 1], |
|
) |
|
for i in range(len(config.filter_list) - 1) |
|
] |
|
) |
|
|
|
|
|
self.omic_lm_heads = nn.ModuleDict( |
|
{ |
|
omic: LMHead( |
|
dim_in=config.conv_init_embed_dim, |
|
embed_dim=config.embed_dim, |
|
dim_out=config.alphabet_size[omic], |
|
num_hidden_layers=config.num_hidden_layers_head, |
|
) |
|
for omic in self.config.alphabet_size |
|
} |
|
) |
|
|
|
def get_embeddings( |
|
self, |
|
input_ids: dict[str, torch.Tensor], |
|
) -> dict[str, torch.Tensor]: |
|
omic_embeddings = {} |
|
for omic, omic_tokens in input_ids.items(): |
|
omic_embeddings[omic] = self.embedding_layers[omic](omic_tokens) |
|
return omic_embeddings |
|
|
|
    def forward(self, input_ids: dict[str, torch.Tensor]) -> dict[str, Any]:
|
outs = {} |
|
embeddings = self.get_embeddings(input_ids) |
|
outs["omic_embeddings"] = embeddings |
|
x = torch.stack(list(embeddings.values()), dim=0).sum(dim=0) |
|
outs["embeddings"] = x |
|
|
|
if self.config.use_gene_embedding: |
|
gene_indices = torch.arange( |
|
self.config.fixed_sequence_length, device=x.device |
|
) |
|
gene_embedding = self.gene_embedding_layer(gene_indices) |
|
if self.config.project_gene_embedding: |
|
gene_embedding = self.fc_gene_embedding(gene_embedding) |
|
x = x + gene_embedding |
|
outs["embeddings_with_gene_embedding"] = x |
|
|
|
x = x.permute(0, 2, 1) |
|
x = self.stem_conv(x) |
|
outs["stem"] = x |
|
|
|
residuals = [] |
|
for conv_block in self.conv_tower: |
|
x, res = conv_block(x) |
|
residuals.append(res) |
|
x = x.permute(0, 2, 1) |
|
outs["conv_tower"] = x |
|
outs["conv_tower_residuals"] = residuals |
|
residuals = residuals[::-1] |
|
|
|
for layer_idx, transformer in enumerate(self.transformer_layers): |
|
output = transformer(x) |
|
x = output["embeddings"] |
|
if (layer_idx + 1) in self.config.embeddings_layers_to_save: |
|
outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"] |
|
if (layer_idx + 1) in self._attention_layers_to_save: |
|
for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]: |
|
dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}" |
|
outs[dkey] = output["attention_weights"][:, map_number + 1] |
|
outs["after_transformer_embedding"] = x |
|
|
|
x = x.permute(0, 2, 1) |
|
for deconv_block, res in zip(self.deconv_tower, residuals): |
|
x = deconv_block(x, res) |
|
x = x.permute(0, 2, 1) |
|
outs["deconv_tower"] = x |
|
|
|
outs["logits"] = { |
|
omic: self.omic_lm_heads[omic](x) for omic in self.config.alphabet_size |
|
} |
|
|
|
return outs |
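

# Minimal end-to-end sketch (illustrative): builds a MOJO model and runs a forward
# pass on random token ids for each omic. The reduced `sequence_length`,
# `num_downsamples`, and `num_layers` values are arbitrary choices to keep the demo
# light; they are not the defaults used by the released configuration.
def _demo_mojo_forward() -> None:
    config = MOJOConfig(sequence_length=2_048, num_downsamples=3, num_layers=2)
    model = MOJO(config)
    batch_size, seq_len = 2, config.fixed_sequence_length
    input_ids = {
        omic: torch.randint(0, config.alphabet_size[omic], (batch_size, seq_len))
        for omic in config.alphabet_size
    }
    outs = model(input_ids)
    assert outs["logits"]["rnaseq"].shape == (
        batch_size,
        seq_len,
        config.alphabet_size["rnaseq"],
    )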
|
|