Commit 9b5daff by JoeArmani · 1 Parent(s): f5346f7

improve split processes

This commit splits data preparation from training: RetrievalChatbot gains a mode flag ('preparation' vs. 'training'), response-pool collection and FAISS indexing now run only in preparation mode, TFRecord parsing moves inside train_streaming, and tf_data_pipeline.py gets batched tokenization and encoding with progress bars.

Files changed (4):
  1. chatbot_model.py +86 -94
  2. run_data_preparer.py +1 -1
  3. run_model_train.py +31 -10
  4. tf_data_pipeline.py +230 -57
chatbot_model.py CHANGED
@@ -138,12 +138,14 @@ class RetrievalChatbot(DeviceAwareModel):
         device: str = None,
         strategy=None,
         reranker: Optional[CrossEncoderReranker] = None,
-        summarizer: Optional[Summarizer] = None
+        summarizer: Optional[Summarizer] = None,
+        mode: str = 'preparation'
     ):
         super().__init__()
         self.config = config
         self.strategy = strategy
         self.device = device or self._setup_default_device()
+        self.mode = mode.lower()
 
         # Initialize reranker, summarizer, tokenizer, and memory monitor
         self.reranker = reranker or self._initialize_reranker()
@@ -151,13 +153,10 @@ class RetrievalChatbot(DeviceAwareModel):
         self.tokenizer = self._initialize_tokenizer()
         self.memory_monitor = GPUMemoryMonitor()
 
-        # Initialize models
-        self.min_batch_size = 8
-        self.max_batch_size = 128
-        self.current_batch_size = 32
-
-        # Collect unique responses from dialogues
-        self.response_pool, self.unique_responses = self._collect_responses(dialogues)
+        # # Initialize models
+        # self.min_batch_size = 8
+        # self.max_batch_size = 128
+        # self.current_batch_size = 32
 
         # Initialize training history
         self.history = {
@@ -166,6 +165,15 @@ class RetrievalChatbot(DeviceAwareModel):
             "train_metrics": {},
             "val_metrics": {}
         }
+
+        # Collect unique responses from dialogues
+        if self.mode == 'preparation':
+            # Collect unique responses from dialogues only in preparation mode
+            self.response_pool, self.unique_responses = self._collect_responses(dialogues)
+        else:
+            # In training mode, assume the response pool is handled via TFRecord
+            self.response_pool = []
+            self.unique_responses = []
 
     def _setup_default_device(self) -> str:
         """Set up default device if none is provided."""
@@ -236,11 +244,14 @@ class RetrievalChatbot(DeviceAwareModel):
             self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
             logger.info(f"Token embeddings resized to: {new_vocab_size}")
 
-        # Initialize FAISS index
-        self._initialize_faiss()
-
-        # Compute and index embeddings
-        self._compute_and_index_embeddings()
+        if self.mode == 'preparation':
+            # Initialize FAISS index only in preparation mode
+            self._initialize_faiss()
+            # Compute and index embeddings
+            self._compute_and_index_embeddings()
+        else:
+            # In training mode, skip FAISS indexing from dialogues
+            logger.info("Training mode: Skipping FAISS index initialization from dialogues.")
 
         # Retrieve embedding dimension from encoder
         embedding_dim = self.config.embedding_dim
@@ -271,10 +282,10 @@ class RetrievalChatbot(DeviceAwareModel):
         self.current_batch_size = new_size
 
     def _initialize_faiss(self):
-        """Initialize FAISS with safer GPU handling and memory monitoring."""
+        """Initialize FAISS with safe GPU handling and memory monitoring."""
         logger.info("Initializing FAISS index...")
 
-        # First, detect if we have GPU-enabled FAISS
+        # Detect if we have GPU-enabled FAISS
         self.faiss_gpu = False
         self.gpu_resources = []
 
@@ -294,53 +305,30 @@ class RetrievalChatbot(DeviceAwareModel):
                     self.gpu_resources.append(res)
                 self.faiss_gpu = True
                 logger.info(f"FAISS GPU resources initialized on {ngpus} GPUs")
-            else:
-                logger.info("Using CPU-only FAISS build")
-
         except Exception as e:
             logger.warning(f"Using CPU due to GPU initialization error: {e}")
 
-        # TODO: figure out buf with faiss-gpu
-        # TODO: consider IndexIVFFlat in the future (speed).
         try:
             # Create appropriate index based on dataset size
             if len(self.unique_responses) < 1000:
                 logger.info("Small dataset detected, using simple FlatIP index")
                 self.index = faiss.IndexFlatIP(self.config.embedding_dim)
             else:
-                # Use IVF index with dynamic number of clusters
-                # nlist = min(
-                #     25,  # max clusters
-                #     max(int(math.sqrt(len(self.unique_responses))), 1)  # min 1 cluster
-                # )
-                # logger.info(f"Using IVF index with {nlist} clusters")
-
-                # quantizer = faiss.IndexFlatIP(self.config.embedding_dim)
-                # self.index = faiss.IndexIVFFlat(
-                #     quantizer,
-                #     self.config.embedding_dim,
-                #     nlist,
-                #     faiss.METRIC_INNER_PRODUCT
-                # )
+                # For larger datasets, consider using more efficient indices like IVF
                 self.index = faiss.IndexFlatIP(self.config.embedding_dim)
 
-            # # Move to GPU(s) if available
-            # if self.faiss_gpu and self.gpu_resources:
-            #     try:
-            #         if len(self.gpu_resources) > 1:
-            #             self.index = faiss.index_cpu_to_gpus_list(self.index, self.gpu_resources)
-            #             logger.info("FAISS index distributed across multiple GPUs")
-            #         else:
-            #             self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, self.index)
-            #             logger.info("FAISS index moved to single GPU")
-            #     except Exception as e:
-            #         logger.warning(f"Failed to move index to GPU: {e}. Falling back to CPU")
-            #         self.faiss_gpu = False
-
-            # # Set search parameters for IVF index
-            # if isinstance(self.index, faiss.IndexIVFFlat):
-            #     self.index.nprobe = min(10, nlist)
-
+            # Move to GPU(s) if available and needed
+            if self.faiss_gpu and self.gpu_resources:
+                try:
+                    if len(self.gpu_resources) > 1:
+                        self.index = faiss.index_cpu_to_gpus_list(self.index, self.gpu_resources)
+                        logger.info("FAISS index distributed across multiple GPUs")
+                    else:
+                        self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, self.index)
+                        logger.info("FAISS index moved to single GPU")
+                except Exception as e:
+                    logger.warning(f"Failed to move index to GPU: {e}. Falling back to CPU")
+                    self.faiss_gpu = False
         except Exception as e:
             logger.error(f"Error initializing FAISS: {e}")
             raise
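
Note: the deleted comment block gestured at an IVF index for larger response pools. For reference, a minimal standalone sketch of that variant using the standard FAISS API (the cluster count and nprobe below are illustrative, not values from this repo):

```python
import faiss
import numpy as np

dim = 768
nlist = 100                                   # clusters; ~sqrt(N) is a common default
quantizer = faiss.IndexFlatIP(dim)
index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)

xb = np.random.rand(10_000, dim).astype('float32')
faiss.normalize_L2(xb)                        # keep inner product == cosine similarity
index.train(xb)                               # IVF must be trained before add()
index.add(xb)
index.nprobe = 10                             # clusters probed per query: speed/recall knob
```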
@@ -353,21 +341,16 @@ class RetrievalChatbot(DeviceAwareModel):
         """
         Encodes responses with more conservative memory management.
         """
+        if not responses:
+            logger.info("No responses to encode. Returning empty tensor.")
+            return tf.constant([], dtype=tf.float32)
+
         all_embeddings = []
         self.current_batch_size = batch_size
 
         if self.memory_monitor.has_gpu:
             batch_size = 128
 
-        # Memory stats
-        # if self.memory_monitor.has_gpu:
-        #     initial_stats = self.memory_monitor.get_memory_stats()
-        #     if initial_stats:
-        #         logger.info("Initial GPU memory state:")
-        #         logger.info(f"Total: {initial_stats.total / 1e9:.2f}GB")
-        #         logger.info(f"Used: {initial_stats.used / 1e9:.2f}GB")
-        #         logger.info(f"Free: {initial_stats.free / 1e9:.2f}GB")
-
         total_processed = 0
 
         with tqdm(total=len(responses), desc="Encoding responses") as pbar:
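
One caveat on the new early return: `tf.constant([], dtype=tf.float32)` is rank-1 with shape [0], while every non-empty path returns shape [N, embedding_dim]. If callers stack or index the result, an explicitly shaped empty tensor is safer; a hedged alternative, assuming `self.config.embedding_dim` as used elsewhere in this class:

```python
if not responses:
    logger.info("No responses to encode. Returning empty tensor.")
    # Rank-2 empty tensor so downstream shape logic keeps working.
    return tf.zeros([0, self.config.embedding_dim], dtype=tf.float32)
```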
@@ -434,7 +417,10 @@ class RetrievalChatbot(DeviceAwareModel):
                 raise
 
         # Concatenate results
-        #logger.info("Concatenating embeddings...")
+        if not all_embeddings:
+            logger.info("No embeddings were encoded. Returning empty tensor.")
+            return tf.constant([], dtype=tf.float32)
+
         if len(all_embeddings) == 1:
             final_embeddings = all_embeddings[0]
         else:
@@ -727,7 +713,11 @@ class RetrievalChatbot(DeviceAwareModel):
             raise
 
     def verify_faiss_index(self):
-        """Verify that FAISS index matches the response pool."""
+        """Verify that FAISS index matches the response pool, if the index exists."""
+        if not hasattr(self, 'index') or self.index is None:
+            logger.info("FAISS index not initialized. Skipping verification.")
+            return
+
         indexed_size = self.index.ntotal
         pool_size = len(self.response_pool)
         logger.info(f"FAISS index size: {indexed_size}")
@@ -813,6 +803,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
         """Retrieve top-k responses using FAISS."""
+        if not hasattr(self, 'index') or self.index is None:
+            logger.warning("FAISS index not initialized. Cannot retrieve responses.")
+            return []
+
         # Encode the query
         q_emb = self.encode_query(query)  # Shape: [1, embedding_dim]
         q_emb_np = q_emb.numpy().astype('float32')  # Ensure type match
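
Note: the guard above returns `[]` when no index exists; for context, the usual FAISS continuation of such a method (a sketch of the standard search pattern, not code from this commit) maps results back into the response pool:

```python
# q_emb_np: [1, embedding_dim], float32 and L2-normalized.
distances, indices = self.index.search(q_emb_np, top_k)
results = [
    (self.response_pool[idx], float(score))
    for idx, score in zip(indices[0], distances[0])
    if idx != -1                      # FAISS pads missing neighbors with -1
]
```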
@@ -874,32 +868,6 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"Models and tokenizer loaded from {load_dir}.")
         return chatbot
 
-    def parse_tfrecord_fn(example_proto, max_length, neg_samples):
-        """
-        Parses a single TFRecord example.
-
-        Args:
-            example_proto: A serialized TFRecord example.
-            max_length: The maximum sequence length for tokenization.
-            neg_samples: The number of hard negatives per query.
-
-        Returns:
-            A tuple of (query_ids, positive_ids, negative_ids).
-        """
-        feature_description = {
-            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
-            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
-            'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
-        }
-        parsed_features = tf.io.parse_single_example(example_proto, feature_description)
-
-        query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
-        positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
-        negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
-        negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
-
-        return query_ids, positive_ids, negative_ids
-
     def train_streaming(
         self,
         tfrecord_file_path: str,
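
Moving parse_tfrecord_fn inside train_streaming is more than tidying: the deleted class-level def had no `self` parameter, so calling it as a bound method would shift every argument by one. A minimal reproduction of that failure mode:

```python
class C:
    def f(x, y):          # note: no `self`
        return (x, y)

c = C()
c.f(1, 2)                 # TypeError: f() takes 2 positional arguments but 3 were given
                          # (Python implicitly passes `c` as the first argument)
```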
@@ -915,10 +883,34 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> None:
         """Training using a pre-prepared TFRecord dataset."""
         logger.info("Starting training with pre-prepared TFRecord dataset...")
-
-        # Calculate total steps for learning rate schedule
-        # Estimate total pairs by counting the number of records in the TFRecord
-        # Assuming each record corresponds to one pair
+
+        def parse_tfrecord_fn(example_proto, max_length, neg_samples):
+            """
+            Parses a single TFRecord example.
+
+            Args:
+                example_proto: A serialized TFRecord example.
+                max_length: The maximum sequence length for tokenization.
+                neg_samples: The number of hard negatives per query.
+
+            Returns:
+                A tuple of (query_ids, positive_ids, negative_ids).
+            """
+            feature_description = {
+                'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+                'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
+                'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
+            }
+            parsed_features = tf.io.parse_single_example(example_proto, feature_description)
+
+            query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
+            positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
+            negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
+            negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
+
+            return query_ids, positive_ids, negative_ids
+
+        # Calculate total steps by counting the number of records in the TFRecord
         raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         total_pairs = sum(1 for _ in raw_dataset)
         logger.info(f"Total pairs in TFRecord: {total_pairs}")
@@ -964,12 +956,12 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"TensorBoard logs will be saved in {log_dir}")
 
         # Define the parsing function with the appropriate max_length and neg_samples
-        parse_fn = lambda x: self.parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
+        parse_fn = lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples)
 
         # Create the full dataset
         dataset = tf.data.TFRecordDataset(tfrecord_file_path)
         dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
-        dataset = dataset.shuffle(buffer_size=10000)  # Adjust buffer size as needed TODO: what is this?
+        dataset = dataset.shuffle(buffer_size=10000)  # Adjust buffer size as needed
         dataset = dataset.batch(batch_size, drop_remainder=True)
         dataset = dataset.prefetch(tf.data.AUTOTUNE)

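To answer the removed "TODO: what is this?": `shuffle(buffer_size=N)` keeps N parsed examples in memory and samples uniformly from that buffer, so it is only a perfect shuffle when N is at least the dataset size; 10000 trades memory for shuffle quality. The same call with the intent spelled out:

```python
# Buffer of 10k examples, reshuffled each epoch (TensorFlow's default behavior).
dataset = dataset.shuffle(buffer_size=10_000, reshuffle_each_iteration=True)
```
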
run_data_preparer.py CHANGED
@@ -30,7 +30,7 @@ def main():
     TF_RECORD_DIR = 'training_data'
     FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
     FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
-    ENVIRONMENT = 'test'  # or 'production'
+    ENVIRONMENT = 'production'  # or 'test'
     if ENVIRONMENT == 'test':
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
    else:
  else:
run_model_train.py CHANGED
@@ -31,36 +31,57 @@ def run_interactive_chat(chatbot, quality_checker):
     for resp, score in candidates[1:4]:
         print(f"Score: {score:.4f} - {resp}")
 
+def inspect_tfrecord(tfrecord_file_path, num_examples=3):
+    def parse_example(example_proto):
+        feature_description = {
+            'query_ids': tf.io.FixedLenFeature([512], tf.int64),        # Adjust max_length if different
+            'positive_ids': tf.io.FixedLenFeature([512], tf.int64),
+            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),  # Adjust neg_samples if different
+        }
+        return tf.io.parse_single_example(example_proto, feature_description)
+
+    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
+    dataset = dataset.map(parse_example)
+
+    for i, example in enumerate(dataset.take(num_examples)):
+        print(f"Example {i+1}:")
+        print(f"Query IDs: {example['query_ids'].numpy()}")
+        print(f"Positive IDs: {example['positive_ids'].numpy()}")
+        print(f"Negative IDs: {example['negative_ids'].numpy()}")
+        print("-" * 50)
+
 def main():
+
+    # Quick test to inspect TFRecord
+    #inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
+
     # Initialize environment
     tf.keras.backend.clear_session()
     env = EnvironmentSetup()
     env.initialize()
 
-    DEBUG_SAMPLES = 5
-    EPOCHS = 5 if DEBUG_SAMPLES else 20
+    # Training configuration
+    EPOCHS = 20
     TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
 
     # Optimize batch size for Colab
     batch_size = env.optimize_batch_size(base_batch_size=16)
 
+
     # Initialize configuration
     config = ChatbotConfig(
         embedding_dim=768,  # DistilBERT
         max_context_token_limit=512,
         freeze_embeddings=False,
-        neg_samples=3,
     )
 
-    # # Load training data
-    # dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
-    # print(dialogues)
-
-    # Initialize chatbot and verify FAISS index
+    # Initialize chatbot
     #with env.strategy.scope():
-    chatbot = RetrievalChatbot(config)
+    chatbot = RetrievalChatbot(config, mode='training')
     chatbot.build_models()
-    chatbot.verify_faiss_index()
+
+    if chatbot.mode == 'preparation':
+        chatbot.verify_faiss_index()
 
     chatbot.train_streaming(
         tfrecord_file_path=TF_RECORD_FILE_PATH,
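
Note: `inspect_tfrecord` hard-codes 512 and 3, which will silently disagree with a TFRecord produced under a different config. A sketch that derives the shapes from `ChatbotConfig` instead (the same fields this commit already uses; the helper name is illustrative):

```python
def inspect_tfrecord_with_config(tfrecord_file_path, config, num_examples=3):
    max_len, neg = config.max_context_token_limit, config.neg_samples
    feature_description = {
        'query_ids': tf.io.FixedLenFeature([max_len], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([max_len], tf.int64),
        'negative_ids': tf.io.FixedLenFeature([neg * max_len], tf.int64),
    }
    dataset = tf.data.TFRecordDataset(tfrecord_file_path).map(
        lambda proto: tf.io.parse_single_example(proto, feature_description)
    )
    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i + 1}: query shape {example['query_ids'].shape}, "
              f"negatives reshape to {(neg, max_len)}")
```
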
tf_data_pipeline.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
 import faiss
 import tensorflow as tf
 import h5py
+import math
 from tqdm import tqdm
 import json
 from pathlib import Path
@@ -146,7 +147,7 @@ class TFDataPipeline:
     def collect_responses(self, dialogues: List[dict]) -> List[str]:
         """Extract unique assistant responses from dialogues."""
         response_set = set()
-        for dialogue in dialogues:
+        for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
             turns = dialogue.get('turns', [])
             for turn in turns:
                 speaker = turn.get('speaker')
@@ -180,20 +181,17 @@ class TFDataPipeline:
 
     def _compute_and_index_response_embeddings(self):
         """
-        Computes embeddings for the response pool and adds them to the FAISS index.
+        Computes embeddings for the response pool and adds them to the FAISS index with progress bars.
         """
         logger.info("Computing embeddings for the response pool...")
 
-        # Log the contents and types of response_pool
-        for idx, response in enumerate(self.response_pool[:5], 1):  # Log first 5 responses
-            logger.debug(f"Response {idx}: {response} (Type: {type(response)})")
-
         # Ensure all responses are strings
         if not all(isinstance(response, str) for response in self.response_pool):
             logger.error("All elements in response_pool must be strings.")
             raise ValueError("Invalid data type in response_pool.")
 
-        # Proceed with tokenization
+        # Tokenization
+        logger.info("Tokenizing responses...")
         encoded_responses = self.tokenizer(
             self.response_pool,
             padding=True,
@@ -203,28 +201,87 @@ class TFDataPipeline:
         )
         response_ids = encoded_responses['input_ids']
 
-        # Compute embeddings in batches
+        # Compute embeddings in batches with progress bar
         batch_size = getattr(self, 'embedding_batch_size', 64)  # Default to 64 if not set
+        total_responses = len(response_ids)
+        logger.info(f"Computing embeddings in batches of {batch_size}...")
         embeddings = []
-        for i in range(0, len(response_ids), batch_size):
-            batch_ids = response_ids[i:i+batch_size]
-            # Compute embeddings
-            batch_embeddings = self.encoder(batch_ids, training=False).numpy()
-            # Normalize embeddings if using inner product or cosine similarity
-            faiss.normalize_L2(batch_embeddings)
-            embeddings.append(batch_embeddings)
+
+        with tqdm(total=total_responses, desc="Computing Embeddings", unit="response") as pbar:
+            for i in range(0, total_responses, batch_size):
+                batch_ids = response_ids[i:i + batch_size]
+                # Compute embeddings
+                batch_embeddings = self.encoder(batch_ids, training=False).numpy()
+                # Normalize embeddings for cosine similarity
+                faiss.normalize_L2(batch_embeddings)
+                embeddings.append(batch_embeddings)
+                pbar.update(len(batch_ids))
 
         if embeddings:
             embeddings = np.vstack(embeddings).astype(np.float32)
-            # Add embeddings to FAISS index
+            # Add embeddings to FAISS index with progress bar
             logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
-            self.index.add(embeddings)
+
+            # Determine number of batches for indexing
+            index_batch_size = getattr(self, 'index_batch_size', 1000)  # Adjust as needed
+            total_embeddings = len(embeddings)
+            num_index_batches = math.ceil(total_embeddings / index_batch_size)
+
+            with tqdm(total=total_embeddings, desc="Indexing Embeddings", unit="embedding") as pbar_index:
+                for i in range(0, total_embeddings, index_batch_size):
+                    batch = embeddings[i:i + index_batch_size]
+                    self.index.add(batch)
+                    pbar_index.update(len(batch))
+
             logger.info("Response embeddings added to FAISS index.")
         else:
             logger.warning("No embeddings to add to FAISS index.")
 
         # **Sanity Check:** Verify the number of embeddings in FAISS index
         logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
+        # (previous, non-progress-bar version of this method kept commented out in the commit)
 
     def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
         """Find hard negatives for a batch of queries with error handling and retries."""
@@ -349,53 +406,169 @@ class TFDataPipeline:
 
         return query_ids, positive_ids, negative_ids
 
-    def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
-        """Processes dialogues in batches and saves to a TFRecord file."""
-        with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
-            total_dialogues = len(dialogues)
-            logger.debug(f"Total dialogues to process: {total_dialogues}")
-
-            with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
-                for i in range(0, total_dialogues, batch_size):
-                    batch_dialogues = dialogues[i:i+batch_size]
-                    # Process each batch_dialogues
-                    # Extract pairs, find negatives, tokenize, and serialize
-                    # Example:
-                    for dialogue in batch_dialogues:
-                        pairs = self._extract_pairs_from_dialogue(dialogue)
-                        queries = []
-                        positives = []
-
-                        for query, positive in pairs:
-                            queries.append(query)
-                            positives.append(positive)
-
-                        if queries:
-                            # **Compute and cache query embeddings before searching**
-                            self._compute_embeddings(queries)
-
-                            # Find hard negatives
-                            hard_negatives = self._find_hard_negatives_batch(queries, positives)
-
-                            for idx, negatives in enumerate(hard_negatives[:5]):  # Log first 5 examples
-                                logger.debug(f"Query: {queries[idx]}")
-                                logger.debug(f"Positive: {positives[idx]}")
-                                logger.debug(f"Hard Negatives: {negatives}")
-                            # Tokenize and encode
-                            query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)
-
-                            # Serialize each example and write to TFRecord
-                            for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
-                                feature = {
-                                    'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
-                                    'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
-                                    'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
-                                }
-                                example = tf.train.Example(features=tf.train.Features(feature=feature))
-                                writer.write(example.SerializeToString())
-
-                    pbar.update(len(batch_dialogues))
-        logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
+    # Testing updated batch tokenization
+    def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
+        """
+        Processes dialogues in batches and saves to a TFRecord file using optimized batch tokenization and encoding.
+
+        Args:
+            dialogues (List[dict]): List of dialogue dictionaries.
+            tf_record_path (str): Path to save the TFRecord file.
+            batch_size (int): Number of dialogues to process per batch.
+        """
+        logger.info(f"Preparing and saving data to {tf_record_path}...")
+
+        num_dialogues = len(dialogues)
+        num_batches = math.ceil(num_dialogues / batch_size)
+
+        with tf.io.TFRecordWriter(tf_record_path) as writer:
+            # Initialize progress bar
+            with tqdm(total=num_batches, desc="Preparing Data Batches", unit="batch") as pbar:
+                for i in range(num_batches):
+                    start_idx = i * batch_size
+                    end_idx = min(start_idx + batch_size, num_dialogues)
+                    batch_dialogues = dialogues[start_idx:end_idx]
+
+                    # Extract all query-positive pairs in the batch
+                    queries = []
+                    positives = []
+                    for dialogue in batch_dialogues:
+                        pairs = self._extract_pairs_from_dialogue(dialogue)
+                        for query, positive in pairs:
+                            if len(query) <= self.max_length and len(positive) <= self.max_length:
+                                queries.append(query)
+                                positives.append(positive)
+
+                    if not queries:
+                        pbar.update(1)
+                        continue  # Skip if no valid queries
+
+                    # Compute and cache query embeddings
+                    try:
+                        self._compute_embeddings(queries)
+                    except Exception as e:
+                        logger.error(f"Error computing embeddings: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Find hard negatives for the batch
+                    try:
+                        hard_negatives = self._find_hard_negatives_batch(queries, positives)
+                    except Exception as e:
+                        logger.error(f"Error finding hard negatives: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Tokenize and encode all queries, positives, and negatives in the batch
+                    try:
+                        encoded_queries = self.tokenizer.batch_encode_plus(
+                            queries,
+                            max_length=self.config.max_context_token_limit,
+                            truncation=True,
+                            padding='max_length',
+                            return_tensors='tf'
+                        )
+                        encoded_positives = self.tokenizer.batch_encode_plus(
+                            positives,
+                            max_length=self.config.max_context_token_limit,
+                            truncation=True,
+                            padding='max_length',
+                            return_tensors='tf'
+                        )
+                    except Exception as e:
+                        logger.error(f"Error during tokenization: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Flatten hard_negatives while maintaining alignment
+                    # (hard_negatives is a list of lists, one sublist per query)
+                    try:
+                        flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
+                        encoded_negatives = self.tokenizer.batch_encode_plus(
+                            flattened_negatives,
+                            max_length=self.config.max_context_token_limit,
+                            truncation=True,
+                            padding='max_length',
+                            return_tensors='tf'
+                        )
+
+                        # Reshape encoded_negatives['input_ids'] to [num_queries, num_negatives, max_length]
+                        num_negatives = self.config.neg_samples
+                        reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_token_limit)
+                    except Exception as e:
+                        logger.error(f"Error during negatives tokenization: {e}")
+                        pbar.update(1)
+                        continue  # Skip to the next batch
+
+                    # Serialize each example and write to TFRecord
+                    for j in range(len(queries)):
+                        try:
+                            q_id = encoded_queries['input_ids'][j].numpy()
+                            p_id = encoded_positives['input_ids'][j].numpy()
+                            n_id = reshaped_negatives[j]
+
+                            feature = {
+                                'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
+                                'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
+                                'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
+                            }
+                            example = tf.train.Example(features=tf.train.Features(feature=feature))
+                            writer.write(example.SerializeToString())
+                        except Exception as e:
+                            logger.error(f"Error serializing example {j} in batch {i}: {e}")
+                            continue  # Skip to the next example
+
+                    # Update progress bar
+                    pbar.update(1)
+
+        logger.info(f"Data preparation complete. TFRecord saved at {tf_record_path}.")
+        # (previous version of prepare_and_save_data kept commented out in the commit)
 
     def _tokenize_negatives_tf(self, negatives):
         """Tokenizes negatives using tf.py_function."""

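One caveat in the new pair filter: `len(query) <= self.max_length` counts characters, while the limit it guards is a token limit, so the filter is only a rough proxy. A hedged token-based variant (the helper name is illustrative; assumes `self.tokenizer` is a Hugging Face tokenizer, as the `batch_encode_plus` calls above suggest):

```python
def _fits_token_limit(self, text: str) -> bool:
    # Count real tokens instead of characters.
    ids = self.tokenizer(text, add_special_tokens=True, truncation=False)['input_ids']
    return len(ids) <= self.config.max_context_token_limit
```
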