Yanisadel committed on
Commit
264c5fd
·
1 Parent(s): 657ebbe

Delete chatNT.py

Files changed (1)
  1. chatNT.py +0 -1850
chatNT.py DELETED
@@ -1,1850 +0,0 @@
1
- # This file stores ChatNT and all associated layers and configs
2
-
3
- from dataclasses import asdict, dataclass, field
4
- from typing import Dict, List, Optional, Tuple
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F # noqa: N812
10
- from transformers import PretrainedConfig, PreTrainedModel
11
-
12
-
13
- @dataclass
14
- class RotaryEmbeddingConfig:
15
- """
16
- Rotary Positional Embedding configuration
17
- max_seq_len: The number of positions to encode and cache.
18
- dim: Dimension of RoPE.
19
- theta: Rotation angle.
20
- """
21
-
22
- max_seq_len: int
23
- dim: int
24
- theta: float
25
-
26
-
27
- @dataclass
28
- class PerceiverResamplerConfig:
29
- """
30
- Parameters to initialize a PerceiverResampler model. Based on the ESM architecture.
31
-
32
- Args:
33
- emb_layer_norm_before: Whether to use layer norm before the first attention
34
- layer.
35
- attention_heads: Number of attention heads.
36
- key_size: The dimension of the query, key, and values within each attention
37
- head, if not specified, it is set to embed_dim // attention_heads.
38
- It can be useful to set a custom key size if we want to impose the size of
39
- the query, key and value tensor ( for example, tensors shaped with
40
- power of 2 are more efficiently handled on TPUs ).
41
- Note: Parametrizing the model with a custom key size has been done in :
42
- Brown, Tom, et al. "Language models are few-shot learners."
43
- Advances in neural information processing systems 33 (2020): 1877-1901.
44
- embed_dim: Embedding dimension.
45
- ffn_embed_dim: Feed forward embedding dimension.
46
- num_layers: Number of attention blocks.
47
- ffn_activation_name: Activation function to be used in FFN block. Supported
48
- names are "gelu", "relu", "swish".
49
- use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
50
- Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
51
- to True and use swish as ffn_activation_name.
52
- Same principle for a gated-relu. To keep the same number of parameters in
53
- the FFN block, one should multiply ffn_embed_dim by 2/3 when using GLU.
54
- See https://arxiv.org/pdf/2002.05202.pdf for more details.
55
- resampled_length: length of the resampled output of the module
56
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
57
- gradients in the forward pass to reduce the computation in the backward).
58
- """
59
-
60
- # architecture
61
- emb_layer_norm_before: bool = False
62
- attention_heads: int = 20
63
- key_size: Optional[int] = None
64
- embed_dim: int = 1280
65
- ffn_embed_dim: int = 5120
66
- num_layers: int = 24
67
- add_bias_kv: bool = False
68
- add_bias_ffn: bool = True
69
- ffn_activation_name: str = "gelu-no-approx"
70
- use_glu_in_ffn: bool = False
71
- resampled_length: int = 64
72
-
73
- # performance
74
- use_gradient_checkpointing: bool = False
75
-
76
- def __post_init__(self) -> None:
77
- """
78
- Checks that the given values are compatible.
79
- """
80
-
81
- if self.key_size is None:
82
- if not self.embed_dim % self.attention_heads == 0:
83
- raise ValueError(
84
- f"When no key size is provided, the embedding dimension should be "
85
- f"divisible by the number of heads, however provided embedding "
86
- f"dimension is {self.embed_dim} and the number of heads is "
87
- f"{self.attention_heads}."
88
- )
89
- self.key_size = self.embed_dim // self.attention_heads
90
-
91
-
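The 2/3 rule mentioned in the docstring can be checked with a quick parameter count (a minimal sketch ignoring biases; the 1280/5120 figures are this config's defaults):

d, f = 1280, 5120                    # embed_dim, ffn_embed_dim without GLU
plain = d * f + f * d                # fc1 + fc2 parameters, no gating
f_glu = int(f * 2 / 3)               # rescaled hidden width when using GLU
gated = d * (2 * f_glu) + f_glu * d  # the gated fc1 has twice the output width
# plain == 13_107_200 and gated == 13_105_920, i.e. roughly the same parameter budget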
92
- @dataclass
93
- class GptConfig:
94
- """
95
- Parameters to initialize a Gpt model.
96
-
97
- NOTE: the pad token is not defined
98
-
99
- Args:
100
- vocab_size: Token vocabulary.
101
- eos_token_id: used to stop sentence generation
102
- embed_dim: Embedding dimension.
103
- ffn_embed_dim: Feed forward embedding dimension.
104
- num_heads: Number of attention heads.
105
- num_kv_heads: Number of key and value heads to support Grouped-Query and
106
- Multi-Query Attention. If None, the number of key and value heads is
107
- equal to the number of attention heads.
108
- num_layers: Number of decoder layers.
109
- rope_config: The configuration for the rotary positional embeddings
110
- add_bias_ffn: Add bias in feed forward network block.
111
- ffn_activation_name: Activation function to be used in FFN block. Supported
112
- names are "gelu", "gelu-no-approx", "relu", "swish".
113
- use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
114
- Forward Network (FFN) block.
115
- example: To do a swiGLU (gated-swish) put this arg
116
- to True and use swish as ffn_activation_name.
117
- Same principle for a gated-relu.
118
- add_bias_lm_head: whether to use bias in the final LM layer
119
- norm_type: The type of norm used in the pre-normalization scheme. Can be
120
- one of ["layer_norm", "RMS_norm"].
121
- parallel_attention_ff: Whether to do the attention and the MLP in parallel,
122
- and then sum up the results as it is done in Gpt-NeoX :
123
- Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
124
- language model." arXiv preprint arXiv:2204.06745 (2022).
125
- It is said to improve the training time by about 15% when compiling with JAX.
126
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
127
- gradients in the forward pass to reduce the computation in the backward).
128
- add_bias_attn: Add bias to the attention mechanism (key, query, value, and
129
- output projections).
130
- """
131
-
132
- # vocabulary
133
- vocab_size: int
134
- eos_token_id: int
135
-
136
- # architecture
137
- embed_dim: int = 16
138
- ffn_embed_dim: int = 64
139
- num_heads: int = 2
140
- num_kv_heads: Optional[int] = None
141
- num_layers: int = 2
142
- rope_config: RotaryEmbeddingConfig = field(
143
- default_factory=lambda: RotaryEmbeddingConfig(
144
- max_seq_len=512, dim=8, theta=10000.0
145
- )
146
- )
147
- add_bias_ffn: bool = False
148
- ffn_activation_name: str = "swish"
149
- use_glu_in_ffn: bool = True
150
- add_bias_lm_head: bool = False
151
- norm_type: str = "RMS_norm"
152
- rms_norm_eps: float = 1e-6
153
- parallel_attention_ff: bool = True
154
-
155
- # inference / backward behavior
156
- use_gradient_checkpointing: bool = False
157
-
158
- # architecture params with default values
159
- add_bias_attn: bool = False
160
-
161
- def __post_init__(self) -> None:
162
- """
163
- Checks that the given values are compatible.
164
- """
165
- if not self.embed_dim % self.num_heads == 0:
166
- raise ValueError(
167
- f"The embedding dimension should be "
168
- f"divisible by the number of heads, however provided embedding "
169
- f"dimension is {self.embed_dim} and the number of heads is "
170
- f"{self.num_heads}."
171
- )
172
-
173
- if not self.embed_dim // self.num_heads > 1:
174
- raise ValueError(
175
- "embed_dim / num_heads must be higher than 2 to apply rotary embeddings"
176
- )
177
-
178
- if not self.embed_dim // self.num_heads >= self.rope_config.dim:
179
- raise ValueError(
180
- "embed_dim // num_heads must be higher than rope_config.dim "
181
- "to apply rotary embeddings"
182
- )
183
-
184
- def to_dict(self): # type: ignore
185
- output = asdict(self)
186
- output["rope_config"] = asdict(self.rope_config)
187
- return output
188
-
189
-
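As a quick illustration of these defaults and of the __post_init__ checks (a minimal sketch; the vocab_size and eos_token_id values are arbitrary):

cfg = GptConfig(vocab_size=32000, eos_token_id=3)
assert cfg.embed_dim // cfg.num_heads == 8      # head dim, must be at least rope_config.dim
assert cfg.rope_config.dim == 8
d = cfg.to_dict()                               # nested dataclasses are serialized to dicts
assert d["rope_config"]["theta"] == 10000.0
# GptConfig(vocab_size=32000, eos_token_id=3, embed_dim=15) would raise a ValueError,
# since 15 is not divisible by the default num_heads=2.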
190
- @dataclass
191
- class ESMTransformerConfig:
192
- """
193
- Parameters to initialize an ESM model. While the ESM architecture is an encoder-only
194
- model, different choices have been made for each version and this configuration aims
195
- to cover most of them.
196
-
197
- Args:
198
- alphabet_size: Token vocabulary.
199
- pad_token_id: ID of pad token.
200
- mask_token_id: ID of mask token.
201
- max_positions: Maximum sequence length.
202
- embed_scale: Correction ratio applied to the embeddings to make up for the
203
- norm difference between the input during training and inference.
204
- emb_layer_norm_before: Whether to use layer norm before the first attention
205
- layer.
206
- attention_heads: Number of attention heads.
207
- key_size: The dimension of the query, key, and values within each attention
208
- head, if not specified, it is set to embed_dim // attention_heads.
209
- It can be useful to set a custom key size if we want to impose the size of
210
- the query, key and value tensor ( for example, tensors shaped with
211
- power of 2 are more efficiently handled on TPUs ).
212
- Note: Parametrizing the model with a custom key size has been done in :
213
- Brown, Tom, et al. "Language models are few-shot learners."
214
- Advances in neural information processing systems 33 (2020): 1877-1901.
215
- embed_dim: Embedding dimension.
216
- ffn_embed_dim: Feed forward embedding dimension.
217
- num_layers: Number of attention blocks.
218
- positional_embedding: Type of positional embedding to use before the first
219
- attention layer. Options: "learned", "learned_standard", "sinusoidal" or
220
- None.
221
- NOTE: "learned" is the positional embedding of ESM, and "learned_standard"
222
- is a more standard one, used for example in DNAbert.
223
- lm_head: type of language model head. Options: "simple", "roberta" or None.
224
- add_bias_kv: Add bias in attention layer.
225
- add_bias_ffn: Add bias in feed forward network block.
226
- use_rotary_embedding: Whether to use rotary embeddings (for ESM2). Requires:
227
- positional_embeddings = None.
228
- rescaling_factor: Scaling factor to use for rotary embeddings.
229
- ffn_activation_name: Activation function to be used in FFN block. Supported
230
- names are "gelu", "relu", "swish".
231
- use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
232
- Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
233
- to True and use swish as ffn_activation_name.
234
- Same principle for a gated-relu. To keep the same number of parameters in
235
- the FFN block, one should multiply ffn_embed_dim by 2/3 when using GLU.
236
- See https://arxiv.org/pdf/2002.05202.pdf for more details.
237
- mask_before_attention: Use mask before attention layers (for ESM1b and ESM2).
238
- layer_norm_eps: the eps factor in the different layer norms of the model (refer
239
- to layer norm implementation)
240
- token_dropout: Token dropout.
241
- masking_ratio: Masking ratio (used if token dropout is enabled).
242
- masking_prob: Masking probability (used if token dropout is enabled).
243
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
244
- gradients in the forward pass to reduce the computation in the backward).
245
- """
246
-
247
- alphabet_size: int
248
- pad_token_id: int
249
- mask_token_id: int
250
-
251
- max_positions: int = 1024
252
- embed_scale: float = 1.0
253
-
254
- # architecture
255
- emb_layer_norm_before: bool = False
256
- attention_heads: int = 20
257
- key_size: Optional[int] = None
258
- embed_dim: int = 1280
259
- ffn_embed_dim: int = 5120
260
- num_layers: int = 24
261
- positional_embedding: Optional[str] = "learned"
262
- lm_head: Optional[str] = "simple"
263
- add_bias_kv: bool = False
264
- add_bias_ffn: bool = True
265
- use_rotary_embedding: bool = False
266
- rescaling_factor: Optional[float] = None
267
- ffn_activation_name: str = "gelu-no-approx"
268
- use_glu_in_ffn: bool = False
269
- mask_before_attention: bool = False
270
- layer_norm_eps: float = 1e-5
271
- pre_layer_norm: bool = True
272
- bias_word_embedding: bool = False
273
-
274
- # dropout
275
- token_dropout: bool = False
276
- masking_ratio: float = 0.1
277
- masking_prob: float = 0.8
278
-
279
- # logging
280
- use_gradient_checkpointing: bool = False
281
-
282
- # return
283
- embeddings_layers_to_save: List[int] = field(default_factory=list)
284
- attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)
285
-
286
- def __post_init__(self) -> None:
287
- """
288
- Checks that the given values are compatible.
289
- """
290
-
291
- if self.key_size is None:
292
- if not self.embed_dim % self.attention_heads == 0:
293
- raise ValueError(
294
- f"When no key size is provided, the embedding dimension should be "
295
- f"divisible by the number of heads, however provided embedding "
296
- f"dimension is {self.embed_dim} and the number of heads is "
297
- f"{self.attention_heads}."
298
- )
299
- self.key_size = self.embed_dim // self.attention_heads
300
- if self.positional_embedding is not None:
301
- if type(self.positional_embedding) != str:
302
- raise TypeError
303
-
304
- if self.positional_embedding not in [
305
- "learned",
306
- "sinusoidal",
307
- "learned_standard",
308
- "alibi_dnabert_2",
309
- ]:
310
- raise ValueError(
311
- "The positional_embedding argument should either be None,"
312
- "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
313
- )
314
- if self.lm_head is not None:
315
- if type(self.lm_head) != str:
316
- raise TypeError
317
-
318
- if self.lm_head not in ["simple", "roberta"]:
319
- raise ValueError(
320
- "The lm_head argument should either be None,"
321
- "`simple` or `roberta`."
322
- )
323
-
324
- if self.use_rotary_embedding and self.positional_embedding is not None:
325
- raise ValueError(
326
- "When using rotary embedding, positional_embedding must be set to none"
327
- )
328
-
329
- if self.add_bias_kv and self.use_rotary_embedding:
330
- raise ValueError(
331
- "Biases on key and values are not compatible with Rotary embeddings."
332
- )
333
-
334
- if self.positional_embedding == "alibi_dnabert_2":
335
- assert not self.add_bias_kv
336
-
337
-
338
- @dataclass
339
- class ChatNTConfig(PretrainedConfig):
340
- model_type = "ChatNT"
341
-
342
- def __init__(self, **kwargs): # type: ignore
343
- self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
344
- self.esm_config: ESMTransformerConfig = kwargs.get(
345
- "esm_config", ESMTransformerConfig(4000, 1, 4)
346
- )
347
- self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
348
- "perceiver_resampler_config", PerceiverResamplerConfig()
349
- )
350
- self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
351
- self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
352
- self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
353
- super().__init__(**kwargs)
354
-
355
- def to_dict(self): # type: ignore
356
- output = super().to_dict()
357
-
358
- def serialize(obj): # type: ignore
359
- return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
360
-
361
- output["gpt_config"] = serialize(self.gpt_config) # type: ignore
362
- output["esm_config"] = serialize(self.esm_config) # type: ignore
363
- output["perceiver_resampler_config"] = serialize( # type: ignore
364
- self.perceiver_resampler_config
365
- )
366
- return output
367
-
368
-
369
- class TorchBioBrainDecoder(nn.Module):
370
- def __init__(
371
- self,
372
- gpt_config: GptConfig,
373
- seq_token_id: int,
374
- ):
375
- """
376
- Initializes the BioBrain decoder, using a GPT model for text generation with
377
- bio embeddings.
378
-
379
- Args:
380
- gpt_config: Configuration for the GPT model
381
- seq_token_id: Index of the SEQ token
382
- """
383
- super(TorchBioBrainDecoder, self).__init__()
384
- self.gpt_config = gpt_config
385
- self.seq_token_id = seq_token_id
386
-
387
- # Initialize the GPT decoder model
388
- self.gpt_model = TorchGptDecoder(self.gpt_config)
389
-
390
- def forward(
391
- self, english_token_ids: torch.Tensor, projected_bio_embeddings: torch.Tensor
392
- ) -> torch.Tensor:
393
- """
394
- Forward pass through the model.
395
-
396
- Args:
397
- english_token_ids: Tensor of English token IDs with shape
398
- (batch_size, num_english_tokens).
399
- projected_bio_embeddings: Optional tensor of bio embeddings with shape
400
- (batch_size, num_bio_sequences, ?, embed_dim).
401
-
402
- Returns:
403
- torch.Tensor: The logits from the GPT model,
404
- shaped (batch_size, num_english_tokens, vocab_size).
405
- """
406
-
407
- # Compute English token embeddings
408
- tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
409
-
410
- if projected_bio_embeddings is not None:
411
- (
412
- batch_size,
413
- num_bio_sequences,
414
- _,
415
- bio_embed_dim,
416
- ) = projected_bio_embeddings.shape
417
-
418
- # Insert the bio embeddings at the SEQ token positions
419
- processed_tokens_ids = english_token_ids.clone()
420
- for bio_seq_num in range(num_bio_sequences):
421
- tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
422
- processed_tokens_ids,
423
- tokens_embeddings,
424
- projected_bio_embeddings[:, bio_seq_num, :, :],
425
- bio_seq_num=bio_seq_num,
426
- )
427
-
428
- # Regular GPT pass through
429
- embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
430
- embeddings = self.gpt_model.final_norm(embeddings)
431
-
432
- # Compute logits
433
- logits = self.gpt_model.lm_head(embeddings)
434
-
435
- if projected_bio_embeddings is not None:
436
- # Clean logits sequentially
437
- processed_tokens_ids = english_token_ids.clone()
438
- resampled_length = projected_bio_embeddings.shape[-2]
439
- for _ in range(num_bio_sequences):
440
- logits, processed_tokens_ids = self.cleanup_logits(
441
- tokens=processed_tokens_ids,
442
- logits=logits,
443
- resampled_length=resampled_length,
444
- )
445
-
446
- return logits
447
-
448
- def insert_embeddings(
449
- self,
450
- tokens: torch.Tensor,
451
- input_embeddings: torch.Tensor,
452
- resampled_embeddings: torch.Tensor,
453
- bio_seq_num: int,
454
- ) -> Tuple[torch.Tensor, torch.Tensor]:
455
- """
456
- Inserts resampled embeddings in input_embeddings, starting at the SEQ token
457
-
458
- Args:
459
- tokens (torch.Tensor): Shape (batch_size, num_tokens)
460
- input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
461
- resampled_embeddings (torch.Tensor):
462
- Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
463
-
464
- Returns:
465
- Tuple[torch.Tensor, torch.Tensor]:
466
- - input_embeddings with resampled_embeddings inserted at the SEQ token
467
- - tokens with the SEQ token set to -1
468
- """
469
-
470
- def _insert(
471
- tokens_1d: torch.Tensor,
472
- input_embeddings_1d: torch.Tensor,
473
- resampled_embeddings_1d: torch.Tensor,
474
- ) -> Tuple[torch.Tensor, torch.Tensor]:
475
- """
476
- Args:
477
- tokens (torch.Tensor): Shape (num_tokens,)
478
- input_embeddings (torch.Tensor): Shape (num_tokens, embed_dim,)
479
- resampled_embeddings (torch.Tensor):
480
- Shape (bio_sequence_length, embed_dim,)
481
- """
482
- indices = torch.where(tokens_1d == self.seq_token_id)[0]
483
- if indices.numel() > 0:
484
- idx = indices[0].item()
485
- insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
486
- x = torch.cat(
487
- [
488
- input_embeddings_1d[:insertion_pos, :],
489
- resampled_embeddings_1d,
490
- input_embeddings_1d[insertion_pos:, :],
491
- ],
492
- dim=0,
493
- )[: tokens_1d.shape[0] + 1, :]
494
- x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[
495
- :-1, :
496
- ]
497
- tokens_1d[idx] = -1
498
- return x, tokens_1d
499
- else:
500
- return (
501
- input_embeddings_1d,
502
- tokens_1d,
503
- ) # Return unchanged if seq_token_id is not found
504
-
505
- tokens_acc = []
506
- embeddings_acc = []
507
-
508
- for i in range(tokens.shape[0]):
509
- embeddings_out, tokens_out = _insert(
510
- tokens[i].clone(),
511
- input_embeddings[i].clone(),
512
- resampled_embeddings[i].clone(),
513
- )
514
- tokens_acc.append(tokens_out)
515
- embeddings_acc.append(embeddings_out)
516
- tokens_acc = torch.stack(tokens_acc)
517
- embeddings_acc = torch.stack(embeddings_acc)
518
-
519
- return embeddings_acc, tokens_acc
520
-
521
- def cleanup_logits(
522
- self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
523
- ) -> Tuple[torch.Tensor, torch.Tensor]:
524
- """
525
- Removes the logits corresponding to the unused embeddings.
526
-
527
- Args:
528
- tokens: Input english tokens.
529
- logits: Input logits.
530
-
531
- Returns:
532
- Cleaned logits, last values will be equal to 0.
533
- """
534
-
535
- def _clean(
536
- token: torch.Tensor, logit: torch.Tensor
537
- ) -> Tuple[torch.Tensor, torch.Tensor]:
538
- indices = torch.where(token == self.seq_token_id)[0]
539
- if indices.numel() > 0:
540
- idx = indices[0].item()
541
-
542
- mask_idx = (
543
- torch.arange(logit.shape[0] - resampled_length, device=logit.device)
544
- > idx
545
- )
546
- mask_idx = mask_idx.unsqueeze(1)
547
-
548
- # Remove values corresponding to bio tokens
549
- logit = (
550
- logit[:-resampled_length] * (~mask_idx)
551
- + logit[resampled_length:] * mask_idx
552
- )
553
-
554
- # Append zeros at the end
555
- logit = torch.cat(
556
- (
557
- logit,
558
- torch.zeros(
559
- (resampled_length, logit.shape[1]),
560
- dtype=logit.dtype,
561
- device=logit.device,
562
- ),
563
- )
564
- )
565
-
566
- # Update token
567
- token[idx] = -1
568
-
569
- return logit, token
570
-
571
- else:
572
- return logit, token
573
-
574
- tokens_acc = []
575
- logits_acc = []
576
-
577
- for i in range(tokens.shape[0]):
578
- logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone())
579
- tokens_acc.append(tokens_out)
580
- logits_acc.append(logits_out)
581
- tokens_acc = torch.stack(tokens_acc)
582
- logits_acc = torch.stack(logits_acc)
583
-
584
- return logits_acc, tokens_acc
585
-
586
-
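The tensor gymnastics in insert_embeddings are easier to see on toy tensors; the following sketch (hypothetical values, not using the class itself) shows the core splicing idea:

import torch

embed_dim = 4
tokens = torch.tensor([11, 12, 32000, 13, 0, 0])  # 32000 stands in for the SEQ token
embeddings = torch.randn(6, embed_dim)            # one embedding per English token
bio = torch.randn(2, embed_dim)                   # resampled bio embeddings to splice in

idx = int(torch.where(tokens == 32000)[0][0])     # position of the SEQ token
spliced = torch.cat([embeddings[:idx], bio, embeddings[idx:]], dim=0)[: tokens.shape[0]]
# spliced keeps the original length: the bio embeddings are inserted at the SEQ position
# and the overflow (here the trailing padding) is dropped at the end. insert_embeddings
# additionally marks the consumed SEQ token with -1 so that the next bio sequence is
# spliced in after the next remaining SEQ token.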
587
- class TorchMultiOmicsModel(PreTrainedModel):
588
- config_class = ChatNTConfig
589
-
590
- def __init__(self, config: ChatNTConfig) -> None:
591
- super().__init__(config=config)
592
- self.gpt_config = config.gpt_config
593
- self.esm_config = config.esm_config
594
- self.perceiver_resampler_config = config.perceiver_resampler_config
595
- self.seq_token_id = config.seq_token_id
596
- self.bio_pad_token_id = config.bio_pad_token_id
597
- self.english_pad_token_id = config.english_pad_token_id
598
-
599
- # Correct seq_token_id
600
- self.seq_token_id -= 1
601
-
602
- self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
603
- self.biobrain_decoder = TorchBioBrainDecoder(
604
- gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
605
- )
606
- self.projection_model = TorchMultiModalPerceiverResamplerProjection(
607
- perceiver_resampler_config=self.perceiver_resampler_config,
608
- input_embed_dim=self.esm_config.embed_dim,
609
- embed_dim=self.gpt_config.embed_dim,
610
- english_vocab_size=self.gpt_config.vocab_size,
611
- bio_pad_token_id=self.bio_pad_token_id,
612
- english_pad_token_id=self.english_pad_token_id,
613
- )
614
-
615
- def forward(
616
- self,
617
- multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
618
- projection_english_tokens_ids: torch.Tensor,
619
- projected_bio_embeddings: torch.Tensor = None,
620
- ) -> dict[str, torch.Tensor]:
621
- """
622
-
623
- Args:
624
- multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
625
- english_tokens_ids: Represents the prompt tokens (english tokens)
626
- Shape (batch_size, num_english_tokens)
627
-
628
- bio_tokens_ids: Represents the bio sequences tokens
629
- Shape (batch_size, num_bio_sequences, num_bio_tokens)
630
-
631
- projection_english_tokens_ids (torch.Tensor):
632
- Shape (batch_size, num_english_tokens)
633
-
634
- projected_bio_embeddings (projected_bio_embeddings, optional):
635
- Shape (batch_size, num_bio_sequences, ?, embed_dim).
636
- Defaults to None.
637
-
638
- Returns:
639
- dict[str, torch.Tensor] containing:
640
- - logits:
641
- Shape (batch_size, num_tokens, vocab_size)
642
-
643
- - projected_bio_embeddings:
644
- Shape (batch_size, num_bio_sequences, ?, embed_dim)
645
- """
646
- english_token_ids, bio_token_ids = multi_omics_tokens_ids
647
-
648
- # Replace config.vocab_size value in english tokens
649
- # We do this because the default vocab size (32000) does not account for the
650
- # extra seq_token_id (=32000) that was added on top of it.
651
- # Therefore, we remap seq_token_id to 31999 and remap token 31999
652
- # (the unknown token) to 0.
653
- # This is a workaround that avoids changing the vocab size in the config.
654
- vocab_size = self.gpt_config.vocab_size
655
- # Replace vocab
656
- english_token_ids[english_token_ids == vocab_size - 1] = 0
657
- projection_english_tokens_ids[
658
- projection_english_tokens_ids == vocab_size - 1
659
- ] = 0
660
- english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
661
- projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
662
- vocab_size - 1
663
- )
664
-
665
- if bio_token_ids is None:
666
- projected_bio_embeddings = None
667
- else:
668
- num_bio_sequences = bio_token_ids.shape[1]
669
-
670
- if projected_bio_embeddings is None:
671
- # Compute bio sequences embeddings
672
- bio_embeddings_list = [
673
- self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
674
- for bio_seq_num in range(num_bio_sequences)
675
- ]
676
-
677
- # Project these embeddings
678
- projected_bio_embeddings = [
679
- self.projection_model(
680
- bio_token_ids=bio_token_ids[:, bio_seq_num],
681
- bio_embeddings=bio_embeddings,
682
- english_token_ids=projection_english_tokens_ids,
683
- )
684
- for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
685
- ]
686
- projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
687
-
688
- # decode
689
- logits = self.biobrain_decoder(
690
- english_token_ids=english_token_ids,
691
- projected_bio_embeddings=projected_bio_embeddings,
692
- )
693
-
694
- outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
695
-
696
- return outs
697
-
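The vocabulary workaround in forward() amounts to the following remapping (a toy sketch with arbitrary token ids):

import torch

vocab_size = 32000
ids = torch.tensor([[15, 31999, 32000, 7]])
ids[ids == vocab_size - 1] = 0               # 31999 (unknown token) -> 0
ids[ids == vocab_size] = vocab_size - 1      # 32000 (the added SEQ token) -> 31999
# ids is now tensor([[15, 0, 31999, 7]]), so every id fits in the GPT vocabulary.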
698
-
699
- class TorchRotaryEmbedding(torch.nn.Module):
700
- def __init__(self, config: RotaryEmbeddingConfig):
701
- super().__init__()
702
-
703
- self.max_seq_len = config.max_seq_len
704
- self.dim = config.dim
705
- self.theta = config.theta
706
- self.sincos_cache = self._create_sinusoidal_positions()
707
-
708
- def _create_sinusoidal_positions(self) -> torch.Tensor:
709
- """
710
- Create the sines and cosines for the RoPE.
711
-
712
- Returns:
713
- Sinusoidal positions of shape (self.max_seq_len, self.dim).
714
- """
715
- # Create the inverse frequency based on theta and dim
716
- inv_freq = 1.0 / (
717
- self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
718
- )
719
-
720
- # Compute sinusoidal input using the broadcasting
721
- sinusoid_inp = torch.einsum(
722
- "i,j->ij", torch.arange(self.max_seq_len).float(), inv_freq
723
- )
724
-
725
- # Apply sin and cos to the sinusoidal input
726
- sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
727
-
728
- # Allocate a tensor for the final sin-cos values
729
- sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32)
730
-
731
- # Fill the sincos tensor with sin and cos values
732
- sentinel = self.dim // 2 + self.dim % 2
733
- sincos[:, :sentinel] = sin
734
- sincos[:, sentinel:] = cos
735
-
736
- return sincos
737
-
738
- def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
739
- """
740
- Prepare a tensor to apply the RoPE mechanism.
741
-
742
- Args:
743
- x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
744
- typically this is the key or query tensor.
745
-
746
- Returns:
747
- The even indices in the last dimension have their sign flipped.
748
- Tensor of shape (batch_size, seq_len, num_heads, head_dim).
749
- """
750
- # Split the tensor into two halves (odd and even indexed dimensions)
751
- rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)
752
-
753
- # Reshape the tensor to the original shape
754
- rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
755
- return rotate_half
756
-
757
- def _apply_rotary_pos_emb(
758
- self, x: torch.Tensor, sincos: torch.Tensor
759
- ) -> torch.Tensor:
760
- """
761
- Applies rotary embeddings to x.
762
-
763
- Args:
764
- x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
765
- typically this is the key or query tensor.
766
- sincos: Tuple of sine and cosine tensors for position encoding.
767
-
768
- Returns:
769
- RoPE embeddings tensor.
770
- """
771
- sin_pos, cos_pos = sincos
772
-
773
- # Reshape the sin and cos tensors for broadcasting
774
- sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
775
- cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)
776
-
777
- # Apply the rotary embedding mechanism
778
- return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)
779
-
780
- def __call__(
781
- self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
782
- ) -> tuple[torch.Tensor, torch.Tensor]:
783
- """
784
- Applies rotary embeddings to k and q.
785
-
786
- Args:
787
- k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
788
- q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
789
- positions: optional positions offset useful when caching,
790
-
791
- Returns:
792
- RoPE embeddings for the keys and queries.
793
- """
794
- batch_size, seq_len, num_heads, head_dim = k.shape
795
-
796
- # Generate position ids
797
- position_ids = (
798
- torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
799
- )
800
-
801
- if positions is not None:
802
- position_ids = position_ids + positions
803
-
804
- # Retrieve sincos values using the position_ids
805
- sincos = self.sincos_cache.to(k.device)[position_ids]
806
-
807
- # Split sincos into sin_pos and cos_pos
808
- sincos = torch.chunk(sincos, 2, dim=-1)
809
-
810
- # Apply rotary position embedding to key (k) and query (q)
811
- k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
812
- k_pass = k[..., self.dim :]
813
-
814
- q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)
815
- q_pass = q[..., self.dim :]
816
-
817
- # Concatenate the rotated and non-rotated parts
818
- keys = torch.cat([k_rot, k_pass], dim=-1)
819
- queries = torch.cat([q_rot, q_pass], dim=-1)
820
-
821
- return keys, queries
822
-
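A minimal usage sketch for this rotary module, with a sanity check that the rotation preserves per-head norms (shapes and values are arbitrary):

import torch

cfg = RotaryEmbeddingConfig(max_seq_len=32, dim=8, theta=10000.0)
rope = TorchRotaryEmbedding(cfg)
k = torch.randn(2, 16, 4, 8)  # (batch, seq_len, num_heads, head_dim)
q = torch.randn(2, 16, 4, 8)
k_rot, q_rot = rope(k, q)
assert k_rot.shape == k.shape and q_rot.shape == q.shape
# each position applies a pure rotation to channel pairs, so vector norms are unchanged
assert torch.allclose(k_rot.norm(dim=-1), k.norm(dim=-1), atol=1e-5)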
823
-
824
- class TorchGptGroupedQueryAttention(nn.Module):
825
- def __init__(
826
- self,
827
- embed_dim: int,
828
- num_heads: int,
829
- rope_config: RotaryEmbeddingConfig,
830
- num_kv_heads: int = None, # type: ignore
831
- head_dim: int = None, # type: ignore
832
- add_bias_attn: bool = False, # type: ignore
833
- ) -> None:
834
- super().__init__()
835
- self.num_heads = num_heads
836
- self.num_kv_heads = num_kv_heads or num_heads
837
- self.embed_dim = embed_dim
838
- self.head_dim = head_dim or (embed_dim // num_heads)
839
- self.add_bias_attn = add_bias_attn
840
- self.rope = TorchRotaryEmbedding(rope_config)
841
-
842
- self.query_linear = nn.Linear(
843
- embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
844
- )
845
- self.key_linear = nn.Linear(
846
- embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
847
- )
848
- self.value_linear = nn.Linear(
849
- embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
850
- )
851
- self.out_linear = nn.Linear(
852
- self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
853
- )
854
-
855
- def forward(
856
- self,
857
- query_inputs: torch.Tensor,
858
- key_inputs: torch.Tensor,
859
- value_inputs: torch.Tensor,
860
- attention_mask: torch.Tensor = None,
861
- ) -> torch.Tensor:
862
- batch_size, seq_len, _ = query_inputs.shape
863
-
864
- queries = self.query_linear(query_inputs).view( # noqa
865
- batch_size, seq_len, self.num_heads, self.head_dim
866
- )
867
- keys = self.key_linear(key_inputs).view( # noqa
868
- batch_size, seq_len, self.num_kv_heads, self.head_dim
869
- )
870
- values = self.value_linear(value_inputs).view( # noqa
871
- batch_size, seq_len, self.num_kv_heads, self.head_dim
872
- )
873
-
874
- keys, queries = self.rope(keys, queries)
875
-
876
- n_rep = self.num_heads // self.num_kv_heads
877
- keys = keys.repeat_interleave(n_rep, dim=2)
878
- values = values.repeat_interleave(n_rep, dim=2)
879
-
880
- attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
881
- self.head_dim**0.5
882
- )
883
-
884
- if attention_mask is not None:
885
- attention_logits = attention_logits.masked_fill(
886
- attention_mask == 0, float("-inf")
887
- )
888
-
889
- attention_weights = nn.functional.softmax(attention_logits, dim=-1)
890
-
891
- values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
892
- values = values.contiguous().view(batch_size, seq_len, -1)
893
-
894
- return self.out_linear(values)
895
-
896
-
897
- class TorchGptDecoder(nn.Module):
898
- def __init__(self, config: GptConfig, name: Optional[str] = None):
899
- super().__init__()
900
- self.config = config
901
-
902
- self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)
903
-
904
- if config.norm_type == "layer_norm":
905
- self.final_norm = nn.LayerNorm(config.embed_dim)
906
- elif config.norm_type == "RMS_norm":
907
- self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
908
- else:
909
- raise ValueError(f"unrecognized norm_type in config {config.norm_type}")
910
-
911
- self.layers = nn.ModuleList(
912
- [
913
- TorchGptDecoderLayer(
914
- embed_dim=config.embed_dim,
915
- ffn_embed_dim=config.ffn_embed_dim,
916
- num_heads=config.num_heads,
917
- rope_config=config.rope_config,
918
- norm_type=config.norm_type,
919
- parallel_attention_ff=config.parallel_attention_ff,
920
- add_bias_ffn=config.add_bias_ffn,
921
- ffn_activation_name=config.ffn_activation_name,
922
- use_glu_in_ffn=config.use_glu_in_ffn,
923
- num_kv_heads=config.num_kv_heads, # type: ignore
924
- add_bias_attn=config.add_bias_attn,
925
- rms_norm_eps=config.rms_norm_eps,
926
- )
927
- for _ in range(config.num_layers)
928
- ]
929
- )
930
-
931
- self.lm_head = TorchSimpleLMHead(
932
- embed_dim=config.embed_dim,
933
- alphabet_size=config.vocab_size,
934
- add_bias_lm_head=config.add_bias_lm_head,
935
- )
936
-
937
- def apply_transformer_layers(
938
- self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
939
- ) -> torch.Tensor:
940
- if attention_mask is None:
941
- attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
942
- for layer in self.layers:
943
- embeddings = layer(embeddings, attention_mask)
944
-
945
- return embeddings
946
-
947
- def forward(
948
- self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
949
- ) -> dict[str, torch.Tensor]:
950
- if attention_mask is None:
951
- attention_mask = build_causal_attention_mask(1, token_ids.shape[1])
952
-
953
- tokens_embeddings = self.token_embed(token_ids)
954
-
955
- after_transformer_embeddings = self.apply_transformer_layers(
956
- tokens_embeddings, attention_mask=attention_mask
957
- )
958
-
959
- embeddings = self.final_norm(after_transformer_embeddings)
960
- logits = self.lm_head(embeddings)
961
- return {"embeddings": embeddings, "logits": logits}
962
-
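For reference, a minimal end-to-end sketch of the decoder with the default GptConfig fields above (toy vocabulary, runs on CPU):

import torch

decoder = TorchGptDecoder(GptConfig(vocab_size=100, eos_token_id=3))
token_ids = torch.randint(0, 100, (1, 12))
out = decoder(token_ids)                        # the causal mask is built internally
assert out["logits"].shape == (1, 12, 100)
assert out["embeddings"].shape == (1, 12, 16)   # default embed_dim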
963
-
964
- class TorchSimpleLMHead(nn.Module):
965
- def __init__(
966
- self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
967
- ) -> None:
968
- super().__init__()
969
- self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head)
970
-
971
- def forward(self, x: torch.Tensor) -> torch.Tensor:
972
- return self.fc(x)
973
-
974
-
975
- class TorchGptDecoderLayer(nn.Module):
976
- def __init__(
977
- self,
978
- embed_dim: int,
979
- ffn_embed_dim: int,
980
- num_heads: int,
981
- rope_config: RotaryEmbeddingConfig,
982
- norm_type: str,
983
- parallel_attention_ff: bool,
984
- add_bias_ffn: bool,
985
- ffn_activation_name: str,
986
- use_glu_in_ffn: bool,
987
- num_kv_heads: int,
988
- add_bias_attn: bool,
989
- rms_norm_eps: float = 1e-6,
990
- ) -> None:
991
- super().__init__()
992
- self.num_heads = num_heads
993
- self.parallel_attention_ff = parallel_attention_ff
994
- self.use_glu_in_ffn = use_glu_in_ffn
995
-
996
- # Self-Attention layer
997
- self.self_attn = TorchGptGroupedQueryAttention(
998
- embed_dim=embed_dim,
999
- num_heads=num_heads,
1000
- num_kv_heads=num_kv_heads,
1001
- rope_config=rope_config,
1002
- add_bias_attn=add_bias_attn,
1003
- )
1004
-
1005
- # Normalization layers
1006
- if norm_type == "layer_norm":
1007
- self.attn_norm = nn.LayerNorm(embed_dim)
1008
- if not self.parallel_attention_ff:
1009
- self.ffn_norm = nn.LayerNorm(embed_dim)
1010
- elif norm_type == "RMS_norm":
1011
- self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1012
- if not self.parallel_attention_ff:
1013
- self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1014
- else:
1015
- raise ValueError(f"unrecognized norm_type: {norm_type}")
1016
-
1017
- # Feedforward network
1018
- self.activation = get_activation_fn(ffn_activation_name)
1019
- ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1)
1020
- self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn)
1021
- self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1022
-
1023
- def forward(
1024
- self, embeddings: torch.Tensor, attention_mask: torch.Tensor
1025
- ) -> torch.Tensor:
1026
- residuals = embeddings
1027
-
1028
- if self.parallel_attention_ff:
1029
- # Parallel Attention + MLP
1030
- embeddings_normed = self.attn_norm(embeddings)
1031
-
1032
- attn_output = self.self_attn(
1033
- embeddings_normed,
1034
- embeddings_normed,
1035
- embeddings_normed,
1036
- attention_mask=attention_mask,
1037
- )
1038
- ffn_output = self.mlp(embeddings_normed) # type: ignore
1039
-
1040
- return residuals + attn_output + ffn_output
1041
- else:
1042
- # Sequential Attention + MLP
1043
- normed_embeddings = self.attn_norm(embeddings)
1044
-
1045
- attn_output = embeddings + self.self_attn(
1046
- normed_embeddings,
1047
- normed_embeddings,
1048
- normed_embeddings,
1049
- attention_mask=attention_mask,
1050
- )
1051
-
1052
- normed_embeddings2 = self.ffn_norm(attn_output)
1053
- ffn_output = self.mlp(normed_embeddings2) # type: ignore
1054
- return attn_output + ffn_output # Residual connection
1055
-
1056
- def mlp(self, x: torch.Tensor) -> torch.Tensor:
1057
- """Applies the feedforward network (MLP) with optional GLU."""
1058
- ffn_output = self.fc1(x)
1059
-
1060
- if self.use_glu_in_ffn:
1061
- ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1)
1062
- ffn_output = self.activation(ffn_output1) * ffn_output2
1063
- else:
1064
- ffn_output = self.activation(ffn_output)
1065
-
1066
- return self.fc2(ffn_output)
1067
-
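A standalone sketch of one decoder block in its sequential (non-parallel) configuration, with an explicit causal mask (toy sizes, assuming the classes defined in this file):

import torch

layer = TorchGptDecoderLayer(
    embed_dim=16,
    ffn_embed_dim=32,
    num_heads=2,
    rope_config=RotaryEmbeddingConfig(max_seq_len=32, dim=8, theta=10000.0),
    norm_type="RMS_norm",
    parallel_attention_ff=False,
    add_bias_ffn=False,
    ffn_activation_name="swish",
    use_glu_in_ffn=True,
    num_kv_heads=2,
    add_bias_attn=False,
)
x = torch.randn(1, 10, 16)
causal_mask = torch.tril(torch.ones(1, 1, 10, 10))  # same layout as build_causal_attention_mask below
out = layer(x, causal_mask)
assert out.shape == (1, 10, 16)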
1068
-
1069
- class TorchRMSNorm(nn.Module):
1070
- def __init__(self, dim: int, eps: float = 1e-6) -> None:
1071
- super().__init__()
1072
- self.eps = eps
1073
- self.scale = nn.Parameter(torch.ones(dim))
1074
-
1075
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1076
- return (
1077
- x
1078
- * self.scale
1079
- / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
1080
- )
1081
-
1082
-
1083
- def get_activation_fn(activation_name: str): # type: ignore
1084
- activations = {
1085
- "gelu": nn.functional.gelu,
1086
- "relu": nn.functional.relu,
1087
- "swish": nn.functional.silu,
1088
- "silu": nn.functional.silu,
1089
- }
1090
- return activations.get(activation_name, nn.functional.relu)
1091
-
1092
-
1093
- def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
1094
- """
1095
- Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
1096
- to an attention layer.
1097
-
1098
- Args:
1099
- batch_size: Batch size.
1100
- seq_len: Length of the sequences.
1101
-
1102
- Returns:
1103
- Batch of causal masks.
1104
- """
1105
- mask = torch.ones((batch_size, 1, seq_len, seq_len))
1106
- causal_mask = torch.tril(mask)
1107
- return causal_mask
1108
-
1109
-
1110
- @dataclass
1111
- class RotaryEmbeddingConfigBis:
1112
- """
1113
- Parameters to initialize the RotaryEmbedding layer.
1114
- Args:
1115
- rescaling_factor: Optional factor used to rescale the rotary frequencies so that
1116
- the embeddings adapt to sequence lengths longer than those seen during training (one such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf). # noqa
1117
- """
1118
-
1119
- rescaling_factor: Optional[float]
1120
-
1121
-
1122
- class RotaryEmbeddingBis(torch.nn.Module):
1123
- """
1124
- Rotary position embeddings based on those in
1125
- [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
1126
- Query and keys are transformed by rotation
1127
- matrices which depend on their relative positions.
1128
- """
1129
-
1130
- def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
1131
- super().__init__()
1132
-
1133
- # Extract argument from the config
1134
- self.rescaling_factor = rotary_embedding_config.rescaling_factor
1135
- self.upper_freq = 10000
1136
- self.dim = dim
1137
-
1138
- self._seq_len_cached = None
1139
- self._cos_cached = None
1140
- self._sin_cached = None
1141
-
1142
- def _apply_rotary_pos_emb(
1143
- self,
1144
- heads: torch.Tensor,
1145
- cos: torch.Tensor,
1146
- sin: torch.Tensor,
1147
- ) -> torch.Tensor:
1148
- """ """
1149
- x_first, x_second = (
1150
- heads[..., : heads.shape[-1] // 2],
1151
- heads[..., heads.shape[-1] // 2 :],
1152
- )
1153
-
1154
- first_part = x_first * cos - x_second * sin
1155
- second_part = x_second * cos + x_first * sin
1156
-
1157
- return torch.cat((first_part, second_part), dim=-1)
1158
-
1159
- def _compute_cos_sin_tables(
1160
- self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1161
- ) -> tuple[torch.Tensor, torch.Tensor]:
1162
- seq_len = x.shape[seq_dimension]
1163
- # Recompute the cos/sin tables for the current sequence length and device
1164
- # (they are rebuilt on every call; no caching shortcut is applied here)
1165
- self._seq_len_cached = seq_len
1166
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1167
- # freqs = torch.outer(t, inv_freq)
1168
- freqs = torch.einsum("i, j -> ij", t, inv_freq)
1169
-
1170
- self._cos_cached = torch.cos(freqs)[None, :, None, :]
1171
- self._sin_cached = torch.sin(freqs)[None, :, None, :]
1172
- # emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1173
-
1174
- # self._cos_cached = emb.cos()[None, None, :, :]
1175
- # self._sin_cached = emb.sin()[None, None, :, :]
1176
-
1177
- return self._cos_cached, self._sin_cached
1178
-
1179
- def forward(
1180
- self, q: torch.Tensor, k: torch.Tensor
1181
- ) -> Tuple[torch.Tensor, torch.Tensor]:
1182
- if self.rescaling_factor is None:
1183
- inv_freq = 1.0 / (
1184
- self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
1185
- )
1186
- else:
1187
- updated_base = self.upper_freq * (
1188
- self.rescaling_factor ** (self.dim / (self.dim - 2))
1189
- )
1190
- inv_freq = 1.0 / (
1191
- updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
1192
- )
1193
-
1194
- self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1195
- q,
1196
- inv_freq,
1197
- seq_dimension=-3,
1198
- )
1199
-
1200
- return (
1201
- self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1202
- self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1203
- )
1204
-
1205
-
1206
- class MultiHeadAttention(nn.Module):
1207
- def __init__(
1208
- self,
1209
- num_heads: int,
1210
- key_size: int,
1211
- rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1212
- add_bias_kv: bool = False,
1213
- value_size: Optional[int] = None,
1214
- model_size: Optional[int] = None,
1215
- name: Optional[str] = None,
1216
- ):
1217
- super().__init__()
1218
- if not model_size:
1219
- model_size = key_size * num_heads
1220
- if not value_size:
1221
- value_size = key_size
1222
- self.model_size = model_size
1223
- self.key_size = key_size
1224
- self.value_size = value_size
1225
- self.add_bias_kv = add_bias_kv
1226
- self.name = name
1227
- self.num_heads = num_heads
1228
- self._rotary_embedding_config = rotary_embedding_config
1229
-
1230
- self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
1231
- self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
1232
- self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
1233
- self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
1234
- if self._rotary_embedding_config:
1235
- self._rotary_embedding = RotaryEmbeddingBis(
1236
- self.key_size, self._rotary_embedding_config
1237
- )
1238
-
1239
- def apply_rotary_embeddings(
1240
- self,
1241
- query: torch.Tensor,
1242
- key: torch.Tensor,
1243
- ) -> tuple[torch.Tensor, torch.Tensor]:
1244
- """ """
1245
- query, key = self._rotary_embedding(query, key)
1246
- return query, key
1247
-
1248
- def forward(
1249
- self,
1250
- query: torch.Tensor,
1251
- key: torch.Tensor,
1252
- value: torch.Tensor,
1253
- attention_mask: Optional[torch.Tensor] = None,
1254
- attention_weight_bias: Optional[torch.Tensor] = None,
1255
- ) -> dict[str, torch.Tensor]:
1256
- """
1257
- Returns:
1258
- dictionary containing attention weights
1259
- and outputs.
1260
- """
1261
- key_heads = self.w_k(key).reshape(
1262
- (*key.shape[:-1], self.num_heads, self.key_size)
1263
- )
1264
- query_heads = self.w_q(query).reshape(
1265
- (*query.shape[:-1], self.num_heads, self.key_size)
1266
- )
1267
- value_heads = self.w_v(value).reshape(
1268
- (*value.shape[:-1], self.num_heads, self.value_size)
1269
- )
1270
- if self._rotary_embedding_config:
1271
- query_heads, key_heads = self.apply_rotary_embeddings(
1272
- query_heads, key_heads
1273
- )
1274
- attention_weights = torch.einsum(
1275
- "...thd, ...Thd -> ...htT", query_heads, key_heads
1276
- )
1277
- sqrt_key_size = np.sqrt(self.key_size)
1278
- attention_weights = attention_weights / sqrt_key_size
1279
- if attention_mask is not None:
1280
- attention_weights = torch.where(attention_mask, attention_weights, -1e30)
1281
- if attention_weight_bias is not None:
1282
- attention_weights = F.softmax(
1283
- attention_weights + attention_weight_bias, dim=-1
1284
- )
1285
- else:
1286
- attention_weights = F.softmax(attention_weights, dim=-1)
1287
- value_out = torch.einsum(
1288
- "...htT, ...Thd->...thd", attention_weights, value_heads
1289
- )
1290
- value_out = value_out.reshape((*value_out.shape[:-2], -1))
1291
- embeddings = self.output(value_out)
1292
-
1293
- return {"attention_weights": attention_weights, "embeddings": embeddings}
1294
-
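A minimal sketch of the cross-attention usage of this module (arbitrary toy shapes; no rotary config is passed, so no RoPE is applied):

import torch

mha = MultiHeadAttention(num_heads=2, key_size=8, model_size=16)
out = mha(
    query=torch.randn(1, 5, 16),  # e.g. latent queries
    key=torch.randn(1, 7, 16),    # e.g. encoder embeddings
    value=torch.randn(1, 7, 16),
)
assert out["embeddings"].shape == (1, 5, 16)
assert out["attention_weights"].shape == (1, 2, 5, 7)  # (batch, heads, queries, keys)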
1295
-
1296
- class SelfAttentionBlock(nn.Module):
1297
- def __init__(
1298
- self,
1299
- num_heads: int,
1300
- embed_dim: int,
1301
- ffn_embed_dim: int,
1302
- key_size: Optional[int] = None,
1303
- add_bias_kv: bool = False,
1304
- add_bias_fnn: bool = True,
1305
- ffn_activation_name: str = "gelu-no-approx",
1306
- use_glu_in_ffn: bool = False,
1307
- layer_norm_eps: float = 1e-5, # this is the default haiku value
1308
- pre_layer_norm: bool = True,
1309
- name: Optional[str] = None,
1310
- rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1311
- ):
1312
- super().__init__()
1313
- if key_size is None:
1314
- if embed_dim % num_heads != 0:
1315
- raise ValueError(
1316
- f"The embedding dimension should be divisible by the number of "
1317
- f"heads, however provided embedding dimension is {embed_dim} and "
1318
- f"the number of heads is {num_heads}."
1319
- )
1320
- else:
1321
- key_size = embed_dim // num_heads
1322
-
1323
- # Get ffn activation function
1324
- self._pre_layer_norm = pre_layer_norm
1325
- self._use_glu_in_fnn = use_glu_in_ffn
1326
- # Define layers
1327
- if use_glu_in_ffn:
1328
- # user should multiply ffn_embed_dim by 2/3 when using GLU
1329
- # to keep total number of parameters equal
1330
- # see https://arxiv.org/pdf/2002.05202.pdf. for more details
1331
- # we multiply by 2 here as the output will be split in 2 for GLU
1332
- self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
1333
- else:
1334
- self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
1335
-
1336
- self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
1337
-
1338
- self.layer_norm_self_attention = nn.LayerNorm(
1339
- embed_dim,
1340
- )
1341
- self.layer_norm_mlp = nn.LayerNorm(embed_dim)
1342
- if ffn_activation_name == "swish":
1343
- self._ffn_activation_fn = nn.SiLU()
1344
- elif ffn_activation_name == "gelu-no-approx":
1345
- self._ffn_activation_fn = nn.GELU(approximate="none")  # exact GELU
1346
- else:
1347
- self._ffn_activation_fn = getattr(nn.functional, ffn_activation_name)
1348
-
1349
- self.mha = MultiHeadAttention(
1350
- num_heads=num_heads,
1351
- key_size=key_size,
1352
- add_bias_kv=add_bias_kv,
1353
- model_size=embed_dim,
1354
- name="self_attention",
1355
- rotary_embedding_config=rotary_embedding_config,
1356
- )
1357
-
1358
- def mlp(self, embed: torch.Tensor) -> torch.Tensor:
1359
-
1360
- if self._pre_layer_norm:
1361
- x = self.layer_norm_mlp(embed)
1362
- else:
1363
- x = embed
1364
-
1365
- if self._use_glu_in_fnn:
1366
- x = self.fc1(x)
1367
- x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
1368
- x = self._ffn_activation_fn(x1) * x2
1369
- else:
1370
- x = self._ffn_activation_fn(self.fc1(x))
1371
- x = self.fc2(x)
1372
-
1373
- if not self._pre_layer_norm:
1374
- x = self.layer_norm_mlp(x + embed)
1375
- return x
1376
-
1377
- def forward(
1378
- self,
1379
- x: torch.Tensor,
1380
- attention_mask: Optional[torch.Tensor] = None,
1381
- attention_weight_bias: Optional[torch.Tensor] = None,
1382
- ) -> dict[str, torch.Tensor]:
1383
-
1384
- res = x
1385
- if self._pre_layer_norm:
1386
- x = self.layer_norm_self_attention(x)
1387
-
1388
- output: dict[str, torch.Tensor] = self.mha(
1389
- x,
1390
- x,
1391
- x,
1392
- attention_mask=attention_mask,
1393
- attention_weight_bias=attention_weight_bias,
1394
- )
1395
-
1396
- if not self._pre_layer_norm:
1397
- output["embeddings"] = self.layer_norm_self_attention(
1398
- output["embeddings"] + res
1399
- )
1400
-
1401
- x = output["embeddings"]
1402
- else:
1403
- x = output["embeddings"]
1404
- x = res + x
1405
-
1406
- # MLP
1407
- if not self._pre_layer_norm:
1408
- x = self.mlp(x)
1409
- else:
1410
- x = x + self.mlp(x)
1411
-
1412
- output["embeddings"] = x
1413
- return output
1414
-
1415
-
1416
- class RobertaLMHead(nn.Module):
1417
- """
1418
- Roberta Language Model head. Transforms final attention layer output into a
1419
- distribution over tokens at each position.
1420
- """
1421
-
1422
- def __init__(self, embed_dim: int, alphabet_size: int):
1423
- """
1424
- Args:
1425
- embed_dim: Embedding dimension.
1426
- alphabet_size: Number of tokens in the alphabet.
1427
- """
1428
- super().__init__()
1429
- self.embed_dim = embed_dim
1430
- self.alphabet_size = alphabet_size
1431
-
1432
- # Define layers
1433
- self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1434
- self._fc1 = nn.Linear(embed_dim, embed_dim)
1435
- self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1436
- self._final_fc = nn.Linear(embed_dim, alphabet_size)
1437
-
1438
- def forward(self, x: torch.Tensor) -> dict:
1439
- x = self._first_layer_norm(x)
1440
- embeddings = x
1441
- x = self._fc1(x)
1442
- x = nn.functional.gelu(x)
1443
- x = self._second_layer_norm(x)
1444
- logits = self._final_fc(x)
1445
- return {"embeddings": embeddings, "logits": logits}
1446
-
1447
-
1448
- class TorchESMTransformer(nn.Module):
1449
- def __init__(
1450
- self,
1451
- esm_config: ESMTransformerConfig,
1452
- ):
1453
- super(TorchESMTransformer, self).__init__()
1454
- self.esm_config = esm_config
1455
-
1456
- # Other cases are not implemented
1457
- assert esm_config.positional_embedding is None
1458
- assert esm_config.lm_head == "roberta"
1459
- assert esm_config.use_rotary_embedding is True
1460
- assert esm_config.token_dropout is False
1461
- assert esm_config.emb_layer_norm_before is False
1462
- assert esm_config.mask_before_attention is False
1463
- assert esm_config.bias_word_embedding is False
1464
- assert esm_config.use_gradient_checkpointing is False
1465
-
1466
- self.embed_layer = nn.Embedding(esm_config.alphabet_size, esm_config.embed_dim)
1467
-
1468
- self.lm_head = RobertaLMHead(
1469
- embed_dim=esm_config.embed_dim,
1470
- alphabet_size=esm_config.alphabet_size,
1471
- )
1472
-
1473
- self.rotary_embedding_config = RotaryEmbeddingConfigBis(
1474
- rescaling_factor=esm_config.rescaling_factor
1475
- )
1476
-
1477
- self.attention_blocks = nn.ModuleList(
1478
- [
1479
- SelfAttentionBlock( # type: ignore
1480
- num_heads=esm_config.attention_heads,
1481
- embed_dim=esm_config.embed_dim,
1482
- key_size=esm_config.key_size,
1483
- ffn_embed_dim=esm_config.ffn_embed_dim,
1484
- add_bias_kv=esm_config.add_bias_kv,
1485
- add_bias_fnn=esm_config.add_bias_ffn,
1486
- ffn_activation_name=esm_config.ffn_activation_name,
1487
- use_glu_in_ffn=esm_config.use_glu_in_ffn,
1488
- rotary_embedding_config=self.rotary_embedding_config,
1489
- layer_norm_eps=esm_config.layer_norm_eps,
1490
- pre_layer_norm=esm_config.pre_layer_norm,
1491
- )
1492
- for _ in range(esm_config.num_layers)
1493
- ]
1494
- )
1495
-
1496
- def forward(
1497
- self, tokens: torch.Tensor, attention_mask: torch.Tensor = None
1498
- ) -> torch.Tensor:
1499
- """
1500
- Computes the embeddings based on the input tokens.
1501
-
1502
- Args:
1503
- tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
1504
- attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
1505
- If no mask is provided, a default mask is computed, equal to 1 over all
1506
- non-pad tokens and 0 over pad tokens.
1507
-
1508
- Returns:
1509
- The final embeddings of shape (batch_size, seq_len, embed_dim).
1510
- """
1511
- x = self.embed_layer(tokens)
1512
-
1513
- # RoBERTa's embedding scaling factor
1514
- x = self.esm_config.embed_scale * x
1515
-
1516
- if attention_mask is None:
1517
- attention_mask = build_padding_attention_mask(
1518
- tokens=tokens, pad_token_id=self.esm_config.pad_token_id
1519
- )
1520
-
1521
- for layer in self.attention_blocks:
1522
- x = layer(x, attention_mask)["embeddings"]
1523
-
1524
- assert self.esm_config.lm_head == "roberta"
1525
- x = self.lm_head(x)["embeddings"]
1526
-
1527
- return x
1528
-
1529
-
1530
- def build_padding_attention_mask(
1531
- tokens: torch.Tensor, pad_token_id: int
1532
- ) -> torch.Tensor:
1533
- """
1534
- Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
1535
-
1536
- Args:
1537
- tokens: Batch of sequences of shape (batch_size, seq_len).
1538
- pad_token_id: Int corresponding to the <pad> token to mask.
1539
-
1540
- Returns:
1541
- Batch of attention masks, masking out <pad> tokens.
1542
- """
1543
- padding_mask = tokens != pad_token_id
1544
- padding_mask = padding_mask.unsqueeze(1)
1545
- padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask)
1546
- return padding_mask
1547
-
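For example (a small sketch treating id 1 as the pad token):

import torch

tokens = torch.tensor([[5, 6, 1, 1]])
mask = build_padding_attention_mask(tokens, pad_token_id=1)
assert mask.shape == (1, 1, 4, 4)
# mask[0, 0, t, T] is True only when neither position t nor position T is a pad token,
# so the rows and columns of the two padded positions are masked out.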
1548
-
1549
- class TorchBioBrainEncoder(nn.Module):
1550
- def __init__(
1551
- self,
1552
- esm_config: ESMTransformerConfig,
1553
- ):
1554
- super(TorchBioBrainEncoder, self).__init__()
1555
- self.esm_config = esm_config
1556
- self.esm_model = TorchESMTransformer(self.esm_config)
1557
-
1558
- def forward(
1559
- self,
1560
- bio_token_ids: torch.Tensor,
1561
- ) -> torch.Tensor:
1562
- """
1563
- Args:
1564
- bio_token_ids (torch.Tensor):
1565
- Shape (batch_size, num_bio_tokens)
1566
-
1567
- Returns:
1568
- torch.Tensor:
1569
- Shape (batch_size, num_bio_tokens, embed_dim)
1570
- """
1571
- bio_embeddings = self.esm_model(tokens=bio_token_ids)
1572
-
1573
- return bio_embeddings
1574
-
1575
-
1576
- class TorchMultiModalPerceiverResamplerBlock(nn.Module):
1577
- def __init__(
1578
- self,
1579
- num_heads: int,
1580
- embed_dim: int,
1581
- ffn_embed_dim: int,
1582
- key_size: Optional[int] = None,
1583
- add_bias_kv: bool = False,
1584
- add_bias_ffn: bool = True,
1585
- ffn_activation_name: str = "gelu",
1586
- use_glu_in_ffn: bool = False,
1587
- ):
1588
- super().__init__()
1589
-
1590
- if key_size is None:
1591
- if embed_dim % num_heads != 0:
1592
- raise ValueError(
1593
- f"Embedding dimension {embed_dim} should be divisible by "
1594
- f"num_heads {num_heads}."
1595
- )
1596
- key_size = embed_dim // num_heads
1597
-
1598
- self.num_heads = num_heads
1599
- self.embed_dim = embed_dim
1600
- self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim
1601
- self.use_glu_in_ffn = use_glu_in_ffn
1602
-
1603
- self.cross_attention_1 = MultiHeadAttention(
1604
- num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1605
- )
1606
- self.cross_attention_2 = MultiHeadAttention(
1607
- num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1608
- )
1609
-
1610
- self.norm_cross_attention_1 = nn.LayerNorm(embed_dim)
1611
- self.norm_cross_attention_2 = nn.LayerNorm(embed_dim)
1612
- self.norm_mlp = nn.LayerNorm(embed_dim)
1613
-
1614
- self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn)
1615
- self.fc2 = nn.Linear(self.ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1616
-
1617
- self.activation_fn = getattr(
1618
- nn.functional, ffn_activation_name, nn.functional.gelu
1619
- )
1620
-
1621
- def mlp(self, x: torch.Tensor) -> torch.Tensor:
1622
- x = self.norm_mlp(x)
1623
- if self.use_glu_in_ffn:
1624
- x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
1625
- x = self.activation_fn(x1) * x2
1626
- else:
1627
- x = self.activation_fn(self.fc1(x))
1628
- return self.fc2(x)
1629
-
1630
- def forward(
1631
- self,
1632
- x: torch.Tensor,
1633
- cross_attention_embeddings_1: torch.Tensor,
1634
- cross_attention_embeddings_2: torch.Tensor,
1635
- attention_mask_1: Optional[torch.Tensor] = None,
1636
- attention_mask_2: Optional[torch.Tensor] = None,
1637
- ) -> Dict[str, torch.Tensor]:
1638
- res = x
1639
- x = self.norm_cross_attention_1(x)
1640
-
1641
- attn_output = self.cross_attention_1(
1642
- query=x,
1643
- key=cross_attention_embeddings_1,
1644
- value=cross_attention_embeddings_1,
1645
- attention_mask=attention_mask_1,
1646
- )["embeddings"]
1647
- x = res + attn_output
1648
-
1649
- res = x
1650
- x = self.norm_cross_attention_2(x)
1651
- attn_output = self.cross_attention_2(
1652
- query=x,
1653
- key=cross_attention_embeddings_2,
1654
- value=cross_attention_embeddings_2,
1655
- attention_mask=attention_mask_2,
1656
- )["embeddings"]
1657
- x = res + attn_output
1658
-
1659
- x = x + self.mlp(x)
1660
-
1661
- return {"embeddings": x}
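# Illustrative sketch, not part of the original chatNT.py: one resampler block refining a set
# of latent queries against two conditioning streams. It relies on the MultiHeadAttention
# module defined earlier in this file; all shapes below are arbitrary example values
# (embed_dim just has to be divisible by num_heads).
import torch

block = TorchMultiModalPerceiverResamplerBlock(num_heads=4, embed_dim=64, ffn_embed_dim=128)
latents = torch.randn(2, 16, 64)          # (batch, resampled_length, embed_dim)
dna_embeddings = torch.randn(2, 100, 64)  # conditioning stream 1
english_embeddings = torch.randn(2, 30, 64)  # conditioning stream 2
refined = block(
    x=latents,
    cross_attention_embeddings_1=dna_embeddings,
    cross_attention_embeddings_2=english_embeddings,
)["embeddings"]  # (2, 16, 64): same shape as the latent queries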
1662
-
1663
-
1664
- class TorchMultiModalPerceiverResampler(nn.Module):
1665
- """
1666
- Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
1667
- """
1668
-
1669
- def __init__(
1670
- self,
1671
- config: PerceiverResamplerConfig,
1672
- name: Optional[str] = None,
1673
- ):
1674
- """
1675
- Initialize a Perceiver Resampler model.
1676
-
1677
- Args:
1678
- config: Dataclass containing model hyperparameters.
1679
- name: Name for module (a custom name will break weight loading).
1680
- """
1681
- super().__init__()
1682
- self.config = config
1683
- self.name = name
1684
- self.layers = nn.ModuleList(
1685
- [
1686
- TorchMultiModalPerceiverResamplerBlock(
1687
- num_heads=self.config.attention_heads,
1688
- embed_dim=self.config.embed_dim,
1689
- key_size=self.config.key_size,
1690
- ffn_embed_dim=self.config.ffn_embed_dim,
1691
- add_bias_kv=self.config.add_bias_kv,
1692
- add_bias_ffn=self.config.add_bias_ffn,
1693
- ffn_activation_name=self.config.ffn_activation_name,
1694
- use_glu_in_ffn=self.config.use_glu_in_ffn,
1695
- )
1696
- for _ in range(self.config.num_layers)
1697
- ]
1698
- )
1699
-
1700
- self.latent_queries = torch.nn.Parameter(
1701
- torch.randn(self.config.resampled_length, self.config.embed_dim)
1702
- * (
1703
- 1.0
1704
- / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32))
1705
- )
1706
- )
1707
-
1708
- def apply_attention_blocks(
1709
- self,
1710
- x: torch.Tensor,
1711
- xf_1: torch.Tensor,
1712
- xf_2: torch.Tensor,
1713
- outs: Dict[str, torch.Tensor],
1714
- attention_mask_1: Optional[torch.Tensor] = None,
1715
- attention_mask_2: Optional[torch.Tensor] = None,
1716
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
1717
- """
1718
- Applies the successive cross-attention blocks to the latent queries.
1719
- """
1720
- for layer in self.layers:
1721
- concat_input_1 = torch.cat([xf_1, x], dim=1)
1722
- concat_input_2 = torch.cat([xf_2, x], dim=1)
1723
-
1724
- output = layer(
1725
- x=x,
1726
- cross_attention_embeddings_1=concat_input_1,
1727
- cross_attention_embeddings_2=concat_input_2,
1728
- attention_mask_1=attention_mask_1,
1729
- attention_mask_2=attention_mask_2,
1730
- )
1731
- x = output["embeddings"]
1732
-
1733
- return x, outs
1734
-
1735
- def forward(
1736
- self,
1737
- input_embeddings_1: torch.Tensor,
1738
- input_embeddings_2: torch.Tensor,
1739
- attention_mask_1: Optional[torch.Tensor] = None,
1740
- attention_mask_2: Optional[torch.Tensor] = None,
1741
- ) -> Dict[str, torch.Tensor]:
1742
- """
1743
- Computes the resampled embeddings from the two input embedding streams.
1744
- """
1745
- assert (
1746
- input_embeddings_1.shape[-1] == self.config.embed_dim
1747
- ), "The input embedding dim should match the model embed dim"
1748
- assert (
1749
- input_embeddings_2.shape[-1] == self.config.embed_dim
1750
- ), "The input embedding dim should match the model embed dim"
1751
-
1752
- batch_size = input_embeddings_1.shape[0]
1753
-
1754
- latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)
1755
-
1756
- outs: Dict[str, torch.Tensor] = {}
1757
- x = latent_queries
1758
-
1759
- x, outs = self.apply_attention_blocks(
1760
- x=x,
1761
- xf_1=input_embeddings_1,
1762
- xf_2=input_embeddings_2,
1763
- outs=outs,
1764
- attention_mask_1=attention_mask_1,
1765
- attention_mask_2=attention_mask_2,
1766
- )
1767
-
1768
- outs["embeddings"] = x
1769
-
1770
- return outs
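# Illustrative sketch, not part of the original chatNT.py: whatever the lengths of the two
# input streams, the resampler returns exactly `resampled_length` vectors per sample.
# The config values below are small, arbitrary choices for the example (the dataclass
# defaults are much larger).
import torch

cfg = PerceiverResamplerConfig(
    attention_heads=4, embed_dim=64, ffn_embed_dim=128, num_layers=2, resampled_length=8
)
resampler = TorchMultiModalPerceiverResampler(config=cfg)
dna_stream = torch.randn(2, 50, 64)      # e.g. projected DNA embeddings
english_stream = torch.randn(2, 20, 64)  # e.g. English token embeddings
out = resampler(input_embeddings_1=dna_stream, input_embeddings_2=english_stream)
print(out["embeddings"].shape)           # torch.Size([2, 8, 64])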
1771
-
1772
-
1773
- class TorchMultiModalPerceiverResamplerProjection(nn.Module):
1774
- def __init__(
1775
- self,
1776
- perceiver_resampler_config: PerceiverResamplerConfig,
1777
- input_embed_dim: int,
1778
- embed_dim: int,
1779
- bio_pad_token_id: int,
1780
- english_pad_token_id: int,
1781
- english_vocab_size: int,
1782
- ):
1783
- super().__init__()
1784
- self.config = perceiver_resampler_config
1785
- self.input_embed_dim = input_embed_dim
1786
- self.embed_dim = embed_dim
1787
- self.bio_pad_token_id = bio_pad_token_id
1788
- self.english_pad_token_id = english_pad_token_id
1789
- self.english_vocab_size = english_vocab_size
1790
-
1791
- self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
1792
- self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
1793
- self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
1794
-
1795
- def forward(
1796
- self,
1797
- bio_token_ids: torch.Tensor,
1798
- bio_embeddings: torch.Tensor,
1799
- english_token_ids: torch.Tensor,
1800
- ) -> torch.Tensor:
1801
- """
1802
- Args:
1803
- bio_token_ids (torch.Tensor):
1804
- Shape (batch_size, num_bio_tokens)
1805
-
1806
- bio_embeddings (torch.Tensor):
1807
- Shape (batch_size, num_bio_tokens, embed_dim)
1808
-
1809
- english_token_ids (torch.Tensor):
1810
- Shape (batch_size, num_english_tokens)
1811
- """
1812
- projected_bio_embeddings = self.bio_projection(bio_embeddings)
1813
- english_embeddings = self.token_embedding(english_token_ids)
1814
-
1815
- bio_attention_mask = build_perceiver_padding_attention_mask(
1816
- bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
1817
- )
1818
- english_attention_mask = build_perceiver_padding_attention_mask(
1819
- english_token_ids, self.config.resampled_length, self.english_pad_token_id
1820
- )
1821
-
1822
- projected_embeddings = self.perceiver_resampler(
1823
- input_embeddings_1=projected_bio_embeddings,
1824
- attention_mask_1=bio_attention_mask,
1825
- input_embeddings_2=english_embeddings,
1826
- attention_mask_2=english_attention_mask,
1827
- )["embeddings"]
1828
-
1829
- return projected_embeddings
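# Illustrative sketch, not part of the original chatNT.py: projecting bio-encoder outputs
# into the English model's embedding space. Every size and pad id below is an arbitrary
# example value, not a value taken from the ChatNT configs.
import torch

cfg = PerceiverResamplerConfig(
    attention_heads=4, embed_dim=64, ffn_embed_dim=128, num_layers=2, resampled_length=8
)
projection = TorchMultiModalPerceiverResamplerProjection(
    perceiver_resampler_config=cfg,
    input_embed_dim=32,        # embedding dim coming out of the bio encoder
    embed_dim=64,              # embedding dim of the English language model
    bio_pad_token_id=1,
    english_pad_token_id=0,
    english_vocab_size=100,
)
bio_ids = torch.randint(2, 10, (2, 50))
bio_embeddings = torch.randn(2, 50, 32)
english_ids = torch.randint(1, 100, (2, 20))
resampled = projection(
    bio_token_ids=bio_ids, bio_embeddings=bio_embeddings, english_token_ids=english_ids
)
print(resampled.shape)         # torch.Size([2, 8, 64])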
1830
-
1831
-
1832
- def build_perceiver_padding_attention_mask(
1833
- tokens: torch.Tensor, resampled_length: int, pad_token_id: int
1834
- ) -> torch.Tensor:
1835
- batch_size, seq_len = tokens.shape
1836
- padding_mask = tokens != pad_token_id # (batch_size, seq_len)
1837
-
1838
- padding_mask = torch.cat(
1839
- [
1840
- padding_mask,
1841
- torch.ones(
1842
- (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
1843
- ),
1844
- ],
1845
- dim=1,
1846
- ) # (batch_size, seq_len + resampled_length)
1847
-
1848
- padding_mask = padding_mask[:, None, None, :]
1849
- padding_mask = padding_mask.repeat(1, 1, resampled_length, 1) # noqa
1850
- return padding_mask
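# Illustrative sketch, not part of the original chatNT.py: the trailing block of ones covers
# the `resampled_length` latent-query positions that apply_attention_blocks appends to the
# keys/values, so only genuine <pad> tokens of the input sequence are ever masked.
# The toy tokens and pad id below are arbitrary example values.
import torch

toy_tokens = torch.tensor([[7, 7, 0, 0]])  # (batch_size=1, seq_len=4), pad id 0
mask = build_perceiver_padding_attention_mask(toy_tokens, resampled_length=3, pad_token_id=0)
print(mask.shape)     # torch.Size([1, 1, 3, 7]) = (batch, 1, resampled_length, seq_len + resampled_length)
print(mask[0, 0, 0])  # tensor([ True,  True, False, False,  True,  True,  True])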