JoeArmani committed
Commit f5346f7 · 1 Parent(s): 74af405

updating training process
Files changed (3)
  1. chatbot_model.py +368 -78
  2. run_model_train.py +7 -6
  3. tf_data_pipeline.py +25 -1
chatbot_model.py CHANGED
@@ -229,7 +229,6 @@ class RetrievalChatbot(DeviceAwareModel):
229
  self.encoder = EncoderModel(
230
  self.config,
231
  name="shared_encoder",
232
- shared_weights=True # If weight sharing is intended
233
  )
234
 
235
  # Resize token embeddings after adding special tokens
@@ -875,37 +874,35 @@ class RetrievalChatbot(DeviceAwareModel):
875
  logger.info(f"Models and tokenizer loaded from {load_dir}.")
876
  return chatbot
877
 
878
- # @staticmethod
879
- # def load_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
880
- # """
881
- # Load training data from a JSON file.
882
-
883
- # Args:
884
- # data_path (Union[str, Path]): Path to the JSON file containing dialogues.
885
- # debug_samples (Optional[int]): Number of samples to load for debugging.
886
-
887
- # Returns:
888
- # List[dict]: List of dialogue dictionaries.
889
- # """
890
- # logger.info(f"Loading training data from {data_path}...")
891
- # data_path = Path(data_path)
892
- # if not data_path.exists():
893
- # logger.error(f"Data file {data_path} does not exist.")
894
- # return []
895
-
896
- # with open(data_path, 'r', encoding='utf-8') as f:
897
- # dialogues = json.load(f)
898
-
899
- # if debug_samples is not None:
900
- # dialogues = dialogues[:debug_samples]
901
- # logger.info(f"Debug mode: Limited to {debug_samples} dialogues")
902
-
903
- # logger.info(f"Loaded {len(dialogues)} dialogues.")
904
- # return dialogues
905
 
906
  def train_streaming(
907
  self,
908
- dialogues: List[dict],
909
  epochs: int = 20,
910
  batch_size: int = 16,
911
  validation_split: float = 0.2,
@@ -915,31 +912,23 @@ class RetrievalChatbot(DeviceAwareModel):
915
  warmup_steps_ratio: float = 0.1,
916
  early_stopping_patience: int = 3,
917
  min_delta: float = 1e-4,
918
- neg_samples: int = 1
919
  ) -> None:
920
- """Streaming training with tf.data pipeline."""
921
- logger.info("Starting streaming training pipeline with tf.data...")
922
-
923
- # Initialize TFDataPipeline (replaces StreamingDataPipeline)
924
- dataset_preparer = TFDataPipeline(
925
- embedding_batch_size=self.config.embedding_batch_size,
926
- tokenizer=self.tokenizer,
927
- encoder=self.encoder,
928
- index=self.index, # Pass CPU version of FAISS index
929
- response_pool=self.response_pool,
930
- max_length=self.config.max_context_token_limit,
931
- neg_samples=neg_samples
932
- )
933
 
934
  # Calculate total steps for learning rate schedule
935
 - total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
936
  train_size = int(total_pairs * (1 - validation_split))
937
- val_size = int(total_pairs * validation_split)
938
- steps_per_epoch = int(math.ceil(train_size / batch_size))
939
- val_steps = int(math.ceil(val_size / batch_size))
940
  total_steps = steps_per_epoch * epochs
941
 
942
- logger.info(f"Total pairs: {total_pairs}")
943
  logger.info(f"Training pairs: {train_size}")
944
  logger.info(f"Validation pairs: {val_size}")
945
  logger.info(f"Steps per epoch: {steps_per_epoch}")
@@ -974,9 +963,19 @@ class RetrievalChatbot(DeviceAwareModel):
974
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
975
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
976
 
977
- # Create training and validation datasets
978
- train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
979
 - val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
980
 
981
  # Training loop
982
  best_val_loss = float("inf")
@@ -996,7 +995,6 @@ class RetrievalChatbot(DeviceAwareModel):
996
  logger.info("Training progress bar disabled")
997
 
998
  for q_batch, p_batch, n_batch in train_dataset:
999
- #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
1000
  loss = self.train_step(q_batch, p_batch, n_batch)
1001
  epoch_loss_avg(loss)
1002
  batches_processed += 1
@@ -1018,7 +1016,7 @@ class RetrievalChatbot(DeviceAwareModel):
1018
  "lr": f"{current_lr:.2e}",
1019
  "batches": f"{batches_processed}/{steps_per_epoch}"
1020
  })
1021
-
1022
  # Memory cleanup
1023
  gc.collect()
1024
 
@@ -1041,7 +1039,6 @@ class RetrievalChatbot(DeviceAwareModel):
1041
  logger.info("Validation progress bar disabled")
1042
 
1043
  for q_batch, p_batch, n_batch in val_dataset:
1044
- #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
1045
  val_loss = self.validation_step(q_batch, p_batch, n_batch)
1046
  val_loss_avg(val_loss)
1047
  val_batches_processed += 1
@@ -1052,11 +1049,10 @@ class RetrievalChatbot(DeviceAwareModel):
1052
  "val_loss": f"{val_loss.numpy():.4f}",
1053
  "batches": f"{val_batches_processed}/{val_steps}"
1054
  })
1055
-
1056
  # Memory cleanup
1057
  gc.collect()
1058
 
1059
-
1060
  if val_batches_processed >= val_steps:
1061
  break
1062
 
@@ -1100,21 +1096,17 @@ class RetrievalChatbot(DeviceAwareModel):
1100
  logger.info("Early stopping triggered.")
1101
  break
1102
 
1103
- logger.info("Streaming training completed!")
1104
-
1105
-
1106
  @tf.function
1107
  def train_step(
1108
  self,
1109
  q_batch: tf.Tensor,
1110
  p_batch: tf.Tensor,
1111
- n_batch: tf.Tensor,
1112
- attention_mask: Optional[tf.Tensor] = None
1113
  ) -> tf.Tensor:
1114
  """
1115
- Single training step that uses queries, positives, and negatives in a
1116
- contrastive/InfoNCE style. The label is always 0 (the positive) vs.
1117
- the negative alternatives.
1118
  """
1119
  with tf.GradientTape() as tape:
1120
  # Encode queries
@@ -1160,12 +1152,6 @@ class RetrievalChatbot(DeviceAwareModel):
1160
  )
1161
  loss = tf.reduce_mean(loss)
1162
 
1163
- # If there's an attention_mask you want to apply (less common in this scenario),
1164
- # you could do something like:
1165
- if attention_mask is not None:
1166
- loss = loss * attention_mask
1167
- loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1168
-
1169
  # Apply gradients
1170
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
1171
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
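Note on the attention_mask branch removed in this hunk: it rescaled a loss that tf.reduce_mean had already collapsed to a scalar, so the mask could not actually weight individual examples. If per-example weighting were ever reintroduced, the usual pattern is to apply the weights before the reduction. A minimal sketch, assuming a hypothetical example_weights float tensor of shape [batch_size] (not part of this commit), with labels and dot_products as computed in train_step:

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels,          # [batch_size], all zeros: the positive sits in column 0
        logits=dot_products     # [batch_size, 1 + neg_samples]
    )                           # -> [batch_size]
    if example_weights is not None:   # hypothetical per-example weights, not in this commit
        weights = tf.cast(example_weights, per_example_loss.dtype)
        loss = tf.reduce_sum(per_example_loss * weights) / tf.maximum(tf.reduce_sum(weights), 1e-8)
    else:
        loss = tf.reduce_mean(per_example_loss)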
@@ -1176,12 +1162,10 @@ class RetrievalChatbot(DeviceAwareModel):
1176
  self,
1177
  q_batch: tf.Tensor,
1178
  p_batch: tf.Tensor,
1179
- n_batch: tf.Tensor,
1180
- attention_mask: Optional[tf.Tensor] = None
1181
  ) -> tf.Tensor:
1182
  """
1183
- Single validation step with queries, positives, and negatives.
1184
- Uses the same loss calculation as train_step, but `training=False`.
1185
  """
1186
  q_enc = self.encoder(q_batch, training=False)
1187
  p_enc = self.encoder(p_batch, training=False)
@@ -1208,11 +1192,317 @@ class RetrievalChatbot(DeviceAwareModel):
1208
  )
1209
  loss = tf.reduce_mean(loss)
1210
 
1211
- if attention_mask is not None:
1212
- loss = loss * attention_mask
1213
- loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1214
-
1215
  return loss
1216
 
1217
  def _get_lr_schedule(
1218
  self,
 
229
  self.encoder = EncoderModel(
230
  self.config,
231
  name="shared_encoder",
 
232
  )
233
 
234
  # Resize token embeddings after adding special tokens
 
874
  logger.info(f"Models and tokenizer loaded from {load_dir}.")
875
  return chatbot
876
 
877
+ def parse_tfrecord_fn(example_proto, max_length, neg_samples):
878
+ """
879
+ Parses a single TFRecord example.
880
+
881
+ Args:
882
+ example_proto: A serialized TFRecord example.
883
+ max_length: The maximum sequence length for tokenization.
884
+ neg_samples: The number of hard negatives per query.
885
+
886
+ Returns:
887
+ A tuple of (query_ids, positive_ids, negative_ids).
888
+ """
889
+ feature_description = {
890
+ 'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
891
+ 'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
892
+ 'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
893
+ }
894
+ parsed_features = tf.io.parse_single_example(example_proto, feature_description)
895
+
896
+ query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
897
+ positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
898
+ negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
899
+ negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
900
+
901
+ return query_ids, positive_ids, negative_ids
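parse_tfrecord_fn is defined here without a self parameter but is later invoked as self.parse_tfrecord_fn(...); decorating it with @staticmethod (or adding self) would be needed for that call site to bind cleanly. For reference, a hedged sketch of the writer side that produces records this parser expects; the helper name is illustrative, and only the feature names, dtypes, and fixed lengths have to match:

    def write_training_example(writer, query_ids, positive_ids, negative_ids):
        # query_ids, positive_ids: int sequences of length max_length
        # negative_ids: int array of shape [neg_samples, max_length]
        feature = {
            'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=list(query_ids))),
            'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=list(positive_ids))),
            'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(
                value=[int(t) for row in negative_ids for t in row])),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

    # Illustrative usage:
    # with tf.io.TFRecordWriter('training_data/training_data.tfrecord') as writer:
    #     write_training_example(writer, q_ids, p_ids, n_ids)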
 
 
902
 
903
  def train_streaming(
904
  self,
905
+ tfrecord_file_path: str,
906
  epochs: int = 20,
907
  batch_size: int = 16,
908
  validation_split: float = 0.2,
 
912
  warmup_steps_ratio: float = 0.1,
913
  early_stopping_patience: int = 3,
914
  min_delta: float = 1e-4,
 
915
  ) -> None:
916
+ """Training using a pre-prepared TFRecord dataset."""
917
+ logger.info("Starting training with pre-prepared TFRecord dataset...")
918
 
919
  # Calculate total steps for learning rate schedule
920
+ # Estimate total pairs by counting the number of records in the TFRecord
921
+ # Assuming each record corresponds to one pair
922
+ raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
923
+ total_pairs = sum(1 for _ in raw_dataset)
924
+ logger.info(f"Total pairs in TFRecord: {total_pairs}")
925
+
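Counting records with sum(1 for _ in raw_dataset) forces a full pass over the TFRecord file before training begins, and tf.data reports an unknown cardinality for TFRecord files, so there is no cheap query for the count. A hedged alternative, assuming the data-preparation script can also write a small sidecar count file (the .count suffix is an assumption, not part of this commit):

    from pathlib import Path

    count_path = Path(str(tfrecord_file_path) + '.count')  # hypothetical sidecar written at prep time
    if count_path.exists():
        total_pairs = int(count_path.read_text().strip())
    else:
        total_pairs = sum(1 for _ in tf.data.TFRecordDataset(tfrecord_file_path))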
926
  train_size = int(total_pairs * (1 - validation_split))
927
+ val_size = total_pairs - train_size
928
+ steps_per_epoch = math.ceil(train_size / batch_size)
929
+ val_steps = math.ceil(val_size / batch_size)
930
  total_steps = steps_per_epoch * epochs
931
 
 
932
  logger.info(f"Training pairs: {train_size}")
933
  logger.info(f"Validation pairs: {val_size}")
934
  logger.info(f"Steps per epoch: {steps_per_epoch}")
 
963
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
964
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
965
 
966
+ # Define the parsing function with the appropriate max_length and neg_samples
967
+ parse_fn = lambda x: self.parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
968
+
969
+ # Create the full dataset
970
+ dataset = tf.data.TFRecordDataset(tfrecord_file_path)
971
+ dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
972
+ dataset = dataset.shuffle(buffer_size=10000) # Shuffle buffer size; adjust to dataset size and available memory
973
+ dataset = dataset.batch(batch_size, drop_remainder=True)
974
+ dataset = dataset.prefetch(tf.data.AUTOTUNE)
975
+
976
+ # Split into training and validation
977
+ train_dataset = dataset.take(train_size)
978
+ val_dataset = dataset.skip(train_size).take(val_size)
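One caveat with the split above: take and skip run after batch(batch_size), so they count batches, while train_size and val_size were computed in pairs; and because shuffle precedes the split, the two subsets are reshuffled each epoch and are not guaranteed to stay disjoint. A possible rearrangement (a sketch only, reusing parse_fn, val_size, and batch_size from above) splits on parsed examples before shuffling:

    base = tf.data.TFRecordDataset(tfrecord_file_path).map(
        parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

    val_dataset = (base.take(val_size)                      # fixed validation slice
                       .batch(batch_size, drop_remainder=True)
                       .prefetch(tf.data.AUTOTUNE))

    train_dataset = (base.skip(val_size)
                         .shuffle(buffer_size=10000)
                         .batch(batch_size, drop_remainder=True)
                         .prefetch(tf.data.AUTOTUNE))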
979
 
980
  # Training loop
981
  best_val_loss = float("inf")
 
995
  logger.info("Training progress bar disabled")
996
 
997
  for q_batch, p_batch, n_batch in train_dataset:
 
998
  loss = self.train_step(q_batch, p_batch, n_batch)
999
  epoch_loss_avg(loss)
1000
  batches_processed += 1
 
1016
  "lr": f"{current_lr:.2e}",
1017
  "batches": f"{batches_processed}/{steps_per_epoch}"
1018
  })
1019
+
1020
  # Memory cleanup
1021
  gc.collect()
1022
 
 
1039
  logger.info("Validation progress bar disabled")
1040
 
1041
  for q_batch, p_batch, n_batch in val_dataset:
 
1042
  val_loss = self.validation_step(q_batch, p_batch, n_batch)
1043
  val_loss_avg(val_loss)
1044
  val_batches_processed += 1
 
1049
  "val_loss": f"{val_loss.numpy():.4f}",
1050
  "batches": f"{val_batches_processed}/{val_steps}"
1051
  })
1052
+
1053
  # Memory cleanup
1054
  gc.collect()
1055
 
 
1056
  if val_batches_processed >= val_steps:
1057
  break
1058
 
 
1096
  logger.info("Early stopping triggered.")
1097
  break
1098
 
1099
+ logger.info("Training completed!")
1100
+
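The loop saves a checkpoint each epoch through tf.train.CheckpointManager (visible in the retained reference implementation below). A short sketch of restoring the latest one before inference or a resumed run, assuming the default checkpoint_dir of "checkpoints/":

    checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
    manager = tf.train.CheckpointManager(checkpoint, "checkpoints/", max_to_keep=3)
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint).expect_partial()
        logger.info(f"Restored from {manager.latest_checkpoint}")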
 
1101
  @tf.function
1102
  def train_step(
1103
  self,
1104
  q_batch: tf.Tensor,
1105
  p_batch: tf.Tensor,
1106
+ n_batch: tf.Tensor
 
1107
  ) -> tf.Tensor:
1108
  """
1109
+ Single training step using queries, positives, and hard negatives.
 
 
1110
  """
1111
  with tf.GradientTape() as tape:
1112
  # Encode queries
 
1152
  )
1153
  loss = tf.reduce_mean(loss)
1154
 
1155
  # Apply gradients
1156
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
1157
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
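The hunk above shows only the loss reduction and the gradient application; the scoring it reduces over is preserved verbatim in the commented-out reference implementation later in this file, restated here as a compact shape sketch:

    # q_enc: [bs, dim]   p_enc: [bs, dim]   n_enc: [bs, neg_samples, dim]
    combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)  # [bs, 1 + neg, dim]
    dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)               # [bs, 1 + neg]
    labels = tf.zeros([tf.shape(q_enc)[0]], dtype=tf.int32)                   # positive is column 0
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=dot_products))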
 
1162
  self,
1163
  q_batch: tf.Tensor,
1164
  p_batch: tf.Tensor,
1165
+ n_batch: tf.Tensor
 
1166
  ) -> tf.Tensor:
1167
  """
1168
+ Single validation step using queries, positives, and hard negatives.
 
1169
  """
1170
  q_enc = self.encoder(q_batch, training=False)
1171
  p_enc = self.encoder(p_batch, training=False)
 
1192
  )
1193
  loss = tf.reduce_mean(loss)
1194
 
1195
  return loss
1196
+ # def train_streaming(
1197
+ # self,
1198
+ # dialogues: List[dict],
1199
+ # epochs: int = 20,
1200
+ # batch_size: int = 16,
1201
+ # validation_split: float = 0.2,
1202
+ # checkpoint_dir: str = "checkpoints/",
1203
+ # use_lr_schedule: bool = True,
1204
+ # peak_lr: float = 2e-5,
1205
+ # warmup_steps_ratio: float = 0.1,
1206
+ # early_stopping_patience: int = 3,
1207
+ # min_delta: float = 1e-4,
1208
+ # neg_samples: int = 1
1209
+ # ) -> None:
1210
+ # """Streaming training with tf.data pipeline."""
1211
+ # logger.info("Starting streaming training pipeline with tf.data...")
1212
+
1213
+ # # Initialize TFDataPipeline (replaces StreamingDataPipeline)
1214
+ # dataset_preparer = TFDataPipeline(
1215
+ # embedding_batch_size=self.config.embedding_batch_size,
1216
+ # tokenizer=self.tokenizer,
1217
+ # encoder=self.encoder,
1218
+ # index=self.index, # Pass CPU version of FAISS index
1219
+ # response_pool=self.response_pool,
1220
+ # max_length=self.config.max_context_token_limit,
1221
+ # neg_samples=neg_samples
1222
+ # )
1223
+
1224
+ # # Calculate total steps for learning rate schedule
1225
+ # total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
1226
+ # train_size = int(total_pairs * (1 - validation_split))
1227
+ # val_size = int(total_pairs * validation_split)
1228
+ # steps_per_epoch = int(math.ceil(train_size / batch_size))
1229
+ # val_steps = int(math.ceil(val_size / batch_size))
1230
+ # total_steps = steps_per_epoch * epochs
1231
+
1232
+ # logger.info(f"Total pairs: {total_pairs}")
1233
+ # logger.info(f"Training pairs: {train_size}")
1234
+ # logger.info(f"Validation pairs: {val_size}")
1235
+ # logger.info(f"Steps per epoch: {steps_per_epoch}")
1236
+ # logger.info(f"Validation steps: {val_steps}")
1237
+ # logger.info(f"Total steps: {total_steps}")
1238
+
1239
+ # # Set up optimizer with learning rate schedule
1240
+ # if use_lr_schedule:
1241
+ # warmup_steps = int(total_steps * warmup_steps_ratio)
1242
+ # lr_schedule = self._get_lr_schedule(
1243
+ # total_steps=total_steps,
1244
+ # peak_lr=peak_lr,
1245
+ # warmup_steps=warmup_steps
1246
+ # )
1247
+ # self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
1248
+ # logger.info("Using custom learning rate schedule.")
1249
+ # else:
1250
+ # self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
1251
+ # logger.info("Using fixed learning rate.")
1252
+
1253
+ # # Initialize checkpoint manager
1254
+ # checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
1255
+ # manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
1256
+
1257
+ # # Setup TensorBoard
1258
+ # log_dir = Path(checkpoint_dir) / "tensorboard_logs"
1259
+ # log_dir.mkdir(parents=True, exist_ok=True)
1260
+ # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
1261
+ # train_log_dir = str(log_dir / f"train_{current_time}")
1262
+ # val_log_dir = str(log_dir / f"val_{current_time}")
1263
+ # train_summary_writer = tf.summary.create_file_writer(train_log_dir)
1264
+ # val_summary_writer = tf.summary.create_file_writer(val_log_dir)
1265
+ # logger.info(f"TensorBoard logs will be saved in {log_dir}")
1266
+
1267
+ # # Create training and validation datasets
1268
+ # train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
1269
+ # val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
1270
+
1271
+ # # Training loop
1272
+ # best_val_loss = float("inf")
1273
+ # epochs_no_improve = 0
1274
+
1275
+ # for epoch in range(1, epochs + 1):
1276
+ # # --- Training Phase ---
1277
+ # epoch_loss_avg = tf.keras.metrics.Mean()
1278
+ # batches_processed = 0
1279
+
1280
+ # try:
1281
+ # train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
1282
+ # is_tqdm_train = True
1283
+ # except ImportError:
1284
+ # train_pbar = None
1285
+ # is_tqdm_train = False
1286
+ # logger.info("Training progress bar disabled")
1287
+
1288
+ # for q_batch, p_batch, n_batch in train_dataset:
1289
+ # #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
1290
+ # loss = self.train_step(q_batch, p_batch, n_batch)
1291
+ # epoch_loss_avg(loss)
1292
+ # batches_processed += 1
1293
+
1294
+ # # Log to TensorBoard
1295
+ # with train_summary_writer.as_default():
1296
+ # tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed)
1297
+
1298
+ # # Update progress bar
1299
+ # if use_lr_schedule:
1300
+ # current_lr = float(lr_schedule(self.optimizer.iterations))
1301
+ # else:
1302
+ # current_lr = float(self.optimizer.learning_rate.numpy())
1303
+
1304
+ # if is_tqdm_train:
1305
+ # train_pbar.update(1)
1306
+ # train_pbar.set_postfix({
1307
+ # "loss": f"{loss.numpy():.4f}",
1308
+ # "lr": f"{current_lr:.2e}",
1309
+ # "batches": f"{batches_processed}/{steps_per_epoch}"
1310
+ # })
1311
+
1312
+ # # Memory cleanup
1313
+ # gc.collect()
1314
+
1315
+ # if batches_processed >= steps_per_epoch:
1316
+ # break
1317
+
1318
+ # if is_tqdm_train and train_pbar:
1319
+ # train_pbar.close()
1320
+
1321
+ # # --- Validation Phase ---
1322
+ # val_loss_avg = tf.keras.metrics.Mean()
1323
+ # val_batches_processed = 0
1324
+
1325
+ # try:
1326
+ # val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
1327
+ # is_tqdm_val = True
1328
+ # except ImportError:
1329
+ # val_pbar = None
1330
+ # is_tqdm_val = False
1331
+ # logger.info("Validation progress bar disabled")
1332
+
1333
+ # for q_batch, p_batch, n_batch in val_dataset:
1334
+ # #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
1335
+ # val_loss = self.validation_step(q_batch, p_batch, n_batch)
1336
+ # val_loss_avg(val_loss)
1337
+ # val_batches_processed += 1
1338
+
1339
+ # if is_tqdm_val:
1340
+ # val_pbar.update(1)
1341
+ # val_pbar.set_postfix({
1342
+ # "val_loss": f"{val_loss.numpy():.4f}",
1343
+ # "batches": f"{val_batches_processed}/{val_steps}"
1344
+ # })
1345
+
1346
+ # # Memory cleanup
1347
+ # gc.collect()
1348
+
1349
+
1350
+ # if val_batches_processed >= val_steps:
1351
+ # break
1352
+
1353
+ # if is_tqdm_val and val_pbar:
1354
+ # val_pbar.close()
1355
+
1356
+ # # End of epoch: compute final epoch stats, log, and save checkpoint
1357
+ # train_loss = epoch_loss_avg.result().numpy()
1358
+ # val_loss = val_loss_avg.result().numpy()
1359
+ # logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
1360
+
1361
+ # # Log epoch metrics
1362
+ # with train_summary_writer.as_default():
1363
+ # tf.summary.scalar("epoch_loss", train_loss, step=epoch)
1364
+ # with val_summary_writer.as_default():
1365
+ # tf.summary.scalar("val_loss", val_loss, step=epoch)
1366
+
1367
+ # # Save checkpoint
1368
+ # manager.save()
1369
+
1370
+ # # Store metrics in history
1371
+ # self.history['train_loss'].append(train_loss)
1372
+ # self.history['val_loss'].append(val_loss)
1373
+
1374
+ # if use_lr_schedule:
1375
+ # current_lr = float(lr_schedule(self.optimizer.iterations))
1376
+ # else:
1377
+ # current_lr = float(self.optimizer.learning_rate.numpy())
1378
+
1379
+ # self.history.setdefault('learning_rate', []).append(current_lr)
1380
+
1381
+ # # Early stopping logic
1382
+ # if val_loss < best_val_loss - min_delta:
1383
+ # best_val_loss = val_loss
1384
+ # epochs_no_improve = 0
1385
+ # logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
1386
+ # else:
1387
+ # epochs_no_improve += 1
1388
+ # logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
1389
+ # if epochs_no_improve >= early_stopping_patience:
1390
+ # logger.info("Early stopping triggered.")
1391
+ # break
1392
+
1393
+ # logger.info("Streaming training completed!")
1394
+
1395
+
1396
+ # @tf.function
1397
+ # def train_step(
1398
+ # self,
1399
+ # q_batch: tf.Tensor,
1400
+ # p_batch: tf.Tensor,
1401
+ # n_batch: tf.Tensor,
1402
+ # attention_mask: Optional[tf.Tensor] = None
1403
+ # ) -> tf.Tensor:
1404
+ # """
1405
+ # Single training step that uses queries, positives, and negatives in a
1406
+ # contrastive/InfoNCE style. The label is always 0 (the positive) vs.
1407
+ # the negative alternatives.
1408
+ # """
1409
+ # with tf.GradientTape() as tape:
1410
+ # # Encode queries
1411
+ # q_enc = self.encoder(q_batch, training=True) # [batch_size, embed_dim]
1412
+
1413
+ # # Encode positives
1414
+ # p_enc = self.encoder(p_batch, training=True) # [batch_size, embed_dim]
1415
+
1416
+ # # Encode negatives
1417
+ # # n_batch: [batch_size, neg_samples, max_length]
1418
+ # shape = tf.shape(n_batch)
1419
+ # bs = shape[0]
1420
+ # neg_samples = shape[1]
1421
+
1422
+ # # Flatten negatives to feed them in one pass:
1423
+ # # => [batch_size * neg_samples, max_length]
1424
+ # n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1425
+ # n_enc_flat = self.encoder(n_batch_flat, training=True) # [bs*neg_samples, embed_dim]
1426
+
1427
+ # # Reshape back => [batch_size, neg_samples, embed_dim]
1428
+ # n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1429
+
1430
+ # # Combine the positive embedding and negative embeddings along dim=1
1431
+ # # => shape [batch_size, 1 + neg_samples, embed_dim]
1432
+ # # The first column is the positive; subsequent columns are negatives
1433
+ # combined_p_n = tf.concat(
1434
+ # [tf.expand_dims(p_enc, axis=1), n_enc],
1435
+ # axis=1
1436
+ # ) # [bs, (1+neg_samples), embed_dim]
1437
+
1438
+ # # Now compute scores: dot product of q_enc with each column in combined_p_n
1439
+ # # We'll use `tf.einsum` to handle the batch dimension properly
1440
+ # # dot_products => shape [batch_size, (1+neg_samples)]
1441
+ # dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
1442
+
1443
+ # # The label for each row is 0 (the first column is the correct/positive)
1444
+ # labels = tf.zeros([bs], dtype=tf.int32)
1445
+
1446
+ # # Cross-entropy over the [batch_size, 1+neg_samples] scores
1447
+ # loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1448
+ # labels=labels,
1449
+ # logits=dot_products
1450
+ # )
1451
+ # loss = tf.reduce_mean(loss)
1452
+
1453
+ # # If there's an attention_mask you want to apply (less common in this scenario),
1454
+ # # you could do something like:
1455
+ # if attention_mask is not None:
1456
+ # loss = loss * attention_mask
1457
+ # loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1458
+
1459
+ # # Apply gradients
1460
+ # gradients = tape.gradient(loss, self.encoder.trainable_variables)
1461
+ # self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1462
+ # return loss
1463
+
1464
+ # @tf.function
1465
+ # def validation_step(
1466
+ # self,
1467
+ # q_batch: tf.Tensor,
1468
+ # p_batch: tf.Tensor,
1469
+ # n_batch: tf.Tensor,
1470
+ # attention_mask: Optional[tf.Tensor] = None
1471
+ # ) -> tf.Tensor:
1472
+ # """
1473
+ # Single validation step with queries, positives, and negatives.
1474
+ # Uses the same loss calculation as train_step, but `training=False`.
1475
+ # """
1476
+ # q_enc = self.encoder(q_batch, training=False)
1477
+ # p_enc = self.encoder(p_batch, training=False)
1478
+
1479
+ # shape = tf.shape(n_batch)
1480
+ # bs = shape[0]
1481
+ # neg_samples = shape[1]
1482
+
1483
+ # n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1484
+ # n_enc_flat = self.encoder(n_batch_flat, training=False)
1485
+ # n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1486
+
1487
+ # combined_p_n = tf.concat(
1488
+ # [tf.expand_dims(p_enc, axis=1), n_enc],
1489
+ # axis=1
1490
+ # )
1491
+
1492
+ # dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
1493
+ # labels = tf.zeros([bs], dtype=tf.int32)
1494
+
1495
+ # loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1496
+ # labels=labels,
1497
+ # logits=dot_products
1498
+ # )
1499
+ # loss = tf.reduce_mean(loss)
1500
+
1501
+ # if attention_mask is not None:
1502
+ # loss = loss * attention_mask
1503
+ # loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1504
+
1505
+ # return loss
1506
 
1507
  def _get_lr_schedule(
1508
  self,
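The body of _get_lr_schedule is not included in this diff. For orientation only, a linear-warmup / cosine-decay schedule compatible with the total_steps, peak_lr, and warmup_steps arguments used in train_streaming might look like the sketch below; this is an assumption about its general shape, not the project's actual implementation:

    import math
    import tensorflow as tf

    class WarmupCosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Linear warmup to peak_lr, then cosine decay toward zero (illustrative)."""
        def __init__(self, total_steps: int, peak_lr: float, warmup_steps: int):
            self.total_steps = total_steps
            self.peak_lr = peak_lr
            self.warmup_steps = max(warmup_steps, 1)

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            warmup_lr = self.peak_lr * step / float(self.warmup_steps)
            progress = (step - self.warmup_steps) / max(self.total_steps - self.warmup_steps, 1)
            cosine_lr = self.peak_lr * 0.5 * (1.0 + tf.cos(math.pi * tf.minimum(progress, 1.0)))
            return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)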
run_model_train.py CHANGED
@@ -39,7 +39,7 @@ def main():
39
 
40
  DEBUG_SAMPLES = 5
41
  EPOCHS = 5 if DEBUG_SAMPLES else 20
42
- TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
43
 
44
  # Optimize batch size for Colab
45
  batch_size = env.optimize_batch_size(base_batch_size=16)
@@ -49,20 +49,21 @@ def main():
49
  embedding_dim=768, # DistilBERT
50
  max_context_token_limit=512,
51
  freeze_embeddings=False,
 
52
  )
53
 
54
- # Load training data
55
- dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
56
- print(dialogues)
57
 
58
  # Initialize chatbot and verify FAISS index
59
  #with env.strategy.scope():
60
- chatbot = RetrievalChatbot(config, dialogues)
61
  chatbot.build_models()
62
  chatbot.verify_faiss_index()
63
 
64
  chatbot.train_streaming(
65
- dialogues=dialogues,
66
  epochs=EPOCHS,
67
  batch_size=batch_size,
68
  use_lr_schedule=True,
 
39
 
40
  DEBUG_SAMPLES = 5
41
  EPOCHS = 5 if DEBUG_SAMPLES else 20
42
+ TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
43
 
44
  # Optimize batch size for Colab
45
  batch_size = env.optimize_batch_size(base_batch_size=16)
 
49
  embedding_dim=768, # DistilBERT
50
  max_context_token_limit=512,
51
  freeze_embeddings=False,
52
+ neg_samples=3,
53
  )
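neg_samples=3 has to agree with the number of hard negatives that were serialized per record when training_data.tfrecord was generated, since parse_tfrecord_fn reshapes negative_ids to [neg_samples, max_length]. An optional sanity check before a long run (a sketch; it assumes the TFRecord already exists at TF_RECORD_FILE_PATH):

    raw = tf.data.TFRecordDataset(TF_RECORD_FILE_PATH)
    q, p, n = next(iter(raw.map(
        lambda x: RetrievalChatbot.parse_tfrecord_fn(
            x, config.max_context_token_limit, config.neg_samples))))
    assert n.shape == (config.neg_samples, config.max_context_token_limit), \
        "config neg_samples/max_context_token_limit do not match the TFRecord contents"

If they disagree, the parse itself will usually fail first with a FixedLenFeature length error, which is the same signal.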
54
 
55
+ # # Load training data
56
+ # dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
57
+ # print(dialogues)
58
 
59
  # Initialize chatbot and verify FAISS index
60
  #with env.strategy.scope():
61
+ chatbot = RetrievalChatbot(config)
62
  chatbot.build_models()
63
  chatbot.verify_faiss_index()
64
 
65
  chatbot.train_streaming(
66
+ tfrecord_file_path=TF_RECORD_FILE_PATH,
67
  epochs=EPOCHS,
68
  batch_size=batch_size,
69
  use_lr_schedule=True,
tf_data_pipeline.py CHANGED
@@ -689,7 +689,31 @@ class TFDataPipeline:
689
 
690
  return q_ids, p_ids, n_ids
691
 
692
-
 
 
 
693
 
694
  # def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
695
  # """Find hard negatives for a batch of queries with error handling and retries."""
 
689
 
690
  return q_ids, p_ids, n_ids
691
 
692
+ # def parse_tfrecord_fn(example_proto, max_length, neg_samples):
693
+ # """
694
+ # Parses a single TFRecord example.
695
+
696
+ # Args:
697
+ # example_proto: A serialized TFRecord example.
698
+ # max_length: The maximum sequence length for tokenization.
699
+ # neg_samples: The number of hard negatives per query.
700
+
701
+ # Returns:
702
+ # A tuple of (query_ids, positive_ids, negative_ids).
703
+ # """
704
+ # feature_description = {
705
+ # 'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
706
+ # 'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
707
+ # 'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
708
+ # }
709
+ # parsed_features = tf.io.parse_single_example(example_proto, feature_description)
710
+
711
+ # query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
712
+ # positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
713
+ # negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
714
+ # negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
715
+
716
+ # return query_ids, positive_ids, negative_ids
717
 
718
  # def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
719
  # """Find hard negatives for a batch of queries with error handling and retries."""