JoeArmani committed
Commit fc5f33b · Parent: dfb45fe

checkpointing updates, optimizations

chatbot_model.py CHANGED
@@ -30,7 +30,7 @@ class ChatbotConfig:
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    learning_rate: float = 0.001
+    learning_rate: float = 0.0005
     min_text_length: int = 3
     max_context_turns: int = 5
     warmup_steps: int = 200
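
Note: the base learning rate halves from 0.001 to 0.0005. Together with warmup_steps = 200 here and use_lr_schedule=True in train_model.py, the effective rate presumably ramps up before decaying; the schedule itself is not part of this diff, so the following is only a sketch of a typical warmup-then-decay shape, not the repo's implementation:

    import tensorflow as tf

    class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Linear warmup to peak_lr, then inverse-sqrt decay (illustrative)."""
        def __init__(self, peak_lr=0.0005, warmup_steps=200):
            super().__init__()
            self.peak_lr = peak_lr
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            warmup = self.peak_lr * step / self.warmup_steps
            decay = self.peak_lr * tf.math.rsqrt(
                tf.maximum(step, self.warmup_steps) / self.warmup_steps
            )
            return tf.minimum(warmup, decay)  # warmup below step 200, decay after

    optimizer = tf.keras.optimizers.Adam(WarmupDecaySchedule())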
@@ -676,8 +676,9 @@ class RetrievalChatbot(DeviceAwareModel):
 
         # Initialize checkpoint manager
         checkpoint = tf.train.Checkpoint(
-            epoch=tf.Variable(0),
-            optimizer=self.optimizer,
+            epoch=tf.Variable(0, dtype=tf.int32),
+            optimizer=self.optimizer,
+            optimizer_iterations=self.optimizer.iterations,
             model=self.encoder,
             variables=self.encoder.variables
         )
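
Note: checkpointing optimizer state is what lets a resumed run continue its learning-rate schedule instead of restarting it, since optimizer.iterations is the step counter a LearningRateSchedule receives. Because optimizer=self.optimizer already tracks the iterations variable along with the slot variables, the extra optimizer_iterations key stores the same variable under a second name; harmless, but redundant. A minimal sketch (path illustrative; exact internal variable names differ across Keras versions):

    import tensorflow as tf

    opt = tf.keras.optimizers.Adam(5e-4)
    opt.iterations.assign(123)             # pretend we've trained 123 steps
    path = tf.train.Checkpoint(optimizer=opt).save('/tmp/opt-demo')  # illustrative path

    opt2 = tf.keras.optimizers.Adam(5e-4)
    _ = opt2.iterations                    # create the counter before restoring
    ckpt2 = tf.train.Checkpoint(optimizer=opt2)
    ckpt2.restore(path).expect_partial()
    print(int(opt2.iterations))            # 123 -- schedule position survives the restore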
@@ -685,27 +686,29 @@ class RetrievalChatbot(DeviceAwareModel):
 
         # Restore from checkpoint if available
         latest_checkpoint = manager.latest_checkpoint
+        #history_path = Path(checkpoint_dir) / 'training_history.json'
         if latest_checkpoint:
-            history_path = Path(checkpoint_dir) / 'training_history.json'
-            if history_path.exists():
-                try:
-                    with open(history_path, 'r') as f:
-                        self.history = json.load(f)
-                    logger.info(f"Loaded previous training history from {history_path}")
-                except Exception as e:
-                    logger.warning(f"Could not load history, starting fresh: {e}")
-                    self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
-            else:
-                self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+            # if history_path.exists():
+            #     try:
+            #         with open(history_path, 'r') as f:
+            #             self.history = json.load(f)
+            #         logger.info(f"Loaded previous training history from {history_path}")
+            #     except Exception as e:
+            #         logger.warning(f"Could not load history, starting fresh: {e}")
+            #         self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+            #     else:
+            #         self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
 
             status = checkpoint.restore(latest_checkpoint)
             status.expect_partial()
-
             logger.info(f"Restored from checkpoint: {latest_checkpoint}")
+
             # Get the checkpoint number to validate initial_epoch
             ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
             if initial_epoch == 0:
                 initial_epoch = ckpt_number
+
+            checkpoint.epoch.assign(initial_epoch)
             logger.info(f"Resuming from epoch {initial_epoch}")
         else:
             logger.info("Starting training from scratch")
@@ -736,6 +739,7 @@ class RetrievalChatbot(DeviceAwareModel):
         total_pairs = subset_size
         train_size = int(total_pairs * (1 - validation_split))
         val_size = total_pairs - train_size
+        batch_size = min(batch_size, val_size)
         steps_per_epoch = math.ceil(train_size / batch_size)
         val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs
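
Note: batch_size = min(batch_size, val_size) caps the batch size at the size of the validation split, and the next hunk pairs it with drop_remainder=False so a final partial validation batch is kept rather than silently discarded; with the old settings, a split smaller than one batch yielded zero validation batches. (Minor aside on that hunk: .cache() is applied after .prefetch(); the conventional tf.data order is batch -> cache -> prefetch, though the code still runs.) A small demonstration of the failure mode:

    import math
    import tensorflow as tf

    val_size, batch_size = 20, 32
    ds = tf.data.Dataset.range(val_size)

    print(len(list(ds.batch(batch_size, drop_remainder=True))))   # 0 -- whole split dropped
    print(len(list(ds.batch(batch_size, drop_remainder=False))))  # 1 -- partial batch of 20

    batch_size = min(batch_size, val_size)   # the guard added above
    print(math.ceil(val_size / batch_size))  # val_steps = 1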
@@ -758,7 +762,7 @@ class RetrievalChatbot(DeviceAwareModel):
         train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
         train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
 
-        val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
+        val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
         val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
         val_dataset = val_dataset.cache()
 
@@ -766,7 +770,9 @@ class RetrievalChatbot(DeviceAwareModel):
         best_val_loss = float("inf")
         epochs_no_improve = 0
 
-        for epoch in range(initial_epoch + 1, epochs + 1):
+        for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
+            checkpoint.epoch.assign(epoch)
+            logger.info(f"Starting Epoch {epoch}...")
             # --- Training Phase ---
             epoch_loss_avg = tf.keras.metrics.Mean()
             batches_processed = 0
@@ -790,8 +796,8 @@ class RetrievalChatbot(DeviceAwareModel):
                 elif grad_norm_value > 100:
                     logger.warning(f"Potential exploding gradient detected: norm = {grad_norm_value:.2e}")
 
-                if grad_norm_value != post_clip_value:
-                    logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")
+                # if grad_norm_value != post_clip_value:
+                #     logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")
 
                 epoch_loss_avg(loss)
                 batches_processed += 1
@@ -840,9 +846,18 @@ class RetrievalChatbot(DeviceAwareModel):
                 is_tqdm_val = False
                 logger.info("Validation progress bar disabled")
 
+            last_valid_val_loss = None  # Initialize outside the loop
+            valid_batches = False
+
             for q_batch, p_batch, n_batch in val_dataset:
+                if tf.shape(q_batch)[0] < 2:
+                    logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]} (too small for loss calculation)")
+                    continue
+
+                valid_batches = True
                 val_loss = self.validation_step(q_batch, p_batch, n_batch)
                 val_loss_avg(val_loss)
+                last_valid_val_loss = val_loss
                 val_batches_processed += 1
 
                 if is_tqdm_val:
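
Note: the new guard skips validation batches with fewer than two examples, presumably because the loss draws on other in-batch examples; the fallback chain in the next hunk (last valid validation loss, then training loss) keeps val_loss defined for the early-stopping comparison even when every batch was skipped. Minor nit: tf.shape(q_batch)[0] inside the f-string logs a tensor repr, so int(tf.shape(q_batch)[0]) would read cleaner. The same guard could also live at the dataset level (a sketch, assuming the (query, positive, negative) batch tuples used above):

    import tensorflow as tf

    def drop_tiny_batches(ds, min_size=2):
        """Filter out batches too small for in-batch loss computation."""
        return ds.filter(lambda q, p, n: tf.shape(q)[0] >= min_size)

    # val_dataset = drop_tiny_batches(val_dataset.batch(batch_size, drop_remainder=False))
    # val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)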
@@ -857,7 +872,18 @@ class RetrievalChatbot(DeviceAwareModel):
 
                 if val_batches_processed >= val_steps:
                     break
-
+
+            if not valid_batches:
+                logger.warning("No valid validation batches in this epoch, using last known validation loss")
+                if last_valid_val_loss is not None:
+                    val_loss = last_valid_val_loss
+                    val_loss_avg(val_loss)
+                else:
+                    # If we've never had a valid batch (first epoch), use training loss as fallback
+                    logger.warning("No previous validation loss available, using training loss as fallback")
+                    val_loss = train_loss
+                    val_loss_avg(val_loss)
+
             if is_tqdm_val and val_pbar:
                 val_pbar.close()
 
@@ -893,9 +919,10 @@ class RetrievalChatbot(DeviceAwareModel):
             self.history.setdefault('learning_rate', []).append(current_lr)
 
             # Save history to file
-            with open(history_path, 'w') as f:
-                json.dump(self.history, f)
-            logger.info(f"Saved training history to {history_path}")
+            #if history_path.exists():
+            #    with open(history_path, 'w') as f:
+            #        json.dump(self.history, f)
+            #    logger.info(f"Saved training history to {history_path}")
 
             # Early stopping logic
             if val_loss < best_val_loss - min_delta:
@@ -970,7 +997,7 @@ class RetrievalChatbot(DeviceAwareModel):
         gradients_norm = tf.linalg.global_norm(gradients)
 
         # Clip gradients if norm exceeds threshold
-        max_grad_norm = 1.0
+        max_grad_norm = 1.5
         gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
         post_clip_norm = tf.linalg.global_norm(gradients)
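
Note: max_grad_norm rises from 1.0 to 1.5, loosening the clip. tf.clip_by_global_norm rescales every tensor in the gradient list by clip_norm / global_norm when the global norm exceeds clip_norm, so the update direction is preserved and only its magnitude shrinks; passing the precomputed gradients_norm as use_norm avoids computing the norm twice. A standalone illustration:

    import tensorflow as tf

    grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]  # global norm = 13.0
    norm = tf.linalg.global_norm(grads)

    clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.5, use_norm=norm)
    print(norm.numpy())                            # 13.0
    print(tf.linalg.global_norm(clipped).numpy())  # 1.5 -- scaled by 1.5 / 13.0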
 
 
training_plotter.py → plotter.py RENAMED
@@ -2,11 +2,8 @@ from pathlib import Path
 from typing import Dict, List, Optional
 import matplotlib.pyplot as plt
 from datetime import datetime
-import logging
 
-logger = logging.getLogger(__name__)
-
-class TrainingPlotter:
+class Plotter:
     def __init__(self, save_dir: Optional[Path] = None):
         self.save_dir = save_dir
         if save_dir:
@@ -18,10 +15,7 @@ class TrainingPlotter:
         Args:
             history: Dictionary containing training metrics
             title: Title for the plot
-        """
-        # Silence matplotlib debug messages
-        logger.setLevel(logging.WARNING)
-
+        """
         # Create figure with subplots
         fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
 
@@ -46,15 +40,11 @@
         plt.suptitle(title)
         plt.tight_layout()
 
-        # Reset the logger level
-        logger.setLevel(logging.INFO)
-
         # Save if directory provided
         if self.save_dir:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             save_path = self.save_dir / f'training_history_{timestamp}.png'
             plt.savefig(save_path)
-            logger.info(f"Saved training history plot to {save_path}")
 
         plt.show()
 
@@ -64,8 +54,6 @@
         Args:
             metrics: Dictionary of validation metrics. Can handle nested dictionaries.
         """
-        # Silence matplotlib debug messages
-        logger.setLevel(logging.WARNING)
 
         # Flatten nested metrics dictionary
         flat_metrics = {}
@@ -83,7 +71,6 @@
             flat_metrics[key] = value
 
         if not flat_metrics:
-            logger.warning("No numeric metrics to plot")
             return
 
         plt.figure(figsize=(12, 6))
@@ -113,15 +100,11 @@
         # Adjust layout to prevent label cutoff
         plt.tight_layout()
 
-        # Reset the logger level
-        logger.setLevel(logging.INFO)
-
         # Save if directory provided
        if self.save_dir:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             save_path = self.save_dir / f'validation_metrics_{timestamp}.png'
             plt.savefig(save_path)
-            logger.info(f"Saved validation metrics plot to {save_path}")
 
         plt.show()
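
Note: with the rename from TrainingPlotter to Plotter, the module-level logger and the setLevel() calls are gone. Those calls adjusted this module's own logger rather than matplotlib's, so they never actually silenced matplotlib's debug output, and forcing the level back to INFO afterwards was a side effect on shared logger state; callers that want the old "saved plot" messages can log around the calls instead. Expected usage after the rename (history values illustrative):

    from pathlib import Path
    from plotter import Plotter

    plotter = Plotter(save_dir=Path('plots'))
    plotter.plot_training_history({
        'train_loss': [1.2, 0.9, 0.7],
        'val_loss': [1.3, 1.0, 0.8],
        'learning_rate': [5e-4, 5e-4, 5e-4],
    })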
 
 
readme.md CHANGED
@@ -18,9 +18,6 @@ This package automatically downloads the following models during installation:
 
     pip install -e .
 
-On Linux with Cuda/GPU:
-    pip install faiss-gpu>=1.7.0
-
 ## Description
 
 This Python script demonstrates a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
 
train_model.py CHANGED
@@ -1,7 +1,7 @@
 import tensorflow as tf
 from chatbot_model import RetrievalChatbot, ChatbotConfig
 from environment_setup import EnvironmentSetup
-from training_plotter import TrainingPlotter
+from plotter import Plotter
 
 from logger_config import config_logger
 logger = config_logger(__name__)
@@ -38,7 +38,7 @@ def main():
     # Training configuration
     EPOCHS = 20
     TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
-
+    CHECKPOINT_DIR = 'checkpoints/'
     # Optimize batch size for Colab
     batch_size = 32  # env.optimize_batch_size(base_batch_size=16)
 
@@ -48,14 +48,14 @@
     # Initialize chatbot
     chatbot = RetrievalChatbot(config, mode='training')
 
-    # Load from a checkpoint
-    checkpoint_dir = 'checkpoints/'
-    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
-    initial_epoch = 0
-    if latest_checkpoint:
-        ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
-        initial_epoch = ckpt_number
-        logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
+    # # Load from a checkpoint
+
+    # latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
+    # initial_epoch = 0
+    # if latest_checkpoint:
+    #     ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
+    #     initial_epoch = ckpt_number
+    #     logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
 
     # Train the model
     chatbot.train_model(
@@ -64,7 +64,7 @@
         batch_size=batch_size,
         use_lr_schedule=True,
         test_mode=False,
-        initial_epoch=initial_epoch
+        checkpoint_dir=CHECKPOINT_DIR
     )
 
     # Save final model
@@ -72,7 +72,7 @@
     chatbot.save_models(model_save_path)
 
     # Plot and save training history
-    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
+    plotter = Plotter(save_dir=env.training_dirs['plots'])
     plotter.plot_training_history(chatbot.history)
 
 if __name__ == "__main__":
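
Note: the driver no longer parses 'ckpt-N' out of the filename to compute initial_epoch; it passes checkpoint_dir and lets train_model discover, restore, and resume from the latest checkpoint itself, so the two resume paths cannot drift apart. The call site reduces to roughly the following fragment (the leading arguments are elided in the hunk above, so only the trailing keywords are certain):

    chatbot.train_model(
        # ... data path and epoch arguments as before ...
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=False,
        checkpoint_dir=CHECKPOINT_DIR,  # replaces initial_epoch=initial_epoch
    )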
 
validate_model.py CHANGED
@@ -3,7 +3,7 @@ import json
 from chatbot_model import ChatbotConfig, RetrievalChatbot
 from response_quality_checker import ResponseQualityChecker
 from chatbot_validator import ChatbotValidator
-from training_plotter import TrainingPlotter
+from plotter import Plotter
 from environment_setup import EnvironmentSetup
 
 from logger_config import config_logger
@@ -103,7 +103,7 @@ def validate_chatbot():
 
     # Plot validation_metrics
     try:
-        plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
+        plotter = Plotter(save_dir=env.training_dirs['plots'])
         plotter.plot_validation_metrics(validation_metrics)
         logger.info("Validation metrics plotted successfully.")
     except Exception as e:
 