JoeArmani committed · Commit f5346f7 · Parent(s): 74af405

updating training process

Files changed:
- chatbot_model.py +368 -78
- run_model_train.py +7 -6
- tf_data_pipeline.py +25 -1
chatbot_model.py
CHANGED
@@ -229,7 +229,6 @@ class RetrievalChatbot(DeviceAwareModel):
         self.encoder = EncoderModel(
             self.config,
             name="shared_encoder",
-            shared_weights=True  # If weight sharing is intended
         )
 
         # Resize token embeddings after adding special tokens
@@ -875,37 +874,35 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"Models and tokenizer loaded from {load_dir}.")
         return chatbot
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # logger.info(f"Loaded {len(dialogues)} dialogues.")
-        # return dialogues
+    def parse_tfrecord_fn(example_proto, max_length, neg_samples):
+        """
+        Parses a single TFRecord example.
+
+        Args:
+            example_proto: A serialized TFRecord example.
+            max_length: The maximum sequence length for tokenization.
+            neg_samples: The number of hard negatives per query.
+
+        Returns:
+            A tuple of (query_ids, positive_ids, negative_ids).
+        """
+        feature_description = {
+            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+            'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
+        }
+        parsed_features = tf.io.parse_single_example(example_proto, feature_description)
+
+        query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
+        positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
+        negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
+        negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
+
+        return query_ids, positive_ids, negative_ids
 
     def train_streaming(
         self,
-        dialogues: List[dict],
+        tfrecord_file_path: str,
         epochs: int = 20,
         batch_size: int = 16,
         validation_split: float = 0.2,
@@ -915,31 +912,23 @@ class RetrievalChatbot(DeviceAwareModel):
         warmup_steps_ratio: float = 0.1,
         early_stopping_patience: int = 3,
         min_delta: float = 1e-4,
-        neg_samples: int = 1
     ) -> None:
-        """Streaming training with tf.data pipeline."""
-        logger.info("Starting streaming training pipeline with tf.data...")
-
-        # Initialize TFDataPipeline (replaces StreamingDataPipeline)
-        dataset_preparer = TFDataPipeline(
-            embedding_batch_size=self.config.embedding_batch_size,
-            tokenizer=self.tokenizer,
-            encoder=self.encoder,
-            index=self.index,  # Pass CPU version of FAISS index
-            response_pool=self.response_pool,
-            max_length=self.config.max_context_token_limit,
-            neg_samples=neg_samples
-        )
+        """Training using a pre-prepared TFRecord dataset."""
+        logger.info("Starting training with pre-prepared TFRecord dataset...")
 
         # Calculate total steps for learning rate schedule
-        total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
+        # Estimate total pairs by counting the number of records in the TFRecord
+        # Assuming each record corresponds to one pair
+        raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
+        total_pairs = sum(1 for _ in raw_dataset)
+        logger.info(f"Total pairs in TFRecord: {total_pairs}")
+
         train_size = int(total_pairs * (1 - validation_split))
-        val_size = int(total_pairs * validation_split)
-        steps_per_epoch = int(math.ceil(train_size / batch_size))
-        val_steps = int(math.ceil(val_size / batch_size))
+        val_size = total_pairs - train_size
+        steps_per_epoch = math.ceil(train_size / batch_size)
+        val_steps = math.ceil(val_size / batch_size)
         total_steps = steps_per_epoch * epochs
 
-        logger.info(f"Total pairs: {total_pairs}")
         logger.info(f"Training pairs: {train_size}")
         logger.info(f"Validation pairs: {val_size}")
         logger.info(f"Steps per epoch: {steps_per_epoch}")
@@ -974,9 +963,19 @@ class RetrievalChatbot(DeviceAwareModel):
         val_summary_writer = tf.summary.create_file_writer(val_log_dir)
         logger.info(f"TensorBoard logs will be saved in {log_dir}")
 
-        # Create training and validation datasets
-        train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
-        val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
+        # Define the parsing function with the appropriate max_length and neg_samples
+        parse_fn = lambda x: self.parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
+
+        # Create the full dataset
+        dataset = tf.data.TFRecordDataset(tfrecord_file_path)
+        dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
+        dataset = dataset.shuffle(buffer_size=10000)  # Adjust buffer size as needed TODO: what is this?
+        dataset = dataset.batch(batch_size, drop_remainder=True)
+        dataset = dataset.prefetch(tf.data.AUTOTUNE)
+
+        # Split into training and validation
+        train_dataset = dataset.take(train_size)
+        val_dataset = dataset.skip(train_size).take(val_size)
 
         # Training loop
         best_val_loss = float("inf")
@@ -996,7 +995,6 @@ class RetrievalChatbot(DeviceAwareModel):
                 logger.info("Training progress bar disabled")
 
             for q_batch, p_batch, n_batch in train_dataset:
-                #p_batch = p_n_batch[:, 0, :]  # Extract positive from (positive, negative) pair
                 loss = self.train_step(q_batch, p_batch, n_batch)
                 epoch_loss_avg(loss)
                 batches_processed += 1
@@ -1018,7 +1016,7 @@ class RetrievalChatbot(DeviceAwareModel):
                         "lr": f"{current_lr:.2e}",
                         "batches": f"{batches_processed}/{steps_per_epoch}"
                     })
-
+
                 # Memory cleanup
                 gc.collect()
 
@@ -1041,7 +1039,6 @@ class RetrievalChatbot(DeviceAwareModel):
                 logger.info("Validation progress bar disabled")
 
             for q_batch, p_batch, n_batch in val_dataset:
-                #p_batch = p_n_batch[:, 0, :]  # Extract positive from (positive, negative) pair
                 val_loss = self.validation_step(q_batch, p_batch, n_batch)
                 val_loss_avg(val_loss)
                 val_batches_processed += 1
@@ -1052,11 +1049,10 @@ class RetrievalChatbot(DeviceAwareModel):
                         "val_loss": f"{val_loss.numpy():.4f}",
                         "batches": f"{val_batches_processed}/{val_steps}"
                     })
-
+
                 # Memory cleanup
                 gc.collect()
 
-
                 if val_batches_processed >= val_steps:
                     break
 
@@ -1100,21 +1096,17 @@ class RetrievalChatbot(DeviceAwareModel):
                 logger.info("Early stopping triggered.")
                 break
 
-        logger.info("Streaming training completed!")
-
-
+        logger.info("Training completed!")
+
     @tf.function
     def train_step(
         self,
        q_batch: tf.Tensor,
         p_batch: tf.Tensor,
-        n_batch: tf.Tensor,
-        attention_mask: Optional[tf.Tensor] = None
+        n_batch: tf.Tensor
     ) -> tf.Tensor:
         """
-        Single training step that uses queries, positives, and negatives in a
-        contrastive/InfoNCE style. The label is always 0 (the positive) vs.
-        the negative alternatives.
+        Single training step using queries, positives, and hard negatives.
         """
         with tf.GradientTape() as tape:
             # Encode queries
@@ -1160,12 +1152,6 @@ class RetrievalChatbot(DeviceAwareModel):
             )
             loss = tf.reduce_mean(loss)
 
-            # If there's an attention_mask you want to apply (less common in this scenario),
-            # you could do something like:
-            if attention_mask is not None:
-                loss = loss * attention_mask
-                loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
-
         # Apply gradients
         gradients = tape.gradient(loss, self.encoder.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
@@ -1176,12 +1162,10 @@ class RetrievalChatbot(DeviceAwareModel):
         self,
         q_batch: tf.Tensor,
         p_batch: tf.Tensor,
-        n_batch: tf.Tensor,
-        attention_mask: Optional[tf.Tensor] = None
+        n_batch: tf.Tensor
     ) -> tf.Tensor:
         """
-        Single validation step with queries, positives, and negatives.
-        Uses the same loss calculation as train_step, but `training=False`.
+        Single validation step using queries, positives, and hard negatives.
         """
         q_enc = self.encoder(q_batch, training=False)
         p_enc = self.encoder(p_batch, training=False)
@@ -1208,11 +1192,317 @@ class RetrievalChatbot(DeviceAwareModel):
         )
         loss = tf.reduce_mean(loss)
 
-        if attention_mask is not None:
-            loss = loss * attention_mask
-            loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
-
         return loss
+    # def train_streaming(
+    #     self,
+    #     dialogues: List[dict],
+    #     epochs: int = 20,
+    #     batch_size: int = 16,
+    #     validation_split: float = 0.2,
+    #     checkpoint_dir: str = "checkpoints/",
+    #     use_lr_schedule: bool = True,
+    #     peak_lr: float = 2e-5,
+    #     warmup_steps_ratio: float = 0.1,
+    #     early_stopping_patience: int = 3,
+    #     min_delta: float = 1e-4,
+    #     neg_samples: int = 1
+    # ) -> None:
+    #     """Streaming training with tf.data pipeline."""
+    #     logger.info("Starting streaming training pipeline with tf.data...")
+
+    #     # Initialize TFDataPipeline (replaces StreamingDataPipeline)
+    #     dataset_preparer = TFDataPipeline(
+    #         embedding_batch_size=self.config.embedding_batch_size,
+    #         tokenizer=self.tokenizer,
+    #         encoder=self.encoder,
+    #         index=self.index,  # Pass CPU version of FAISS index
+    #         response_pool=self.response_pool,
+    #         max_length=self.config.max_context_token_limit,
+    #         neg_samples=neg_samples
+    #     )
+
+    #     # Calculate total steps for learning rate schedule
+    #     total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
+    #     train_size = int(total_pairs * (1 - validation_split))
+    #     val_size = int(total_pairs * validation_split)
+    #     steps_per_epoch = int(math.ceil(train_size / batch_size))
+    #     val_steps = int(math.ceil(val_size / batch_size))
+    #     total_steps = steps_per_epoch * epochs
+
+    #     logger.info(f"Total pairs: {total_pairs}")
+    #     logger.info(f"Training pairs: {train_size}")
+    #     logger.info(f"Validation pairs: {val_size}")
+    #     logger.info(f"Steps per epoch: {steps_per_epoch}")
+    #     logger.info(f"Validation steps: {val_steps}")
+    #     logger.info(f"Total steps: {total_steps}")
+
+    #     # Set up optimizer with learning rate schedule
+    #     if use_lr_schedule:
+    #         warmup_steps = int(total_steps * warmup_steps_ratio)
+    #         lr_schedule = self._get_lr_schedule(
+    #             total_steps=total_steps,
+    #             peak_lr=peak_lr,
+    #             warmup_steps=warmup_steps
+    #         )
+    #         self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
+    #         logger.info("Using custom learning rate schedule.")
+    #     else:
+    #         self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
+    #         logger.info("Using fixed learning rate.")
+
+    #     # Initialize checkpoint manager
+    #     checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
+    #     manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
+
+    #     # Setup TensorBoard
+    #     log_dir = Path(checkpoint_dir) / "tensorboard_logs"
+    #     log_dir.mkdir(parents=True, exist_ok=True)
+    #     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+    #     train_log_dir = str(log_dir / f"train_{current_time}")
+    #     val_log_dir = str(log_dir / f"val_{current_time}")
+    #     train_summary_writer = tf.summary.create_file_writer(train_log_dir)
+    #     val_summary_writer = tf.summary.create_file_writer(val_log_dir)
+    #     logger.info(f"TensorBoard logs will be saved in {log_dir}")
+
+    #     # Create training and validation datasets
+    #     train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
+    #     val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
+
+    #     # Training loop
+    #     best_val_loss = float("inf")
+    #     epochs_no_improve = 0
+
+    #     for epoch in range(1, epochs + 1):
+    #         # --- Training Phase ---
+    #         epoch_loss_avg = tf.keras.metrics.Mean()
+    #         batches_processed = 0
+
+    #         try:
+    #             train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
+    #             is_tqdm_train = True
+    #         except ImportError:
+    #             train_pbar = None
+    #             is_tqdm_train = False
+    #             logger.info("Training progress bar disabled")
+
+    #         for q_batch, p_batch, n_batch in train_dataset:
+    #             #p_batch = p_n_batch[:, 0, :]  # Extract positive from (positive, negative) pair
+    #             loss = self.train_step(q_batch, p_batch, n_batch)
+    #             epoch_loss_avg(loss)
+    #             batches_processed += 1
+
+    #             # Log to TensorBoard
+    #             with train_summary_writer.as_default():
+    #                 tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed)
+
+    #             # Update progress bar
+    #             if use_lr_schedule:
+    #                 current_lr = float(lr_schedule(self.optimizer.iterations))
+    #             else:
+    #                 current_lr = float(self.optimizer.learning_rate.numpy())
+
+    #             if is_tqdm_train:
+    #                 train_pbar.update(1)
+    #                 train_pbar.set_postfix({
+    #                     "loss": f"{loss.numpy():.4f}",
+    #                     "lr": f"{current_lr:.2e}",
+    #                     "batches": f"{batches_processed}/{steps_per_epoch}"
+    #                 })
+
+    #             # Memory cleanup
+    #             gc.collect()
+
+    #             if batches_processed >= steps_per_epoch:
+    #                 break
+
+    #         if is_tqdm_train and train_pbar:
+    #             train_pbar.close()
+
+    #         # --- Validation Phase ---
+    #         val_loss_avg = tf.keras.metrics.Mean()
+    #         val_batches_processed = 0
+
+    #         try:
+    #             val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
+    #             is_tqdm_val = True
+    #         except ImportError:
+    #             val_pbar = None
+    #             is_tqdm_val = False
+    #             logger.info("Validation progress bar disabled")
+
+    #         for q_batch, p_batch, n_batch in val_dataset:
+    #             #p_batch = p_n_batch[:, 0, :]  # Extract positive from (positive, negative) pair
+    #             val_loss = self.validation_step(q_batch, p_batch, n_batch)
+    #             val_loss_avg(val_loss)
+    #             val_batches_processed += 1
+
+    #             if is_tqdm_val:
+    #                 val_pbar.update(1)
+    #                 val_pbar.set_postfix({
+    #                     "val_loss": f"{val_loss.numpy():.4f}",
+    #                     "batches": f"{val_batches_processed}/{val_steps}"
+    #                 })
+
+    #             # Memory cleanup
+    #             gc.collect()
+
+
+    #             if val_batches_processed >= val_steps:
+    #                 break
+
+    #         if is_tqdm_val and val_pbar:
+    #             val_pbar.close()
+
+    #         # End of epoch: compute final epoch stats, log, and save checkpoint
+    #         train_loss = epoch_loss_avg.result().numpy()
+    #         val_loss = val_loss_avg.result().numpy()
+    #         logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
+
+    #         # Log epoch metrics
+    #         with train_summary_writer.as_default():
+    #             tf.summary.scalar("epoch_loss", train_loss, step=epoch)
+    #         with val_summary_writer.as_default():
+    #             tf.summary.scalar("val_loss", val_loss, step=epoch)
+
+    #         # Save checkpoint
+    #         manager.save()
+
+    #         # Store metrics in history
+    #         self.history['train_loss'].append(train_loss)
+    #         self.history['val_loss'].append(val_loss)
+
+    #         if use_lr_schedule:
+    #             current_lr = float(lr_schedule(self.optimizer.iterations))
+    #         else:
+    #             current_lr = float(self.optimizer.learning_rate.numpy())
+
+    #         self.history.setdefault('learning_rate', []).append(current_lr)
+
+    #         # Early stopping logic
+    #         if val_loss < best_val_loss - min_delta:
+    #             best_val_loss = val_loss
+    #             epochs_no_improve = 0
+    #             logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
+    #         else:
+    #             epochs_no_improve += 1
+    #             logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
+    #             if epochs_no_improve >= early_stopping_patience:
+    #                 logger.info("Early stopping triggered.")
+    #                 break
+
+    #     logger.info("Streaming training completed!")
+
+
+    # @tf.function
+    # def train_step(
+    #     self,
+    #     q_batch: tf.Tensor,
+    #     p_batch: tf.Tensor,
+    #     n_batch: tf.Tensor,
+    #     attention_mask: Optional[tf.Tensor] = None
+    # ) -> tf.Tensor:
+    #     """
+    #     Single training step that uses queries, positives, and negatives in a
+    #     contrastive/InfoNCE style. The label is always 0 (the positive) vs.
+    #     the negative alternatives.
+    #     """
+    #     with tf.GradientTape() as tape:
+    #         # Encode queries
+    #         q_enc = self.encoder(q_batch, training=True)  # [batch_size, embed_dim]
+
+    #         # Encode positives
+    #         p_enc = self.encoder(p_batch, training=True)  # [batch_size, embed_dim]
+
+    #         # Encode negatives
+    #         # n_batch: [batch_size, neg_samples, max_length]
+    #         shape = tf.shape(n_batch)
+    #         bs = shape[0]
+    #         neg_samples = shape[1]
+
+    #         # Flatten negatives to feed them in one pass:
+    #         # => [batch_size * neg_samples, max_length]
+    #         n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
+    #         n_enc_flat = self.encoder(n_batch_flat, training=True)  # [bs*neg_samples, embed_dim]
+
+    #         # Reshape back => [batch_size, neg_samples, embed_dim]
+    #         n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
+
+    #         # Combine the positive embedding and negative embeddings along dim=1
+    #         # => shape [batch_size, 1 + neg_samples, embed_dim]
+    #         # The first column is the positive; subsequent columns are negatives
+    #         combined_p_n = tf.concat(
+    #             [tf.expand_dims(p_enc, axis=1), n_enc],
+    #             axis=1
+    #         )  # [bs, (1+neg_samples), embed_dim]
+
+    #         # Now compute scores: dot product of q_enc with each column in combined_p_n
+    #         # We'll use `tf.einsum` to handle the batch dimension properly
+    #         # dot_products => shape [batch_size, (1+neg_samples)]
+    #         dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
+
+    #         # The label for each row is 0 (the first column is the correct/positive)
+    #         labels = tf.zeros([bs], dtype=tf.int32)
+
+    #         # Cross-entropy over the [batch_size, 1+neg_samples] scores
+    #         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+    #             labels=labels,
+    #             logits=dot_products
+    #         )
+    #         loss = tf.reduce_mean(loss)
+
+    #         # If there's an attention_mask you want to apply (less common in this scenario),
+    #         # you could do something like:
+    #         if attention_mask is not None:
+    #             loss = loss * attention_mask
+    #             loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
+
+    #     # Apply gradients
+    #     gradients = tape.gradient(loss, self.encoder.trainable_variables)
+    #     self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
+    #     return loss
+
+    # @tf.function
+    # def validation_step(
+    #     self,
+    #     q_batch: tf.Tensor,
+    #     p_batch: tf.Tensor,
+    #     n_batch: tf.Tensor,
+    #     attention_mask: Optional[tf.Tensor] = None
+    # ) -> tf.Tensor:
+    #     """
+    #     Single validation step with queries, positives, and negatives.
+    #     Uses the same loss calculation as train_step, but `training=False`.
+    #     """
+    #     q_enc = self.encoder(q_batch, training=False)
+    #     p_enc = self.encoder(p_batch, training=False)
+
+    #     shape = tf.shape(n_batch)
+    #     bs = shape[0]
+    #     neg_samples = shape[1]
+
+    #     n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
+    #     n_enc_flat = self.encoder(n_batch_flat, training=False)
+    #     n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
+
+    #     combined_p_n = tf.concat(
+    #         [tf.expand_dims(p_enc, axis=1), n_enc],
+    #         axis=1
+    #     )
+
+    #     dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
+    #     labels = tf.zeros([bs], dtype=tf.int32)
+
+    #     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+    #         labels=labels,
+    #         logits=dot_products
+    #     )
+    #     loss = tf.reduce_mean(loss)
+
+    #     if attention_mask is not None:
+    #         loss = loss * attention_mask
+    #         loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
+
+    #     return loss
 
     def _get_lr_schedule(
         self,
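A note on the dataset split in the new train_streaming hunk above: tf.data's take() and skip() count elements of the dataset they are applied to, so after .batch(batch_size) they count batches, while train_size and val_size are computed in pairs. Below is a minimal sketch of the same split performed on unbatched records, so pair counts and step counts stay consistent; the helper name is hypothetical and not part of this commit, and it assumes the parse function and TFRecord path from the diff.

import math
import tensorflow as tf

def split_tfrecord_dataset(tfrecord_file_path, parse_fn, batch_size, validation_split=0.2):
    """Hypothetical helper: split parsed records into train/val before batching,
    so take()/skip() count individual (query, positive, negatives) records."""
    # One record corresponds to one training pair, as assumed in the commit.
    total_pairs = sum(1 for _ in tf.data.TFRecordDataset(tfrecord_file_path))
    train_size = int(total_pairs * (1 - validation_split))
    val_size = total_pairs - train_size

    records = tf.data.TFRecordDataset(tfrecord_file_path)
    parsed = records.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

    # Shuffle only the training portion; validation order does not matter.
    train_ds = (parsed.take(train_size)
                .shuffle(buffer_size=10000)
                .batch(batch_size, drop_remainder=True)
                .prefetch(tf.data.AUTOTUNE))
    val_ds = (parsed.skip(train_size)
              .batch(batch_size, drop_remainder=True)
              .prefetch(tf.data.AUTOTUNE))

    steps_per_epoch = math.ceil(train_size / batch_size)
    val_steps = math.ceil(val_size / batch_size)
    return train_ds, val_ds, steps_per_epoch, val_steps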
run_model_train.py
CHANGED
@@ -39,7 +39,7 @@ def main():
 
     DEBUG_SAMPLES = 5
     EPOCHS = 5 if DEBUG_SAMPLES else 20
-
+    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
 
     # Optimize batch size for Colab
     batch_size = env.optimize_batch_size(base_batch_size=16)
@@ -49,20 +49,21 @@ def main():
         embedding_dim=768,  # DistilBERT
         max_context_token_limit=512,
         freeze_embeddings=False,
+        neg_samples=3,
     )
 
-    # Load training data
-    dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
-    print(dialogues)
+    # # Load training data
+    # dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
+    # print(dialogues)
 
     # Initialize chatbot and verify FAISS index
     #with env.strategy.scope():
-    chatbot = RetrievalChatbot(config
+    chatbot = RetrievalChatbot(config)
    chatbot.build_models()
     chatbot.verify_faiss_index()
 
     chatbot.train_streaming(
-
+        tfrecord_file_path=TF_RECORD_FILE_PATH,
         epochs=EPOCHS,
         batch_size=batch_size,
         use_lr_schedule=True,
tf_data_pipeline.py
CHANGED
@@ -689,7 +689,31 @@ class TFDataPipeline:
 
         return q_ids, p_ids, n_ids
 
-
+    # def parse_tfrecord_fn(example_proto, max_length, neg_samples):
+    #     """
+    #     Parses a single TFRecord example.
+
+    #     Args:
+    #         example_proto: A serialized TFRecord example.
+    #         max_length: The maximum sequence length for tokenization.
+    #         neg_samples: The number of hard negatives per query.
+
+    #     Returns:
+    #         A tuple of (query_ids, positive_ids, negative_ids).
+    #     """
+    #     feature_description = {
+    #         'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+    #         'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+    #         'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
+    #     }
+    #     parsed_features = tf.io.parse_single_example(example_proto, feature_description)
+
+    #     query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
+    #     positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
+    #     negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
+    #     negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
+
+    #     return query_ids, positive_ids, negative_ids
 
     # def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
     #     """Find hard negatives for a batch of queries with error handling and retries."""
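For reference, the parser commented out here (and added as an active function in chatbot_model.py) fixes the TFRecord schema: query_ids and positive_ids are int64 features of length max_length, and negative_ids is stored flattened as neg_samples * max_length. Below is a minimal sketch of the matching writer side, assuming already-tokenized integer ID lists; the helper names and the writing loop are illustrative, not part of this commit.

import tensorflow as tf

def _int64_feature(values):
    """Wrap a flat list of ints as an int64 Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

def write_training_tfrecord(path, examples, max_length, neg_samples):
    """Hypothetical writer matching the parse_tfrecord_fn schema above.

    `examples` yields (query_ids, positive_ids, negative_ids) where query_ids and
    positive_ids have length max_length and negative_ids has shape
    [neg_samples, max_length].
    """
    with tf.io.TFRecordWriter(path) as writer:
        for query_ids, positive_ids, negative_ids in examples:
            # Flatten negatives row-major so the parser can reshape to [neg_samples, max_length].
            flat_negatives = [tok for row in negative_ids for tok in row]
            assert len(query_ids) == max_length
            assert len(positive_ids) == max_length
            assert len(flat_negatives) == neg_samples * max_length
            example = tf.train.Example(features=tf.train.Features(feature={
                'query_ids': _int64_feature(query_ids),
                'positive_ids': _int64_feature(positive_ids),
                'negative_ids': _int64_feature(flat_negatives),
            }))
            writer.write(example.SerializeToString())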