Yanisadel committed
Commit d8365d8 · 1 Parent(s): 7974b5e

Delete multi_omics_model.py

Files changed (1)
  1. multi_omics_model.py +0 -127
multi_omics_model.py DELETED
@@ -1,127 +0,0 @@
-import torch
-from transformers import PreTrainedModel
-
-from genomics_research.biobrain_p1.porting_to_pytorch.configs.chatNT_config import (
-    ChatNTConfig,
-)
-from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_decoder import (
-    TorchBioBrainDecoder,
-)
-from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_encoder import (
-    TorchBioBrainEncoder,
-)
-from genomics_research.biobrain_p1.porting_to_pytorch.models.perceiver_resampler_projection import (  # noqa
-    TorchMultiModalPerceiverResamplerProjection,
-)
-
-
-class TorchMultiOmicsModel(PreTrainedModel):
-    config_class = ChatNTConfig
-
-    def __init__(self, config: ChatNTConfig) -> None:
-        super().__init__(config=config)
-        self.gpt_config = config.gpt_config
-        self.esm_config = config.esm_config
-        self.perceiver_resampler_config = config.perceiver_resampler_config
-        self.seq_token_id = config.seq_token_id
-        self.bio_pad_token_id = config.bio_pad_token_id
-        self.english_pad_token_id = config.english_pad_token_id
-
-        # Correct seq_token_id
-        self.seq_token_id -= 1
-
-        self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
-        self.biobrain_decoder = TorchBioBrainDecoder(
-            gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
-        )
-        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
-            perceiver_resampler_config=self.perceiver_resampler_config,
-            input_embed_dim=self.esm_config.embed_dim,
-            embed_dim=self.gpt_config.embed_dim,
-            english_vocab_size=self.gpt_config.vocab_size,
-            bio_pad_token_id=self.bio_pad_token_id,
-            english_pad_token_id=self.english_pad_token_id,
-        )
-
-    def forward(
-        self,
-        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
-        projection_english_tokens_ids: torch.Tensor,
-        projected_bio_embeddings: torch.Tensor | None = None,
-    ) -> dict[str, torch.Tensor]:
-        """
-        Args:
-            multi_omics_tokens_ids (tuple[torch.Tensor, torch.Tensor]):
-                english_tokens_ids: Represents the prompt tokens (english tokens)
-                    Shape (batch_size, num_english_tokens)
-
-                bio_tokens_ids: Represents the bio sequences tokens
-                    Shape (batch_size, num_bio_sequences, num_bio_tokens)
-
-            projection_english_tokens_ids (torch.Tensor):
-                Shape (batch_size, num_english_tokens)
-
-            projected_bio_embeddings (torch.Tensor, optional):
-                Shape (batch_size, num_bio_sequences, ?, embed_dim).
-                Defaults to None.
-
-        Returns:
-            dict[str, torch.Tensor] containing:
-                - logits:
-                    Shape (batch_size, num_tokens, vocab_size)
-
-                - projected_bio_embeddings:
-                    Shape (batch_size, num_bio_sequences, ?, embed_dim)
-        """
-        english_token_ids, bio_token_ids = multi_omics_tokens_ids
-
-        # Remap english token ids. The default vocab size (32000) does not
-        # match the number of tokens because seq_token_id (=32000) was added.
-        # As a workaround (to avoid changing the vocab size in the config),
-        # seq_token_id is remapped to 31999, and token 31999 (the unknown
-        # token) is remapped to 0.
-        vocab_size = self.gpt_config.vocab_size
-        english_token_ids[english_token_ids == vocab_size - 1] = 0
-        projection_english_tokens_ids[
-            projection_english_tokens_ids == vocab_size - 1
-        ] = 0
-        english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
-        projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
-            vocab_size - 1
-        )
-
-        if bio_token_ids is None:
-            projected_bio_embeddings = None
-        else:
-            num_bio_sequences = bio_token_ids.shape[1]
-
-            if projected_bio_embeddings is None:
-                # Compute bio sequences embeddings
-                bio_embeddings_list = [
-                    self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
-                    for bio_seq_num in range(num_bio_sequences)
-                ]
-
-                # Project these embeddings
-                projected_bio_embeddings = [
-                    self.projection_model(
-                        bio_token_ids=bio_token_ids[:, bio_seq_num],
-                        bio_embeddings=bio_embeddings,
-                        english_token_ids=projection_english_tokens_ids,
-                    )
-                    for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
-                ]
-                projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
-
-        # Decode
-        logits = self.biobrain_decoder(
-            english_token_ids=english_token_ids,
-            projected_bio_embeddings=projected_bio_embeddings,
-        )
-
-        outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
-
-        return outs
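
The token remapping at the top of forward() is easier to follow outside the diff. A minimal standalone reproduction of just that step (plain PyTorch; the 32000 vocab size comes from the file's own comment, the example ids are illustrative):

import torch

vocab_size = 32000  # per the comment in the deleted file
ids = torch.tensor([5, 31999, 32000, 42])  # 32000 is the appended seq_token_id

# Same order as in forward(): first free up 31999, then move 32000 into it.
ids[ids == vocab_size - 1] = 0           # 31999 (unknown token) -> 0
ids[ids == vocab_size] = vocab_size - 1  # 32000 (seq token)     -> 31999

print(ids)  # tensor([    5,     0, 31999,    42])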
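
For context on the interface being removed, a minimal sketch of a forward() call. It assumes the class is importable and that a checkpoint exists; the checkpoint path, vocab sizes, and tensor shapes below are illustrative assumptions, not part of this commit:

import torch

# Hypothetical checkpoint path; from_pretrained is inherited from PreTrainedModel.
model = TorchMultiOmicsModel.from_pretrained("path/to/chatnt-checkpoint")

batch_size, num_english_tokens = 2, 16
num_bio_sequences, num_bio_tokens = 1, 32

# Shapes follow the forward() docstring.
english_tokens = torch.randint(0, 32000, (batch_size, num_english_tokens))
bio_tokens = torch.randint(0, 4000, (batch_size, num_bio_sequences, num_bio_tokens))

outs = model(
    multi_omics_tokens_ids=(english_tokens, bio_tokens),
    projection_english_tokens_ids=english_tokens,
)
print(outs["logits"].shape)                    # (batch_size, num_tokens, vocab_size)
print(outs["projected_bio_embeddings"].shape)  # (batch_size, num_bio_sequences, ?, embed_dim)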