Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -471,11 +471,8 @@ class TorchBioBrainDecoder(nn.Module):
|
|
471 |
- input_embeddings with resampled_embeddings inserted at the SEQ token
|
472 |
- tokens with the SEQ token set to -1
|
473 |
"""
|
474 |
-
print("
|
475 |
-
print("
|
476 |
-
print("Input embeddings : ", input_embeddings.shape)
|
477 |
-
print("Resampled embeddings : ", resampled_embeddings.shape)
|
478 |
-
print("Bio seq num : ", bio_seq_num)
|
479 |
|
480 |
def _insert(
|
481 |
tokens_1d: torch.Tensor,
|
@@ -489,9 +486,9 @@ class TorchBioBrainDecoder(nn.Module):
|
|
489 |
resampled_embeddings (torch.Tensor):
|
490 |
Shape (bio_sequence_length, embed_dim,)
|
491 |
"""
|
492 |
-
print("_insert input : ", input_embeddings_1d.shape, resampled_embeddings_1d.shape)
|
493 |
indices = torch.where(tokens_1d == self.seq_token_id)[0]
|
494 |
if indices.numel() > 0:
|
|
|
495 |
idx = indices[0].item()
|
496 |
insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
|
497 |
x = torch.cat(
|
@@ -506,9 +503,9 @@ class TorchBioBrainDecoder(nn.Module):
|
|
506 |
:-1, :
|
507 |
]
|
508 |
tokens_1d[idx] = -1
|
509 |
-
print("_insert output : ", x.shape)
|
510 |
return x, tokens_1d
|
511 |
else:
|
|
|
512 |
return (
|
513 |
input_embeddings,
|
514 |
tokens_1d,
|
@@ -526,10 +523,8 @@ class TorchBioBrainDecoder(nn.Module):
|
|
526 |
tokens_acc.append(tokens_out)
|
527 |
embeddings_acc.append(embeddings_out)
|
528 |
|
529 |
-
print("(Embeddings_acc[0] shape : ", embeddings_acc[0].shape)
|
530 |
tokens_acc = torch.stack(tokens_acc)
|
531 |
embeddings_acc = torch.stack(embeddings_acc)
|
532 |
-
print("Embeddings acc shape : ", embeddings_acc.shape)
|
533 |
|
534 |
return embeddings_acc, tokens_acc
|
535 |
|
@@ -703,6 +698,8 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
703 |
vocab_size - 1
|
704 |
)
|
705 |
|
|
|
|
|
706 |
if bio_token_ids is None:
|
707 |
projected_bio_embeddings = None
|
708 |
else:
|
@@ -728,6 +725,8 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
728 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
729 |
|
730 |
# decode
|
|
|
|
|
731 |
logits = self.biobrain_decoder(
|
732 |
english_token_ids=english_token_ids,
|
733 |
projected_bio_embeddings=projected_bio_embeddings,
|
|
|
471 |
- input_embeddings with resampled_embeddings inserted at the SEQ token
|
472 |
- tokens with the SEQ token set to -1
|
473 |
"""
|
474 |
+
print("Tokens : ", list(tokens))
|
475 |
+
print("seq_token_id : ", self.seq_token_id)
|
|
|
|
|
|
|
476 |
|
477 |
def _insert(
|
478 |
tokens_1d: torch.Tensor,
|
|
|
486 |
resampled_embeddings (torch.Tensor):
|
487 |
Shape (bio_sequence_length, embed_dim,)
|
488 |
"""
|
|
|
489 |
indices = torch.where(tokens_1d == self.seq_token_id)[0]
|
490 |
if indices.numel() > 0:
|
491 |
+
print("going in if")
|
492 |
idx = indices[0].item()
|
493 |
insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
|
494 |
x = torch.cat(
|
|
|
503 |
:-1, :
|
504 |
]
|
505 |
tokens_1d[idx] = -1
|
|
|
506 |
return x, tokens_1d
|
507 |
else:
|
508 |
+
print("going in else")
|
509 |
return (
|
510 |
input_embeddings,
|
511 |
tokens_1d,
|
|
|
523 |
tokens_acc.append(tokens_out)
|
524 |
embeddings_acc.append(embeddings_out)
|
525 |
|
|
|
526 |
tokens_acc = torch.stack(tokens_acc)
|
527 |
embeddings_acc = torch.stack(embeddings_acc)
|
|
|
528 |
|
529 |
return embeddings_acc, tokens_acc
|
530 |
|
|
|
698 |
vocab_size - 1
|
699 |
)
|
700 |
|
701 |
+
print("seq token id : ", self.seq_token_id)
|
702 |
+
print("Tokens at step 1 in multiomics : ", list(english_token_ids))
|
703 |
if bio_token_ids is None:
|
704 |
projected_bio_embeddings = None
|
705 |
else:
|
|
|
725 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
726 |
|
727 |
# decode
|
728 |
+
print("Tokens at step 2 in multiomics : ", list(english_token_ids))
|
729 |
+
|
730 |
logits = self.biobrain_decoder(
|
731 |
english_token_ids=english_token_ids,
|
732 |
projected_bio_embeddings=projected_bio_embeddings,
|