JoeArmani committed
Commit 2183656 · 1 Parent(s): 775baf9

upgrade to tf-dataset

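
A minimal sketch (not the repository's code) of the tf.data pattern this commit adopts: a Python generator of (query, positive, hard_negatives) triples wrapped in a Dataset, batched, and prefetched, replacing the old thread-and-Queue streaming loop. Toy strings stand in for the real dialogue pairs.

import tensorflow as tf

toy_triples = [
    ("how do I reset my password?", "You can reset it from the settings page.",
     ["The weather is nice today.", "<EMPTY_NEGATIVE>"]),
    ("what are your hours?", "We are open 9am to 5pm on weekdays.",
     ["Please restart the router.", "<EMPTY_NEGATIVE>"]),
]

def triple_generator():
    # Yields (query, positive, hard_negatives), mirroring TFDataPipeline.data_generator.
    for query, positive, negatives in toy_triples:
        yield query, positive, negatives

dataset = tf.data.Dataset.from_generator(
    triple_generator,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.string),       # query
        tf.TensorSpec(shape=(), dtype=tf.string),       # positive
        tf.TensorSpec(shape=(None,), dtype=tf.string),  # hard negatives
    ),
)
dataset = dataset.batch(2).prefetch(tf.data.AUTOTUNE)

for q, p, n in dataset:
    print(q.shape, p.shape, n.shape)  # (2,) (2,) (2, 2)
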
chatbot_model.py CHANGED
@@ -2,8 +2,6 @@ import time
2
  from transformers import TFAutoModel, AutoTokenizer
3
  import tensorflow as tf
4
  import numpy as np
5
- import threading
6
- from queue import Queue, Empty
7
  from typing import Generator, List, Tuple, Dict, Optional, Union, Any
8
  import math
9
  from dataclasses import dataclass
@@ -12,7 +10,6 @@ from pathlib import Path
12
  import datetime
13
  import faiss
14
  import gc
15
- import random
16
  from response_quality_checker import ResponseQualityChecker
17
  from cross_encoder_reranker import CrossEncoderReranker
18
  from conversation_summarizer import DeviceAwareModel, Summarizer
@@ -29,7 +26,7 @@ class ChatbotConfig:
29
  """Configuration for the RetrievalChatbot."""
30
  vocab_size: int = 30526 # DistilBERT vocab size + special tokens
31
  max_context_token_limit: int = 512
32
- embedding_dim: int = 512
33
  encoder_units: int = 256
34
  num_attention_heads: int = 8
35
  dropout_rate: float = 0.2
@@ -42,6 +39,7 @@ class ChatbotConfig:
42
  pretrained_model: str = 'distilbert-base-uncased'
43
  dtype: str = 'float32'
44
  freeze_embeddings: bool = False
 
45
  # Additional configurations can be added here
46
 
47
  def to_dict(self) -> dict:
@@ -103,9 +101,9 @@ class EncoderModel(tf.keras.Model):
103
 
104
  # Apply pooling, projection, dropout, and normalization
105
  x = self.pooler(x) # Shape: [batch_size, 768]
106
- x = self.projection(x) # Shape: [batch_size, 512]
107
  x = self.dropout(x, training=training) # Apply dropout
108
- x = self.normalize(x) # Shape: [batch_size, 512]
109
 
110
  return x
111
 
@@ -139,17 +137,6 @@ class RetrievalChatbot(DeviceAwareModel):
139
  summarizer = Summarizer(device=self.device)
140
  self.summarizer = summarizer
141
 
142
- # # Configure XLA optimization if on GPU/TPU
143
- # if self.device in ["GPU", "TPU"]:
144
- # tf.config.optimizer.set_jit(True)
145
- # logger.info(f"XLA compilation enabled for {self.device}")
146
-
147
- # # Configure mixed precision for GPU/TPU
148
- # if self.device != "CPU":
149
- # policy = tf.keras.mixed_precision.Policy('mixed_float16')
150
- # tf.keras.mixed_precision.set_global_policy(policy)
151
- # logger.info("Mixed precision training enabled (float16)")
152
-
153
  # Special tokens
154
  self.special_tokens = {
155
  "user": "<USER>",
@@ -354,6 +341,9 @@ class RetrievalChatbot(DeviceAwareModel):
354
  """
355
  all_embeddings = []
356
  self.current_batch_size = batch_size
357
 
358
  # Memory stats
359
  # if self.memory_monitor.has_gpu:
@@ -541,9 +531,9 @@ class RetrievalChatbot(DeviceAwareModel):
541
  logger.info("Starting vector addition process...")
542
 
543
  # Even smaller batches
544
- initial_batch_size = 50 # Start smaller
545
- min_batch_size = 10
546
- max_batch_size = 500 # Lower maximum
547
 
548
  total_added = 0
549
  retry_count = 0
@@ -572,7 +562,6 @@ class RetrievalChatbot(DeviceAwareModel):
572
  # Update progress
573
  batch_size = len(batch)
574
  total_added += batch_size
575
- #logger.info(f"Added batch of {batch_size} vectors ({total_added}/{len(response_embeddings)} total)")
576
 
577
  # Memory cleanup every few batches
578
  if total_added % (initial_batch_size * 5) == 0:
@@ -618,7 +607,7 @@ class RetrievalChatbot(DeviceAwareModel):
618
  cpu_index = self.index
619
 
620
  # Add remaining vectors on CPU with very small batches
621
- batch_size = 50 # Extremely conservative batch size for CPU
622
  total_added = already_added
623
 
624
  for i in range(0, len(remaining_embeddings), batch_size):
@@ -911,36 +900,33 @@ class RetrievalChatbot(DeviceAwareModel):
911
  warmup_steps_ratio: float = 0.1,
912
  early_stopping_patience: int = 3,
913
  min_delta: float = 1e-4,
914
- buffer_size: int = 10,
915
  neg_samples: int = 1
916
  ) -> None:
917
- """
918
- Streaming version of training that interleaves training/val batches by
919
- giving priority to training until we meet `steps_per_epoch`, then
920
- sending leftover batches to validation.
921
- """
922
- logger.info("Starting streaming training pipeline...")
923
 
924
- # Initialize dataset preparer
925
- dataset_preparer = StreamingDataPipeline(
 
926
  tokenizer=self.tokenizer,
927
  encoder=self.encoder,
928
- index=self.index,
929
  response_pool=self.response_pool,
930
  max_length=self.config.max_context_token_limit,
931
- batch_size=batch_size,
932
  neg_samples=neg_samples
933
  )
934
 
935
  # Calculate total steps for learning rate schedule
936
  total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
937
- train_size = total_pairs * (1 - validation_split)
 
938
  steps_per_epoch = int(math.ceil(train_size / batch_size))
939
- val_steps = int(math.ceil((total_pairs * validation_split) / batch_size))
940
  total_steps = steps_per_epoch * epochs
941
 
942
  logger.info(f"Total pairs: {total_pairs}")
943
  logger.info(f"Training pairs: {train_size}")
 
944
  logger.info(f"Steps per epoch: {steps_per_epoch}")
945
  logger.info(f"Validation steps: {val_steps}")
946
  logger.info(f"Total steps: {total_steps}")
@@ -971,276 +957,245 @@ class RetrievalChatbot(DeviceAwareModel):
971
  val_log_dir = str(log_dir / f"val_{current_time}")
972
  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
973
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
974
-
975
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
976
 
977
  # Training loop
978
  best_val_loss = float("inf")
979
  epochs_no_improve = 0
980
 
981
- try:
982
- epoch_pbar = tqdm(range(1, epochs + 1), desc="Training", unit="epoch")
983
- is_tqdm_epoch = True
984
- except ImportError:
985
- epoch_pbar = range(1, epochs + 1)
986
- is_tqdm_epoch = False
987
- logger.info("Epoch progress bar disabled - continuing without visual progress")
988
-
989
- for epoch in epoch_pbar:
990
- # Shared queues for streaming pipeline
991
- train_queue = Queue(maxsize=buffer_size)
992
- val_queue = Queue(maxsize=buffer_size)
993
- stop_flag = threading.Event()
994
-
995
- def data_pipeline_worker():
996
- """Thread function that processes dialogues and sends batches to train or val."""
997
- try:
998
- train_batches_needed = steps_per_epoch # 9 in your logs
999
- val_batches_needed = val_steps # 3 in your logs
1000
- train_batches_sent = 0
1001
- val_batches_sent = 0
1002
 
1003
- logger.info(f"Pipeline starting: need {train_batches_needed} train batches, {val_batches_needed} val batches")
 
 
1004
 
1005
- # Possibly shuffle your processed pairs to avoid repeating them in the same order
1006
- # (If you haven't already done so in the pipeline)
1007
- random.shuffle(dataset_preparer.processed_pairs)
1008
 
1009
- while (train_batches_sent < train_batches_needed or
1010
- val_batches_sent < val_batches_needed):
1011
-
1012
- # We loop over the generator
1013
- for batch in dataset_preparer.process_dialogues(dialogues):
1014
- if stop_flag.is_set():
1015
- logger.warning("Pipeline stopped early")
1016
- break
1017
-
1018
- if train_batches_sent < train_batches_needed:
1019
- train_queue.put(batch)
1020
- train_batches_sent += 1
1021
- elif val_batches_sent < val_batches_needed:
1022
- val_queue.put(batch)
1023
- val_batches_sent += 1
1024
- else:
1025
- # We have enough batches for both train & val
1026
- break
1027
-
1028
- # If we still haven't met our target steps, REPEAT the data
1029
- if train_batches_sent < train_batches_needed or val_batches_sent < val_batches_needed:
1030
- logger.info("Data exhausted, repeating since we still need more batches...")
1031
- # Optionally shuffle again
1032
- random.shuffle(dataset_preparer.processed_pairs)
1033
- else:
1034
- # We have enough
1035
- break
1036
 
1037
- logger.info(
1038
- f"Pipeline complete: sent {train_batches_sent}/{train_batches_needed} train batches, "
1039
- f"{val_batches_sent}/{val_batches_needed} val batches"
1040
- )
1041
 
1042
- except Exception as e:
1043
- logger.error(f"Error in pipeline worker: {str(e)}")
1044
- raise e
1045
- finally:
1046
- train_queue.put(None)
1047
- val_queue.put(None)
1048
 
1049
- # Start data preparation pipeline in background thread
1050
- pipeline_thread = threading.Thread(target=data_pipeline_worker)
1051
- pipeline_thread.start()
1052
 
1053
  try:
1054
- # --- Training Phase ---
1055
- epoch_loss_avg = tf.keras.metrics.Mean()
1056
- batches_processed = 0
 
 
1057
 
1058
- try:
1059
- train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}")
1060
- is_tqdm_train = True
1061
- except ImportError:
1062
- train_pbar = None
1063
- is_tqdm_train = False
1064
- logger.info("Training progress bar disabled")
1065
-
1066
- while batches_processed < steps_per_epoch:
1067
- try:
1068
- batch = train_queue.get(timeout=1200) # 20 minutes timeout
1069
- if batch is None:
1070
- logger.warning(f"Received end signal after only {batches_processed}/{steps_per_epoch} batches")
1071
- break
1072
-
1073
- q_batch, p_batch = batch[0], batch[1]
1074
- attention_mask = batch[2] if len(batch) > 2 else None
1075
-
1076
- loss = self.train_step(q_batch, p_batch, attention_mask)
1077
- epoch_loss_avg(loss)
1078
- batches_processed += 1
1079
-
1080
- # Log to TensorBoard
1081
- with train_summary_writer.as_default():
1082
- tf.summary.scalar("loss", loss, step=epoch)
1083
-
1084
- # Update progress bar
1085
- if use_lr_schedule:
1086
- current_lr = float(lr_schedule(self.optimizer.iterations))
1087
- else:
1088
- current_lr = float(self.optimizer.learning_rate.numpy())
1089
 
1090
- if is_tqdm_train:
1091
- train_pbar.update(1)
1092
- train_pbar.set_postfix({
1093
- "loss": f"{loss.numpy():.4f}",
1094
- "lr": f"{current_lr:.2e}",
1095
- "batches": f"{batches_processed}/{steps_per_epoch}"
1096
- })
1097
 
1098
- except Empty:
1099
- logger.warning(f"Queue timeout after {batches_processed}/{steps_per_epoch} batches")
1100
- break
1101
 
1102
- if is_tqdm_train and train_pbar:
1103
- train_pbar.close()
1104
 
1105
- # --- Validation Phase ---
1106
- val_loss_avg = tf.keras.metrics.Mean()
1107
- val_batches_processed = 0
1108
 
1109
- try:
1110
- val_pbar = tqdm(total=val_steps, desc="Validation")
1111
- is_tqdm_val = True
1112
- except ImportError:
1113
- val_pbar = None
1114
- is_tqdm_val = False
1115
- logger.info("Validation progress bar disabled")
1116
-
1117
- while val_batches_processed < val_steps:
1118
- try:
1119
- batch = val_queue.get(timeout=30)
1120
- if batch is None:
1121
- logger.warning(
1122
- f"Received end signal after {val_batches_processed}/{val_steps} validation batches"
1123
- )
1124
- break
1125
-
1126
- q_batch, p_batch = batch[0], batch[1]
1127
- attention_mask = batch[2] if len(batch) > 2 else None
1128
-
1129
- val_loss = self.validation_step(q_batch, p_batch, attention_mask)
1130
- val_loss_avg(val_loss)
1131
- val_batches_processed += 1
1132
-
1133
- if is_tqdm_val:
1134
- val_pbar.update(1)
1135
- val_pbar.set_postfix({
1136
- "val_loss": f"{val_loss.numpy():.4f}",
1137
- "batches": f"{val_batches_processed}/{val_steps}"
1138
- })
1139
-
1140
- except Empty:
1141
- logger.warning(
1142
- f"Validation queue timeout after {val_batches_processed}/{val_steps} batches"
1143
- )
1144
- break
1145
-
1146
- if is_tqdm_val and val_pbar:
1147
- val_pbar.close()
1148
-
1149
- # End of epoch: compute final epoch stats
1150
- train_loss = epoch_loss_avg.result().numpy()
1151
- val_loss = val_loss_avg.result().numpy()
1152
- logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
1153
-
1154
- # Log epoch metrics
1155
- with val_summary_writer.as_default():
1156
- tf.summary.scalar("val_loss", val_loss, step=epoch)
1157
-
1158
- # Save checkpoint
1159
- manager.save()
1160
-
1161
- # Store metrics in history
1162
- self.history['train_loss'].append(train_loss)
1163
- self.history['val_loss'].append(val_loss)
1164
 
1165
- if use_lr_schedule:
1166
- current_lr = float(lr_schedule(self.optimizer.iterations))
1167
- else:
1168
- current_lr = float(self.optimizer.learning_rate.numpy())
1169
 
1170
- self.history.setdefault('learning_rate', []).append(current_lr)
1171
 
1172
- # Early stopping logic
1173
- if val_loss < best_val_loss - min_delta:
1174
- best_val_loss = val_loss
1175
- epochs_no_improve = 0
1176
- logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
1177
- else:
1178
- epochs_no_improve += 1
1179
- logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
1180
- if epochs_no_improve >= early_stopping_patience:
1181
- logger.info("Early stopping triggered.")
1182
- break
1183
 
1184
- except Exception as e:
1185
- logger.error(f"Error during training: {str(e)}")
1186
- stop_flag.set()
1187
- raise e
1188
- finally:
1189
- # Clean up epoch resources
1190
- stop_flag.set()
1191
- pipeline_thread.join()
1192
 
1193
  logger.info("Streaming training completed!")
1194
 
1195
 
1196
  @tf.function
1197
- def train_step(self, q_batch: tf.Tensor, p_batch: tf.Tensor, attention_mask: Optional[tf.Tensor] = None) -> tf.Tensor:
1198
- """Single training step with tf.function optimization and partial batch handling."""
 
 
1199
  with tf.GradientTape() as tape:
1200
- q_enc = self.encoder(q_batch, training=True)
1201
- p_enc = self.encoder(p_batch, training=True)
1202
-
1203
- sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
1204
-
1205
- # Handle partial batches
1206
- batch_size = tf.shape(q_enc)[0]
1207
- labels = tf.range(batch_size, dtype=tf.int32)
1208
-
 
 
1209
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1210
- labels=labels, logits=sim_matrix
 
1211
  )
1212
-
1213
- # If there's an attention mask, apply it
1214
  if attention_mask is not None:
1215
  loss = loss * attention_mask
1216
- # normalize by the sum of attention_mask
1217
  loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1218
- else:
1219
- loss = tf.reduce_mean(loss)
1220
 
 
1221
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
1222
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1223
  return loss
1224
 
1225
  @tf.function
1226
- def validation_step(self, q_batch: tf.Tensor, p_batch: tf.Tensor, attention_mask: Optional[tf.Tensor] = None) -> tf.Tensor:
1227
- """Single validation step with partial batch handling."""
 
 
1228
  q_enc = self.encoder(q_batch, training=False)
1229
  p_enc = self.encoder(p_batch, training=False)
1230
 
1231
- sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
1232
- batch_size = tf.shape(q_enc)[0]
1233
- labels = tf.range(batch_size, dtype=tf.int32)
 
 
1234
 
1235
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1236
- labels=labels, logits=sim_matrix
 
1237
  )
1238
-
 
1239
  if attention_mask is not None:
1240
  loss = loss * attention_mask
1241
  loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1242
- else:
1243
- loss = tf.reduce_mean(loss)
1244
 
1245
  return loss
1246
 
@@ -1382,235 +1337,33 @@ class RetrievalChatbot(DeviceAwareModel):
1382
  conversation_parts.append(f"{self.special_tokens['user']} {query}")
1383
  return "\n".join(conversation_parts)
1384
 
1385
- class StreamingDataPipeline:
1386
- """Helper class to manage the streaming data preparation pipeline with optimized caching and GPU usage."""
1387
  def __init__(
1388
- self,
1389
- tokenizer,
1390
- encoder,
1391
- index,
1392
- response_pool,
1393
- max_length: int,
1394
- batch_size: int,
1395
- neg_samples: int
1396
  ):
 
1397
  self.tokenizer = tokenizer
1398
  self.encoder = encoder
1399
- self.index = index
1400
  self.response_pool = response_pool
1401
  self.max_length = max_length
1402
- self.base_batch_size = batch_size
1403
  self.neg_samples = neg_samples
 
1404
  self.memory_monitor = GPUMemoryMonitor()
1405
-
1406
- # Caching structures
1407
- self.hard_negatives_cache = {}
1408
- self.processed_pairs = []
1409
- self.query_embeddings_cache = {}
1410
-
1411
- # Error tracking
1412
- self.error_count = 0
1413
  self.max_retries = 3
1414
-
1415
- # Batch processing settings
1416
- self.current_batch_size = batch_size
1417
- self.batch_increase_factor = 1.25
1418
-
1419
- # TODO: use GPU/strategy
1420
- if len(response_pool) < 100:
1421
- self.embedding_batch_size = 16
1422
- self.search_batch_size = 16
1423
- self.max_batch_size = 32
1424
- self.min_batch_size = 4
1425
- else:
1426
- self.embedding_batch_size = 64
1427
- self.search_batch_size = 64
1428
- self.min_batch_size = 8
1429
- self.max_batch_size = 64
1430
-
1431
- def save_cache(self, cache_dir: Path) -> None:
1432
- """Save all cached data for future runs."""
1433
- cache_dir = Path(cache_dir)
1434
- cache_dir.mkdir(parents=True, exist_ok=True)
1435
-
1436
- logger.info(f"Saving cache to {cache_dir}")
1437
-
1438
- # Save embeddings cache
1439
- embeddings_path = cache_dir / "query_embeddings.npy"
1440
- np.save(
1441
- embeddings_path,
1442
- {k: v.numpy() if hasattr(v, 'numpy') else v
1443
- for k, v in self.query_embeddings_cache.items()}
1444
- )
1445
-
1446
- # Save hard negatives and processed pairs
1447
- with open(cache_dir / "hard_negatives.json", 'w') as f:
1448
- json.dump(self.hard_negatives_cache, f)
1449
-
1450
- with open(cache_dir / "processed_pairs.json", 'w') as f:
1451
- json.dump(self.processed_pairs, f)
1452
-
1453
- logger.info("Cache saved successfully")
1454
-
1455
- def load_cache(self, cache_dir: Path) -> bool:
1456
- """Load cached data if available."""
1457
- cache_dir = Path(cache_dir)
1458
- required_files = [
1459
- "query_embeddings.npy",
1460
- "hard_negatives.json",
1461
- "processed_pairs.json"
1462
- ]
1463
-
1464
- if not all((cache_dir / f).exists() for f in required_files):
1465
- logger.info("Cache files not found")
1466
- return False
1467
-
1468
- try:
1469
- logger.info("Loading cache...")
1470
-
1471
- # Load embeddings
1472
- self.query_embeddings_cache = np.load(
1473
- cache_dir / "query_embeddings.npy",
1474
- allow_pickle=True
1475
- ).item()
1476
-
1477
- # Load other caches
1478
- with open(cache_dir / "hard_negatives.json", 'r') as f:
1479
- self.hard_negatives_cache = json.load(f)
1480
-
1481
- with open(cache_dir / "processed_pairs.json", 'r') as f:
1482
- self.processed_pairs = json.load(f)
1483
-
1484
- logger.info(f"Cache loaded successfully: {len(self.processed_pairs)} pairs")
1485
- return True
1486
-
1487
- except Exception as e:
1488
- logger.error(f"Error loading cache: {e}")
1489
- return False
1490
 
1491
- def _adjust_batch_size(self) -> None:
1492
- """Dynamically adjust batch size based on GPU memory usage."""
1493
- if self.memory_monitor:
1494
- if self.memory_monitor.should_reduce_batch_size():
1495
- new_size = max(self.min_batch_size, self.current_batch_size // 2)
1496
- if new_size != self.current_batch_size:
1497
- if new_size < self.min_batch_size:
1498
- logger.info(f"Reducing batch size to {new_size} due to high memory usage")
1499
- self.current_batch_size = new_size
1500
- gc.collect()
1501
- if tf.config.list_physical_devices('GPU'):
1502
- tf.keras.backend.clear_session()
1503
-
1504
- elif self.memory_monitor.can_increase_batch_size():
1505
- new_size = min(self.max_batch_size, int(self.current_batch_size * self.batch_increase_factor)) # More gradual increase
1506
- if new_size != self.current_batch_size:
1507
- if new_size > self.max_batch_size:
1508
- logger.info(f"Increasing batch size to {new_size}")
1509
- self.current_batch_size = new_size
1510
-
1511
- def _add_progress_metrics(self, pbar, **metrics) -> None:
1512
- """Add memory and batch size metrics to progress bars."""
1513
- if self.memory_monitor:
1514
- gpu_usage = self.memory_monitor.get_memory_usage()
1515
- metrics['gpu_mem'] = f"{gpu_usage:.1%}"
1516
- metrics['batch_size'] = self.current_batch_size
1517
- pbar.set_postfix(**metrics)
1518
-
1519
- def preprocess_dialogues(self, dialogues: List[dict]) -> None:
1520
- """Preprocess all dialogues with error recovery and caching."""
1521
- retry_count = 0
1522
-
1523
- while retry_count < self.max_retries:
1524
- try:
1525
- self._preprocess_dialogues_internal(dialogues)
1526
- break
1527
- except Exception as e:
1528
- retry_count += 1
1529
- logger.warning(f"Preprocessing attempt {retry_count} failed: {e}")
1530
- if retry_count == self.max_retries:
1531
- logger.error("Max retries reached. Falling back to CPU processing")
1532
- self._fallback_to_cpu_processing(dialogues)
1533
-
1534
- def _preprocess_dialogues_internal(self, dialogues: List[dict]) -> None:
1535
- """Internal preprocessing implementation with progress tracking."""
1536
- logger.info("Starting dialogue preprocessing...")
1537
-
1538
- # Collect unique queries and pairs
1539
- unique_queries = set()
1540
- query_positive_pairs = []
1541
-
1542
- with tqdm(total=len(dialogues), desc="Collecting dialogue pairs") as pbar:
1543
- for dialogue in dialogues:
1544
- pairs = self._extract_pairs_from_dialogue(dialogue)
1545
- for query, positive in pairs:
1546
- unique_queries.add(query)
1547
- query_positive_pairs.append((query, positive))
1548
- pbar.update(1)
1549
- self._add_progress_metrics(pbar, pairs=len(query_positive_pairs))
1550
-
1551
- # Precompute embeddings
1552
- logger.info("Precomputing query embeddings...")
1553
- self.precompute_query_embeddings(list(unique_queries))
1554
-
1555
- # Find hard negatives
1556
- logger.info("Finding hard negatives for all pairs...")
1557
- self._find_hard_negatives_for_pairs(query_positive_pairs)
1558
 
1559
- def precompute_query_embeddings(self, queries: List[str]) -> None:
1560
- """Precompute embeddings for all unique queries in batches."""
1561
- unique_queries = list(set(queries))
1562
-
1563
- with tqdm(total=len(unique_queries), desc="Precomputing query embeddings") as pbar:
1564
- for i in range(0, len(unique_queries), self.embedding_batch_size):
1565
- # Adjust batch size based on memory
1566
- self._adjust_batch_size()
1567
- batch_size = min(self.embedding_batch_size, len(unique_queries) - i)
1568
-
1569
- # Get batch of queries
1570
- batch_queries = unique_queries[i:i + batch_size]
1571
-
1572
- try:
1573
- # Tokenize batch
1574
- encoded = self.tokenizer(
1575
- batch_queries,
1576
- padding=True,
1577
- truncation=True,
1578
- max_length=self.max_length,
1579
- return_tensors='tf'
1580
- )
1581
-
1582
- # Get embeddings
1583
- embeddings = self.encoder(encoded['input_ids'], training=False)
1584
- embeddings_np = embeddings.numpy().astype('float32')
1585
-
1586
- # Normalize for similarity search
1587
- faiss.normalize_L2(embeddings_np)
1588
-
1589
- # Cache embeddings
1590
- for query, emb in zip(batch_queries, embeddings_np):
1591
- self.query_embeddings_cache[query] = emb
1592
-
1593
- pbar.update(len(batch_queries))
1594
- self._add_progress_metrics(
1595
- pbar,
1596
- cached=len(self.query_embeddings_cache),
1597
- batch_size=batch_size
1598
- )
1599
-
1600
- except Exception as e:
1601
- logger.warning(f"Error processing batch: {e}")
1602
- # Reduce batch size and retry
1603
- self.embedding_batch_size = max(self.min_batch_size, self.embedding_batch_size // 2)
1604
- continue
1605
-
1606
- # Memory cleanup after successful batch
1607
- if i % (self.embedding_batch_size * 10) == 0:
1608
- gc.collect()
1609
- if tf.config.list_physical_devices('GPU'):
1610
- tf.keras.backend.clear_session()
1611
-
1612
- logger.info(f"Cached embeddings for {len(self.query_embeddings_cache)} unique queries")
1613
-
1614
  def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
1615
  """Extract query-response pairs from a dialogue."""
1616
  pairs = []
@@ -1631,305 +1384,474 @@ class StreamingDataPipeline:
1631
 
1632
  return pairs
1633
 
1634
- def _fallback_to_cpu_processing(self, dialogues: List[dict]) -> None:
1635
- """Fallback processing method using CPU only."""
1636
- logger.info("Falling back to CPU-only processing")
1637
- # Reset GPU-specific settings
1638
- self.current_batch_size = self.min_batch_size
1639
- self.embedding_batch_size = 32
1640
- self.search_batch_size = 16
1641
-
1642
- # Attempt preprocessing with reduced batches
1643
- self._preprocess_dialogues_internal(dialogues)
1644
-
1645
- def process_dialogues(self, dialogues: List[dict]) -> Generator[Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]], None, None]:
1646
- """
1647
- Process dialogues using cached data with dynamic batch sizing.
1648
- Yields (q_tokens['input_ids'], p_tokens['input_ids'], attention_mask) tuples.
1649
- """
1650
- # Preprocess if not already done
1651
- if not self.processed_pairs:
1652
- self.preprocess_dialogues(dialogues)
1653
-
1654
- # Generate batches from cached data
1655
- current_queries = []
1656
- current_positives = []
1657
-
1658
- # Counters for logging
1659
- total_examples_yielded = 0
1660
- total_batches_yielded = 0
1661
-
1662
- with tqdm(total=len(self.processed_pairs), desc="Generating training batches", leave=False) as pbar:
1663
- for i, (query, positive) in enumerate(self.processed_pairs):
1664
- # Periodically adjust batch size
1665
- if i % 10 == 0: # Check more frequently (e.g., every 10 pairs)
1666
- self._adjust_batch_size()
1667
-
1668
- # Add original pair
1669
- current_queries.append(query)
1670
- current_positives.append(positive)
1671
-
1672
- # Add cached hard negatives for each query
1673
- hard_negatives = self.hard_negatives_cache.get((query, positive), [])
1674
- for neg_text in hard_negatives:
1675
- current_queries.append(query)
1676
- current_positives.append(neg_text)
1677
-
1678
- # If we have enough examples to form a full batch, yield it
1679
- while len(current_queries) >= self.current_batch_size:
1680
- batch_queries = current_queries[:self.current_batch_size]
1681
- batch_positives = current_positives[:self.current_batch_size]
1682
-
1683
- # Update counters and logs
1684
- batch_size_to_yield = len(batch_queries)
1685
- total_examples_yielded += batch_size_to_yield
1686
- total_batches_yielded += 1
1687
-
1688
- yield self._prepare_batch(batch_queries, batch_positives, pad_to_batch_size=False)
1689
-
1690
- # Remove used entries
1691
- current_queries = current_queries[self.current_batch_size:]
1692
- current_positives = current_positives[self.current_batch_size:]
1693
-
1694
- # Update progress bar
1695
- pbar.update(1)
1696
- self._add_progress_metrics(
1697
- pbar,
1698
- pairs_processed=pbar.n,
1699
- pending_pairs=len(current_queries)
1700
- )
1701
-
1702
- # After the loop, if anything is left, yield a final partial batch
1703
- if current_queries:
1704
- leftover_size = len(current_queries)
1705
- total_examples_yielded += leftover_size
1706
- total_batches_yielded += 1
1707
-
1708
- yield self._prepare_batch(
1709
- current_queries,
1710
- current_positives,
1711
- pad_to_batch_size=True
1712
- )
1713
-
1714
- def _find_hard_negatives_for_pairs(self, query_positive_pairs: List[Tuple[str, str]]) -> None:
1715
- """Process pairs in batches to find hard negatives with GPU acceleration."""
1716
- total_pairs = len(query_positive_pairs)
1717
-
1718
- # Use smaller batch size for small datasets
1719
- if len(self.response_pool) < 1000:
1720
- batch_size = min(8, self.search_batch_size)
1721
- else:
1722
- batch_size = self.search_batch_size
1723
-
1724
- try:
1725
- pbar = tqdm(total=total_pairs, desc="Finding hard negatives")
1726
- is_tqdm = True
1727
- except ImportError:
1728
- pbar = None
1729
- is_tqdm = False
1730
- logger.info("Progress bar disabled - continuing without visual progress")
1731
-
1732
- for i in range(0, total_pairs, batch_size):
1733
- self._adjust_batch_size()
1734
-
1735
- batch_pairs = query_positive_pairs[i:i + batch_size]
1736
- batch_queries, batch_positives = zip(*batch_pairs)
1737
-
1738
- batch_negatives = self._find_hard_negatives_batch(
1739
- list(batch_queries),
1740
- list(batch_positives)
1741
- )
1742
-
1743
- for query, positive, negatives in zip(batch_queries, batch_positives, batch_negatives):
1744
- self.hard_negatives_cache[(query, positive)] = negatives
1745
- self.processed_pairs.append((query, positive))
1746
-
1747
- if is_tqdm:
1748
- pbar.update(len(batch_pairs))
1749
- self._add_progress_metrics(
1750
- pbar,
1751
- cached=len(self.processed_pairs),
1752
- progress=f"{i+len(batch_pairs)}/{total_pairs}"
1753
- )
1754
-
1755
- if is_tqdm:
1756
- pbar.close()
1757
-
1758
  def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
1759
  """Find hard negatives for a batch of queries with error handling and retries."""
1760
  retry_count = 0
1761
  total_responses = len(self.response_pool)
1762
-
1763
- # For very small datasets (testing), just use random sampling
1764
- if total_responses < 100:
1765
- all_negatives = []
1766
- for positive in positives:
1767
- available = [r for r in self.response_pool if r.strip() != positive.strip()]
1768
- if available:
1769
- negatives = list(np.random.choice(
1770
- available,
1771
- size=min(self.neg_samples, len(available)),
1772
- replace=False
1773
- ))
1774
- else:
1775
- negatives = []
1776
- # Pad with empty strings if needed
1777
- while len(negatives) < self.neg_samples:
1778
- negatives.append("")
1779
- all_negatives.append(negatives)
1780
- return all_negatives
1781
-
1782
  while retry_count < self.max_retries:
1783
  try:
1784
- # Get cached embeddings and ensure they're the right type
1785
  query_embeddings = np.vstack([
1786
  self.query_embeddings_cache[q] for q in queries
1787
  ]).astype(np.float32)
1788
 
1789
- if not query_embeddings.flags['C_CONTIGUOUS']:
1790
- query_embeddings = np.ascontiguousarray(query_embeddings)
1791
-
1792
- # Normalize embeddings
1793
  faiss.normalize_L2(query_embeddings)
1794
-
1795
- k = 1 #min(total_responses - 1, max(3, self.neg_samples + 2))
1796
  #logger.debug(f"Searching with k={k} among {total_responses} responses")
1797
-
1798
- assert query_embeddings.dtype == np.float32, f"Embeddings are not float32: {query_embeddings.dtype}" # Assertion here
1799
 
1800
- try:
1801
- distances, indices = self.index.search(query_embeddings, k)
1802
- except RuntimeError as e:
1803
- logger.error(f"FAISS search failed: {e}")
1804
- return self._fallback_random_negatives(queries, positives)
1805
-
1806
- # Process results
1807
  all_negatives = []
1808
- for i, (query_indices, query, positive) in enumerate(zip(indices, queries, positives)):
1809
  negatives = []
1810
  positive_strip = positive.strip()
1811
-
1812
- # Filter valid indices and deduplicate
1813
  seen = {positive_strip}
 
1814
  for idx in query_indices:
1815
  if idx >= 0 and idx < total_responses:
1816
  candidate = self.response_pool[idx].strip()
1817
- if candidate and candidate not in seen: # Check for non-empty strings
1818
  seen.add(candidate)
1819
  negatives.append(candidate)
1820
  if len(negatives) >= self.neg_samples:
1821
  break
1822
-
1823
- # If we don't have enough negatives, use random sampling from remaining pool
1824
- if len(negatives) < self.neg_samples:
1825
- available = [r for r in self.response_pool if r.strip() not in seen and r.strip()]
1826
- if available:
1827
- additional = np.random.choice(
1828
- available,
1829
- size=min(self.neg_samples - len(negatives), len(available)),
1830
- replace=False
1831
- )
1832
- negatives.extend(additional)
1833
-
1834
- # Still pad with empty strings if needed
1835
  while len(negatives) < self.neg_samples:
1836
- negatives.append("")
1837
-
1838
  all_negatives.append(negatives)
1839
-
1840
  return all_negatives
1841
-
1842
  except Exception as e:
1843
  retry_count += 1
1844
  logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
1845
  if retry_count == self.max_retries:
1846
  logger.error("Max retries reached for hard negative search")
1847
- return [[] for _ in queries] # Return empty lists on complete failure
1848
  gc.collect()
1849
  if tf.config.list_physical_devices('GPU'):
1850
  tf.keras.backend.clear_session()
1851
-
1852
- def _fallback_random_negatives(self, queries: List[str], positives: List[str]) -> List[List[str]]:
1853
- """Fallback to random sampling when similarity search fails."""
1854
- logger.warning("Falling back to random negative sampling")
1855
- all_negatives = []
1856
- for positive in positives:
1857
- available = [r for r in self.response_pool if r.strip() != positive.strip()]
1858
- negatives = list(np.random.choice(
1859
- available,
1860
- size=min(self.neg_samples, len(available)),
1861
- replace=False
1862
- )) if available else []
1863
- while len(negatives) < self.neg_samples:
1864
- negatives.append("")
1865
- all_negatives.append(negatives)
1866
- return all_negatives
1867
-
1868
- def _prepare_batch(
1869
- self,
1870
- queries: List[str],
1871
- positives: List[str],
1872
- pad_to_batch_size: bool = False
1873
- ) -> Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]]:
1874
- """Prepare a batch with dynamic padding and memory optimization."""
1875
- actual_size = len(queries)
1876
-
1877
- # Handle padding if requested and needed
1878
- if pad_to_batch_size and actual_size < self.current_batch_size:
1879
- padding_needed = self.current_batch_size - actual_size
1880
- queries.extend([queries[0]] * padding_needed)
1881
- positives.extend([positives[0]] * padding_needed)
1882
- # Create attention mask for padded examples
1883
- attention_mask = tf.concat([
1884
- tf.ones((actual_size,), dtype=tf.float32),
1885
- tf.zeros((padding_needed,), dtype=tf.float32)
1886
- ], axis=0)
1887
- else:
1888
- attention_mask = None
1889
 
1890
- try:
1891
- # Tokenize batch
1892
- q_tokens = self.tokenizer(
1893
- queries,
 
 
1894
  padding='max_length',
1895
  truncation=True,
1896
  max_length=self.max_length,
1897
  return_tensors='tf'
1898
  )
1899
- p_tokens = self.tokenizer(
1900
- positives,
1901
- padding='max_length',
 
 
1902
  truncation=True,
1903
  max_length=self.max_length,
1904
  return_tensors='tf'
1905
  )
1906
 
1907
- return q_tokens['input_ids'], p_tokens['input_ids'], attention_mask
1908
-
1909
- except Exception as e:
1910
- logger.error(f"Error preparing batch: {e}")
1911
- # Emergency memory cleanup
1912
- gc.collect()
1913
- if tf.config.list_physical_devices('GPU'):
1914
- tf.keras.backend.clear_session()
1915
- raise
1916
 
1917
- def estimate_total_pairs(self, dialogues: List[dict]) -> int:
1918
- """Estimate total number of training pairs including hard negatives."""
1919
- base_pairs = sum(
1920
- len([
1921
- 1 for i in range(len(d.get('turns', [])) - 1)
1922
- if (d['turns'][i].get('speaker') == 'user' and
1923
- d['turns'][i+1].get('speaker') == 'assistant')
1924
- ])
1925
- for d in dialogues
 
 
1926
  )
1927
- # Account for hard negatives
1928
- return base_pairs * (1 + self.neg_samples)
 
 
1929
 
1930
- def cleanup(self):
1931
- """Cleanup resources and memory."""
1932
- self.query_embeddings_cache.clear()
1933
- gc.collect()
1934
- if tf.config.list_physical_devices('GPU'):
1935
- tf.keras.backend.clear_session()
 
2
  from transformers import TFAutoModel, AutoTokenizer
3
  import tensorflow as tf
4
  import numpy as np
5
  from typing import Generator, List, Tuple, Dict, Optional, Union, Any
6
  import math
7
  from dataclasses import dataclass
 
10
  import datetime
11
  import faiss
12
  import gc
 
13
  from response_quality_checker import ResponseQualityChecker
14
  from cross_encoder_reranker import CrossEncoderReranker
15
  from conversation_summarizer import DeviceAwareModel, Summarizer
 
26
  """Configuration for the RetrievalChatbot."""
27
  vocab_size: int = 30526 # DistilBERT vocab size + special tokens
28
  max_context_token_limit: int = 512
29
+ embedding_dim: int = 768
30
  encoder_units: int = 256
31
  num_attention_heads: int = 8
32
  dropout_rate: float = 0.2
 
39
  pretrained_model: str = 'distilbert-base-uncased'
40
  dtype: str = 'float32'
41
  freeze_embeddings: bool = False
42
+ embedding_batch_size: int = 128
43
  # Additional configurations can be added here
44
 
45
  def to_dict(self) -> dict:
 
101
 
102
  # Apply pooling, projection, dropout, and normalization
103
  x = self.pooler(x) # Shape: [batch_size, 768]
104
+ x = self.projection(x) # Shape: [batch_size, 768]
105
  x = self.dropout(x, training=training) # Apply dropout
106
+ x = self.normalize(x) # Shape: [batch_size, 768]
107
 
108
  return x
109
 
 
137
  summarizer = Summarizer(device=self.device)
138
  self.summarizer = summarizer
139
 
 
 
140
  # Special tokens
141
  self.special_tokens = {
142
  "user": "<USER>",
 
341
  """
342
  all_embeddings = []
343
  self.current_batch_size = batch_size
344
+
345
+ if self.memory_monitor.has_gpu:
346
+ batch_size = 128
347
 
348
  # Memory stats
349
  # if self.memory_monitor.has_gpu:
 
531
  logger.info("Starting vector addition process...")
532
 
533
  # Even smaller batches
534
+ initial_batch_size = 128
535
+ min_batch_size = 32
536
+ max_batch_size = 1024
537
 
538
  total_added = 0
539
  retry_count = 0
 
562
  # Update progress
563
  batch_size = len(batch)
564
  total_added += batch_size
 
565
 
566
  # Memory cleanup every few batches
567
  if total_added % (initial_batch_size * 5) == 0:
 
607
  cpu_index = self.index
608
 
609
  # Add remaining vectors on CPU with very small batches
610
+ batch_size = 128
611
  total_added = already_added
612
 
613
  for i in range(0, len(remaining_embeddings), batch_size):
 
900
  warmup_steps_ratio: float = 0.1,
901
  early_stopping_patience: int = 3,
902
  min_delta: float = 1e-4,
 
903
  neg_samples: int = 1
904
  ) -> None:
905
+ """Streaming training with tf.data pipeline."""
906
+ logger.info("Starting streaming training pipeline with tf.data...")
907
 
908
+ # Initialize TFDataPipeline (replaces StreamingDataPipeline)
909
+ dataset_preparer = TFDataPipeline(
910
+ embedding_batch_size=self.config.embedding_batch_size,
911
  tokenizer=self.tokenizer,
912
  encoder=self.encoder,
913
+ index=self.index, # Pass CPU version of FAISS index
914
  response_pool=self.response_pool,
915
  max_length=self.config.max_context_token_limit,
 
916
  neg_samples=neg_samples
917
  )
918
 
919
  # Calculate total steps for learning rate schedule
920
  total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
921
+ train_size = int(total_pairs * (1 - validation_split))
922
+ val_size = int(total_pairs * validation_split)
923
  steps_per_epoch = int(math.ceil(train_size / batch_size))
924
+ val_steps = int(math.ceil(val_size / batch_size))
925
  total_steps = steps_per_epoch * epochs
926
 
927
  logger.info(f"Total pairs: {total_pairs}")
928
  logger.info(f"Training pairs: {train_size}")
929
+ logger.info(f"Validation pairs: {val_size}")
930
  logger.info(f"Steps per epoch: {steps_per_epoch}")
931
  logger.info(f"Validation steps: {val_steps}")
932
  logger.info(f"Total steps: {total_steps}")
 
957
  val_log_dir = str(log_dir / f"val_{current_time}")
958
  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
959
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
 
960
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
961
 
962
+ # Create training and validation datasets
963
+ train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
964
+ val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
965
+
966
  # Training loop
967
  best_val_loss = float("inf")
968
  epochs_no_improve = 0
969
 
970
+ for epoch in range(1, epochs + 1):
971
+ # --- Training Phase ---
972
+ epoch_loss_avg = tf.keras.metrics.Mean()
973
+ batches_processed = 0
 
 
974
 
975
+ try:
976
+ train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
977
+ is_tqdm_train = True
978
+ except ImportError:
979
+ train_pbar = None
980
+ is_tqdm_train = False
981
+ logger.info("Training progress bar disabled")
982
+
983
+ for q_batch, p_batch, n_batch in train_dataset:
984
+ #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
985
+ loss = self.train_step(q_batch, p_batch, n_batch)
986
+ epoch_loss_avg(loss)
987
+ batches_processed += 1
988
+
989
+ # Log to TensorBoard
990
+ with train_summary_writer.as_default():
991
+ tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed)
992
 
993
+ # Update progress bar
994
+ if use_lr_schedule:
995
+ current_lr = float(lr_schedule(self.optimizer.iterations))
996
+ else:
997
+ current_lr = float(self.optimizer.learning_rate.numpy())
998
 
999
+ if is_tqdm_train:
1000
+ train_pbar.update(1)
1001
+ train_pbar.set_postfix({
1002
+ "loss": f"{loss.numpy():.4f}",
1003
+ "lr": f"{current_lr:.2e}",
1004
+ "batches": f"{batches_processed}/{steps_per_epoch}"
1005
+ })
1006
+
1007
+ # Memory cleanup
1008
+ gc.collect()
 
 
1009
 
1010
+ if batches_processed >= steps_per_epoch:
1011
+ break
 
 
1012
 
1013
+ if is_tqdm_train and train_pbar:
1014
+ train_pbar.close()
 
 
 
 
1015
 
1016
+ # --- Validation Phase ---
1017
+ val_loss_avg = tf.keras.metrics.Mean()
1018
+ val_batches_processed = 0
1019
 
1020
  try:
1021
+ val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
1022
+ is_tqdm_val = True
1023
+ except ImportError:
1024
+ val_pbar = None
1025
+ is_tqdm_val = False
1026
+ logger.info("Validation progress bar disabled")
1027
+
1028
+ for q_batch, p_batch, n_batch in val_dataset:
1029
+ #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
1030
+ val_loss = self.validation_step(q_batch, p_batch, n_batch)
1031
+ val_loss_avg(val_loss)
1032
+ val_batches_processed += 1
1033
+
1034
+ if is_tqdm_val:
1035
+ val_pbar.update(1)
1036
+ val_pbar.set_postfix({
1037
+ "val_loss": f"{val_loss.numpy():.4f}",
1038
+ "batches": f"{val_batches_processed}/{val_steps}"
1039
+ })
1040
+
1041
+ # Memory cleanup
1042
+ gc.collect()
1043
 
 
 
1044
 
1045
+ if val_batches_processed >= val_steps:
1046
+ break
 
 
 
 
 
1047
 
1048
+ if is_tqdm_val and val_pbar:
1049
+ val_pbar.close()
 
1050
 
1051
+ # End of epoch: compute final epoch stats, log, and save checkpoint
1052
+ train_loss = epoch_loss_avg.result().numpy()
1053
+ val_loss = val_loss_avg.result().numpy()
1054
+ logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
1055
 
1056
+ # Log epoch metrics
1057
+ with train_summary_writer.as_default():
1058
+ tf.summary.scalar("epoch_loss", train_loss, step=epoch)
1059
+ with val_summary_writer.as_default():
1060
+ tf.summary.scalar("val_loss", val_loss, step=epoch)
1061
 
1062
+ # Save checkpoint
1063
+ manager.save()
 
 
 
1064
 
1065
+ # Store metrics in history
1066
+ self.history['train_loss'].append(train_loss)
1067
+ self.history['val_loss'].append(val_loss)
 
1068
 
1069
+ if use_lr_schedule:
1070
+ current_lr = float(lr_schedule(self.optimizer.iterations))
1071
+ else:
1072
+ current_lr = float(self.optimizer.learning_rate.numpy())
1073
 
1074
+ self.history.setdefault('learning_rate', []).append(current_lr)
 
 
 
1075
 
1076
+ # Early stopping logic
1077
+ if val_loss < best_val_loss - min_delta:
1078
+ best_val_loss = val_loss
1079
+ epochs_no_improve = 0
1080
+ logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
1081
+ else:
1082
+ epochs_no_improve += 1
1083
+ logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
1084
+ if epochs_no_improve >= early_stopping_patience:
1085
+ logger.info("Early stopping triggered.")
1086
+ break
1087
 
1088
  logger.info("Streaming training completed!")
1089
 
1090
 
1091
  @tf.function
1092
+ def train_step(
1093
+ self,
1094
+ q_batch: tf.Tensor,
1095
+ p_batch: tf.Tensor,
1096
+ n_batch: tf.Tensor,
1097
+ attention_mask: Optional[tf.Tensor] = None
1098
+ ) -> tf.Tensor:
1099
+ """
1100
+ Single training step that uses queries, positives, and negatives in a
1101
+ contrastive/InfoNCE style. The label is always 0 (the positive) vs.
1102
+ the negative alternatives.
1103
+ """
1104
  with tf.GradientTape() as tape:
1105
+ # Encode queries
1106
+ q_enc = self.encoder(q_batch, training=True) # [batch_size, embed_dim]
1107
+
1108
+ # Encode positives
1109
+ p_enc = self.encoder(p_batch, training=True) # [batch_size, embed_dim]
1110
+
1111
+ # Encode negatives
1112
+ # n_batch: [batch_size, neg_samples, max_length]
1113
+ shape = tf.shape(n_batch)
1114
+ bs = shape[0]
1115
+ neg_samples = shape[1]
1116
+
1117
+ # Flatten negatives to feed them in one pass:
1118
+ # => [batch_size * neg_samples, max_length]
1119
+ n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1120
+ n_enc_flat = self.encoder(n_batch_flat, training=True) # [bs*neg_samples, embed_dim]
1121
+
1122
+ # Reshape back => [batch_size, neg_samples, embed_dim]
1123
+ n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1124
+
1125
+ # Combine the positive embedding and negative embeddings along dim=1
1126
+ # => shape [batch_size, 1 + neg_samples, embed_dim]
1127
+ # The first column is the positive; subsequent columns are negatives
1128
+ combined_p_n = tf.concat(
1129
+ [tf.expand_dims(p_enc, axis=1), n_enc],
1130
+ axis=1
1131
+ ) # [bs, (1+neg_samples), embed_dim]
1132
+
1133
+ # Now compute scores: dot product of q_enc with each column in combined_p_n
1134
+ # We'll use `tf.einsum` to handle the batch dimension properly
1135
+ # dot_products => shape [batch_size, (1+neg_samples)]
1136
+ dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
1137
+
1138
+ # The label for each row is 0 (the first column is the correct/positive)
1139
+ labels = tf.zeros([bs], dtype=tf.int32)
1140
+
1141
+ # Cross-entropy over the [batch_size, 1+neg_samples] scores
1142
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1143
+ labels=labels,
1144
+ logits=dot_products
1145
  )
1146
+ loss = tf.reduce_mean(loss)
1147
+
1148
+ # If there's an attention_mask you want to apply (less common in this scenario),
1149
+ # you could do something like:
1150
  if attention_mask is not None:
1151
  loss = loss * attention_mask
 
1152
  loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
 
 
1153
 
1154
+ # Apply gradients
1155
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
1156
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1157
  return loss
1158
 
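
A toy illustration (random tensors, not the model) of the scoring scheme described in the docstring above: each query is compared against its own positive plus its negatives, and the correct class is always column 0.

import tensorflow as tf

batch_size, neg_samples, embed_dim = 2, 2, 4
q_enc = tf.random.normal([batch_size, embed_dim])
p_enc = tf.random.normal([batch_size, embed_dim])
n_enc = tf.random.normal([batch_size, neg_samples, embed_dim])

combined = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)  # [B, 1+N, D]
logits = tf.einsum('bd,bkd->bk', q_enc, combined)                     # [B, 1+N]
labels = tf.zeros([batch_size], dtype=tf.int32)                       # positive is column 0
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
)
print(float(loss))
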
1159
  @tf.function
1160
+ def validation_step(
1161
+ self,
1162
+ q_batch: tf.Tensor,
1163
+ p_batch: tf.Tensor,
1164
+ n_batch: tf.Tensor,
1165
+ attention_mask: Optional[tf.Tensor] = None
1166
+ ) -> tf.Tensor:
1167
+ """
1168
+ Single validation step with queries, positives, and negatives.
1169
+ Uses the same loss calculation as train_step, but `training=False`.
1170
+ """
1171
  q_enc = self.encoder(q_batch, training=False)
1172
  p_enc = self.encoder(p_batch, training=False)
1173
 
1174
+ shape = tf.shape(n_batch)
1175
+ bs = shape[0]
1176
+ neg_samples = shape[1]
1177
+
1178
+ n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1179
+ n_enc_flat = self.encoder(n_batch_flat, training=False)
1180
+ n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1181
+
1182
+ combined_p_n = tf.concat(
1183
+ [tf.expand_dims(p_enc, axis=1), n_enc],
1184
+ axis=1
1185
+ )
1186
+
1187
+ dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
1188
+ labels = tf.zeros([bs], dtype=tf.int32)
1189
 
1190
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1191
+ labels=labels,
1192
+ logits=dot_products
1193
  )
1194
+ loss = tf.reduce_mean(loss)
1195
+
1196
  if attention_mask is not None:
1197
  loss = loss * attention_mask
1198
  loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
 
 
1199
 
1200
  return loss
1201
 
 
1337
  conversation_parts.append(f"{self.special_tokens['user']} {query}")
1338
  return "\n".join(conversation_parts)
1339
 
1340
+ class TFDataPipeline:
 
1341
  def __init__(
1342
+ self,
1343
+ embedding_batch_size,
1344
+ tokenizer,
1345
+ encoder,
1346
+ index,
1347
+ response_pool,
1348
+ max_length: int,
1349
+ neg_samples: int,
1350
  ):
1351
+ self.embedding_batch_size = embedding_batch_size
1352
  self.tokenizer = tokenizer
1353
  self.encoder = encoder
1354
+ self.index = index # CPU version of the index
1355
  self.response_pool = response_pool
1356
  self.max_length = max_length
 
1357
  self.neg_samples = neg_samples
1358
+ self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
1359
+ self.search_batch_size = 8 if len(response_pool) < 100 else 32
1360
+ self.max_batch_size = 32 if len(response_pool) < 100 else 256
1361
  self.memory_monitor = GPUMemoryMonitor()
 
 
 
1362
  self.max_retries = 3
 
 
1363
 
1364
+ # In-memory cache for embeddings
1365
+ self.query_embeddings_cache = {}
 
 
1366
 
 
 
1367
  def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
1368
  """Extract query-response pairs from a dialogue."""
1369
  pairs = []
 
1384
 
1385
  return pairs
1386
 
1387
+ def estimate_total_pairs(self, dialogues: List[dict]) -> int:
1388
+ """Estimate total number of training pairs including hard negatives."""
1389
+ base_pairs = sum(
1390
+ len([
1391
+ 1 for i in range(len(d.get('turns', [])) - 1)
1392
+ if (d['turns'][i].get('speaker') == 'user' and
1393
+ d['turns'][i+1].get('speaker') == 'assistant')
1394
+ ])
1395
+ for d in dialogues
1396
+ )
1397
+ # Account for hard negatives
1398
+ return base_pairs * (1 + self.neg_samples)
1399
+
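
A quick sanity check of this estimate: one dialogue with two user-to-assistant exchanges and neg_samples=1 should give 2 * (1 + 1) = 4.

toy_dialogues = [{
    "turns": [
        {"speaker": "user", "text": "hi"},
        {"speaker": "assistant", "text": "hello"},
        {"speaker": "user", "text": "what are your hours?"},
        {"speaker": "assistant", "text": "9am to 5pm"},
    ]
}]
base_pairs = sum(
    sum(1 for i in range(len(d.get("turns", [])) - 1)
        if d["turns"][i].get("speaker") == "user"
        and d["turns"][i + 1].get("speaker") == "assistant")
    for d in toy_dialogues
)
print(base_pairs * (1 + 1))  # 4
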
 
 
1400
  def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
1401
  """Find hard negatives for a batch of queries with error handling and retries."""
1402
  retry_count = 0
1403
  total_responses = len(self.response_pool)
1404
+
 
 
1405
  while retry_count < self.max_retries:
1406
  try:
 
1407
  query_embeddings = np.vstack([
1408
  self.query_embeddings_cache[q] for q in queries
1409
  ]).astype(np.float32)
1410
 
1411
+ query_embeddings = np.ascontiguousarray(query_embeddings)
 
 
 
1412
  faiss.normalize_L2(query_embeddings)
1413
+
1414
+ k = 1 # TODO: try higher k for better results
1415
  #logger.debug(f"Searching with k={k} among {total_responses} responses")
 
 
1416
 
1417
+ distances, indices = self.index.search(query_embeddings, k)
1418
+
 
 
 
 
 
1419
  all_negatives = []
1420
+ for query_indices, query, positive in zip(indices, queries, positives):
1421
  negatives = []
1422
  positive_strip = positive.strip()
 
 
1423
  seen = {positive_strip}
1424
+
1425
  for idx in query_indices:
1426
  if idx >= 0 and idx < total_responses:
1427
  candidate = self.response_pool[idx].strip()
1428
+ if candidate and candidate not in seen:
1429
  seen.add(candidate)
1430
  negatives.append(candidate)
1431
  if len(negatives) >= self.neg_samples:
1432
  break
1433
+
1434
+ # Pad with a special empty negative if necessary
 
 
 
 
 
 
 
 
 
 
 
1435
  while len(negatives) < self.neg_samples:
1436
+ negatives.append("<EMPTY_NEGATIVE>") # Use a special token
1437
+
1438
  all_negatives.append(negatives)
1439
+
1440
  return all_negatives
1441
+
1442
  except Exception as e:
1443
  retry_count += 1
1444
  logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
1445
  if retry_count == self.max_retries:
1446
  logger.error("Max retries reached for hard negative search")
1447
+ return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries] # Return empty negatives for all queries
1448
  gc.collect()
1449
  if tf.config.list_physical_devices('GPU'):
1450
  tf.keras.backend.clear_session()
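
An illustrative sketch of the FAISS lookup pattern behind the hard-negative search above, assuming a small in-memory inner-product index over L2-normalized response embeddings (toy data only, not the chatbot's index):

import faiss
import numpy as np

dim = 8
response_embeddings = np.random.rand(100, dim).astype(np.float32)
faiss.normalize_L2(response_embeddings)
index = faiss.IndexFlatIP(dim)  # inner product equals cosine similarity after normalization
index.add(response_embeddings)

query = np.random.rand(1, dim).astype(np.float32)
faiss.normalize_L2(query)
distances, indices = index.search(query, 5)
# Each row of `indices` holds candidate hard negatives; the caller filters out the
# positive response and duplicates, then pads with "<EMPTY_NEGATIVE>" if needed.
print(indices[0])
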
 
 
 
1451
 
1452
+ def _tokenize_negatives_tf(self, negatives):
1453
+ """Tokenizes negatives using tf.py_function."""
1454
+ # Handle the case where negatives is an empty tensor
1455
+ if tf.size(negatives) == 0:
1456
+ return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
1457
+
1458
+ # Convert EagerTensor to a list of strings
1459
+ negatives_list = []
1460
+ for neg_list in negatives.numpy():
1461
+ decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] # Filter out empty strings
1462
+ negatives_list.append(decoded_negs)
1463
+
1464
+ # Flatten the list of lists
1465
+ flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
1466
+
1467
+ # Tokenize the flattened negatives
1468
+ if flattened_negatives:
1469
+ n_tokens = self.tokenizer(
1470
+ flattened_negatives,
1471
  padding='max_length',
1472
  truncation=True,
1473
  max_length=self.max_length,
1474
  return_tensors='tf'
1475
  )
1476
+ # Reshape the tokens
1477
+ n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
1478
+ return n_tokens_reshaped
1479
+ else:
1480
+ return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
1481
+
1482
+ def _compute_embeddings(self, queries: List[str]) -> None:
1483
+ """Computes and caches embeddings for new queries."""
1484
+ new_queries = [q for q in queries if q not in self.query_embeddings_cache]
1485
+ if not new_queries:
1486
+ return # All queries already cached
1487
+
1488
+ new_embeddings = []
1489
+ for i in range(0, len(new_queries), self.embedding_batch_size):
1490
+ batch_queries = new_queries[i:i + self.embedding_batch_size]
1491
+
1492
+ encoded = self.tokenizer(
1493
+ batch_queries,
1494
+ padding=True,
1495
  truncation=True,
1496
  max_length=self.max_length,
1497
  return_tensors='tf'
1498
  )
1499
 
1500
+ # Compute embeddings on CPU
1501
+ with tf.device('/CPU:0'):
1502
+ batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
 
 
 
 
 
 
1503
 
1504
+ new_embeddings.extend(batch_embeddings)
1505
+
1506
+ # Update cache with new embeddings
1507
+ for query, emb in zip(new_queries, new_embeddings):
1508
+ self.query_embeddings_cache[query] = emb
1509
+
1510
+ def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
1511
+ """
1512
+ Generates training examples: (query, positive, hard_negatives).
1513
+ Wrapped the outer loop with tqdm for progress tracking.
1514
+ """
1515
+ total_dialogues = len(dialogues)
1516
+ logger.debug(f"Total dialogues to process: {total_dialogues}")
1517
+
1518
+ # Initialize tqdm progress bar
1519
+ with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
1520
+ for dialogue in dialogues:
1521
+ pairs = self._extract_pairs_from_dialogue(dialogue)
1522
+ for query, positive in pairs:
1523
+ # Ensure embeddings are computed, find hard negatives, etc.
1524
+ self._compute_embeddings([query])
1525
+ hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
1526
+ yield (query, positive, hard_negatives)
1527
+ pbar.update(1)
1528
+
1529
+ def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
1530
+ """Prepares a batch of data for training."""
1531
+
1532
+ # Convert EagerTensors to lists of strings
1533
+ queries_list = [query.decode("utf-8") for query in queries.numpy()]
1534
+ positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
1535
+
1536
+ # Tokenize queries and positives
1537
+ q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
1538
+ p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
1539
+
1540
+ # Decode negatives and ensure they are lists of strings
1541
+ negatives_list = []
1542
+ for neg_list in negatives.numpy():
1543
+ decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] # Filter out empty strings
1544
+ negatives_list.append(decoded_negs)
1545
+
1546
+ # Flatten negatives for tokenization if there are any valid negatives
1547
+ flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
1548
+
1549
+ # Tokenize negatives if there are any
1550
+ n_tokens_reshaped = None
1551
+ if flattened_negatives:
1552
+ n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
1553
+
1554
+ # Reshape n_tokens to match the expected shape based on the number of negatives per query
1555
+ # This part may need adjustment if the number of negatives varies per query
1556
+ n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
1557
+ else:
1558
+ # Create a placeholder tensor for the case where there are no negatives
1559
+ n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
1560
+
1561
+ # Ensure n_tokens_reshaped has a consistent shape even when there are no negatives
1562
+ # Adjust shape to [batch_size, num_neg_samples, max_length]
1563
+ if n_tokens_reshaped.shape[1] != self.neg_samples:
1564
+ # Pad or truncate the second dimension to match neg_samples
1565
+ padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
1566
+ n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
1567
+ n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
1568
+
1569
+ # Concatenate the positive and negative examples along the 'neg_samples' dimension
1570
+ combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
1571
+
1572
+ return q_tokens['input_ids'], combined_p_n_tokens
1573
+
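To make the final concatenation concrete, a small sketch (toy shapes, assumed sizes): the positive IDs are expanded to a length-1 axis and stacked in front of the negatives, giving [batch_size, 1 + neg_samples, max_length].

import tensorflow as tf

p_ids = tf.ones([2, 4], tf.int32)      # positives: [batch_size, max_length]
n_ids = tf.zeros([2, 3, 4], tf.int32)  # negatives: [batch_size, neg_samples, max_length]

combined = tf.concat([tf.expand_dims(p_ids, axis=1), n_ids], axis=1)
print(combined.shape)                  # (2, 4, 4): positive first, then 3 negatives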
1574
+ def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
1575
+ """
1576
+ Creates a tf.data.Dataset for streaming training that yields
1577
+ (input_ids_query, input_ids_positive, input_ids_negatives).
1578
+ """
1579
+ # 1) Start with a generator dataset
1580
+ dataset = tf.data.Dataset.from_generator(
1581
+ lambda: self.data_generator(dialogues),
1582
+ output_signature=(
1583
+ tf.TensorSpec(shape=(), dtype=tf.string), # Query (single string)
1584
+ tf.TensorSpec(shape=(), dtype=tf.string), # Positive (single string)
1585
+ tf.TensorSpec(shape=(None,), dtype=tf.string) # Hard Negatives (list of strings)
1586
+ )
1587
  )
1588
+
1589
+ # 2) Batch the raw strings
1590
+ dataset = dataset.batch(batch_size)
1591
+
1592
+ # 3) Now map them through a tokenize step (via py_function)
1593
+ dataset = dataset.map(
1594
+ lambda q, p, n: self._tokenize_triple(q, p, n),
1595
+ num_parallel_calls=1  # limit parallelism for the py_function tokenization step (tf.data.AUTOTUNE disabled)
1596
+ )
1597
+
1598
+ dataset = dataset.prefetch(tf.data.AUTOTUNE)
1599
+ return dataset
1600
+
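For orientation, a hedged sketch of how batches from this pipeline could be consumed; the generator and loop below only imitate the (q_ids, p_ids, n_ids) shapes fixed in _tokenize_triple and are not the project's training code.

import tensorflow as tf

max_length, neg_samples, batch_size = 8, 3, 4

def fake_batches():
    # Stand-in for the tokenized batches that get_tf_dataset would yield.
    for _ in range(2):
        yield (tf.zeros([batch_size, max_length], tf.int32),
               tf.zeros([batch_size, max_length], tf.int32),
               tf.zeros([batch_size, neg_samples, max_length], tf.int32))

dataset = tf.data.Dataset.from_generator(
    fake_batches,
    output_signature=(
        tf.TensorSpec([None, max_length], tf.int32),
        tf.TensorSpec([None, max_length], tf.int32),
        tf.TensorSpec([None, neg_samples, max_length], tf.int32),
    ),
)

for q_ids, p_ids, n_ids in dataset:  # same loop shape as iterating get_tf_dataset(...)
    print(q_ids.shape, p_ids.shape, n_ids.shape)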
1601
+ def _tokenize_triple(
1602
+ self,
1603
+ q: tf.Tensor,
1604
+ p: tf.Tensor,
1605
+ n: tf.Tensor
1606
+ ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
1607
+ """
1608
+ Wraps a Python function via tf.py_function to convert tf.Tensors of strings
1609
+ -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
1610
+
1611
+ q is shape [batch_size], p is shape [batch_size],
1612
+ n is shape [batch_size, neg_samples] (i.e., each row is a list of negatives).
1613
+ """
1614
+ # Use tf.py_function with limited parallelism
1615
+ q_ids, p_ids, n_ids = tf.py_function(
1616
+ func=self._tokenize_triple_py,
1617
+ inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
1618
+ Tout=[tf.int32, tf.int32, tf.int32]
1619
+ )
1620
+
1621
+ # Manually set shape information
1622
+ q_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
1623
+ p_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
1624
+ n_ids.set_shape([None, self.neg_samples, self.max_length]) # [batch_size, neg_samples, max_length]
1625
+
1626
+ return q_ids, p_ids, n_ids
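The set_shape calls matter because tf.py_function erases static shape information, so downstream ops would otherwise see unknown ranks. A minimal illustration with toy data (not the project's tokenizer):

import tensorflow as tf

def double(x):
    # Arbitrary Python-side work, standing in for the HF tokenizer call.
    return x.numpy() * 2

def mapped(x):
    y = tf.py_function(double, inp=[x], Tout=tf.float32)
    # y.shape is <unknown> here; restore what we know, as _tokenize_triple does.
    y.set_shape([None, 3])
    return y

ds = tf.data.Dataset.from_tensor_slices(tf.ones([4, 3])).batch(2).map(mapped)
print(ds.element_spec)  # TensorSpec(shape=(None, 3), dtype=tf.float32, ...)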
1627
+ # def _tokenize_triple(
1628
+ # self,
1629
+ # q: tf.Tensor,
1630
+ # p: tf.Tensor,
1631
+ # n: tf.Tensor
1632
+ # ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
1633
+ # """
1634
+ # Wraps a Python function via tf.py_function to convert tf.Tensors of strings
1635
+ # -> Python lists of strings -> HF tokenizer -> Tensors of IDs.
1636
+
1637
+ # q is shape [batch_size], p is shape [batch_size],
1638
+ # n is shape [batch_size, None] (i.e. each row is a variable number of negatives).
1639
+ # """
1640
+ # # Use tf.py_function
1641
+ # # We pass in self.max_length as well, so we can do it in one shot.
1642
+ # q_ids, p_ids, n_ids = tf.py_function(
1643
+ # func=self._tokenize_triple_py,
1644
+ # inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
1645
+ # Tout=[tf.int32, tf.int32, tf.int32]
1646
+ # )
1647
+
1648
+ # # We must manually set shape information so that TF data pipeline knows the dimensions
1649
+ # q_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
1650
+ # p_ids.set_shape([None, self.max_length]) # [batch_size, max_length]
1651
+ # n_ids.set_shape([None, self.neg_samples, self.max_length])
1652
+ # # The negative dimension is set to `self.neg_samples` for consistency.
1653
+
1654
+ # return q_ids, p_ids, n_ids
1655
+
1656
+ def _tokenize_triple_py(
1657
+ self,
1658
+ q: tf.Tensor,
1659
+ p: tf.Tensor,
1660
+ n: tf.Tensor,
1661
+ max_len: tf.Tensor,
1662
+ neg_samples: tf.Tensor
1663
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
1664
+ """
1665
+ Python function that:
1666
+ - Decodes each tf.string Tensor to a Python list of strings
1667
+ - Calls the HF tokenizer
1668
+ - Reshapes negatives
1669
+ - Returns np.array of int32s for (q_ids, p_ids, n_ids).
1670
+
1671
+ q: shape [batch_size], p: shape [batch_size]
1672
+ n: shape [batch_size, neg_samples]
1673
+ max_len: scalar int
1674
+ neg_samples: scalar int
1675
+ """
1676
+ max_len = int(max_len.numpy()) # Convert to Python int
1677
+ neg_samples = int(neg_samples.numpy())
1678
+
1679
+ # 1) Convert Tensors -> Python lists of strings
1680
+ q_list = [q_i.decode("utf-8") for q_i in q.numpy()] # shape [batch_size]
1681
+ p_list = [p_i.decode("utf-8") for p_i in p.numpy()] # shape [batch_size]
1682
+
1683
+ # shape [batch_size, neg_samples], decode each row
1684
+ n_list = []
1685
+ for row in n.numpy():
1686
+ # row is shape [neg_samples], each is a tf.string
1687
+ decoded = [neg.decode("utf-8") for neg in row]
1688
+ n_list.append(decoded)
1689
+
1690
+ # 2) Tokenize queries & positives
1691
+ q_enc = self.tokenizer(
1692
+ q_list,
1693
+ padding="max_length",
1694
+ truncation=True,
1695
+ max_length=max_len,
1696
+ return_tensors="np"
1697
+ )
1698
+ p_enc = self.tokenizer(
1699
+ p_list,
1700
+ padding="max_length",
1701
+ truncation=True,
1702
+ max_length=max_len,
1703
+ return_tensors="np"
1704
+ )
1705
+
1706
+ # 3) Tokenize negatives
1707
+ # Flatten [batch_size, neg_samples] -> single list
1708
+ flattened_negatives = [neg for row in n_list for neg in row]
1709
+ if len(flattened_negatives) == 0:
1710
+ # No negatives at all: return a zero array
1711
+ n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
1712
+ else:
1713
+ n_enc = self.tokenizer(
1714
+ flattened_negatives,
1715
+ padding="max_length",
1716
+ truncation=True,
1717
+ max_length=max_len,
1718
+ return_tensors="np"
1719
+ )
1720
+ # shape [batch_size * neg_samples, max_len]
1721
+ n_input_ids = n_enc["input_ids"]
1722
+
1723
+ # We want to reshape to [batch_size, neg_samples, max_len]
1724
+ # Handle cases where there might be fewer negatives
1725
+ batch_size = len(q_list)
1726
+ n_ids_list = []
1727
+ for i in range(batch_size):
1728
+ start_idx = i * neg_samples
1729
+ end_idx = start_idx + neg_samples
1730
+ row_negs = n_input_ids[start_idx:end_idx]
1731
+
1732
+ # If fewer negatives, pad with zeros
1733
+ if row_negs.shape[0] < neg_samples:
1734
+ deficit = neg_samples - row_negs.shape[0]
1735
+ pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
1736
+ row_negs = np.concatenate([row_negs, pad_arr], axis=0)
1737
+
1738
+ n_ids_list.append(row_negs)
1739
+
1740
+ # stack them -> shape [batch_size, neg_samples, max_len]
1741
+ n_ids = np.stack(n_ids_list, axis=0)
1742
+
1743
+ # 4) Return as np.int32 arrays
1744
+ q_ids = q_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
1745
+ p_ids = p_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
1746
+ n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
1747
+
1748
+ return q_ids, p_ids, n_ids
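A NumPy-only sketch of the padding branch (toy token IDs, assumed sizes): when a row contributes fewer than neg_samples negatives, zero rows are appended before stacking back to [batch_size, neg_samples, max_len].

import numpy as np

neg_samples, max_len = 3, 4
# Pretend the flattened tokenizer output covers two rows, but the second row had only 2 negatives.
n_input_ids = np.arange(5 * max_len, dtype=np.int32).reshape(5, max_len)

rows = []
for i in range(2):                                  # batch_size = 2
    row = n_input_ids[i * neg_samples:(i + 1) * neg_samples]
    if row.shape[0] < neg_samples:                  # zero-pad the deficit
        pad = np.zeros((neg_samples - row.shape[0], max_len), dtype=np.int32)
        row = np.concatenate([row, pad], axis=0)
    rows.append(row)

n_ids = np.stack(rows, axis=0)
print(n_ids.shape)                                  # (2, 3, 4)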
1749
+ # def _tokenize_triple_py(
1750
+ # self,
1751
+ # q: tf.Tensor,
1752
+ # p: tf.Tensor,
1753
+ # n: tf.Tensor,
1754
+ # max_len: tf.Tensor,
1755
+ # neg_samples: tf.Tensor
1756
+ # ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
1757
+ # """
1758
+ # Python function that:
1759
+ # - Decodes each tf.string Tensor to a Python list of strings
1760
+ # - Calls the HF tokenizer
1761
+ # - Reshapes negatives
1762
+ # - Returns np.array of int32s for (q_ids, p_ids, n_ids).
1763
+
1764
+ # q: shape [batch_size], p: shape [batch_size]
1765
+ # n: shape [batch_size, None]
1766
+ # max_len: scalar int
1767
+ # neg_samples: scalar int
1768
+ # """
1769
+ # max_len = int(max_len.numpy()) # convert to python int
1770
+ # neg_samples = int(neg_samples.numpy())
1771
+
1772
+ # # 1) Convert Tensors -> Python lists of strings
1773
+ # q_list = [q_i.decode("utf-8") for q_i in q.numpy()] # shape [batch_size]
1774
+ # p_list = [p_i.decode("utf-8") for p_i in p.numpy()] # shape [batch_size]
1775
+
1776
+ # # shape [batch_size, variable_negatives], decode each row
1777
+ # n_list = []
1778
+ # for row in n.numpy():
1779
+ # # row is shape [N], each is a tf.string
1780
+ # decoded = [neg.decode("utf-8") for neg in row]
1781
+ # n_list.append(decoded)
1782
+
1783
+ # # 2) Tokenize queries & positives
1784
+ # q_enc = self.tokenizer(
1785
+ # q_list,
1786
+ # padding="max_length",
1787
+ # truncation=True,
1788
+ # max_length=max_len,
1789
+ # return_tensors="np" # you can do return_tensors="tf", but "np" is often simpler here
1790
+ # )
1791
+ # p_enc = self.tokenizer(
1792
+ # p_list,
1793
+ # padding="max_length",
1794
+ # truncation=True,
1795
+ # max_length=max_len,
1796
+ # return_tensors="np"
1797
+ # )
1798
+
1799
+ # # 3) Tokenize negatives
1800
+ # # Flatten [batch_size, variable_negatives] -> single list
1801
+ # flattened_negatives = [neg for row in n_list for neg in row]
1802
+ # if len(flattened_negatives) == 0:
1803
+ # # No negatives at all: return a zero array
1804
+ # n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
1805
+ # else:
1806
+ # n_enc = self.tokenizer(
1807
+ # flattened_negatives,
1808
+ # padding="max_length",
1809
+ # truncation=True,
1810
+ # max_length=max_len,
1811
+ # return_tensors="np"
1812
+ # )
1813
+ # # shape [batch_size * total_negatives, max_len]
1814
+ # n_input_ids = n_enc["input_ids"]
1815
+
1816
+ # # We want to reshape to [batch_size, neg_samples, max_len].
1817
+ # # If each row truly has exactly `neg_samples` (or fewer), we can do:
1818
+ # # n_input_ids = n_input_ids.reshape(len(q_list), neg_samples, max_len)
1819
+ # # But if the rows have variable # of negatives, we must clamp or pad.
1820
+ # # For simplicity, let's just "take first neg_samples" per row
1821
+ # # and pad if fewer.
1822
+
1823
+ # # We'll do it row by row:
1824
+ # batch_size = len(q_list)
1825
+ # row_offsets = 0
1826
+ # n_ids_list = []
1827
+ # for row_idx in range(batch_size):
1828
+ # row_negs = n_list[row_idx]
1829
+ # row_count = len(row_negs)
1830
+
1831
+ # # slice from the flattened array
1832
+ # row_slice = n_input_ids[row_offsets:row_offsets + row_count]
1833
+ # row_offsets += row_count
1834
+
1835
+ # # Now pick out up to neg_samples
1836
+ # row_slice = row_slice[:neg_samples]
1837
+
1838
+ # # If fewer than neg_samples, pad
1839
+ # if row_slice.shape[0] < neg_samples:
1840
+ # deficit = neg_samples - row_slice.shape[0]
1841
+ # pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
1842
+ # row_slice = np.concatenate([row_slice, pad_arr], axis=0)
1843
+
1844
+ # # row_slice is now shape [neg_samples, max_len]
1845
+ # n_ids_list.append(row_slice)
1846
+
1847
+ # # stack them -> shape [batch_size, neg_samples, max_len]
1848
+ # n_ids = np.stack(n_ids_list, axis=0)
1849
+
1850
+ # # 4) Return as np.int32 arrays (tokenizer should already return int32,
1851
+ # # but we can cast to be sure)
1852
+ # q_ids = q_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
1853
+ # p_ids = p_enc["input_ids"].astype(np.int32) # shape [batch_size, max_len]
1854
+ # n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
1855
 
1856
+ # return q_ids, p_ids, n_ids
1857
+
 
 
 
 
deduplicate_augmented_dialogues.py ADDED
@@ -0,0 +1,74 @@
1
+ import json
2
+ from pathlib import Path
3
+ import logging
4
+ from typing import List, Dict
5
+ from collections import defaultdict
6
+
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ def load_json_file(file_path: str) -> List[Dict]:
11
+ """Load and parse a JSON file."""
12
+ try:
13
+ with open(file_path, 'r', encoding='utf-8') as f:
14
+ return json.load(f)
15
+ except json.JSONDecodeError as e:
16
+ logger.error(f"Error parsing JSON from {file_path}: {e}")
17
+ return []
18
+ except Exception as e:
19
+ logger.error(f"Error reading file {file_path}: {e}")
20
+ return []
21
+
22
+ def combine_json_files(input_directory: str, output_file: str):
23
+ """
24
+ Combine multiple JSON files while removing duplicates based on dialogue_id.
25
+
26
+ Args:
27
+ input_directory: Directory containing JSON files to process
28
+ output_file: Path to save the combined output
29
+ """
30
+ # Track unique dialogues and their source files
31
+ dialogue_map = {}
32
+ duplicate_count = 0
33
+
34
+ # Process all JSON files in the directory
35
+ input_path = Path(input_directory)
36
+ for json_file in input_path.glob('*.json'):
37
+ logger.info(f"Processing {json_file}")
38
+
39
+ data = load_json_file(str(json_file))
40
+
41
+ # Process each dialogue in the file
42
+ for dialogue in data:
43
+ dialogue_id = dialogue.get('dialogue_id')
44
+
45
+ if not dialogue_id:
46
+ logger.warning(f"Found dialogue without ID in {json_file}")
47
+ continue
48
+
49
+ # Keep only the first occurrence of each dialogue_id; later duplicates are counted and skipped
50
+ if dialogue_id in dialogue_map:
51
+ duplicate_count += 1
52
+ logger.debug(f"Duplicate dialogue_id found: {dialogue_id}")
53
+
54
+ else:
55
+ dialogue_map[dialogue_id] = dialogue
56
+
57
+ # Convert the map of unique dialogues back to a list
58
+ unique_dialogues = list(dialogue_map.values())
59
+
60
+ # Save combined dialogues to a new file
61
+ try:
62
+ with open(output_file, 'w', encoding='utf-8') as f:
63
+ json.dump(unique_dialogues, f, indent=4)
64
+ logger.info(f"Successfully combined files. Found {duplicate_count} duplicates.")
65
+ logger.info(f"Total unique dialogues: {len(unique_dialogues)}")
66
+ except Exception as e:
67
+ logger.error(f"Error writing output file: {e}")
68
+
69
+ # Usage example
70
+ if __name__ == "__main__":
71
+ combine_json_files(
72
+ input_directory="/Users/joe/Desktop/Grad School/CSC525/CSC525_mod8_option2_joseph_armani/processed_outputs",
73
+ output_file="augmented_dialogues.json"
74
+ )
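A quick, self-contained check of the deduplication behaviour (assuming the script above is importable as deduplicate_augmented_dialogues): two files sharing one dialogue_id collapse to three unique dialogues.

import json
import tempfile
from pathlib import Path

from deduplicate_augmented_dialogues import combine_json_files

with tempfile.TemporaryDirectory() as d:
    Path(d, "a.json").write_text(json.dumps([{"dialogue_id": "1"}, {"dialogue_id": "2"}]))
    Path(d, "b.json").write_text(json.dumps([{"dialogue_id": "2"}, {"dialogue_id": "3"}]))
    out = Path(d, "combined.json")
    combine_json_files(d, str(out))
    print(len(json.loads(out.read_text())))  # 3 unique dialogues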
run_model_train.py CHANGED
@@ -5,7 +5,6 @@ from response_quality_checker import ResponseQualityChecker
5
  from chatbot_validator import ChatbotValidator
6
  from training_plotter import TrainingPlotter
7
 
8
-
9
  # Configure logging
10
  from logger_config import config_logger
11
  logger = config_logger(__name__)
@@ -38,7 +37,7 @@ def main():
38
  env = EnvironmentSetup()
39
  env.initialize()
40
 
41
- DEBUG_SAMPLES = 15
42
  EPOCHS = 5 if DEBUG_SAMPLES else 20
43
  TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
44
 
@@ -47,13 +46,14 @@ def main():
47
 
48
  # Initialize configuration
49
  config = ChatbotConfig(
50
- embedding_dim=512, # 768, # Match DistilBERT's dimension
51
  max_context_token_limit=512,
52
  freeze_embeddings=False,
53
  )
54
 
55
  # Load training data
56
  dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
 
57
 
58
  # Initialize chatbot and verify FAISS index
59
  #with env.strategy.scope():
 
5
  from chatbot_validator import ChatbotValidator
6
  from training_plotter import TrainingPlotter
7
 
 
8
  # Configure logging
9
  from logger_config import config_logger
10
  logger = config_logger(__name__)
 
37
  env = EnvironmentSetup()
38
  env.initialize()
39
 
40
+ DEBUG_SAMPLES = 5
41
  EPOCHS = 5 if DEBUG_SAMPLES else 20
42
  TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
43
 
 
46
 
47
  # Initialize configuration
48
  config = ChatbotConfig(
49
+ embedding_dim=768, # DistilBERT
50
  max_context_token_limit=512,
51
  freeze_embeddings=False,
52
  )
53
 
54
  # Load training data
55
  dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
56
+ print(dialogues)
57
 
58
  # Initialize chatbot and verify FAISS index
59
  #with env.strategy.scope():