Update chatNT.py
chatNT.py
@@ -417,10 +417,6 @@ class TorchBioBrainDecoder(nn.Module):
 
         # Insert the bio embeddings at the SEQ token positions
         processed_tokens_ids = english_token_ids.clone()
-        print("(debug) Before call tokens embeddings shape : ", tokens_embeddings.shape)
-        print("(debug) Before call Processed tokens ids shape : ", processed_tokens_ids.shape)
-        print("(debug) Before call Projected bio embeddings shape : ", projected_bio_embeddings.shape)
-        print("num bio sequences : ", num_bio_sequences)
         for bio_seq_num in range(num_bio_sequences):
             tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                 processed_tokens_ids,
@@ -428,7 +424,6 @@ class TorchBioBrainDecoder(nn.Module):
                 projected_bio_embeddings[:, bio_seq_num, :, :],
                 bio_seq_num=bio_seq_num,
             )
-        print("After call : ", tokens_embeddings.shape)
 
         # Regular GPT pass through
         embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
@@ -471,8 +466,6 @@ class TorchBioBrainDecoder(nn.Module):
             - input_embeddings with resampled_embeddings inserted at the SEQ token
             - tokens with the SEQ token set to -1
         """
-        print("Tokens : ", list(tokens))
-        print("seq_token_id : ", self.seq_token_id)
 
         def _insert(
             tokens_1d: torch.Tensor,
@@ -488,7 +481,6 @@ class TorchBioBrainDecoder(nn.Module):
             """
             indices = torch.where(tokens_1d == self.seq_token_id)[0]
             if indices.numel() > 0:
-                print("going in if")
                 idx = indices[0].item()
                 insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
                 x = torch.cat(
@@ -505,7 +497,6 @@ class TorchBioBrainDecoder(nn.Module):
                 tokens_1d[idx] = -1
                 return x, tokens_1d
             else:
-                print("going in else")
                 return (
                     input_embeddings,
                     tokens_1d,
@@ -680,6 +671,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
            Shape (batch_size, num_bio_sequences, ?, embed_dim)
        """
        english_token_ids, bio_token_ids = multi_omics_tokens_ids
+       english_token_ids = english_token_ids.clone()
+       bio_token_ids = bio_token_ids.clone()
+       projection_english_tokens_ids = projection_english_tokens_ids.clone()
+       if projected_bio_embeddings is not None:
+           projected_bio_embeddings = projected_bio_embeddings.clone()
 
        # Replace config.vocab_size value in english tokens
        # We do this because the default vocab size (32000) doesn't match with the
@@ -698,8 +694,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
            vocab_size - 1
        )
 
-       print("seq token id : ", self.seq_token_id)
-       print("Tokens at step 1 in multiomics : ", list(english_token_ids))
        if bio_token_ids is None:
            projected_bio_embeddings = None
        else:
@@ -724,9 +718,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
            ]
            projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
-       # decode
-       print("Tokens at step 2 in multiomics : ", list(english_token_ids))
-
+       # decode
        logits = self.biobrain_decoder(
            english_token_ids=english_token_ids,
            projected_bio_embeddings=projected_bio_embeddings,