JoeArmani committed
Commit d53c64b · 1 Parent(s): fc5f33b

fix checkpointing restoration

Files changed (2)
  1. chatbot_model.py +171 -137
  2. train_model.py +14 -10
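
The diff below reworks how the training loop builds, saves, and restores its tf.train.Checkpoint. For orientation, here is a minimal sketch of the Checkpoint / CheckpointManager resume pattern the commit moves toward; the model, optimizer, and "checkpoints" directory are illustrative stand-ins, not the repo's actual objects.

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    optimizer = tf.keras.optimizers.Adam(1e-3)

    # Track the model, the optimizer, and an epoch counter in one checkpoint object.
    checkpoint = tf.train.Checkpoint(
        epoch=tf.Variable(0, dtype=tf.int32),
        optimizer=optimizer,
        model=model,
    )
    manager = tf.train.CheckpointManager(
        checkpoint, directory="checkpoints", max_to_keep=3, checkpoint_name="ckpt"
    )

    if manager.latest_checkpoint:
        # Checkpoints are named ckpt-<N>; the commit reuses <N> as the epoch to resume from.
        checkpoint.restore(manager.latest_checkpoint)
        start_epoch = int(manager.latest_checkpoint.split("ckpt-")[-1])
    else:
        start_epoch = 0

    for epoch in range(start_epoch + 1, 4):
        # ... one epoch of training would go here ...
        checkpoint.epoch.assign(epoch)
        manager.save()  # writes checkpoints/ckpt-1, ckpt-2, ...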
chatbot_model.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import numpy as np
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 from typing import List, Tuple, Dict, Optional, Union, Any
@@ -74,7 +75,8 @@ class EncoderModel(tf.keras.Model):
         self.projection = tf.keras.layers.Dense(
             config.embedding_dim,
             activation='tanh',
-            name="projection"
+            name="projection",
+            dtype=tf.float32
         )
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
@@ -613,20 +615,20 @@ class RetrievalChatbot(DeviceAwareModel):
         test_mode: bool = False,
         initial_epoch: int = 0
     ) -> None:
-        """Training using a pre-prepared TFRecord dataset."""
+        """
+        Train the retrieval model using a pre-prepared TFRecord dataset.
+        This method handles:
+          - Checkpoint loading/restoring
+          - LR scheduling
+          - Epoch/iteration tracking
+          - Optional training-history logging
+          - Basic early stopping
+        """
         logger.info("Starting training with pre-prepared TFRecord dataset...")
-
+
         def parse_tfrecord_fn(example_proto, max_length, neg_samples):
             """
             Parses a single TFRecord example.
-
-            Args:
-                example_proto: A serialized TFRecord example.
-                max_length: The maximum sequence length for tokenization.
-                neg_samples: The number of hard negatives per query.
-
-            Returns:
-                A tuple of (query_ids, positive_ids, negative_ids).
             """
             feature_description = {
                 'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
@@ -640,9 +642,9 @@ class RetrievalChatbot(DeviceAwareModel):
             negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
             negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
 
-            return query_ids, positive_ids, negative_ids
-
-        # Calculate total steps by counting the number of records in the TFRecord
+            return query_ids, positive_ids, negative_ids
+
+        # Count total records in TFRecord
         raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         total_pairs = sum(1 for _ in raw_dataset)
         logger.info(f"Total pairs in TFRecord: {total_pairs}")
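
The parser above assumes each serialized example stores fixed-length int64 features: query_ids and positive_ids of length max_length, plus negative_ids flattened from [neg_samples, max_length]. Below is a hedged sketch of a writer that produces records in that layout; the token values, sizes, and "pairs.tfrecord" filename are made up for illustration.

    import tensorflow as tf

    max_length, neg_samples = 4, 2  # illustrative sizes, not the repo's config values

    def int64_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    example = tf.train.Example(features=tf.train.Features(feature={
        'query_ids': int64_feature([101, 7, 8, 102]),
        'positive_ids': int64_feature([101, 9, 10, 102]),
        # negatives are stored flattened; the parser reshapes them to [neg_samples, max_length]
        'negative_ids': int64_feature([101, 11, 12, 102, 101, 13, 14, 102]),
    }))

    with tf.io.TFRecordWriter("pairs.tfrecord") as writer:
        writer.write(example.SerializeToString())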
@@ -652,7 +654,7 @@ class RetrievalChatbot(DeviceAwareModel):
         steps_per_epoch = math.ceil(train_size / batch_size)
         val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs
-        buffer_size = total_pairs // 10  # 10% of the dataset
+        buffer_size = max(1, total_pairs // 10)  # 10% of the dataset
 
         logger.info(f"Training pairs: {train_size}")
         logger.info(f"Validation pairs: {val_size}")
@@ -660,61 +662,104 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"Validation steps: {val_steps}")
         logger.info(f"Total steps: {total_steps}")
 
-        # Set up optimizer with learning rate schedule
+        # Set up optimizer & LR schedule
         if use_lr_schedule:
             warmup_steps = int(total_steps * warmup_steps_ratio)
             lr_schedule = self._get_lr_schedule(
                 total_steps=total_steps,
-                peak_lr=peak_lr,
+                peak_lr=tf.cast(peak_lr, tf.float32),
                 warmup_steps=warmup_steps
             )
             self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
             logger.info("Using custom learning rate schedule.")
         else:
-            self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
+            self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
             logger.info("Using fixed learning rate.")
 
-        # Initialize checkpoint manager
+        # Initialize optimizer with dummy step
+        dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
+        with tf.GradientTape() as tape:
+            dummy_output = self.encoder(dummy_input)
+            dummy_loss = tf.cast(tf.reduce_mean(dummy_output), tf.float32)
+        dummy_grads = tape.gradient(dummy_loss, self.encoder.trainable_variables)
+        self.optimizer.apply_gradients(zip(dummy_grads, self.encoder.trainable_variables))
+
+        # Create checkpoint and manager
         checkpoint = tf.train.Checkpoint(
             epoch=tf.Variable(0, dtype=tf.int32),
             optimizer=self.optimizer,
-            optimizer_iterations=self.optimizer.iterations,
-            model=self.encoder,
-            variables=self.encoder.variables
+            model=self.encoder
+        )
+
+        manager = tf.train.CheckpointManager(
+            checkpoint,
+            directory=checkpoint_dir,
+            max_to_keep=3,
+            checkpoint_name='ckpt'
         )
-        manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3, checkpoint_name='ckpt')
 
-        # Restore from checkpoint if available
+        # Restore from existing checkpoint if present
         latest_checkpoint = manager.latest_checkpoint
-        #history_path = Path(checkpoint_dir) / 'training_history.json'
-        if latest_checkpoint:
-            # if history_path.exists():
-            #     try:
-            #         with open(history_path, 'r') as f:
-            #             self.history = json.load(f)
-            #         logger.info(f"Loaded previous training history from {history_path}")
-            #     except Exception as e:
-            #         logger.warning(f"Could not load history, starting fresh: {e}")
-            #         self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
-            # else:
-            #     self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+        history_path = Path(checkpoint_dir) / 'training_history.json'
+
+        # If you want to log all epoch losses across runs
+        if not hasattr(self, 'history'):
+            self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+
+        if latest_checkpoint and not test_mode:
+            # Debug info before restore
+            logger.info("\nEncoder Variables:")
+            for var in self.encoder.variables:
+                logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")
 
+            logger.info("\nOptimizer Variables:")
+            for var in self.optimizer.variables:
+                logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")
+
+            # Add checkpoint inspection
+            logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
+            reader = tf.train.load_checkpoint(latest_checkpoint)
+            shape_from_key = reader.get_variable_to_shape_map()
+            dtype_from_key = reader.get_variable_to_dtype_map()
+            logger.info("\nCheckpoint Variables:")
+            for key in shape_from_key:
+                logger.info(f"{key}: dtype={dtype_from_key[key]} - Shape: {shape_from_key[key]}")
+
             status = checkpoint.restore(latest_checkpoint)
-            status.expect_partial()
+            status.assert_consumed()
             logger.info(f"Restored from checkpoint: {latest_checkpoint}")
+            logger.info(f"Optimizer iterations after restore: {self.optimizer.iterations.numpy()}")
 
-            # Get the checkpoint number to validate initial_epoch
+            # Verify learning rate after restore
+            if use_lr_schedule:
+                current_lr = float(lr_schedule(self.optimizer.iterations))
+            else:
+                current_lr = float(self.optimizer.learning_rate.numpy())
+            logger.info(f"Current learning rate after restore: {current_lr:.2e}")
+
+            # Derive initial_epoch from checkpoint name if not passed in
             ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
             if initial_epoch == 0:
                 initial_epoch = ckpt_number
-
-            checkpoint.epoch.assign(initial_epoch)
+
+            # Assign to checkpoint.epoch so we keep counting from that
+            checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
             logger.info(f"Resuming from epoch {initial_epoch}")
+
+            # If you want to load old history from file:
+            if history_path.exists():
+                try:
+                    with open(history_path, 'r') as f:
+                        self.history = json.load(f)
+                    logger.info(f"Loaded previous training history from {history_path}")
+                except Exception as e:
+                    logger.warning(f"Could not load history, starting fresh: {e}")
         else:
             logger.info("Starting training from scratch")
+            checkpoint.epoch.assign(tf.cast(0, tf.int32))
             initial_epoch = 0
-
-        # Setup TensorBoard
+
+        # Set up TensorBoard
         log_dir = Path(checkpoint_dir) / "tensorboard_logs"
         log_dir.mkdir(parents=True, exist_ok=True)
         current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
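
Context for the dummy forward/backward pass added in the hunk above: Adam creates its per-variable slot variables (m, v) only on the first apply_gradients() call, so a freshly built optimizer cannot absorb a checkpoint that contains them and assert_consumed() raises. A minimal sketch of the work-around under that assumption; the small Dense model here is a hypothetical stand-in for the encoder.

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
    optimizer = tf.keras.optimizers.Adam(1e-3)

    # One throwaway step forces Adam to create slot variables for every
    # trainable variable, so a later restore can match all saved values.
    with tf.GradientTape() as tape:
        dummy_loss = tf.reduce_mean(model(tf.zeros((1, 3))))
    grads = tape.gradient(dummy_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_path = tf.train.latest_checkpoint("checkpoints")  # illustrative directory
    if ckpt_path:
        checkpoint.restore(ckpt_path).assert_consumed()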
@@ -724,18 +769,15 @@ class RetrievalChatbot(DeviceAwareModel):
         val_summary_writer = tf.summary.create_file_writer(val_log_dir)
         logger.info(f"TensorBoard logs will be saved in {log_dir}")
 
-        # Define the parsing function with the appropriate max_length and neg_samples
-        parse_fn = lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
-
-        # Create the full dataset
+        # Parse dataset
         dataset = tf.data.TFRecordDataset(tfrecord_file_path)
-
-        # Test mode for debugging
+
+        # Optional: test/debug mode with small subset
         if test_mode:
-            subset_size = 200
+            subset_size = 150
             dataset = dataset.take(subset_size)
             logger.info(f"TEST MODE: Using only {subset_size} examples")
-            # Recalculate sizes
+            # Recompute sizes, steps, epochs, etc., as needed
             total_pairs = subset_size
             train_size = int(total_pairs * (1 - validation_split))
             val_size = total_pairs - train_size
@@ -743,22 +785,23 @@ class RetrievalChatbot(DeviceAwareModel):
             steps_per_epoch = math.ceil(train_size / batch_size)
             val_steps = math.ceil(val_size / batch_size)
             total_steps = steps_per_epoch * epochs
-            buffer_size = total_pairs // 10  # 10% of the dataset
-            epochs = min(epochs, 5)  # Limit epochs in test mode
+            buffer_size = max(1, total_pairs // 10)
+            epochs = min(epochs, 5)  # For quick debug
             early_stopping_patience = 2
             logger.info(f"New training pairs: {train_size}")
             logger.info(f"New validation pairs: {val_size}")
-
-        dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
-
-        # Split into training and validation sets
+
+        dataset = dataset.map(
+            lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples),
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+
+        # Train/val split
         train_dataset = dataset.take(train_size)
         val_dataset = dataset.skip(train_size).take(val_size)
 
-        # Shuffle the training data
+        # Shuffle and batch
         train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
-
-        # Batch both datasets
         train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
         train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
 
@@ -773,41 +816,34 @@ class RetrievalChatbot(DeviceAwareModel):
         for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
             checkpoint.epoch.assign(epoch)
             logger.info(f"Starting Epoch {epoch}...")
+
             # --- Training Phase ---
-            epoch_loss_avg = tf.keras.metrics.Mean()
+            epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
             batches_processed = 0
 
+            # Progress bar
            try:
-                train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
+                train_pbar = tqdm(
+                    total=steps_per_epoch,
+                    desc=f"Training Epoch {epoch}",
+                    unit="batch"
+                )
                is_tqdm_train = True
            except ImportError:
                train_pbar = None
                is_tqdm_train = False
-                logger.info("Training progress bar disabled")
 
            for q_batch, p_batch, n_batch in train_dataset:
                loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
-
-                # Check for gradient issues
-                grad_norm_value = float(grad_norm.numpy())
-                post_clip_value = float(post_clip_norm.numpy())
-                if grad_norm_value < 1e-7:
-                    logger.warning(f"Potential vanishing gradient detected: norm = {grad_norm_value:.2e}")
-                elif grad_norm_value > 100:
-                    logger.warning(f"Potential exploding gradient detected: norm = {grad_norm_value:.2e}")
-
-                # if grad_norm_value != post_clip_value:
-                #     logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")
-
                epoch_loss_avg(loss)
                batches_processed += 1
 
                # Log to TensorBoard
                with train_summary_writer.as_default():
                    step = (epoch - 1) * steps_per_epoch + batches_processed
-                    tf.summary.scalar("loss", loss, step=step)
-                    tf.summary.scalar("gradient_norm_pre_clip", grad_norm, step=step)
-                    tf.summary.scalar("gradient_norm_post_clip", post_clip_norm, step=step)
+                    tf.summary.scalar("loss", tf.cast(loss, tf.float32), step=step)
+                    tf.summary.scalar("gradient_norm_pre_clip", tf.cast(grad_norm, tf.float32), step=step)
+                    tf.summary.scalar("gradient_norm_post_clip", tf.cast(post_clip_norm, tf.float32), step=step)
 
                # Update progress bar
                if use_lr_schedule:
@@ -819,15 +855,15 @@ class RetrievalChatbot(DeviceAwareModel):
                    train_pbar.update(1)
                    train_pbar.set_postfix({
                        "loss": f"{loss.numpy():.4f}",
-                        "pre_clip": f"{grad_norm_value:.2e}",
-                        "post_clip": f"{post_clip_value:.2e}",
+                        "pre_clip": f"{grad_norm.numpy():.2e}",
+                        "post_clip": f"{post_clip_norm.numpy():.2e}",
                        "lr": f"{current_lr:.2e}",
                        "batches": f"{batches_processed}/{steps_per_epoch}"
                    })
 
-                # Memory cleanup
                gc.collect()
 
+                # End the epoch early if we've processed all steps
                if batches_processed >= steps_per_epoch:
                    break
 
@@ -835,7 +871,7 @@ class RetrievalChatbot(DeviceAwareModel):
                train_pbar.close()
 
            # --- Validation Phase ---
-            val_loss_avg = tf.keras.metrics.Mean()
+            val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
            val_batches_processed = 0
 
            try:
@@ -844,16 +880,16 @@ class RetrievalChatbot(DeviceAwareModel):
            except ImportError:
                val_pbar = None
                is_tqdm_val = False
-                logger.info("Validation progress bar disabled")
 
-            last_valid_val_loss = None  # Initialize outside the loop
+            last_valid_val_loss = None
            valid_batches = False
-
+
            for q_batch, p_batch, n_batch in val_dataset:
-                if tf.shape(q_batch)[0] < 2:
-                    logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]} (too small for loss calculation)")
+                # If batch is too small, skip
+                if tf.shape(q_batch)[0] < 2:
+                    logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
                    continue
-
+
                valid_batches = True
                val_loss = self.validation_step(q_batch, p_batch, n_batch)
                val_loss_avg(val_loss)
@@ -867,32 +903,30 @@ class RetrievalChatbot(DeviceAwareModel):
                        "batches": f"{val_batches_processed}/{val_steps}"
                    })
 
-                # Memory cleanup
                gc.collect()
 
                if val_batches_processed >= val_steps:
                    break
-
+
            if not valid_batches:
-                logger.warning("No valid validation batches in this epoch, using last known validation loss")
+                # If no valid batch is found, fallback
+                logger.warning("No valid validation batches in this epoch")
                if last_valid_val_loss is not None:
                    val_loss = last_valid_val_loss
                    val_loss_avg(val_loss)
                else:
-                    # If we've never had a valid batch (first epoch), use training loss as fallback
-                    logger.warning("No previous validation loss available, using training loss as fallback")
-                    val_loss = train_loss
+                    val_loss = epoch_loss_avg.result()
                    val_loss_avg(val_loss)
-
+
            if is_tqdm_val and val_pbar:
                val_pbar.close()
 
-            # End of epoch: compute final epoch stats, log, and save checkpoint
+            # End of epoch: final stats
            train_loss = epoch_loss_avg.result().numpy()
            val_loss = val_loss_avg.result().numpy()
            logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
 
-            # Log epoch metrics
+            # TensorBoard epoch logs
            with train_summary_writer.as_default():
                tf.summary.scalar("epoch_loss", train_loss, step=epoch)
            with val_summary_writer.as_default():
@@ -900,31 +934,38 @@ class RetrievalChatbot(DeviceAwareModel):
 
            # Save checkpoint
            manager.save()
-
-            # Save model after each epoch for testing/inference
+
+            # (Optional) Save model for quick testing/inference
            model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
            self.save_models(model_save_path)
            logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
 
-            # Store metrics in history
+            # Update local history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
-
-            if use_lr_schedule:
-                current_lr = float(lr_schedule(self.optimizer.iterations))
-            else:
-                current_lr = float(self.optimizer.learning_rate.numpy())
-
-            # Log learning rate
            self.history.setdefault('learning_rate', []).append(current_lr)
-
-            # Save history to file
-            #if history_path.exists():
-            #    with open(history_path, 'w') as f:
-            #        json.dump(self.history, f)
-            #    logger.info(f"Saved training history to {history_path}")
 
-            # Early stopping logic
+            def convert_to_py_floats(obj):
+                if isinstance(obj, dict):
+                    return {k: convert_to_py_floats(v) for k, v in obj.items()}
+                elif isinstance(obj, list):
+                    return [convert_to_py_floats(x) for x in obj]
+                elif isinstance(obj, (np.float32, np.float64)):
+                    return float(obj)
+                elif tf.is_tensor(obj):
+                    return float(obj.numpy())
+                else:
+                    return obj
+
+            json_history = convert_to_py_floats(self.history)
+
+            # Save training history to file every epoch
+            # (Create or overwrite the file so we always have the latest.)
+            with open(history_path, 'w') as f:
+                json.dump(json_history, f)
+            logger.info(f"Saved training history to {history_path}")
+
+            # Early stopping
            if val_loss < best_val_loss - min_delta:
                best_val_loss = val_loss
                epochs_no_improve = 0
@@ -980,26 +1021,20 @@ class RetrievalChatbot(DeviceAwareModel):
            # Now compute scores: dot product of q_enc with each column in combined_p_n
            # We'll use `tf.einsum` to handle the batch dimension properly
            # dot_products => shape [batch_size, (1+neg_samples)]
-            dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
-
-            # The label for each row is 0 (the first column is the correct/positive)
-            labels = tf.zeros([bs], dtype=tf.int32)
-
-            # Cross-entropy over the [batch_size, 1+neg_samples] scores
+            dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
+            labels = tf.zeros([bs], dtype=tf.int32)  # Keep labels as int32
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels,
                logits=dot_products
            )
-            loss = tf.reduce_mean(loss)
+            loss = tf.cast(tf.reduce_mean(loss), tf.float32)
 
        # Calculate gradients
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
-        gradients_norm = tf.linalg.global_norm(gradients)
-
-        # Clip gradients if norm exceeds threshold
-        max_grad_norm = 1.5
+        gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
+        max_grad_norm = tf.constant(1.5, dtype=tf.float32)
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
-        post_clip_norm = tf.linalg.global_norm(gradients)
+        post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
 
        # Apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
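
The loss in train_step treats each query's positive as class 0 among 1 + neg_samples candidate responses scored by dot product. A small self-contained sketch of that objective with made-up shapes and random encodings:

    import tensorflow as tf

    batch, dim, neg = 2, 3, 2
    q_enc = tf.random.normal((batch, dim))
    pos_enc = tf.random.normal((batch, dim))
    neg_enc = tf.random.normal((batch, neg, dim))

    # Column 0 holds the positive, columns 1..neg hold the hard negatives.
    combined = tf.concat([tf.expand_dims(pos_enc, 1), neg_enc], axis=1)  # [batch, 1+neg, dim]
    logits = tf.einsum('bd,bkd->bk', q_enc, combined)                    # [batch, 1+neg]
    labels = tf.zeros([batch], dtype=tf.int32)                           # positive is always index 0
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    )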
@@ -1032,14 +1067,14 @@ class RetrievalChatbot(DeviceAwareModel):
            axis=1
        )
 
-        dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
-        labels = tf.zeros([bs], dtype=tf.int32)
-
+        dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
+        labels = tf.zeros([bs], dtype=tf.int32)  # Keep labels as int32
+
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels,
            logits=dot_products
        )
-        loss = tf.reduce_mean(loss)
+        loss = tf.cast(tf.reduce_mean(loss), tf.float32)
 
        return loss
 
@@ -1066,8 +1101,8 @@ class RetrievalChatbot(DeviceAwareModel):
        self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)
 
        # Calculate and store constants
-        self.initial_lr = self.peak_lr * 0.1   # Start at 10% of peak
-        self.min_lr = self.peak_lr * 0.01      # Minimum 1% of peak
+        self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
+        self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)
 
        logger.info(f"Learning rate schedule initialized:")
        logger.info(f"  Initial LR: {float(self.initial_lr):.6f}")
@@ -1080,15 +1115,14 @@ class RetrievalChatbot(DeviceAwareModel):
        step = tf.cast(step, tf.float32)
 
        # Warmup phase
-        warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
+        warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
        warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
 
        # Decay phase
-        decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
-        decay_factor = (step - self.warmup_steps) / decay_steps
-        decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0)  # Clip to [0,1]
-
-        cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor))
+        decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
+        decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
+        decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
+        cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
        decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
 
        # Choose between warmup and decay
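
The _get_lr_schedule changes above keep the warmup and cosine-decay math in float32. Assembled as a standalone schedule, the logic looks roughly like the sketch below; the class name and the 1000 / 100 / 2e-5 numbers are illustrative, not values taken from the repo.

    import math
    import tensorflow as tf

    class WarmupCosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, total_steps, warmup_steps, peak_lr):
            super().__init__()
            self.total_steps = tf.cast(total_steps, tf.float32)
            self.warmup_steps = tf.cast(warmup_steps, tf.float32)
            self.peak_lr = tf.cast(peak_lr, tf.float32)
            self.initial_lr = self.peak_lr * 0.1   # start at 10% of peak
            self.min_lr = self.peak_lr * 0.01      # floor at 1% of peak

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            # Linear warmup toward peak_lr
            warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
            warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
            # Cosine decay from peak_lr down to min_lr
            decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
            decay_factor = tf.clip_by_value((step - self.warmup_steps) / decay_steps, 0.0, 1.0)
            cosine_decay = 0.5 * (1.0 + tf.cos(math.pi * decay_factor))
            decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
            return tf.where(step < self.warmup_steps, warmup_lr, decay_lr)

    schedule = WarmupCosineSchedule(total_steps=1000, warmup_steps=100, peak_lr=2e-5)
    optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)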
 
train_model.py CHANGED
@@ -48,14 +48,17 @@ def main():
     # Initialize chatbot
     chatbot = RetrievalChatbot(config, mode='training')
 
-    # # Load from a checkpoint
-
-    # latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
-    # initial_epoch = 0
-    # if latest_checkpoint:
-    #     ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
-    #     initial_epoch = ckpt_number
-    #     logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
+    # Check for existing checkpoint and get initial epoch
+    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
+    initial_epoch = 0
+    if latest_checkpoint:
+        try:
+            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
+            initial_epoch = ckpt_number
+            logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
+        except (IndexError, ValueError):
+            logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
+            initial_epoch = 0
 
     # Train the model
     chatbot.train_model(
@@ -63,8 +66,9 @@
         epochs=EPOCHS,
         batch_size=batch_size,
         use_lr_schedule=True,
-        test_mode=False,
-        checkpoint_dir=CHECKPOINT_DIR
+        test_mode=True,
+        checkpoint_dir=CHECKPOINT_DIR,
+        initial_epoch=initial_epoch
     )
 
     # Save final model
48
  # Initialize chatbot
49
  chatbot = RetrievalChatbot(config, mode='training')
50
 
51
+ # Check for existing checkpoint and get initial epoch
52
+ latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
53
+ initial_epoch = 0
54
+ if latest_checkpoint:
55
+ try:
56
+ ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
57
+ initial_epoch = ckpt_number
58
+ logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
59
+ except (IndexError, ValueError):
60
+ logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
61
+ initial_epoch = 0
62
 
63
  # Train the model
64
  chatbot.train_model(
 
66
  epochs=EPOCHS,
67
  batch_size=batch_size,
68
  use_lr_schedule=True,
69
+ test_mode=True,
70
+ checkpoint_dir=CHECKPOINT_DIR,
71
+ initial_epoch=initial_epoch
72
  )
73
 
74
  # Save final model