JoeArmani committed fc5f33b (parent: dfb45fe)

checkpointing updates, optimizations

Files changed:
- chatbot_model.py                    +51 -24
- training_plotter.py → plotter.py    +2 -19
- readme.md                           +0 -3
- train_model.py                      +12 -12
- validate_model.py                   +2 -2
chatbot_model.py  CHANGED

@@ -30,7 +30,7 @@ class ChatbotConfig:
     num_attention_heads: int = 8
     dropout_rate: float = 0.2
     l2_reg_weight: float = 0.001
-    learning_rate: float = 0.
+    learning_rate: float = 0.0005
     min_text_length: int = 3
     max_context_turns: int = 5
     warmup_steps: int = 200

@@ -676,8 +676,9 @@ class RetrievalChatbot(DeviceAwareModel):
 
         # Initialize checkpoint manager
         checkpoint = tf.train.Checkpoint(
-            epoch=tf.Variable(0),
-            optimizer=self.optimizer,
+            epoch=tf.Variable(0, dtype=tf.int32),
+            optimizer=self.optimizer,
+            optimizer_iterations=self.optimizer.iterations,
             model=self.encoder,
             variables=self.encoder.variables
         )

@@ -685,27 +686,29 @@ class RetrievalChatbot(DeviceAwareModel):
 
         # Restore from checkpoint if available
         latest_checkpoint = manager.latest_checkpoint
+        #history_path = Path(checkpoint_dir) / 'training_history.json'
         if latest_checkpoint:
-            history_path = Path(checkpoint_dir) / 'training_history.json'
-            if history_path.exists():
-                try:
-                    with open(history_path, 'r') as f:
-                        self.history = json.load(f)
-                    logger.info(f"Loaded previous training history from {history_path}")
-                except Exception as e:
-                    logger.warning(f"Could not load history, starting fresh: {e}")
-                    self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
-            else:
-                self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+            # if history_path.exists():
+            #     try:
+            #         with open(history_path, 'r') as f:
+            #             self.history = json.load(f)
+            #         logger.info(f"Loaded previous training history from {history_path}")
+            #     except Exception as e:
+            #         logger.warning(f"Could not load history, starting fresh: {e}")
+            #         self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+            # else:
+            #     self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
 
             status = checkpoint.restore(latest_checkpoint)
             status.expect_partial()
-
             logger.info(f"Restored from checkpoint: {latest_checkpoint}")
+
             # Get the checkpoint number to validate initial_epoch
             ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
             if initial_epoch == 0:
                 initial_epoch = ckpt_number
+
+            checkpoint.epoch.assign(initial_epoch)
             logger.info(f"Resuming from epoch {initial_epoch}")
         else:
             logger.info("Starting training from scratch")

@@ -736,6 +739,7 @@ class RetrievalChatbot(DeviceAwareModel):
         total_pairs = subset_size
         train_size = int(total_pairs * (1 - validation_split))
         val_size = total_pairs - train_size
+        batch_size = min(batch_size, val_size)
         steps_per_epoch = math.ceil(train_size / batch_size)
         val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs

@@ -758,7 +762,7 @@ class RetrievalChatbot(DeviceAwareModel):
         train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
         train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
 
-        val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
+        val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
         val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
         val_dataset = val_dataset.cache()
 

@@ -766,7 +770,9 @@ class RetrievalChatbot(DeviceAwareModel):
         best_val_loss = float("inf")
         epochs_no_improve = 0
 
-        for epoch in range(
+        for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
+            checkpoint.epoch.assign(epoch)
+            logger.info(f"Starting Epoch {epoch}...")
             # --- Training Phase ---
             epoch_loss_avg = tf.keras.metrics.Mean()
             batches_processed = 0

@@ -790,8 +796,8 @@ class RetrievalChatbot(DeviceAwareModel):
                 elif grad_norm_value > 100:
                     logger.warning(f"Potential exploding gradient detected: norm = {grad_norm_value:.2e}")
 
-                if grad_norm_value != post_clip_value:
-                    logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")
+                # if grad_norm_value != post_clip_value:
+                #     logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")
 
                 epoch_loss_avg(loss)
                 batches_processed += 1

@@ -840,9 +846,18 @@ class RetrievalChatbot(DeviceAwareModel):
                 is_tqdm_val = False
                 logger.info("Validation progress bar disabled")
 
+            last_valid_val_loss = None  # Initialize outside the loop
+            valid_batches = False
+
             for q_batch, p_batch, n_batch in val_dataset:
+                if tf.shape(q_batch)[0] < 2:
+                    logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]} (too small for loss calculation)")
+                    continue
+
+                valid_batches = True
                 val_loss = self.validation_step(q_batch, p_batch, n_batch)
                 val_loss_avg(val_loss)
+                last_valid_val_loss = val_loss
                 val_batches_processed += 1
 
                 if is_tqdm_val:

@@ -857,7 +872,18 @@ class RetrievalChatbot(DeviceAwareModel):
 
                 if val_batches_processed >= val_steps:
                     break
-
+
+            if not valid_batches:
+                logger.warning("No valid validation batches in this epoch, using last known validation loss")
+                if last_valid_val_loss is not None:
+                    val_loss = last_valid_val_loss
+                    val_loss_avg(val_loss)
+                else:
+                    # If we've never had a valid batch (first epoch), use training loss as fallback
+                    logger.warning("No previous validation loss available, using training loss as fallback")
+                    val_loss = train_loss
+                    val_loss_avg(val_loss)
+
             if is_tqdm_val and val_pbar:
                 val_pbar.close()
 

@@ -893,9 +919,10 @@ class RetrievalChatbot(DeviceAwareModel):
             self.history.setdefault('learning_rate', []).append(current_lr)
 
             # Save history to file
-            with open(history_path, 'w') as f:
-                json.dump(self.history, f)
-            logger.info(f"Saved training history to {history_path}")
+            #if history_path.exists():
+            #    with open(history_path, 'w') as f:
+            #        json.dump(self.history, f)
+            #    logger.info(f"Saved training history to {history_path}")
 
             # Early stopping logic
             if val_loss < best_val_loss - min_delta:

@@ -970,7 +997,7 @@ class RetrievalChatbot(DeviceAwareModel):
         gradients_norm = tf.linalg.global_norm(gradients)
 
         # Clip gradients if norm exceeds threshold
-        max_grad_norm = 1.
+        max_grad_norm = 1.5
         gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
         post_clip_norm = tf.linalg.global_norm(gradients)
 
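Taken together, the chatbot_model.py changes make checkpoint resume epoch-aware: the epoch counter lives as a tf.Variable inside the tf.train.Checkpoint, is restored alongside the optimizer and encoder, and the training loop picks up at checkpoint.epoch + 1 while re-assigning the counter each epoch. A minimal standalone sketch of that pattern follows; the names CKPT_DIR and EPOCHS and the throwaway Dense model are illustrative placeholders, not taken from the repository:

import tensorflow as tf

CKPT_DIR = "checkpoints_demo/"   # placeholder directory, not the repo's checkpoint_dir
EPOCHS = 5

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])   # stand-in for the encoder
optimizer = tf.keras.optimizers.Adam(1e-3)

# Track the epoch counter inside the checkpoint so it is saved and restored with the weights.
checkpoint = tf.train.Checkpoint(
    epoch=tf.Variable(0, dtype=tf.int32),
    optimizer=optimizer,
    model=model,
)
manager = tf.train.CheckpointManager(checkpoint, CKPT_DIR, max_to_keep=3)

# Restore the latest checkpoint if one exists; expect_partial() silences warnings
# about values (e.g. optimizer slots) that have not been created yet.
if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint).expect_partial()
    print(f"Resuming from epoch {int(checkpoint.epoch.numpy())}")

# Resume one epoch after the last completed one, updating the counter as training proceeds.
for epoch in range(int(checkpoint.epoch.numpy()) + 1, EPOCHS + 1):
    checkpoint.epoch.assign(epoch)
    # ... run the training and validation steps for this epoch ...
    manager.save()   # writes ckpt-<n>, including the epoch variable

Keeping the counter inside the checkpoint means the resume point travels with the weights, so parsing the ckpt-N suffix (still done above to seed initial_epoch) acts as a fallback rather than the source of truth.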
training_plotter.py → plotter.py  RENAMED

@@ -2,11 +2,8 @@ from pathlib import Path
 from typing import Dict, List, Optional
 import matplotlib.pyplot as plt
 from datetime import datetime
-import logging
 
-
-
-class TrainingPlotter:
+class Plotter:
     def __init__(self, save_dir: Optional[Path] = None):
         self.save_dir = save_dir
         if save_dir:

@@ -18,10 +15,7 @@ class TrainingPlotter:
         Args:
             history: Dictionary containing training metrics
             title: Title for the plot
-        """
-        # Silence matplotlib debug messages
-        logger.setLevel(logging.WARNING)
-
+        """
         # Create figure with subplots
         fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
 

@@ -46,15 +40,11 @@ class TrainingPlotter:
         plt.suptitle(title)
         plt.tight_layout()
 
-        # Reset the logger level
-        logger.setLevel(logging.INFO)
-
         # Save if directory provided
         if self.save_dir:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             save_path = self.save_dir / f'training_history_{timestamp}.png'
             plt.savefig(save_path)
-            logger.info(f"Saved training history plot to {save_path}")
 
         plt.show()
 

@@ -64,8 +54,6 @@ class TrainingPlotter:
         Args:
             metrics: Dictionary of validation metrics. Can handle nested dictionaries.
         """
-        # Silence matplotlib debug messages
-        logger.setLevel(logging.WARNING)
 
         # Flatten nested metrics dictionary
         flat_metrics = {}

@@ -83,7 +71,6 @@ class TrainingPlotter:
                 flat_metrics[key] = value
 
         if not flat_metrics:
-            logger.warning("No numeric metrics to plot")
             return
 
         plt.figure(figsize=(12, 6))

@@ -113,15 +100,11 @@ class TrainingPlotter:
         # Adjust layout to prevent label cutoff
         plt.tight_layout()
 
-        # Reset the logger level
-        logger.setLevel(logging.INFO)
-
         # Save if directory provided
         if self.save_dir:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             save_path = self.save_dir / f'validation_metrics_{timestamp}.png'
             plt.savefig(save_path)
-            logger.info(f"Saved validation metrics plot to {save_path}")
 
         plt.show()
 
readme.md  CHANGED

@@ -18,9 +18,6 @@ This package automatically downloads the following models during installation:
 
 pip install -e .
 
-On Linux with Cuda/GPU:
-pip install faiss-gpu>=1.7.0
-
 ## Description
 
 This Python script demonstrates a complete pipeline for dialogue augmentation, including validation, optimization, and data augmentation.
train_model.py  CHANGED

@@ -1,7 +1,7 @@
 import tensorflow as tf
 from chatbot_model import RetrievalChatbot, ChatbotConfig
 from environment_setup import EnvironmentSetup
-from training_plotter import TrainingPlotter
+from plotter import Plotter
 
 from logger_config import config_logger
 logger = config_logger(__name__)

@@ -38,7 +38,7 @@ def main():
     # Training configuration
     EPOCHS = 20
     TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
-
+    CHECKPOINT_DIR = 'checkpoints/'
     # Optimize batch size for Colab
     batch_size = 32 # env.optimize_batch_size(base_batch_size=16)
 

@@ -48,14 +48,14 @@ def main():
     # Initialize chatbot
     chatbot = RetrievalChatbot(config, mode='training')
 
-    # Load from a checkpoint
-
-    latest_checkpoint = tf.train.latest_checkpoint(
-    initial_epoch = 0
-    if latest_checkpoint:
-        ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
-        initial_epoch = ckpt_number
-        logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
+    # # Load from a checkpoint
+
+    # latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
+    # initial_epoch = 0
+    # if latest_checkpoint:
+    #     ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
+    #     initial_epoch = ckpt_number
+    #     logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
 
     # Train the model
     chatbot.train_model(

@@ -64,7 +64,7 @@ def main():
         batch_size=batch_size,
         use_lr_schedule=True,
         test_mode=False,
-
+        checkpoint_dir=CHECKPOINT_DIR
     )
 
     # Save final model

@@ -72,7 +72,7 @@ def main():
     chatbot.save_models(model_save_path)
 
     # Plot and save training history
-    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
+    plotter = Plotter(save_dir=env.training_dirs['plots'])
     plotter.plot_training_history(chatbot.history)
 
 if __name__ == "__main__":
validate_model.py  CHANGED

@@ -3,7 +3,7 @@ import json
 from chatbot_model import ChatbotConfig, RetrievalChatbot
 from response_quality_checker import ResponseQualityChecker
 from chatbot_validator import ChatbotValidator
-from training_plotter import TrainingPlotter
+from plotter import Plotter
 from environment_setup import EnvironmentSetup
 
 from logger_config import config_logger

@@ -103,7 +103,7 @@ def validate_chatbot():
 
     # Plot validation_metrics
     try:
-        plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
+        plotter = Plotter(save_dir=env.training_dirs['plots'])
         plotter.plot_validation_metrics(validation_metrics)
         logger.info("Validation metrics plotted successfully.")
     except Exception as e: