Commit d53c64b
Author: JoeArmani
Parent(s): fc5f33b

fix checkpointing restoration

Files changed:
- chatbot_model.py +171 -137
- train_model.py +14 -10
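The substance of the fix is in how training state is checkpointed: the `tf.train.Checkpoint` now tracks the encoder (`model=self.encoder`) alongside the optimizer and an epoch counter, the restore is verified with `status.assert_consumed()`, and a dummy optimizer step runs before restoring so optimizer state has variables to attach to. For orientation, a minimal standalone sketch of this save/restore pattern follows; the toy model, directory name, and variable names are placeholders, not the project's actual classes:

    import tensorflow as tf

    # Toy stand-ins for the real encoder and optimizer (hypothetical names).
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build((None, 8))
    optimizer = tf.keras.optimizers.Adam(1e-3)

    # Track everything that must survive a restart: weights, optimizer state, epoch counter.
    checkpoint = tf.train.Checkpoint(
        epoch=tf.Variable(0, dtype=tf.int32),
        optimizer=optimizer,
        model=model,
    )
    manager = tf.train.CheckpointManager(
        checkpoint, directory="checkpoints", max_to_keep=3, checkpoint_name="ckpt"
    )

    # Save once per epoch; CheckpointManager writes checkpoints/ckpt-1, ckpt-2, ...
    checkpoint.epoch.assign_add(1)
    manager.save()

    # On a later run, restore the most recent checkpoint if one exists.
    latest = manager.latest_checkpoint
    if latest:
        status = checkpoint.restore(latest)
        status.assert_consumed()  # fail loudly if saved and live objects do not match
        print("Resuming at epoch", int(checkpoint.epoch.numpy()))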
chatbot_model.py
CHANGED
@@ -1,4 +1,5 @@
 import os
+import numpy as np
 from transformers import TFAutoModel, AutoTokenizer
 import tensorflow as tf
 from typing import List, Tuple, Dict, Optional, Union, Any
@@ -74,7 +75,8 @@ class EncoderModel(tf.keras.Model):
         self.projection = tf.keras.layers.Dense(
             config.embedding_dim,
             activation='tanh',
-            name="projection"
+            name="projection",
+            dtype=tf.float32
         )
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         self.normalize = tf.keras.layers.Lambda(
@@ -613,20 +615,20 @@ class RetrievalChatbot(DeviceAwareModel):
         test_mode: bool = False,
         initial_epoch: int = 0
     ) -> None:
-        """
+        """
+        Train the retrieval model using a pre-prepared TFRecord dataset.
+        This method handles:
+          - Checkpoint loading/restoring
+          - LR scheduling
+          - Epoch/iteration tracking
+          - Optional training-history logging
+          - Basic early stopping
+        """
         logger.info("Starting training with pre-prepared TFRecord dataset...")
-
+
         def parse_tfrecord_fn(example_proto, max_length, neg_samples):
             """
             Parses a single TFRecord example.
-
-            Args:
-                example_proto: A serialized TFRecord example.
-                max_length: The maximum sequence length for tokenization.
-                neg_samples: The number of hard negatives per query.
-
-            Returns:
-                A tuple of (query_ids, positive_ids, negative_ids).
             """
             feature_description = {
                 'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
@@ -640,9 +642,9 @@ class RetrievalChatbot(DeviceAwareModel):
             negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
             negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])

-            return query_ids, positive_ids, negative_ids
-
-        #
+            return query_ids, positive_ids, negative_ids
+
+        # Count total records in TFRecord
         raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         total_pairs = sum(1 for _ in raw_dataset)
         logger.info(f"Total pairs in TFRecord: {total_pairs}")
@@ -652,7 +654,7 @@ class RetrievalChatbot(DeviceAwareModel):
         steps_per_epoch = math.ceil(train_size / batch_size)
         val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs
-        buffer_size = total_pairs // 10
+        buffer_size = max(1, total_pairs // 10)  # 10% of the dataset

         logger.info(f"Training pairs: {train_size}")
         logger.info(f"Validation pairs: {val_size}")
@@ -660,61 +662,104 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"Validation steps: {val_steps}")
         logger.info(f"Total steps: {total_steps}")

-        # Set up optimizer
+        # Set up optimizer & LR schedule
         if use_lr_schedule:
             warmup_steps = int(total_steps * warmup_steps_ratio)
             lr_schedule = self._get_lr_schedule(
                 total_steps=total_steps,
-                peak_lr=peak_lr,
+                peak_lr=tf.cast(peak_lr, tf.float32),
                 warmup_steps=warmup_steps
             )
             self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
             logger.info("Using custom learning rate schedule.")
         else:
-            self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
+            self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
             logger.info("Using fixed learning rate.")

-        # Initialize
+        # Initialize optimizer with dummy step
+        dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
+        with tf.GradientTape() as tape:
+            dummy_output = self.encoder(dummy_input)
+            dummy_loss = tf.cast(tf.reduce_mean(dummy_output), tf.float32)
+        dummy_grads = tape.gradient(dummy_loss, self.encoder.trainable_variables)
+        self.optimizer.apply_gradients(zip(dummy_grads, self.encoder.trainable_variables))
+
+        # Create checkpoint and manager
         checkpoint = tf.train.Checkpoint(
             epoch=tf.Variable(0, dtype=tf.int32),
             optimizer=self.optimizer,
-
-
-
+            model=self.encoder
+        )
+
+        manager = tf.train.CheckpointManager(
+            checkpoint,
+            directory=checkpoint_dir,
+            max_to_keep=3,
+            checkpoint_name='ckpt'
         )
-        manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3, checkpoint_name='ckpt')

-        # Restore from checkpoint if
+        # Restore from existing checkpoint if present
         latest_checkpoint = manager.latest_checkpoint
-
-
-
-
-
-
-
-        #
-
-
-
-        # self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+        history_path = Path(checkpoint_dir) / 'training_history.json'
+
+        # If you want to log all epoch losses across runs
+        if not hasattr(self, 'history'):
+            self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
+
+        if latest_checkpoint and not test_mode:
+            # Debug info before restore
+            logger.info("\nEncoder Variables:")
+            for var in self.encoder.variables:
+                logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")

+            logger.info("\nOptimizer Variables:")
+            for var in self.optimizer.variables:
+                logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")
+
+            # Add checkpoint inspection
+            logger.info("\nTrying to load checkpoint from: ", latest_checkpoint)
+            reader = tf.train.load_checkpoint(latest_checkpoint)
+            shape_from_key = reader.get_variable_to_shape_map()
+            dtype_from_key = reader.get_variable_to_dtype_map()
+            logger.info("\nCheckpoint Variables:")
+            for key in shape_from_key:
+                logger.info(f"{key}: dtype={dtype_from_key[key]} - Shape: {shape_from_key[key]}")
+
             status = checkpoint.restore(latest_checkpoint)
-            status.
+            status.assert_consumed()
             logger.info(f"Restored from checkpoint: {latest_checkpoint}")
+            logger.info(f"Optimizer iterations after restore: {self.optimizer.iterations.numpy()}")

-            #
+            # Verify learning rate after restore
+            if use_lr_schedule:
+                current_lr = float(lr_schedule(self.optimizer.iterations))
+            else:
+                current_lr = float(self.optimizer.learning_rate.numpy())
+            logger.info(f"Current learning rate after restore: {current_lr:.2e}")
+
+            # Derive initial_epoch from checkpoint name if not passed in
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            if initial_epoch == 0:
                initial_epoch = ckpt_number
-
-            checkpoint.epoch
+
+            # Assign to checkpoint.epoch so we keep counting from that
+            checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
            logger.info(f"Resuming from epoch {initial_epoch}")
+
+            # If you want to load old history from file:
+            if history_path.exists():
+                try:
+                    with open(history_path, 'r') as f:
+                        self.history = json.load(f)
+                    logger.info(f"Loaded previous training history from {history_path}")
+                except Exception as e:
+                    logger.warning(f"Could not load history, starting fresh: {e}")
         else:
             logger.info("Starting training from scratch")
+            checkpoint.epoch.assign(tf.cast(0, tf.int32))
             initial_epoch = 0
-
-        #
+
+        # Set up TensorBoard
         log_dir = Path(checkpoint_dir) / "tensorboard_logs"
         log_dir.mkdir(parents=True, exist_ok=True)
         current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -724,18 +769,15 @@ class RetrievalChatbot(DeviceAwareModel):
         val_summary_writer = tf.summary.create_file_writer(val_log_dir)
         logger.info(f"TensorBoard logs will be saved in {log_dir}")

-        #
-        parse_fn = lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
-
-        # Create the full dataset
+        # Parse dataset
         dataset = tf.data.TFRecordDataset(tfrecord_file_path)
-
-        #
+
+        # Optional: test/debug mode with small subset
         if test_mode:
-            subset_size =
+            subset_size = 150
             dataset = dataset.take(subset_size)
             logger.info(f"TEST MODE: Using only {subset_size} examples")
-            #
+            # Recompute sizes, steps, epochs, etc., as needed
             total_pairs = subset_size
             train_size = int(total_pairs * (1 - validation_split))
             val_size = total_pairs - train_size
@@ -743,22 +785,23 @@ class RetrievalChatbot(DeviceAwareModel):
             steps_per_epoch = math.ceil(train_size / batch_size)
             val_steps = math.ceil(val_size / batch_size)
             total_steps = steps_per_epoch * epochs
-            buffer_size = total_pairs // 10
-            epochs = min(epochs, 5)  #
+            buffer_size = max(1, total_pairs // 10)
+            epochs = min(epochs, 5)  # For quick debug
             early_stopping_patience = 2
             logger.info(f"New training pairs: {train_size}")
             logger.info(f"New validation pairs: {val_size}")
-
-        dataset = dataset.map(
-
-
+
+        dataset = dataset.map(
+            lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples),
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+
+        # Train/val split
         train_dataset = dataset.take(train_size)
         val_dataset = dataset.skip(train_size).take(val_size)

-        # Shuffle
+        # Shuffle and batch
         train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
-
-        # Batch both datasets
         train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
         train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

@@ -773,41 +816,34 @@ class RetrievalChatbot(DeviceAwareModel):
         for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
             checkpoint.epoch.assign(epoch)
             logger.info(f"Starting Epoch {epoch}...")
+
             # --- Training Phase ---
-            epoch_loss_avg = tf.keras.metrics.Mean()
+            epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
             batches_processed = 0

+            # Progress bar
             try:
-                train_pbar = tqdm(
+                train_pbar = tqdm(
+                    total=steps_per_epoch,
+                    desc=f"Training Epoch {epoch}",
+                    unit="batch"
+                )
                 is_tqdm_train = True
             except ImportError:
                 train_pbar = None
                 is_tqdm_train = False
-                logger.info("Training progress bar disabled")

             for q_batch, p_batch, n_batch in train_dataset:
                 loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
-
-                # Check for gradient issues
-                grad_norm_value = float(grad_norm.numpy())
-                post_clip_value = float(post_clip_norm.numpy())
-                if grad_norm_value < 1e-7:
-                    logger.warning(f"Potential vanishing gradient detected: norm = {grad_norm_value:.2e}")
-                elif grad_norm_value > 100:
-                    logger.warning(f"Potential exploding gradient detected: norm = {grad_norm_value:.2e}")
-
-                # if grad_norm_value != post_clip_value:
-                #     logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")
-
                 epoch_loss_avg(loss)
                 batches_processed += 1

                 # Log to TensorBoard
                 with train_summary_writer.as_default():
                     step = (epoch - 1) * steps_per_epoch + batches_processed
-                    tf.summary.scalar("loss", loss, step=step)
-                    tf.summary.scalar("gradient_norm_pre_clip", grad_norm, step=step)
-                    tf.summary.scalar("gradient_norm_post_clip", post_clip_norm, step=step)
+                    tf.summary.scalar("loss", tf.cast(loss, tf.float32), step=step)
+                    tf.summary.scalar("gradient_norm_pre_clip", tf.cast(grad_norm, tf.float32), step=step)
+                    tf.summary.scalar("gradient_norm_post_clip", tf.cast(post_clip_norm, tf.float32), step=step)

                 # Update progress bar
                 if use_lr_schedule:
@@ -819,15 +855,15 @@ class RetrievalChatbot(DeviceAwareModel):
                     train_pbar.update(1)
                     train_pbar.set_postfix({
                         "loss": f"{loss.numpy():.4f}",
-                        "pre_clip": f"{
-                        "post_clip": f"{
+                        "pre_clip": f"{grad_norm.numpy():.2e}",
+                        "post_clip": f"{post_clip_norm.numpy():.2e}",
                         "lr": f"{current_lr:.2e}",
                         "batches": f"{batches_processed}/{steps_per_epoch}"
                     })

-                # Memory cleanup
                 gc.collect()

+                # End the epoch early if we've processed all steps
                 if batches_processed >= steps_per_epoch:
                     break

@@ -835,7 +871,7 @@ class RetrievalChatbot(DeviceAwareModel):
                 train_pbar.close()

             # --- Validation Phase ---
-            val_loss_avg = tf.keras.metrics.Mean()
+            val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
             val_batches_processed = 0

             try:
@@ -844,16 +880,16 @@ class RetrievalChatbot(DeviceAwareModel):
             except ImportError:
                 val_pbar = None
                 is_tqdm_val = False
-                logger.info("Validation progress bar disabled")

-            last_valid_val_loss = None
+            last_valid_val_loss = None
             valid_batches = False
-
+
             for q_batch, p_batch, n_batch in val_dataset:
-
-
+                # If batch is too small, skip
+                if tf.shape(q_batch)[0] < 2:
+                    logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
                     continue
-
+
                 valid_batches = True
                 val_loss = self.validation_step(q_batch, p_batch, n_batch)
                 val_loss_avg(val_loss)
@@ -867,32 +903,30 @@ class RetrievalChatbot(DeviceAwareModel):
                         "batches": f"{val_batches_processed}/{val_steps}"
                     })

-                # Memory cleanup
                 gc.collect()

                 if val_batches_processed >= val_steps:
                     break
-
+
             if not valid_batches:
-
+                # If no valid batch is found, fallback
+                logger.warning("No valid validation batches in this epoch")
                 if last_valid_val_loss is not None:
                     val_loss = last_valid_val_loss
                     val_loss_avg(val_loss)
                 else:
-
-                    logger.warning("No previous validation loss available, using training loss as fallback")
-                    val_loss = train_loss
+                    val_loss = epoch_loss_avg.result()
                     val_loss_avg(val_loss)
-
+
             if is_tqdm_val and val_pbar:
                 val_pbar.close()

-            # End of epoch:
+            # End of epoch: final stats
             train_loss = epoch_loss_avg.result().numpy()
             val_loss = val_loss_avg.result().numpy()
             logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

-            #
+            # TensorBoard epoch logs
             with train_summary_writer.as_default():
                 tf.summary.scalar("epoch_loss", train_loss, step=epoch)
             with val_summary_writer.as_default():
@@ -900,31 +934,38 @@ class RetrievalChatbot(DeviceAwareModel):

             # Save checkpoint
             manager.save()
-
-            # Save model
+
+            # (Optional) Save model for quick testing/inference
             model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
             self.save_models(model_save_path)
             logger.info(f"Saved model for epoch {epoch} at {model_save_path}")

-            #
+            # Update local history
             self.history['train_loss'].append(train_loss)
             self.history['val_loss'].append(val_loss)
-
-            if use_lr_schedule:
-                current_lr = float(lr_schedule(self.optimizer.iterations))
-            else:
-                current_lr = float(self.optimizer.learning_rate.numpy())
-
-            # Log learning rate
             self.history.setdefault('learning_rate', []).append(current_lr)
-
-            # Save history to file
-            #if history_path.exists():
-            #    with open(history_path, 'w') as f:
-            #        json.dump(self.history, f)
-            #        logger.info(f"Saved training history to {history_path}")

-
+            def convert_to_py_floats(obj):
+                if isinstance(obj, dict):
+                    return {k: convert_to_py_floats(v) for k, v in obj.items()}
+                elif isinstance(obj, list):
+                    return [convert_to_py_floats(x) for x in obj]
+                elif isinstance(obj, (np.float32, np.float64)):
+                    return float(obj)
+                elif tf.is_tensor(obj):
+                    return float(obj.numpy())
+                else:
+                    return obj
+
+            json_history = convert_to_py_floats(self.history)
+
+            # Save training history to file every epoch
+            # (Create or overwrite the file so we always have the latest.)
+            with open(history_path, 'w') as f:
+                json.dump(json_history, f)
+            logger.info(f"Saved training history to {history_path}")
+
+            # Early stopping
             if val_loss < best_val_loss - min_delta:
                 best_val_loss = val_loss
                 epochs_no_improve = 0
@@ -980,26 +1021,20 @@ class RetrievalChatbot(DeviceAwareModel):
             # Now compute scores: dot product of q_enc with each column in combined_p_n
             # We'll use `tf.einsum` to handle the batch dimension properly
             # dot_products => shape [batch_size, (1+neg_samples)]
-            dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
-
-            # The label for each row is 0 (the first column is the correct/positive)
-            labels = tf.zeros([bs], dtype=tf.int32)
-
-            # Cross-entropy over the [batch_size, 1+neg_samples] scores
+            dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
+            labels = tf.zeros([bs], dtype=tf.int32)  # Keep labels as int32
             loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                 labels=labels,
                 logits=dot_products
             )
-            loss = tf.reduce_mean(loss)
+            loss = tf.cast(tf.reduce_mean(loss), tf.float32)

         # Calculate gradients
         gradients = tape.gradient(loss, self.encoder.trainable_variables)
-        gradients_norm = tf.linalg.global_norm(gradients)
-
-        # Clip gradients if norm exceeds threshold
-        max_grad_norm = 1.5
+        gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
+        max_grad_norm = tf.constant(1.5, dtype=tf.float32)
         gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
-        post_clip_norm = tf.linalg.global_norm(gradients)
+        post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)

         # Apply gradients
         self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
@@ -1032,14 +1067,14 @@ class RetrievalChatbot(DeviceAwareModel):
             axis=1
         )

-        dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
-        labels = tf.zeros([bs], dtype=tf.int32)
-
+        dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
+        labels = tf.zeros([bs], dtype=tf.int32)  # Keep labels as int32
+
         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=labels,
             logits=dot_products
         )
-        loss = tf.reduce_mean(loss)
+        loss = tf.cast(tf.reduce_mean(loss), tf.float32)

         return loss

@@ -1066,8 +1101,8 @@ class RetrievalChatbot(DeviceAwareModel):
         self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)

         # Calculate and store constants
-        self.initial_lr = self.peak_lr * 0.1
-        self.min_lr = self.peak_lr * 0.01
+        self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
+        self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)

         logger.info(f"Learning rate schedule initialized:")
         logger.info(f"  Initial LR: {float(self.initial_lr):.6f}")
@@ -1080,15 +1115,14 @@ class RetrievalChatbot(DeviceAwareModel):
         step = tf.cast(step, tf.float32)

         # Warmup phase
-        warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
+        warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
         warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

         # Decay phase
-        decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
-        decay_factor = (step - self.warmup_steps) / decay_steps
-        decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0)
-
-        cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi) * decay_factor))
+        decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
+        decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
+        decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
+        cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
         decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

         # Choose between warmup and decay
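A note on the dummy step added above: Adam creates its slot variables lazily, on the first `apply_gradients` call, so restoring a checkpoint into a freshly constructed optimizer can leave the saved optimizer state with nothing to map onto, and `status.assert_consumed()` may then fail. The priming step forces those variables into existence before the restore. A small self-contained illustration of the idea (toy encoder and shapes, not the project's real `self.encoder`):

    import tensorflow as tf

    # Toy encoder standing in for self.encoder; the shapes are made up for the example.
    encoder = tf.keras.Sequential([
        tf.keras.layers.Embedding(input_dim=100, output_dim=16),
        tf.keras.layers.GlobalAveragePooling1D(),
    ])
    optimizer = tf.keras.optimizers.Adam(1e-3)

    # One throwaway step: this forces Adam to build its per-weight slot variables
    # for every trainable variable before any checkpoint restore happens.
    dummy_input = tf.zeros((1, 32), dtype=tf.int32)
    with tf.GradientTape() as tape:
        dummy_loss = tf.reduce_mean(encoder(dummy_input))
    dummy_grads = tape.gradient(dummy_loss, encoder.trainable_variables)
    optimizer.apply_gradients(zip(dummy_grads, encoder.trainable_variables))

    print("optimizer iterations after priming step:", int(optimizer.iterations.numpy()))

If the stricter check is not needed, an alternative is to restore with `status.expect_partial()` and let the optimizer rebuild its slots on the first real training step.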
train_model.py
CHANGED
@@ -48,14 +48,17 @@ def main():
     # Initialize chatbot
     chatbot = RetrievalChatbot(config, mode='training')

-    #
-
-
-
-
-
-
-
+    # Check for existing checkpoint and get initial epoch
+    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
+    initial_epoch = 0
+    if latest_checkpoint:
+        try:
+            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
+            initial_epoch = ckpt_number
+            logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
+        except (IndexError, ValueError):
+            logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
+            initial_epoch = 0

     # Train the model
     chatbot.train_model(
@@ -63,8 +66,9 @@ def main():
         epochs=EPOCHS,
         batch_size=batch_size,
         use_lr_schedule=True,
-        test_mode=
-        checkpoint_dir=CHECKPOINT_DIR
+        test_mode=True,
+        checkpoint_dir=CHECKPOINT_DIR,
+        initial_epoch=initial_epoch
     )

     # Save final model
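On the train_model.py side, the starting epoch is recovered from the checkpoint file name: `CheckpointManager` numbers its checkpoints `ckpt-1`, `ckpt-2`, ... and the training loop saves once per epoch, so the numeric suffix doubles as an epoch counter. A tiny sketch of that parsing, with a hypothetical helper name:

    def epoch_from_checkpoint(path: str) -> int:
        """Return N from a path like 'checkpoints/ckpt-7', or 0 if it cannot be parsed."""
        try:
            return int(path.split('ckpt-')[-1])
        except (IndexError, ValueError):
            return 0

    assert epoch_from_checkpoint("checkpoints/ckpt-7") == 7
    assert epoch_from_checkpoint("checkpoints/oddly-named") == 0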