Yanisadel committed · Commit 05a712b · 1 Parent(s): 2164e14

Update chatNT.py

Files changed (1): chatNT.py (+8, -9)
chatNT.py CHANGED
@@ -471,11 +471,8 @@ class TorchBioBrainDecoder(nn.Module):
             - input_embeddings with resampled_embeddings inserted at the SEQ token
             - tokens with the SEQ token set to -1
         """
-        print("Insert_embeddings input shape : ")
-        print("Tokens : ", tokens.shape)
-        print("Input embeddings : ", input_embeddings.shape)
-        print("Resampled embeddings : ", resampled_embeddings.shape)
-        print("Bio seq num : ", bio_seq_num)
+        print("Tokens : ", list(tokens))
+        print("seq_token_id : ", self.seq_token_id)
 
         def _insert(
             tokens_1d: torch.Tensor,
@@ -489,9 +486,9 @@ class TorchBioBrainDecoder(nn.Module):
             resampled_embeddings (torch.Tensor):
                 Shape (bio_sequence_length, embed_dim,)
             """
-            print("_insert input : ", input_embeddings_1d.shape, resampled_embeddings_1d.shape)
             indices = torch.where(tokens_1d == self.seq_token_id)[0]
             if indices.numel() > 0:
+                print("going in if")
                 idx = indices[0].item()
                 insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
                 x = torch.cat(
@@ -506,9 +503,9 @@ class TorchBioBrainDecoder(nn.Module):
                     :-1, :
                 ]
                 tokens_1d[idx] = -1
-                print("_insert output : ", x.shape)
                 return x, tokens_1d
             else:
+                print("going in else")
                 return (
                     input_embeddings,
                     tokens_1d,
@@ -526,10 +523,8 @@ class TorchBioBrainDecoder(nn.Module):
             tokens_acc.append(tokens_out)
             embeddings_acc.append(embeddings_out)
 
-        print("(Embeddings_acc[0] shape : ", embeddings_acc[0].shape)
         tokens_acc = torch.stack(tokens_acc)
         embeddings_acc = torch.stack(embeddings_acc)
-        print("Embeddings acc shape : ", embeddings_acc.shape)
 
         return embeddings_acc, tokens_acc
 
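For context, here is a minimal, self-contained sketch of the insertion step these hunks instrument, reconstructed from the fragments visible above (the torch.where lookup of self.seq_token_id, the bio_seq_num offset on insertion_pos, the truncating slice after torch.cat, and tokens_1d[idx] = -1). The function name and the exact truncation bookkeeping are assumptions; the diff only shows fragments of _insert.

import torch

def insert_at_seq_token(
    tokens_1d: torch.Tensor,                # (seq_len,) token ids
    input_embeddings_1d: torch.Tensor,      # (seq_len, embed_dim)
    resampled_embeddings_1d: torch.Tensor,  # (bio_sequence_length, embed_dim)
    seq_token_id: int,
    bio_seq_num: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Find the first remaining SEQ placeholder, as in the diff.
    indices = torch.where(tokens_1d == seq_token_id)[0]
    if indices.numel() == 0:
        # No placeholder left: pass inputs through unchanged.
        return input_embeddings_1d, tokens_1d
    idx = indices[0].item()
    # Earlier insertions each grew the embedding sequence by one bio block,
    # so offset the target position by bio_sequence_length per prior insert.
    insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
    x = torch.cat(
        (
            input_embeddings_1d[:insertion_pos, :],
            resampled_embeddings_1d,
            input_embeddings_1d[insertion_pos:, :],
        ),
        dim=0,
    )[: input_embeddings_1d.shape[0], :]  # truncate back to seq_len (assumed)
    tokens_1d = tokens_1d.clone()
    tokens_1d[idx] = -1  # mark the placeholder consumed so it is not matched again
    return x, tokens_1d

# Quick check with toy shapes: SEQ id 3 sits at position 1.
emb, tok = insert_at_seq_token(
    torch.tensor([7, 3, 8, 9]), torch.randn(4, 16), torch.randn(2, 16),
    seq_token_id=3, bio_seq_num=0,
)
# emb.shape == (4, 16)  (length preserved by the truncating slice)
# tok == tensor([7, -1, 8, 9])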
 
@@ -703,6 +698,8 @@ class TorchMultiOmicsModel(PreTrainedModel):
                 vocab_size - 1
             )
 
+        print("seq token id : ", self.seq_token_id)
+        print("Tokens at step 1 in multiomics : ", list(english_token_ids))
         if bio_token_ids is None:
             projected_bio_embeddings = None
         else:
@@ -728,6 +725,8 @@ class TorchMultiOmicsModel(PreTrainedModel):
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
         # decode
+        print("Tokens at step 2 in multiomics : ", list(english_token_ids))
+
         logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
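The last two hunks sit inside the TorchMultiOmicsModel forward pass. Below is a rough sketch of that flow, using only what the diff shows (a per-sequence projection list stacked along dim=1, then the biobrain_decoder call); encode_and_project and the loop over bio_token_ids[:, i] are hypothetical stand-ins for the encoder-plus-projection step not visible here.

import torch

def forward_sketch(english_token_ids, bio_token_ids, encode_and_project, biobrain_decoder):
    if bio_token_ids is None:
        projected_bio_embeddings = None
    else:
        # One projected embedding tensor per bio sequence (loop shape assumed) ...
        projected_bio_embeddings = [
            encode_and_project(bio_token_ids[:, i])
            for i in range(bio_token_ids.shape[1])
        ]
        # ... stacked along dim=1, as in the diff.
        projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)

    # decode, matching the call shown in the diff
    return biobrain_decoder(
        english_token_ids=english_token_ids,
        projected_bio_embeddings=projected_bio_embeddings,
    )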