Yanisadel committed on
Commit 3279d7a · 1 Parent(s): 05a712b

Update chatNT.py
Files changed (1): chatNT.py (+6, -14)
chatNT.py CHANGED
@@ -417,10 +417,6 @@ class TorchBioBrainDecoder(nn.Module):
 
         # Insert the bio embeddings at the SEQ token positions
         processed_tokens_ids = english_token_ids.clone()
-        print("(debug) Before call tokens embeddings shape : ", tokens_embeddings.shape)
-        print("(debug) Before call Processed tokens ids shape : ", processed_tokens_ids.shape)
-        print("(debug) Before call Projected bio embeddings shape : ", projected_bio_embeddings.shape)
-        print("num bio sequences : ", num_bio_sequences)
         for bio_seq_num in range(num_bio_sequences):
             tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
                 processed_tokens_ids,
@@ -428,7 +424,6 @@ class TorchBioBrainDecoder(nn.Module):
                 projected_bio_embeddings[:, bio_seq_num, :, :],
                 bio_seq_num=bio_seq_num,
             )
-        print("After call : ", tokens_embeddings.shape)
 
         # Regular GPT pass through
         embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
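
For readers skimming the diff: the loop above splices one block of projected bio embeddings into the English token embeddings per bio sequence, so `tokens_embeddings` grows along the sequence axis on every pass, which is what the removed shape prints were tracking. A minimal sketch of that shape bookkeeping, with made-up sizes and with each block simply appended rather than spliced at the SEQ position (the splice itself is `insert_embeddings`' job):

```python
import torch

# Hypothetical sizes, not values from chatNT.py.
batch, eng_len, bio_len, embed_dim, num_bio = 2, 16, 4, 8, 3

tokens_embeddings = torch.randn(batch, eng_len, embed_dim)
projected_bio_embeddings = torch.randn(batch, num_bio, bio_len, embed_dim)

for bio_seq_num in range(num_bio):
    # One (batch, bio_len, embed_dim) block per bio sequence; the real code
    # inserts it at the SEQ token position, here it is appended only to show
    # how the sequence axis grows on each pass.
    block = projected_bio_embeddings[:, bio_seq_num, :, :]
    tokens_embeddings = torch.cat([tokens_embeddings, block], dim=1)

assert tokens_embeddings.shape == (batch, eng_len + num_bio * bio_len, embed_dim)
```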
@@ -471,8 +466,6 @@ class TorchBioBrainDecoder(nn.Module):
             - input_embeddings with resampled_embeddings inserted at the SEQ token
             - tokens with the SEQ token set to -1
         """
-        print("Tokens : ", list(tokens))
-        print("seq_token_id : ", self.seq_token_id)
 
         def _insert(
             tokens_1d: torch.Tensor,
@@ -488,7 +481,6 @@ class TorchBioBrainDecoder(nn.Module):
             """
             indices = torch.where(tokens_1d == self.seq_token_id)[0]
             if indices.numel() > 0:
-                print("going in if")
                 idx = indices[0].item()
                 insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
                 x = torch.cat(
@@ -505,7 +497,6 @@ class TorchBioBrainDecoder(nn.Module):
                 tokens_1d[idx] = -1
                 return x, tokens_1d
             else:
-                print("going in else")
                 return (
                     input_embeddings,
                     tokens_1d,
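
The two hunks above elide the body of the `torch.cat` call, so here is a hedged, self-contained reconstruction of the splice `_insert` performs: everything before the first SEQ token, then the resampled bio embeddings, then the rest. Sizes and the SEQ id are placeholders, and the before/insert/after split is inferred from the surrounding lines rather than shown in this diff:

```python
import torch

SEQ_TOKEN_ID = 32000                          # placeholder SEQ id
tokens_1d = torch.tensor([5, 9, SEQ_TOKEN_ID, 7])
input_embeddings = torch.randn(4, 8)          # (num_tokens, embed_dim)
resampled_embeddings_1d = torch.randn(3, 8)   # (bio_len, embed_dim)
bio_seq_num = 0

indices = torch.where(tokens_1d == SEQ_TOKEN_ID)[0]
if indices.numel() > 0:
    idx = indices[0].item()
    # Earlier bio sequences have already widened the embedding matrix,
    # hence the offset by bio length * bio_seq_num.
    insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
    x = torch.cat(
        [
            input_embeddings[:insertion_pos, :],
            resampled_embeddings_1d,
            input_embeddings[insertion_pos:, :],
        ],
        dim=0,
    )
    tokens_1d[idx] = -1                       # mark this SEQ token as consumed
    assert x.shape == (4 + 3, 8)
```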
@@ -680,6 +671,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
             Shape (batch_size, num_bio_sequences, ?, embed_dim)
         """
         english_token_ids, bio_token_ids = multi_omics_tokens_ids
+        english_token_ids = english_token_ids.clone()
+        bio_token_ids = bio_token_ids.clone()
+        projection_english_tokens_ids = projection_english_tokens_ids.clone()
+        if projected_bio_embeddings is not None:
+            projected_bio_embeddings = projected_bio_embeddings.clone()
 
         # Replace config.vocab_size value in english tokens
         # We do this because the default vocab size (32000) doesn't match with the
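
The five `+` lines are defensive copies: further down, the method clamps out-of-range ids in place (the `vocab_size - 1` assignment in the next hunk), and without `.clone()` that write would leak back into the caller's tensors. A small stand-alone illustration of the aliasing problem, not code from chatNT.py:

```python
import torch

def clamp_ids(ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
    ids = ids.clone()                      # without this, the caller's tensor is mutated
    ids[ids >= vocab_size] = vocab_size - 1
    return ids

user_ids = torch.tensor([3, 31999, 32000])
_ = clamp_ids(user_ids, 32000)
assert user_ids[2] == 32000                # caller's tensor unchanged thanks to clone()
```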
@@ -698,8 +694,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
                 vocab_size - 1
             )
 
-        print("seq token id : ", self.seq_token_id)
-        print("Tokens at step 1 in multiomics : ", list(english_token_ids))
         if bio_token_ids is None:
             projected_bio_embeddings = None
         else:
@@ -724,9 +718,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
             ]
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
 
-        # decode
-        print("Tokens at step 2 in multiomics : ", list(english_token_ids))
-
+        # decode
         logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
 