JoeArmani committed on
Commit 9decf80 · 1 Parent(s): 19403c5

FAISS and streaming updates

chatbot_model.py CHANGED
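Both the old and new code in this diff retrieve responses by cosine similarity, implemented as an inner product over L2-normalized embeddings in a FAISS index. A minimal, self-contained sketch of that pattern (random vectors stand in for the encoder output; dimensions and names are illustrative, not code from this file):

import numpy as np
import faiss

dim = 512                                            # matches ChatbotConfig.embedding_dim
emb = np.random.rand(1000, dim).astype('float32')    # stand-in for response embeddings
faiss.normalize_L2(emb)                              # in-place normalization to unit length
index = faiss.IndexFlatIP(dim)                       # inner product == cosine on unit vectors
index.add(emb)

query = np.random.rand(1, dim).astype('float32')     # stand-in for an encoded query
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)                 # top-5 most similar responses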
@@ -1,26 +1,35 @@
 
1
  from transformers import TFAutoModel, AutoTokenizer
2
  import tensorflow as tf
3
  import numpy as np
4
- from typing import List, Tuple, Dict, Optional, Union, Any
 
 
5
  import math
6
  from dataclasses import dataclass
7
  import json
8
- from tqdm import tqdm
9
  from pathlib import Path
10
  import datetime
11
  import faiss
 
 
12
  from response_quality_checker import ResponseQualityChecker
13
  from cross_encoder_reranker import CrossEncoderReranker
14
  from conversation_summarizer import DeviceAwareModel, Summarizer
 
 
15
  from logger_config import config_logger
 
 
 
16
  logger = config_logger(__name__)
17
 
18
  @dataclass
19
  class ChatbotConfig:
20
  """Configuration for the RetrievalChatbot."""
21
- vocab_size: int = 30526 # DistilBERT vocab size
22
  max_context_token_limit: int = 512
23
- embedding_dim: int = 512 # Match DistilBERT's dimension
24
  encoder_units: int = 256
25
  num_attention_heads: int = 8
26
  dropout_rate: float = 0.2
@@ -130,16 +139,16 @@ class RetrievalChatbot(DeviceAwareModel):
130
  summarizer = Summarizer(device=self.device)
131
  self.summarizer = summarizer
132
 
133
- # Configure XLA optimization if on GPU/TPU
134
- if self.device in ["GPU", "TPU"]:
135
- tf.config.optimizer.set_jit(True)
136
- logger.info(f"XLA compilation enabled for {self.device}")
137
 
138
- # Configure mixed precision for GPU/TPU
139
- if self.device != "CPU":
140
- policy = tf.keras.mixed_precision.Policy('mixed_float16')
141
- tf.keras.mixed_precision.set_global_policy(policy)
142
- logger.info("Mixed precision training enabled (float16)")
143
 
144
  # Special tokens
145
  self.special_tokens = {
@@ -155,18 +164,13 @@ class RetrievalChatbot(DeviceAwareModel):
155
  {'additional_special_tokens': list(self.special_tokens.values())}
156
  )
157
 
158
- # Build encoders within device strategy scope
159
- if self.strategy:
160
- with self.strategy.scope():
161
- self._build_models()
162
- else:
163
- self._build_models()
164
-
165
- # Initialize FAISS index
166
- self._initialize_faiss()
167
 
168
- # Precompute and index response embeddings
169
- self._precompute_and_index_responses(dialogues)
170
 
171
  # Initialize training history
172
  self.history = {
@@ -176,9 +180,10 @@ class RetrievalChatbot(DeviceAwareModel):
176
  "val_metrics": {}
177
  }
178
 
179
- def _build_models(self):
180
  """Initialize the shared encoder."""
181
  logger.info("Building encoder model...")
 
182
 
183
  # Shared encoder for both queries and responses
184
  self.encoder = EncoderModel(
@@ -191,11 +196,10 @@ class RetrievalChatbot(DeviceAwareModel):
191
  self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
192
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
193
 
194
- # Debug embeddings attributes
195
- logger.info("Inspecting embeddings attributes:")
196
- for attr in dir(self.encoder.pretrained.distilbert.embeddings):
197
- if not attr.startswith('_'):
198
- logger.info(f" {attr}")
199
 
200
  # Try different ways to get embedding dimension
201
  try:
@@ -227,45 +231,18 @@ class RetrievalChatbot(DeviceAwareModel):
227
  logger.error("Vocabulary size is less than embedding dimension.")
228
  raise ValueError("Vocabulary size is less than embedding dimension.")
229
 
230
- def _initialize_faiss(self):
231
- """Initialize FAISS index based on available resources."""
232
- logger.info("Initializing FAISS index...")
233
- # Determine if GPU FAISS is available
234
- try:
235
- res = faiss.StandardGpuResources()
236
- self.faiss_gpu = True
237
- logger.info("FAISS GPU resources initialized.")
238
- except Exception as e:
239
- self.faiss_gpu = False
240
- logger.info("FAISS GPU resources not available. Using FAISS CPU.")
241
-
242
- # Initialize FAISS index for Inner Product (for cosine similarity)
243
- if self.faiss_gpu:
244
- self.index = faiss.IndexFlatIP(self.config.embedding_dim)
245
- self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
246
- else:
247
- self.index = faiss.IndexFlatIP(self.config.embedding_dim)
248
- logger.info("FAISS index initialized.")
249
-
250
- def verify_faiss_index(self):
251
- """Verify that FAISS index matches the response pool."""
252
- indexed_size = self.index.ntotal
253
- pool_size = len(self.response_pool)
254
- logger.info(f"FAISS index size: {indexed_size}")
255
- logger.info(f"Response pool size: {pool_size}")
256
- if indexed_size != pool_size:
257
- logger.warning("Mismatch between FAISS index size and response pool size.")
258
- else:
259
- logger.info("FAISS index correctly matches the response pool.")
260
-
261
-
262
- def _precompute_and_index_responses(self, dialogues: List[dict]):
263
- """Precompute embeddings for all responses and index them using FAISS."""
264
- logger.info("Precomputing response embeddings and indexing with FAISS...")
265
 
266
- # Use tqdm for collecting responses
267
  responses = []
268
- for dialogue in tqdm(dialogues, desc="Collecting assistant responses"):
 
 
 
 
 
 
269
  turns = dialogue.get('turns', [])
270
  for turn in turns:
271
  if turn.get('speaker') == 'assistant' and 'text' in turn:
@@ -275,33 +252,97 @@ class RetrievalChatbot(DeviceAwareModel):
275
  unique_responses = list(set(responses))
276
  logger.info(f"Found {len(unique_responses)} unique responses.")
277
 
278
- # Encode responses
279
- logger.info("Encoding unique responses")
280
- response_embeddings = self.encode_responses(unique_responses)
281
- response_embeddings = response_embeddings.numpy()
282
-
283
- # Ensure float32
284
- if response_embeddings.dtype != np.float32:
285
- response_embeddings = response_embeddings.astype('float32')
286
-
287
- # Ensure the array is contiguous in memory
288
- if not response_embeddings.flags['C_CONTIGUOUS']:
289
- logger.info("Making embeddings contiguous in memory.")
290
- response_embeddings = np.ascontiguousarray(response_embeddings)
 
 
291
 
292
- # Normalize embeddings for cosine similarity
293
- logger.info("Normalizing embeddings with FAISS.")
294
- faiss.normalize_L2(response_embeddings)
295
 
296
- # Add to FAISS index
297
- logger.info("Adding embeddings to FAISS index...")
298
- self.index.add(response_embeddings)
299
- logger.info(f"Indexed {self.index.ntotal} responses.")
 
 
300
 
301
- # Store responses and embeddings
302
- self.response_pool = unique_responses
303
- self.response_embeddings = response_embeddings
304
- logger.info("Precomputation and indexing completed.")
 
 
305
 
306
  def encode_responses(
307
  self,
@@ -309,55 +350,390 @@ class RetrievalChatbot(DeviceAwareModel):
309
  batch_size: int = 64
310
  ) -> tf.Tensor:
311
  """
312
- Encodes a list of responses into embeddings, using chunked/batched processing
313
- to avoid running out of memory when there are many responses.
314
-
315
- Args:
316
- responses (List[str]): The list of response texts to encode.
317
- batch_size (int): How many responses to encode per chunk.
318
- Adjust based on available GPU/CPU memory.
319
-
320
- Returns:
321
- tf.Tensor: Tensor of shape (N, emb_dim) with all response embeddings.
322
  """
323
- # Accumulate embeddings in a list and concatenate at the end
324
  all_embeddings = []
 
 
325
 
326
- # Process the responses in chunks of 'batch_size'
327
- for start_idx in range(0, len(responses), batch_size):
328
- end_idx = start_idx + batch_size
329
- batch_texts = responses[start_idx:end_idx]
330
-
331
- # Tokenize the current batch
332
- encodings = self.tokenizer(
333
- batch_texts,
334
- padding='max_length',
335
- truncation=True,
336
- max_length=self.config.max_context_token_limit,
337
- return_tensors='tf',
338
- )
339
-
340
- # Run the encoder forward pass
341
- input_ids = encodings['input_ids']
342
- embeddings_batch = self.encoder(input_ids, training=False)
343
-
344
- # Cast to float32 if needed
345
- if embeddings_batch.dtype != tf.float32:
346
- embeddings_batch = tf.cast(embeddings_batch, tf.float32)
347
 
348
- # Collect
349
- all_embeddings.append(embeddings_batch)
 
 
350
 
351
- # Concatenate all batch embeddings along axis=0
 
352
  if len(all_embeddings) == 1:
353
- # Only one batch
354
  final_embeddings = all_embeddings[0]
355
  else:
356
- # Multiple batches, concatenate
357
  final_embeddings = tf.concat(all_embeddings, axis=0)
358
 
359
  return final_embeddings
360
 
 
361
  def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
362
  """Encode a query with optional conversation context."""
363
  # Prepare query with context
@@ -436,7 +812,7 @@ class RetrievalChatbot(DeviceAwareModel):
436
  """Retrieve top-k responses using FAISS."""
437
  # Encode the query
438
  q_emb = self.encode_query(query) # Shape: [1, embedding_dim]
439
- q_emb_np = q_emb.numpy().astype('float32') # Ensure type matches FAISS requirements
440
 
441
  # Normalize the query embedding for cosine similarity
442
  faiss.normalize_L2(q_emb_np)
@@ -523,130 +899,9 @@ class RetrievalChatbot(DeviceAwareModel):
523
  logger.info(f"Loaded {len(dialogues)} dialogues.")
524
  return dialogues
525
 
526
- def prepare_dataset(
527
- self,
528
- dialogues: List[dict],
529
- neg_samples: int = 1,
530
- debug_samples: int = None
531
- ) -> Tuple[tf.Tensor, tf.Tensor]:
532
- """
533
- Prepares dataset for multiple-negatives ranking,
534
- but also appends 'hard negative' pairs for each query.
535
-
536
- We'll generate:
537
- - (query, positive) as usual
538
- - (query, negative) for each query, using FAISS top-1 approx. negative.
539
- Then, in-batch training sees them as 'two different positives'
540
- for the same query, forcing the model to discriminate them.
541
- """
542
-
543
- logger.info("Preparing in-batch dataset with hard negatives...")
544
-
545
- queries, positives = [], []
546
-
547
- # Assemble (q, p)
548
- for dialogue in dialogues:
549
- turns = dialogue.get('turns', [])
550
- for i in range(len(turns) - 1):
551
- current_turn = turns[i]
552
- next_turn = turns[i+1]
553
-
554
- if (current_turn.get('speaker') == 'user'
555
- and next_turn.get('speaker') == 'assistant'
556
- and 'text' in current_turn
557
- and 'text' in next_turn):
558
-
559
- query_text = current_turn['text'].strip()
560
- pos_text = next_turn['text'].strip()
561
-
562
- queries.append(query_text)
563
- positives.append(pos_text)
564
-
565
- # Debug slicing
566
- if debug_samples is not None:
567
- queries = queries[:debug_samples]
568
- positives = positives[:debug_samples]
569
- logger.info(f"Debug mode: limited to {debug_samples} pairs.")
570
-
571
- logger.info(f"Prepared {len(queries)} (query, positive) pairs initially.")
572
-
573
- # Find a hard negative from FAISS for each (q, p)
574
- # Create a second 'positive' row => (q, negative). In-batch, it's seen as a different 'positive' row, but is a hard negative.
575
- augmented_queries = []
576
- augmented_positives = []
577
-
578
- for q_text, p_text in zip(queries, positives):
579
- neg_texts = self._find_hard_negative(q_text, p_text, top_k=5, neg_samples=neg_samples)
580
- for neg_text in neg_texts:
581
- augmented_queries.append(q_text)
582
- augmented_positives.append(neg_text)
583
-
584
- logger.info(f"Found hard negatives for {len(augmented_queries)} queries.")
585
-
586
- # Combine them into a single big list -> Original pairs: (q, p) & Hard neg pairs: (q, n)
587
- final_queries = queries + augmented_queries
588
- final_positives = positives + augmented_positives
589
- logger.info(f"Total dataset size after adding hard neg: {len(final_queries)}")
590
-
591
- # Tokenize
592
- encoded_queries = self.tokenizer(
593
- final_queries,
594
- padding='max_length',
595
- truncation=True,
596
- max_length=self.config.max_context_token_limit,
597
- return_tensors='tf'
598
- )
599
- encoded_positives = self.tokenizer(
600
- final_positives,
601
- padding='max_length',
602
- truncation=True,
603
- max_length=self.config.max_context_token_limit,
604
- return_tensors='tf'
605
- )
606
-
607
- q_tensor = encoded_queries['input_ids']
608
- p_tensor = encoded_positives['input_ids']
609
-
610
- logger.info("Tokenized and padded sequences for in-batch training + hard negatives.")
611
- return q_tensor, p_tensor
612
-
613
- def _find_hard_negative(
614
- self,
615
- query_text: str,
616
- positive_text: str,
617
- top_k: int = 5,
618
- neg_samples: int = 1
619
- ) -> List[str]:
620
- """
621
- Return up to `neg_samples` unique negatives from top_k FAISS results,
622
- excluding the known positive_text.
623
- """
624
- # Encode the query to get the embedding
625
- query_emb = self.encode_query(query_text)
626
- q_emb_np = query_emb.numpy().astype('float32')
627
-
628
- # Normalize for cosine similarity
629
- faiss.normalize_L2(q_emb_np)
630
-
631
- # Search in FAISS
632
- distances, indices = self.index.search(q_emb_np, top_k)
633
-
634
- # Exclude the actual positive from these results
635
- hard_negatives = []
636
- for idx in indices[0]:
637
- if idx < len(self.response_pool):
638
- candidate = self.response_pool[idx].strip()
639
- if candidate != positive_text.strip():
640
- hard_negatives.append(candidate)
641
- if len(hard_negatives) == neg_samples:
642
- break
643
-
644
- return hard_negatives
645
-
646
- def train(
647
  self,
648
- q_pad: tf.Tensor,
649
- p_pad: tf.Tensor,
650
  epochs: int = 20,
651
  batch_size: int = 16,
652
  validation_split: float = 0.2,
@@ -656,23 +911,41 @@ class RetrievalChatbot(DeviceAwareModel):
656
  warmup_steps_ratio: float = 0.1,
657
  early_stopping_patience: int = 3,
658
  min_delta: float = 1e-4,
659
- accum_steps: int = 2 # Gradient accumulation steps
660
- ):
661
- dataset_size = tf.shape(q_pad)[0].numpy()
662
- val_size = int(dataset_size * validation_split)
663
- train_size = dataset_size - val_size
664
-
665
- logger.info(f"Total samples: {dataset_size}")
666
- logger.info(f"Training samples: {train_size}")
667
- logger.info(f"Validation samples: {val_size}")
 
 
668
 
669
- steps_per_epoch = train_size // batch_size
670
- if train_size % batch_size != 0:
671
- steps_per_epoch += 1
 
 
672
  total_steps = steps_per_epoch * epochs
673
- logger.info(f"Total training steps (approx): {total_steps}")
674
 
675
- # 1) Set up LR schedule or fixed LR
 
 
676
  if use_lr_schedule:
677
  warmup_steps = int(total_steps * warmup_steps_ratio)
678
  lr_schedule = self._get_lr_schedule(
@@ -686,175 +959,290 @@ class RetrievalChatbot(DeviceAwareModel):
686
  self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
687
  logger.info("Using fixed learning rate.")
688
 
689
- # 2) Prepare data splits
690
- train_q = q_pad[:train_size]
691
- train_p = p_pad[:train_size]
692
- val_q = q_pad[train_size:]
693
- val_p = p_pad[train_size:]
694
-
695
- train_dataset = (tf.data.Dataset.from_tensor_slices((train_q, train_p))
696
- .shuffle(4096)
697
- .batch(batch_size)
698
- .prefetch(tf.data.AUTOTUNE))
699
-
700
- val_dataset = (tf.data.Dataset.from_tensor_slices((val_q, val_p))
701
- .batch(batch_size)
702
- .prefetch(tf.data.AUTOTUNE))
703
-
704
- # 3) Checkpoint + manager
705
  checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
706
  manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
707
 
708
- # 4) TensorBoard setup
709
  log_dir = Path(checkpoint_dir) / "tensorboard_logs"
710
  log_dir.mkdir(parents=True, exist_ok=True)
711
-
712
  current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
713
  train_log_dir = str(log_dir / f"train_{current_time}")
714
  val_log_dir = str(log_dir / f"val_{current_time}")
715
-
716
  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
717
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
718
-
719
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
720
 
721
- # 5) Early stopping
722
  best_val_loss = float("inf")
723
  epochs_no_improve = 0
724
 
725
- logger.info("Beginning training loop...")
726
- global_step = 0
727
-
728
- # Prepare zero-initialized accumulators for your trainable variables
729
- # We'll accumulate gradients across mini-batches, then apply them every accum_steps.
730
- train_vars = self.encoder.pretrained.trainable_variables
731
- accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
732
-
733
- from tqdm import tqdm
734
- for epoch in range(1, epochs + 1):
735
- logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
736
- epoch_loss_avg = tf.keras.metrics.Mean()
737
-
738
- step_in_epoch = 0
739
- with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
740
- for (q_batch, p_batch) in train_dataset:
741
- step_in_epoch += 1
742
- global_step += 1
743
-
744
- with tf.GradientTape() as tape:
745
- q_enc = self.encoder(q_batch, training=True)
746
- p_enc = self.encoder(p_batch, training=True)
747
-
748
- sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
749
- bsz = tf.shape(q_enc)[0]
750
- labels = tf.range(bsz, dtype=tf.int32)
751
- loss_value = tf.nn.sparse_softmax_cross_entropy_with_logits(
752
- labels=labels, logits=sim_matrix
753
- )
754
- loss_value = tf.reduce_mean(loss_value)
 
 
755
 
756
- gradients = tape.gradient(loss_value, train_vars)
 
 
 
757
 
758
- # -- Accumulate gradients --
759
- for i, grad in enumerate(gradients):
760
- if grad is not None:
761
- accum_grads[i] += tf.cast(grad, tf.float32)
 
 
762
 
763
- epoch_loss_avg(loss_value)
 
 
764
 
765
- # -- Apply gradients every 'accum_steps' mini-batches --
766
- if (step_in_epoch % accum_steps) == 0:
767
- # Scale by 1/accum_steps so that each accumulation cycle
768
- # is effectively the same as one “normal” update
769
- for i in range(len(accum_grads)):
770
- accum_grads[i] /= accum_steps
771
 
772
- self.optimizer.apply_gradients(
773
- [(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))]
774
- )
775
- # Reset the accumulator
776
- accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
777
-
778
- # Logging / tqdm updates
779
- if use_lr_schedule:
780
- # measure current LR
781
- lr = self.optimizer.learning_rate
782
- if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
783
- current_step = tf.cast(self.optimizer.iterations, tf.float32)
784
- current_lr = lr(current_step)
 
 
785
  else:
786
- current_lr = lr
787
- current_lr_value = float(current_lr.numpy())
788
- else:
789
- current_lr_value = float(self.optimizer.learning_rate.numpy())
790
-
791
- pbar.update(1)
792
- pbar.set_postfix({
793
- "loss": f"{loss_value.numpy():.4f}",
794
- "lr": f"{current_lr_value:.2e}"
795
- })
796
-
797
- # TensorBoard logging omitted for brevity...
798
-
799
- # -- Handle leftover partial accumulation at epoch end --
800
- leftover = (step_in_epoch % accum_steps)
801
- if leftover != 0:
802
- logger.info(f"Applying leftover accum_grads for partial batch group (size={leftover}).")
803
- # If you want each leftover batch to contribute proportionally:
804
- # multiply by leftover/accum_steps (this ensures leftover
805
- # steps have the same "average" effect as a full accumulation cycle)
806
- for i in range(len(accum_grads)):
807
- accum_grads[i] *= float(leftover) / float(accum_steps)
808
-
809
- self.optimizer.apply_gradients(
810
- [(accum_grads[i], train_vars[i]) for i in range(len(accum_grads))]
811
- )
812
- accum_grads = [tf.zeros_like(var, dtype=tf.float32) for var in train_vars]
813
-
814
- # Validation
815
- val_loss_avg = tf.keras.metrics.Mean()
816
- for q_val, p_val in val_dataset:
817
- q_enc = self.encoder(q_val, training=False)
818
- p_enc = self.encoder(p_val, training=False)
819
- sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
820
- bs_val = tf.shape(q_enc)[0]
821
- labels_val = tf.range(bs_val, dtype=tf.int32)
822
- loss_val = tf.nn.sparse_softmax_cross_entropy_with_logits(
823
- labels=labels_val,
824
- logits=sim_matrix
825
- )
826
- val_loss_avg(tf.reduce_mean(loss_val))
 
 
827
 
828
- train_loss = epoch_loss_avg.result().numpy()
829
- val_loss = val_loss_avg.result().numpy()
 
 
 
 
 
 
830
 
831
- logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
832
 
833
- # TensorBoard: validation loss
834
- with val_summary_writer.as_default():
835
- tf.summary.scalar("val_loss", val_loss, step=epoch)
836
 
837
- # Save checkpoint
838
- manager.save()
 
 
 
 
839
 
840
- # Update history
841
- self.history['train_loss'].append(train_loss)
842
- self.history['val_loss'].append(val_loss)
843
- self.history.setdefault('learning_rate', []).append(float(current_lr_value))
844
 
845
- # Early stopping
846
- if val_loss < best_val_loss - min_delta:
847
- best_val_loss = val_loss
848
- epochs_no_improve = 0
849
- logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
 
 
850
  else:
851
- epochs_no_improve += 1
852
- logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
853
- if epochs_no_improve >= early_stopping_patience:
854
- logger.info("Early stopping triggered.")
855
- break
 
 
856
 
857
- logger.info("In-batch training completed!")
 
 
858
 
859
  def _get_lr_schedule(
860
  self,
@@ -994,277 +1382,554 @@ class RetrievalChatbot(DeviceAwareModel):
994
  conversation_parts.append(f"{self.special_tokens['user']} {query}")
995
  return "\n".join(conversation_parts)
996
 
 
 
1001
 
1002
-
1003
-
1004
-
1005
-
1006
-
1007
- # def prepare_dataset(
1008
- # self,
1009
- # dialogues: List[dict],
1010
- # debug_samples: int = None
1011
- # ) -> Tuple[tf.Tensor, tf.Tensor]:
1012
- # """
1013
- # Prepares dataset for in-batch negatives:
1014
- # Only returns (query, positive) pairs.
1015
- # """
1016
- # logger.info("Preparing in-batch dataset...")
1017
-
1018
- # queries, positives = [], []
1019
-
1020
- # for dialogue in dialogues:
1021
- # turns = dialogue.get('turns', [])
1022
- # for i in range(len(turns) - 1):
1023
- # current_turn = turns[i]
1024
- # next_turn = turns[i+1]
1025
-
1026
- # if (current_turn.get('speaker') == 'user' and
1027
- # next_turn.get('speaker') == 'assistant' and
1028
- # 'text' in current_turn and
1029
- # 'text' in next_turn):
1030
 
1031
- # query = current_turn['text'].strip()
1032
- # positive = next_turn['text'].strip()
 
 
1033
 
1034
- # queries.append(query)
1035
- # positives.append(positive)
1036
-
1037
- # # Optional debug slicing
1038
- # if debug_samples is not None:
1039
- # queries = queries[:debug_samples]
1040
- # positives = positives[:debug_samples]
1041
- # logger.info(f"Debug mode: limited to {debug_samples} pairs.")
1042
-
1043
- # logger.info(f"Prepared {len(queries)} (query, positive) pairs.")
1044
-
1045
- # # Tokenize queries
1046
- # encoded_queries = self.tokenizer(
1047
- # queries,
1048
- # padding='max_length',
1049
- # truncation=True,
1050
- # max_length=self.config.max_sequence_length,
1051
- # return_tensors='tf'
1052
- # )
1053
- # # Tokenize positives
1054
- # encoded_positives = self.tokenizer(
1055
- # positives,
1056
- # padding='max_length',
1057
- # truncation=True,
1058
- # max_length=self.config.max_sequence_length,
1059
- # return_tensors='tf'
1060
- # )
1061
-
1062
- # q_tensor = encoded_queries['input_ids']
1063
- # p_tensor = encoded_positives['input_ids']
1064
-
1065
- # logger.info("Tokenized and padded sequences for in-batch training.")
1066
- # return q_tensor, p_tensor
1067
-
1068
- # def train(
1069
- # self,
1070
- # q_pad: tf.Tensor,
1071
- # p_pad: tf.Tensor,
1072
- # epochs: int = 20,
1073
- # batch_size: int = 16,
1074
- # validation_split: float = 0.2,
1075
- # checkpoint_dir: str = "checkpoints/",
1076
- # use_lr_schedule: bool = True,
1077
- # peak_lr: float = 2e-5,
1078
- # warmup_steps_ratio: float = 0.1,
1079
- # early_stopping_patience: int = 3,
1080
- # min_delta: float = 1e-4
1081
- # ):
1082
- # dataset_size = tf.shape(q_pad)[0].numpy()
1083
- # val_size = int(dataset_size * validation_split)
1084
- # train_size = dataset_size - val_size
1085
-
1086
- # logger.info(f"Total samples: {dataset_size}")
1087
- # logger.info(f"Training samples: {train_size}")
1088
- # logger.info(f"Validation samples: {val_size}")
1089
-
1090
- # steps_per_epoch = train_size // batch_size
1091
- # if train_size % batch_size != 0:
1092
- # steps_per_epoch += 1
1093
- # total_steps = steps_per_epoch * epochs
1094
- # logger.info(f"Total training steps (approx): {total_steps}")
1095
-
1096
- # # 1) Set up LR schedule or fixed LR
1097
- # if use_lr_schedule:
1098
- # warmup_steps = int(total_steps * warmup_steps_ratio)
1099
- # lr_schedule = self._get_lr_schedule(
1100
- # total_steps=total_steps,
1101
- # peak_lr=peak_lr,
1102
- # warmup_steps=warmup_steps
1103
- # )
1104
- # self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
1105
- # logger.info("Using custom learning rate schedule.")
1106
- # else:
1107
- # self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
1108
- # logger.info("Using fixed learning rate.")
1109
-
1110
- # # 2) Prepare data splits
1111
- # train_q = q_pad[:train_size]
1112
- # train_p = p_pad[:train_size]
1113
- # val_q = q_pad[train_size:]
1114
- # val_p = p_pad[train_size:]
1115
-
1116
- # train_dataset = tf.data.Dataset.from_tensor_slices((train_q, train_p))
1117
- # train_dataset = train_dataset.shuffle(buffer_size=4096).batch(batch_size)
1118
-
1119
- # val_dataset = tf.data.Dataset.from_tensor_slices((val_q, val_p))
1120
- # val_dataset = val_dataset.batch(batch_size)
1121
-
1122
- # # 3) Checkpoint + manager
1123
- # checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
1124
- # manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
1125
-
1126
- # # 4) TensorBoard setup
1127
- # log_dir = Path(checkpoint_dir) / "tensorboard_logs"
1128
- # log_dir.mkdir(parents=True, exist_ok=True)
1129
-
1130
- # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
1131
- # train_log_dir = str(log_dir / f"train_{current_time}")
1132
- # val_log_dir = str(log_dir / f"val_{current_time}")
1133
-
1134
- # train_summary_writer = tf.summary.create_file_writer(train_log_dir)
1135
- # val_summary_writer = tf.summary.create_file_writer(val_log_dir)
1136
-
1137
- # logger.info(f"TensorBoard logs will be saved in {log_dir}")
1138
-
1139
- # # 5) Early stopping
1140
- # best_val_loss = float("inf")
1141
- # epochs_no_improve = 0
1142
-
1143
- # logger.info("Beginning training loop...")
1144
- # global_step = 0
1145
-
1146
- # from tqdm import tqdm
1147
- # for epoch in range(1, epochs + 1):
1148
- # logger.info(f"\n=== Epoch {epoch}/{epochs} ===")
1149
- # epoch_loss_avg = tf.keras.metrics.Mean()
1150
-
1151
- # # Training loop
1152
- # with tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}") as pbar:
1153
- # for (q_batch, p_batch) in train_dataset:
1154
- # global_step += 1
1155
-
1156
- # # Train step
1157
- # batch_loss = self._train_step(q_batch, p_batch)
1158
- # epoch_loss_avg(batch_loss)
1159
-
1160
- # # Get current LR
1161
- # if use_lr_schedule:
1162
- # lr = self.optimizer.learning_rate
1163
- # if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
1164
- # # Get the current step
1165
- # current_step = tf.cast(self.optimizer.iterations, tf.float32)
1166
- # # Compute the current learning rate
1167
- # current_lr = lr(current_step)
1168
- # else:
1169
- # # If learning_rate is not a schedule, use it directly
1170
- # current_lr = lr
1171
- # # Convert to float for logging
1172
- # current_lr_value = float(current_lr.numpy())
1173
- # else:
1174
- # # If using fixed learning rate
1175
- # current_lr_value = float(self.optimizer.learning_rate.numpy())
1176
-
1177
- # # Update tqdm
1178
- # pbar.update(1)
1179
- # pbar.set_postfix({
1180
- # "loss": f"{batch_loss.numpy():.4f}",
1181
- # "lr": f"{current_lr_value:.2e}"
1182
- # })
1183
-
1184
- # # TensorBoard: log train metrics per step
1185
- # with train_summary_writer.as_default():
1186
- # tf.summary.scalar("loss", batch_loss, step=global_step)
1187
- # tf.summary.scalar("learning_rate", current_lr_value, step=global_step)
1188
-
1189
- # # Validation
1190
- # val_loss_avg = tf.keras.metrics.Mean()
1191
- # for q_val, p_val in val_dataset:
1192
- # q_enc = self.encoder(q_val, training=False)
1193
- # p_enc = self.encoder(p_val, training=False)
1194
- # sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
1195
- # bs_val = tf.shape(q_enc)[0]
1196
- # labels_val = tf.range(bs_val, dtype=tf.int32)
1197
- # loss_val = tf.nn.sparse_softmax_cross_entropy_with_logits(
1198
- # labels=labels_val,
1199
- # logits=sim_matrix
1200
- # )
1201
- # val_loss_avg(tf.reduce_mean(loss_val))
1202
-
1203
- # train_loss = epoch_loss_avg.result().numpy()
1204
- # val_loss = val_loss_avg.result().numpy()
1205
-
1206
- # logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
1207
-
1208
- # # TensorBoard: validation loss
1209
- # with val_summary_writer.as_default():
1210
- # tf.summary.scalar("val_loss", val_loss, step=epoch)
1211
-
1212
- # # Save checkpoint
1213
- # manager.save()
1214
-
1215
- # # Update history
1216
- # self.history['train_loss'].append(train_loss)
1217
- # self.history['val_loss'].append(val_loss)
1218
- # self.history.setdefault('learning_rate', []).append(float(current_lr_value))
1219
-
1220
- # # Early stopping
1221
- # if val_loss < best_val_loss - min_delta:
1222
- # best_val_loss = val_loss
1223
- # epochs_no_improve = 0
1224
- # logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
1225
- # else:
1226
- # epochs_no_improve += 1
1227
- # logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
1228
- # if epochs_no_improve >= early_stopping_patience:
1229
- # logger.info("Early stopping triggered.")
1230
- # break
1231
-
1232
- # logger.info("In-batch training completed!")
1233
-
1234
- # @tf.function
1235
- # def _train_step(self, q_batch, p_batch):
1236
- # """
1237
- # Single training step using in-batch negatives.
1238
- # q_batch: (batch_size, seq_len) int32 input_ids for queries
1239
- # p_batch: (batch_size, seq_len) int32 input_ids for positives
1240
- # """
1241
- # with tf.GradientTape() as tape:
1242
- # # Encode queries and positives
1243
- # q_enc = self.encoder(q_batch, training=True) # [B, emb_dim]
1244
- # p_enc = self.encoder(p_batch, training=True) # [B, emb_dim]
1245
-
1246
- # # Compute similarity matrix: (B, B) = q_enc * p_enc^T
1247
- # # If embeddings are L2-normalized, this is cosine similarity
1248
- # sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True) # [B, B]
1249
-
1250
- # # Labels are just the diagonal indices
1251
- # batch_size = tf.shape(q_enc)[0]
1252
- # labels = tf.range(batch_size, dtype=tf.int32) # [0..B-1]
1253
-
1254
- # # Softmax cross-entropy
1255
- # loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1256
- # labels=labels,
1257
- # logits=sim_matrix
1258
- # )
1259
- # loss = tf.reduce_mean(loss)
1260
-
1261
- # # Compute gradients for the pretrained DistilBERT variables only
1262
- # train_vars = self.encoder.pretrained.trainable_variables
1263
- # gradients = tape.gradient(loss, train_vars)
1264
 
1265
- # # Remove any None grads (in case some layers are frozen)
1266
- # grads_and_vars = [(g, v) for g, v in zip(gradients, train_vars) if g is not None]
1267
- # if grads_and_vars:
1268
- # self.optimizer.apply_gradients(grads_and_vars)
 
 
1269
 
1270
- # return loss
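The removed training code above scores every query against every positive in its batch and treats the diagonal of the similarity matrix as the correct match (in-batch negatives). A compact, standalone sketch of that loss, with random tensors standing in for the encoder outputs:

import tensorflow as tf

q_enc = tf.random.normal([16, 512])    # hypothetical query embeddings    [B, dim]
p_enc = tf.random.normal([16, 512])    # hypothetical positive embeddings [B, dim]

sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)    # [B, B] similarity logits
labels = tf.range(tf.shape(q_enc)[0], dtype=tf.int32)     # row i pairs with column i
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=sim_matrix)
)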
 
 
1
+ import time
2
  from transformers import TFAutoModel, AutoTokenizer
3
  import tensorflow as tf
4
  import numpy as np
5
+ import threading
6
+ from queue import Queue, Empty
7
+ from typing import Generator, List, Tuple, Dict, Optional, Union, Any
8
  import math
9
  from dataclasses import dataclass
10
  import json
 
11
  from pathlib import Path
12
  import datetime
13
  import faiss
14
+ import gc
15
+ import random
16
  from response_quality_checker import ResponseQualityChecker
17
  from cross_encoder_reranker import CrossEncoderReranker
18
  from conversation_summarizer import DeviceAwareModel, Summarizer
19
+ from gpu_monitor import GPUMemoryMonitor
20
+ import absl.logging
21
  from logger_config import config_logger
22
+ from tqdm.auto import tqdm
23
+
24
+ absl.logging.set_verbosity(absl.logging.WARNING)
25
  logger = config_logger(__name__)
26
 
27
  @dataclass
28
  class ChatbotConfig:
29
  """Configuration for the RetrievalChatbot."""
30
+ vocab_size: int = 30526 # DistilBERT vocab size + special tokens
31
  max_context_token_limit: int = 512
32
+ embedding_dim: int = 512
33
  encoder_units: int = 256
34
  num_attention_heads: int = 8
35
  dropout_rate: float = 0.2
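The new threading and Queue imports in this hunk support the streaming trainer added further down, which feeds batches from a worker thread through bounded queues. A minimal sketch of that producer/consumer pattern (the names and the toy batch source are illustrative, not this module's API):

import threading
from queue import Queue, Empty

def producer(batches, out_queue, stop_flag):
    for batch in batches:
        if stop_flag.is_set():
            break
        out_queue.put(batch)          # blocks while the bounded buffer is full
    out_queue.put(None)               # sentinel: no more batches

train_queue = Queue(maxsize=10)       # bounded buffer, like buffer_size in train_streaming
stop_flag = threading.Event()
worker = threading.Thread(target=producer, args=(range(5), train_queue, stop_flag), daemon=True)
worker.start()

while True:
    try:
        batch = train_queue.get(timeout=1.0)
    except Empty:
        continue
    if batch is None:
        break
    # ...run one training step on `batch` here...
worker.join()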
 
139
  summarizer = Summarizer(device=self.device)
140
  self.summarizer = summarizer
141
 
142
+ # # Configure XLA optimization if on GPU/TPU
143
+ # if self.device in ["GPU", "TPU"]:
144
+ # tf.config.optimizer.set_jit(True)
145
+ # logger.info(f"XLA compilation enabled for {self.device}")
146
 
147
+ # # Configure mixed precision for GPU/TPU
148
+ # if self.device != "CPU":
149
+ # policy = tf.keras.mixed_precision.Policy('mixed_float16')
150
+ # tf.keras.mixed_precision.set_global_policy(policy)
151
+ # logger.info("Mixed precision training enabled (float16)")
152
 
153
  # Special tokens
154
  self.special_tokens = {
 
164
  {'additional_special_tokens': list(self.special_tokens.values())}
165
  )
166
 
167
+ self.memory_monitor = GPUMemoryMonitor()
168
+ self.min_batch_size = 8
169
+ self.max_batch_size = 128
170
+ self.current_batch_size = 32
 
 
 
 
 
171
 
172
+ # Collect unique responses from dialogues
173
+ self.response_pool, self.unique_responses = self._collect_responses(dialogues)
174
 
175
  # Initialize training history
176
  self.history = {
 
180
  "val_metrics": {}
181
  }
182
 
183
+ def build_models(self):
184
  """Initialize the shared encoder."""
185
  logger.info("Building encoder model...")
186
+ tf.keras.backend.clear_session()
187
 
188
  # Shared encoder for both queries and responses
189
  self.encoder = EncoderModel(
 
196
  self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
197
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
198
 
199
+ # Initialize FAISS index (moved here from __init__)
200
+ self._initialize_faiss()
201
+ # Compute embeddings after FAISS is initialized and moved
202
+ self._compute_and_index_embeddings()
 
203
 
204
  # Try different ways to get embedding dimension
205
  try:
 
231
  logger.error("Vocabulary size is less than embedding dimension.")
232
  raise ValueError("Vocabulary size is less than embedding dimension.")
233
 
234
+ def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
235
+ """Collect all unique responses from dialogues."""
236
+ logger.info("Collecting responses from dialogues...")
 
 
 
 
 
 
 
 
237
 
 
238
  responses = []
239
+ try:
240
+ progress_bar = tqdm(dialogues, desc="Collecting assistant responses")
241
+ except ImportError:
242
+ progress_bar = dialogues
243
+ logger.info("Progress bar disabled - continuing without visual progress")
244
+
245
+ for dialogue in progress_bar:
246
  turns = dialogue.get('turns', [])
247
  for turn in turns:
248
  if turn.get('speaker') == 'assistant' and 'text' in turn:
 
252
  unique_responses = list(set(responses))
253
  logger.info(f"Found {len(unique_responses)} unique responses.")
254
 
255
+ return responses, unique_responses
256
+
257
+ def _adjust_batch_size(self) -> None:
258
+ """Dynamically adjust batch size based on GPU memory usage."""
259
+ if self.memory_monitor.should_reduce_batch_size():
260
+ new_size = max(self.min_batch_size, self.current_batch_size // 2)
261
+ if new_size != self.current_batch_size:
262
+ logger.info(f"Reducing batch size to {new_size} due to high memory usage")
263
+ self.current_batch_size = new_size
264
+ gc.collect()
265
+ if tf.config.list_physical_devices('GPU'):
266
+ tf.keras.backend.clear_session()
267
+ elif self.memory_monitor.can_increase_batch_size():
268
+ new_size = min(self.max_batch_size, self.current_batch_size * 2)
269
+ if new_size != self.current_batch_size:
270
+ logger.info(f"Increasing batch size to {new_size}")
271
+ self.current_batch_size = new_size
272
+
273
+ def _initialize_faiss(self):
274
+ """Initialize FAISS with safer GPU handling and memory monitoring."""
275
+ logger.info("Initializing FAISS index...")
276
 
277
+ # First, detect if we have GPU-enabled FAISS
278
+ self.faiss_gpu = False
279
+ self.gpu_resources = []
280
 
281
+ try:
282
+ if hasattr(faiss, 'get_num_gpus'):
283
+ ngpus = faiss.get_num_gpus()
284
+ if ngpus > 0:
285
+ # Configure GPU resources with memory limit
286
+ for i in range(ngpus):
287
+ res = faiss.StandardGpuResources()
288
+ # Set temp memory to 1/4 of total memory to avoid OOM
289
+ if self.memory_monitor.has_gpu:
290
+ stats = self.memory_monitor.get_memory_stats()
291
+ if stats:
292
+ temp_memory = int(stats.total * 0.25) # 25% of total memory
293
+ res.setTempMemory(temp_memory)
294
+ self.gpu_resources.append(res)
295
+ self.faiss_gpu = True
296
+ logger.info(f"FAISS GPU resources initialized on {ngpus} GPUs")
297
+ else:
298
+ logger.info("Using CPU-only FAISS build")
299
+
300
+ except Exception as e:
301
+ logger.warning(f"Using CPU due to GPU initialization error: {e}")
302
 
303
+ # TODO: figure out bug with faiss-gpu
304
+ try:
305
+ # Create appropriate index based on dataset size
306
+ if len(self.unique_responses) < 1000:
307
+ logger.info("Small dataset detected, using simple FlatIP index")
308
+ self.index = faiss.IndexFlatIP(self.config.embedding_dim)
309
+ else:
310
+ # Use IVF index with dynamic number of clusters
311
+ # nlist = min(
312
+ # 25, # max clusters
313
+ # max(int(math.sqrt(len(self.unique_responses))), 1) # min 1 cluster
314
+ # )
315
+ # logger.info(f"Using IVF index with {nlist} clusters")
316
+
317
+ # quantizer = faiss.IndexFlatIP(self.config.embedding_dim)
318
+ # self.index = faiss.IndexIVFFlat(
319
+ # quantizer,
320
+ # self.config.embedding_dim,
321
+ # nlist,
322
+ # faiss.METRIC_INNER_PRODUCT
323
+ # )
324
+ self.index = faiss.IndexFlatIP(self.config.embedding_dim)
325
+
326
+ # # Move to GPU(s) if available
327
+ # if self.faiss_gpu and self.gpu_resources:
328
+ # try:
329
+ # if len(self.gpu_resources) > 1:
330
+ # self.index = faiss.index_cpu_to_gpus_list(self.index, self.gpu_resources)
331
+ # logger.info("FAISS index distributed across multiple GPUs")
332
+ # else:
333
+ # self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, self.index)
334
+ # logger.info("FAISS index moved to single GPU")
335
+ # except Exception as e:
336
+ # logger.warning(f"Failed to move index to GPU: {e}. Falling back to CPU")
337
+ # self.faiss_gpu = False
338
+
339
+ # # Set search parameters for IVF index
340
+ # if isinstance(self.index, faiss.IndexIVFFlat):
341
+ # self.index.nprobe = min(10, nlist)
342
+
343
+ except Exception as e:
344
+ logger.error(f"Error initializing FAISS: {e}")
345
+ raise
346
 
347
  def encode_responses(
348
  self,
 
350
  batch_size: int = 64
351
  ) -> tf.Tensor:
352
  """
353
+ Encodes responses with more conservative memory management.
 
 
354
  """
 
355
  all_embeddings = []
356
+ self.current_batch_size = batch_size
357
+
358
+ # Memory stats
359
+ # if self.memory_monitor.has_gpu:
360
+ # initial_stats = self.memory_monitor.get_memory_stats()
361
+ # if initial_stats:
362
+ # logger.info("Initial GPU memory state:")
363
+ # logger.info(f"Total: {initial_stats.total / 1e9:.2f}GB")
364
+ # logger.info(f"Used: {initial_stats.used / 1e9:.2f}GB")
365
+ # logger.info(f"Free: {initial_stats.free / 1e9:.2f}GB")
366
+
367
+ total_processed = 0
368
+
369
+ with tqdm(total=len(responses), desc="Encoding responses") as pbar:
370
+ while total_processed < len(responses):
371
+ # Monitor memory and adjust batch size
372
+ if self.memory_monitor.has_gpu:
373
+ gpu_usage = self.memory_monitor.get_memory_usage()
374
+ if gpu_usage > 0.8: # Over 80% usage
375
+ self.current_batch_size = max(self.min_batch_size, self.current_batch_size // 2)
376
+ logger.info(f"High GPU memory usage ({gpu_usage:.1%}), reducing batch size to {self.current_batch_size}")
377
+ gc.collect()
378
+ tf.keras.backend.clear_session()
379
+
380
+ # Get batch
381
+ end_idx = min(total_processed + self.current_batch_size, len(responses))
382
+ batch_texts = responses[total_processed:end_idx]
383
+
384
+ try:
385
+ # Tokenize
386
+ encodings = self.tokenizer(
387
+ batch_texts,
388
+ padding='max_length',
389
+ truncation=True,
390
+ max_length=self.config.max_context_token_limit,
391
+ return_tensors='tf'
392
+ )
393
 
394
+ # Encode
395
+ embeddings_batch = self.encoder(encodings['input_ids'], training=False)
396
+
397
+ # Cast to float32
398
+ if embeddings_batch.dtype != tf.float32:
399
+ embeddings_batch = tf.cast(embeddings_batch, tf.float32)
 
 
 
400
 
401
+ # Store
402
+ all_embeddings.append(embeddings_batch)
403
+
404
+ # Update progress
405
+ batch_processed = len(batch_texts)
406
+ total_processed += batch_processed
407
+
408
+ # Update progress bar
409
+ if self.memory_monitor.has_gpu:
410
+ gpu_usage = self.memory_monitor.get_memory_usage()
411
+ pbar.set_postfix({
412
+ 'GPU mem': f'{gpu_usage:.1%}',
413
+ 'batch_size': self.current_batch_size
414
+ })
415
+ pbar.update(batch_processed)
416
+
417
+ # Memory cleanup every 1000 samples
418
+ if total_processed % 1000 == 0:
419
+ gc.collect()
420
+ if tf.config.list_physical_devices('GPU'):
421
+ tf.keras.backend.clear_session()
422
+
423
+ except tf.errors.ResourceExhaustedError:
424
+ logger.warning("GPU memory exhausted during encoding, reducing batch size")
425
+ self.current_batch_size = max(8, self.current_batch_size // 2)
426
+ continue
427
+
428
+ except Exception as e:
429
+ logger.error(f"Error during encoding: {str(e)}")
430
+ raise
431
 
432
+ # Concatenate results
433
+ #logger.info("Concatenating embeddings...")
434
  if len(all_embeddings) == 1:
 
435
  final_embeddings = all_embeddings[0]
436
  else:
 
437
  final_embeddings = tf.concat(all_embeddings, axis=0)
438
 
439
  return final_embeddings
440
 
441
+ def _train_faiss_index(self, response_embeddings: np.ndarray) -> None:
442
+ """Train FAISS index with better memory management and robust fallback mechanisms."""
443
+ if self.index.is_trained:
444
+ logger.info("Index already trained, skipping training phase")
445
+ return
446
+
447
+ logger.info("Starting FAISS index training...")
448
+
449
+ try:
450
+ # First attempt: Try training with smaller subset
451
+ subset_size = min(5000, len(response_embeddings)) # Reduced from 10000
452
+ logger.info(f"Using {subset_size} samples for initial training attempt")
453
+ subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False)
454
+ training_embeddings = response_embeddings[subset_idx].copy() # Make a copy
455
+
456
+ # Ensure contiguous memory layout
457
+ training_embeddings = np.ascontiguousarray(training_embeddings)
458
+
459
+ # Force cleanup before training
460
+ gc.collect()
461
+ if tf.config.list_physical_devices('GPU'):
462
+ tf.keras.backend.clear_session()
463
+
464
+ # Verify data properties
465
+ logger.info(f"FAISS training data shape: {training_embeddings.shape}")
466
+ logger.info(f"FAISS training data dtype: {training_embeddings.dtype}")
467
+
468
+ logger.info("Starting initial training attempt...")
469
+ self.index.train(training_embeddings)
470
+ logger.info("Training completed successfully")
471
+
472
+ except Exception as e:
473
+ logger.warning(f"Initial training attempt failed: {str(e)}")
474
+ logger.info("Attempting fallback strategy...")
475
+
476
+ try:
477
+ # Move to CPU for more stable training
478
+ if self.faiss_gpu:
479
+ logger.info("Moving index to CPU for fallback training")
480
+ cpu_index = faiss.index_gpu_to_cpu(self.index)
481
+ else:
482
+ cpu_index = self.index
483
+
484
+ # Create simpler index type if needed
485
+ if isinstance(cpu_index, faiss.IndexIVFFlat):
486
+ logger.info("Creating simpler FlatL2 index for fallback")
487
+ cpu_index = faiss.IndexFlatL2(self.config.embedding_dim)
488
+
489
+ # Use even smaller subset for CPU training
490
+ subset_size = min(2000, len(response_embeddings))
491
+ subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False)
492
+ fallback_embeddings = response_embeddings[subset_idx].copy()
493
+
494
+ # Ensure data is properly formatted
495
+ if not fallback_embeddings.flags['C_CONTIGUOUS']:
496
+ fallback_embeddings = np.ascontiguousarray(fallback_embeddings)
497
+ if fallback_embeddings.dtype != np.float32:
498
+ fallback_embeddings = fallback_embeddings.astype(np.float32)
499
+
500
+ # Train on CPU
501
+ logger.info("Training fallback index on CPU...")
502
+ cpu_index.train(fallback_embeddings)
503
+
504
+ # Move back to GPU if needed
505
+ if self.faiss_gpu:
506
+ logger.info("Moving trained index back to GPU...")
507
+ if len(self.gpu_resources) > 1:
508
+ self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
509
+ else:
510
+ self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
511
+ else:
512
+ self.index = cpu_index
513
+
514
+ logger.info("Fallback training completed successfully")
515
+
516
+ except Exception as e2:
517
+ logger.error(f"Fallback training also failed: {str(e2)}")
518
+ logger.warning("Creating basic brute-force index as last resort")
519
+
520
+ try:
521
+ # Create basic brute-force index as last resort
522
+ dim = response_embeddings.shape[1]
523
+ basic_index = faiss.IndexFlatL2(dim)
524
+
525
+ if self.faiss_gpu:
526
+ if len(self.gpu_resources) > 1:
527
+ self.index = faiss.index_cpu_to_gpus_list(basic_index, self.gpu_resources)
528
+ else:
529
+ self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, basic_index)
530
+ else:
531
+ self.index = basic_index
532
+
533
+ logger.info("Basic index created as fallback")
534
+
535
+ except Exception as e3:
536
+ logger.error(f"All training attempts failed: {str(e3)}")
537
+ raise RuntimeError("Unable to create working FAISS index")
538
+
539
+ def _add_vectors_to_index(self, response_embeddings: np.ndarray) -> None:
540
+ """Add vectors to FAISS index with enhanced memory management."""
541
+ logger.info("Starting vector addition process...")
542
+
543
+ # Even smaller batches
544
+ initial_batch_size = 50 # Start smaller
545
+ min_batch_size = 10
546
+ max_batch_size = 500 # Lower maximum
547
+
548
+ total_added = 0
549
+ retry_count = 0
550
+ max_retries = 5
551
+
552
+ while total_added < len(response_embeddings):
553
+ try:
554
+ # Monitor memory
555
+ if self.memory_monitor.has_gpu:
556
+ gpu_usage = self.memory_monitor.get_memory_usage()
557
+ #logger.info(f"GPU memory usage before batch: {gpu_usage:.1%}")
558
+
559
+ # Force cleanup if memory usage is high
560
+ if gpu_usage > 0.7: # Lower threshold to 70%
561
+ logger.info("High memory usage detected, forcing cleanup")
562
+ gc.collect()
563
+ tf.keras.backend.clear_session()
564
+
565
+ # Get batch
566
+ end_idx = min(total_added + initial_batch_size, len(response_embeddings))
567
+ batch = response_embeddings[total_added:end_idx]
568
+
569
+ # Add batch
570
+ self.index.add(batch)
571
+
572
+ # Update progress
573
+ batch_size = len(batch)
574
+ total_added += batch_size
575
+ #logger.info(f"Added batch of {batch_size} vectors ({total_added}/{len(response_embeddings)} total)")
576
+
577
+ # Memory cleanup every few batches
578
+ if total_added % (initial_batch_size * 5) == 0:
579
+ gc.collect()
580
+ if tf.config.list_physical_devices('GPU'):
581
+ tf.keras.backend.clear_session()
582
+
583
+ # Gradually increase batch size
584
+ if initial_batch_size < max_batch_size:
585
+ initial_batch_size = min(initial_batch_size + 25, max_batch_size)
586
+
587
+ except Exception as e:
588
+ logger.warning(f"Error adding batch: {str(e)}")
589
+ retry_count += 1
590
+
591
+ if retry_count > max_retries:
592
+ logger.error("Max retries exceeded.")
593
+ raise
594
+
595
+ # Reduce batch size
596
+ initial_batch_size = max(min_batch_size, initial_batch_size // 2)
597
+ logger.info(f"Reducing batch size to {initial_batch_size} and retrying...")
598
+
599
+ # Cleanup
600
+ gc.collect()
601
+ if tf.config.list_physical_devices('GPU'):
602
+ tf.keras.backend.clear_session()
603
+
604
+ time.sleep(1) # Brief pause before retry
605
+
606
+ logger.info(f"Successfully added all {total_added} vectors to index")
607
+
608
+ def _add_vectors_cpu_fallback(self, remaining_embeddings: np.ndarray, already_added: int = 0) -> None:
609
+ """CPU fallback with extra safeguards and progress tracking."""
610
+ logger.info(f"CPU Fallback: Adding {len(remaining_embeddings)} remaining vectors...")
611
+
612
+ try:
613
+ # Move index to CPU
614
+ if self.faiss_gpu:
615
+ logger.info("Moving index to CPU...")
616
+ cpu_index = faiss.index_gpu_to_cpu(self.index)
617
+ else:
618
+ cpu_index = self.index
619
+
620
+ # Add remaining vectors on CPU with very small batches
621
+ batch_size = 50 # Extremely conservative batch size for CPU
622
+ total_added = already_added
623
+
624
+ for i in range(0, len(remaining_embeddings), batch_size):
625
+ end_idx = min(i + batch_size, len(remaining_embeddings))
626
+ batch = remaining_embeddings[i:end_idx]
627
+
628
+ # Add batch
629
+ cpu_index.add(batch)
630
+
631
+ # Update progress
632
+ total_added += len(batch)
633
+ if i % (batch_size * 10) == 0:
634
+ logger.info(f"Added {total_added} vectors total "
635
+ f"({i}/{len(remaining_embeddings)} in current phase)")
636
+
637
+ # Periodic cleanup
638
+ if i % (batch_size * 20) == 0:
639
+ gc.collect()
640
+
641
+ # Move back to GPU if needed
642
+ if self.faiss_gpu:
643
+ logger.info("Moving index back to GPU...")
644
+ if len(self.gpu_resources) > 1:
645
+ self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
646
+ else:
647
+ self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
648
+ else:
649
+ self.index = cpu_index
650
+
651
+ logger.info("CPU fallback completed successfully")
652
+
653
+ except Exception as e:
654
+ logger.error(f"Error during CPU fallback: {str(e)}")
655
+ raise
656
+
657
+ def _compute_and_index_embeddings(self):
658
+ """Compute embeddings and build FAISS index with simpler handling."""
659
+ logger.info("Computing embeddings and indexing with FAISS...")
660
+
661
+ try:
662
+ # Encode responses with memory monitoring
663
+ logger.info("Encoding unique responses")
664
+ response_embeddings = self.encode_responses(self.unique_responses)
665
+ response_embeddings = response_embeddings.numpy()
666
+
667
+ # Memory cleanup after encoding
668
+ gc.collect()
669
+ if tf.config.list_physical_devices('GPU'):
670
+ tf.keras.backend.clear_session()
671
+
672
+ # Ensure float32 and memory contiguous
673
+ response_embeddings = response_embeddings.astype('float32')
674
+ response_embeddings = np.ascontiguousarray(response_embeddings)
675
+
676
+ # Log memory state before normalization
677
+ if self.memory_monitor.has_gpu:
678
+ stats = self.memory_monitor.get_memory_stats()
679
+ if stats:
680
+ logger.info(f"GPU memory before normalization: {stats.used/1e9:.2f}GB used")
681
+
682
+ # Normalize embeddings
683
+ logger.info("Normalizing embeddings with FAISS")
684
+ faiss.normalize_L2(response_embeddings)
685
+
686
+ # Create and initialize simple FlatIP index
687
+ dim = response_embeddings.shape[1]
688
+ if self.faiss_gpu:
689
+ cpu_index = faiss.IndexFlatIP(dim)
690
+ if len(self.gpu_resources) > 1:
691
+ self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
692
+ else:
693
+ self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
694
+ else:
695
+ self.index = faiss.IndexFlatIP(dim)
696
+
697
+ # Add vectors to index
698
+ self._add_vectors_to_index(response_embeddings)
699
+
700
+ # Store responses and embeddings
701
+ self.response_pool = self.unique_responses
702
+ self.response_embeddings = response_embeddings
703
+
704
+ # Final memory cleanup
705
+ gc.collect()
706
+ if tf.config.list_physical_devices('GPU'):
707
+ tf.keras.backend.clear_session()
708
+
709
+ # Log final state
710
+ logger.info(f"Successfully indexed {self.index.ntotal} responses")
711
+ if self.memory_monitor.has_gpu:
712
+ stats = self.memory_monitor.get_memory_stats()
713
+ if stats:
714
+ logger.info(f"Final GPU memory usage: {stats.used/1e9:.2f}GB used")
715
+
716
+ logger.info("Indexing completed successfully")
717
+
718
+ except Exception as e:
719
+ logger.error(f"Error during indexing: {e}")
720
+ # Ensure cleanup even on error
721
+ gc.collect()
722
+ if tf.config.list_physical_devices('GPU'):
723
+ tf.keras.backend.clear_session()
724
+ raise
725
+
726
+ def verify_faiss_index(self):
727
+ """Verify that FAISS index matches the response pool."""
728
+ indexed_size = self.index.ntotal
729
+ pool_size = len(self.response_pool)
730
+ logger.info(f"FAISS index size: {indexed_size}")
731
+ logger.info(f"Response pool size: {pool_size}")
732
+ if indexed_size != pool_size:
733
+ logger.warning("Mismatch between FAISS index size and response pool size.")
734
+ else:
735
+ logger.info("FAISS index correctly matches the response pool.")
736
+
737
  def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
738
  """Encode a query with optional conversation context."""
739
  # Prepare query with context
 
812
  """Retrieve top-k responses using FAISS."""
813
  # Encode the query
814
  q_emb = self.encode_query(query) # Shape: [1, embedding_dim]
815
+ q_emb_np = q_emb.numpy().astype('float32') # Ensure type match
816
 
817
  # Normalize the query embedding for cosine similarity
818
  faiss.normalize_L2(q_emb_np)
 
899
  logger.info(f"Loaded {len(dialogues)} dialogues.")
900
  return dialogues
901
 
902
+ def train_streaming(
 
 
 
 
903
  self,
904
+ dialogues: List[dict],
 
905
  epochs: int = 20,
906
  batch_size: int = 16,
907
  validation_split: float = 0.2,
 
911
  warmup_steps_ratio: float = 0.1,
912
  early_stopping_patience: int = 3,
913
  min_delta: float = 1e-4,
914
+ buffer_size: int = 10,
915
+ neg_samples: int = 1
916
+ ) -> None:
917
+ """
918
+ Streaming version of training that interleaves training/val batches by
919
+ giving priority to training until we meet `steps_per_epoch`, then
920
+ sending leftover batches to validation.
921
+ """
922
+ logger.info("Starting streaming training pipeline...")
923
+
924
+ # Initialize dataset preparer
925
+ dataset_preparer = StreamingDataPipeline(
926
+ tokenizer=self.tokenizer,
927
+ encoder=self.encoder,
928
+ index=self.index,
929
+ response_pool=self.response_pool,
930
+ max_length=self.config.max_context_token_limit,
931
+ batch_size=batch_size,
932
+ neg_samples=neg_samples
933
+ )
934
 
935
+ # Calculate total steps for learning rate schedule
936
+ total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
937
+ train_size = total_pairs * (1 - validation_split)
938
+ steps_per_epoch = int(math.ceil(train_size / batch_size))
939
+ val_steps = int(math.ceil((total_pairs * validation_split) / batch_size))
940
  total_steps = steps_per_epoch * epochs
 
941
 
942
+ logger.info(f"Total pairs: {total_pairs}")
943
+ logger.info(f"Training pairs: {train_size}")
944
+ logger.info(f"Steps per epoch: {steps_per_epoch}")
945
+ logger.info(f"Validation steps: {val_steps}")
946
+ logger.info(f"Total steps: {total_steps}")
947
+
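For a concrete sense of these quantities (illustrative numbers, chosen only to match the example batch counts mentioned in the worker comments below): with total_pairs = 180, validation_split = 0.2 and batch_size = 16, train_size = 144, steps_per_epoch = ceil(144 / 16) = 9, val_steps = ceil(36 / 16) = 3, and total_steps = 9 * epochs.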
948
+ # Set up optimizer with learning rate schedule
949
  if use_lr_schedule:
950
  warmup_steps = int(total_steps * warmup_steps_ratio)
951
  lr_schedule = self._get_lr_schedule(
 
959
  self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
960
  logger.info("Using fixed learning rate.")
961
 
962
+ # Initialize checkpoint manager
963
  checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
964
  manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
965
 
966
+ # Setup TensorBoard
967
  log_dir = Path(checkpoint_dir) / "tensorboard_logs"
968
  log_dir.mkdir(parents=True, exist_ok=True)
 
969
  current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
970
  train_log_dir = str(log_dir / f"train_{current_time}")
971
  val_log_dir = str(log_dir / f"val_{current_time}")
 
972
  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
973
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
974
+
975
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
976
 
977
+ # Training loop
978
  best_val_loss = float("inf")
979
  epochs_no_improve = 0
980
 
981
+ try:
982
+ epoch_pbar = tqdm(range(1, epochs + 1), desc="Training", unit="epoch")
983
+ is_tqdm_epoch = True
984
+ except ImportError:
985
+ epoch_pbar = range(1, epochs + 1)
986
+ is_tqdm_epoch = False
987
+ logger.info("Epoch progress bar disabled - continuing without visual progress")
988
+
989
+ for epoch in epoch_pbar:
990
+ # Shared queues for streaming pipeline
991
+ train_queue = Queue(maxsize=buffer_size)
992
+ val_queue = Queue(maxsize=buffer_size)
993
+ stop_flag = threading.Event()
994
+
995
+ def data_pipeline_worker():
996
+ """Thread function that processes dialogues and sends batches to train or val."""
997
+ try:
998
+ train_batches_needed = steps_per_epoch # 9 in your logs
999
+ val_batches_needed = val_steps # 3 in your logs
1000
+ train_batches_sent = 0
1001
+ val_batches_sent = 0
1002
+
1003
+ logger.info(f"Pipeline starting: need {train_batches_needed} train batches, {val_batches_needed} val batches")
1004
+
1005
+ # Possibly shuffle your processed pairs to avoid repeating them in the same order
1006
+ # (If you haven't already done so in the pipeline)
1007
+ random.shuffle(dataset_preparer.processed_pairs)
1008
+
1009
+ while (train_batches_sent < train_batches_needed or
1010
+ val_batches_sent < val_batches_needed):
1011
+
1012
+ # We loop over the generator
1013
+ for batch in dataset_preparer.process_dialogues(dialogues):
1014
+ if stop_flag.is_set():
1015
+ logger.warning("Pipeline stopped early")
1016
+ break
1017
+
1018
+ if train_batches_sent < train_batches_needed:
1019
+ train_queue.put(batch)
1020
+ train_batches_sent += 1
1021
+ elif val_batches_sent < val_batches_needed:
1022
+ val_queue.put(batch)
1023
+ val_batches_sent += 1
1024
+ else:
1025
+ # We have enough batches for both train & val
1026
+ break
1027
+
1028
+ # If we still haven't met our target steps, REPEAT the data
1029
+ if train_batches_sent < train_batches_needed or val_batches_sent < val_batches_needed:
1030
+ logger.info("Data exhausted, repeating since we still need more batches...")
1031
+ # Optionally shuffle again
1032
+ random.shuffle(dataset_preparer.processed_pairs)
1033
+ else:
1034
+ # We have enough
1035
+ break
1036
 
1037
+ logger.info(
1038
+ f"Pipeline complete: sent {train_batches_sent}/{train_batches_needed} train batches, "
1039
+ f"{val_batches_sent}/{val_batches_needed} val batches"
1040
+ )
1041
 
1042
+ except Exception as e:
1043
+ logger.error(f"Error in pipeline worker: {str(e)}")
1044
+ raise e
1045
+ finally:
1046
+ train_queue.put(None)
1047
+ val_queue.put(None)
1048
 
1049
+ # Start data preparation pipeline in background thread
1050
+ pipeline_thread = threading.Thread(target=data_pipeline_worker)
1051
+ pipeline_thread.start()
1052
 
1053
+ try:
1054
+ # --- Training Phase ---
1055
+ epoch_loss_avg = tf.keras.metrics.Mean()
1056
+ batches_processed = 0
 
 
1057
 
1058
+ try:
1059
+ train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}")
1060
+ is_tqdm_train = True
1061
+ except ImportError:
1062
+ train_pbar = None
1063
+ is_tqdm_train = False
1064
+ logger.info("Training progress bar disabled")
1065
+
1066
+ while batches_processed < steps_per_epoch:
1067
+ try:
1068
+ batch = train_queue.get(timeout=1200) # 20 minutes timeout
1069
+ if batch is None:
1070
+ logger.warning(f"Received end signal after only {batches_processed}/{steps_per_epoch} batches")
1071
+ break
1072
+
1073
+ q_batch, p_batch = batch[0], batch[1]
1074
+ attention_mask = batch[2] if len(batch) > 2 else None
1075
+
1076
+ loss = self.train_step(q_batch, p_batch, attention_mask)
1077
+ epoch_loss_avg(loss)
1078
+ batches_processed += 1
1079
+
1080
+ # Log to TensorBoard
1081
+ with train_summary_writer.as_default():
1082
+ tf.summary.scalar("loss", loss, step=epoch)
1083
+
1084
+ # Update progress bar
1085
+ if use_lr_schedule:
1086
+ current_lr = float(lr_schedule(self.optimizer.iterations))
1087
  else:
1088
+ current_lr = float(self.optimizer.learning_rate.numpy())
1089
+
1090
+ if is_tqdm_train:
1091
+ train_pbar.update(1)
1092
+ train_pbar.set_postfix({
1093
+ "loss": f"{loss.numpy():.4f}",
1094
+ "lr": f"{current_lr:.2e}",
1095
+ "batches": f"{batches_processed}/{steps_per_epoch}"
1096
+ })
1097
+
1098
+ except Empty:
1099
+ logger.warning(f"Queue timeout after {batches_processed}/{steps_per_epoch} batches")
1100
+ break
1101
+
1102
+ if is_tqdm_train and train_pbar:
1103
+ train_pbar.close()
1104
+
1105
+ # --- Validation Phase ---
1106
+ val_loss_avg = tf.keras.metrics.Mean()
1107
+ val_batches_processed = 0
1108
+
1109
+ try:
1110
+ val_pbar = tqdm(total=val_steps, desc="Validation")
1111
+ is_tqdm_val = True
1112
+ except ImportError:
1113
+ val_pbar = None
1114
+ is_tqdm_val = False
1115
+ logger.info("Validation progress bar disabled")
1116
+
1117
+ while val_batches_processed < val_steps:
1118
+ try:
1119
+ batch = val_queue.get(timeout=30)
1120
+ if batch is None:
1121
+ logger.warning(
1122
+ f"Received end signal after {val_batches_processed}/{val_steps} validation batches"
1123
+ )
1124
+ break
1125
+
1126
+ q_batch, p_batch = batch[0], batch[1]
1127
+ attention_mask = batch[2] if len(batch) > 2 else None
1128
+
1129
+ val_loss = self.validation_step(q_batch, p_batch, attention_mask)
1130
+ val_loss_avg(val_loss)
1131
+ val_batches_processed += 1
1132
+
1133
+ if is_tqdm_val:
1134
+ val_pbar.update(1)
1135
+ val_pbar.set_postfix({
1136
+ "val_loss": f"{val_loss.numpy():.4f}",
1137
+ "batches": f"{val_batches_processed}/{val_steps}"
1138
+ })
1139
+
1140
+ except Empty:
1141
+ logger.warning(
1142
+ f"Validation queue timeout after {val_batches_processed}/{val_steps} batches"
1143
+ )
1144
+ break
1145
+
1146
+ if is_tqdm_val and val_pbar:
1147
+ val_pbar.close()
1148
+
1149
+ # End of epoch: compute final epoch stats
1150
+ train_loss = epoch_loss_avg.result().numpy()
1151
+ val_loss = val_loss_avg.result().numpy()
1152
+ logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
1153
+
1154
+ # Log epoch metrics
1155
+ with val_summary_writer.as_default():
1156
+ tf.summary.scalar("val_loss", val_loss, step=epoch)
1157
+
1158
+ # Save checkpoint
1159
+ manager.save()
1160
+
1161
+ # Store metrics in history
1162
+ self.history['train_loss'].append(train_loss)
1163
+ self.history['val_loss'].append(val_loss)
1164
+
1165
+ if use_lr_schedule:
1166
+ current_lr = float(lr_schedule(self.optimizer.iterations))
1167
+ else:
1168
+ current_lr = float(self.optimizer.learning_rate.numpy())
1169
+
1170
+ self.history.setdefault('learning_rate', []).append(current_lr)
1171
+
1172
+ # Early stopping logic
1173
+ if val_loss < best_val_loss - min_delta:
1174
+ best_val_loss = val_loss
1175
+ epochs_no_improve = 0
1176
+ logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
1177
+ else:
1178
+ epochs_no_improve += 1
1179
+ logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
1180
+ if epochs_no_improve >= early_stopping_patience:
1181
+ logger.info("Early stopping triggered.")
1182
+ break
1183
 
1184
+ except Exception as e:
1185
+ logger.error(f"Error during training: {str(e)}")
1186
+ stop_flag.set()
1187
+ raise e
1188
+ finally:
1189
+ # Clean up epoch resources
1190
+ stop_flag.set()
1191
+ pipeline_thread.join()
1192
 
1193
+ logger.info("Streaming training completed!")
1194
 
 
 
 
1195
 
1196
+ @tf.function
1197
+ def train_step(self, q_batch: tf.Tensor, p_batch: tf.Tensor, attention_mask: Optional[tf.Tensor] = None) -> tf.Tensor:
1198
+ """Single training step with tf.function optimization and partial batch handling."""
1199
+ with tf.GradientTape() as tape:
1200
+ q_enc = self.encoder(q_batch, training=True)
1201
+ p_enc = self.encoder(p_batch, training=True)
1202
 
1203
+ sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
 
 
 
1204
 
1205
+ # Handle partial batches
1206
+ batch_size = tf.shape(q_enc)[0]
1207
+ labels = tf.range(batch_size, dtype=tf.int32)
1208
+
1209
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1210
+ labels=labels, logits=sim_matrix
1211
+ )
1212
+
1213
+ # If there's an attention mask, apply it
1214
+ if attention_mask is not None:
1215
+ loss = loss * attention_mask
1216
+ # normalize by the sum of attention_mask
1217
+ loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1218
  else:
1219
+ loss = tf.reduce_mean(loss)
1220
+
1221
+ gradients = tape.gradient(loss, self.encoder.trainable_variables)
1222
+ self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1223
+ return loss
1224
+
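The loss above is the standard in-batch contrastive setup: row i of sim_matrix scores query i against every positive in the batch, and labels = range(batch_size) marks the diagonal (each query's own response) as the correct class. A toy sketch of just that computation (random tensors; normalization added here only for readability):

import tensorflow as tf

q_enc = tf.math.l2_normalize(tf.random.normal((4, 8)), axis=1)   # 4 query embeddings
p_enc = tf.math.l2_normalize(tf.random.normal((4, 8)), axis=1)   # 4 positive embeddings
sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)           # [4, 4] similarity scores
labels = tf.range(4)                                             # diagonal entries are the positives
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=sim_matrix)
)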
1225
+ @tf.function
1226
+ def validation_step(self, q_batch: tf.Tensor, p_batch: tf.Tensor, attention_mask: Optional[tf.Tensor] = None) -> tf.Tensor:
1227
+ """Single validation step with partial batch handling."""
1228
+ q_enc = self.encoder(q_batch, training=False)
1229
+ p_enc = self.encoder(p_batch, training=False)
1230
+
1231
+ sim_matrix = tf.matmul(q_enc, p_enc, transpose_b=True)
1232
+ batch_size = tf.shape(q_enc)[0]
1233
+ labels = tf.range(batch_size, dtype=tf.int32)
1234
 
1235
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1236
+ labels=labels, logits=sim_matrix
1237
+ )
1238
+
1239
+ if attention_mask is not None:
1240
+ loss = loss * attention_mask
1241
+ loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1242
+ else:
1243
+ loss = tf.reduce_mean(loss)
1244
+
1245
+ return loss
1246
 
1247
  def _get_lr_schedule(
1248
  self,
 
1382
  conversation_parts.append(f"{self.special_tokens['user']} {query}")
1383
  return "\n".join(conversation_parts)
1384
 
1385
+ class StreamingDataPipeline:
1386
+ """Helper class to manage the streaming data preparation pipeline with optimized caching and GPU usage."""
1387
+ def __init__(
1388
+ self,
1389
+ tokenizer,
1390
+ encoder,
1391
+ index,
1392
+ response_pool,
1393
+ max_length: int,
1394
+ batch_size: int,
1395
+ neg_samples: int
1396
+ ):
1397
+ self.tokenizer = tokenizer
1398
+ self.encoder = encoder
1399
+ self.index = index
1400
+ self.response_pool = response_pool
1401
+ self.max_length = max_length
1402
+ self.base_batch_size = batch_size
1403
+ self.neg_samples = neg_samples
1404
+ self.memory_monitor = GPUMemoryMonitor()
1405
+
1406
+ # Caching structures
1407
+ self.hard_negatives_cache = {}
1408
+ self.processed_pairs = []
1409
+ self.query_embeddings_cache = {}
1410
+
1411
+ # Error tracking
1412
+ self.error_count = 0
1413
+ self.max_retries = 3
1414
+
1415
+ # Batch processing settings
1416
+ self.current_batch_size = batch_size
1417
+ self.batch_increase_factor = 1.25
1418
+
1419
+ # TODO: use GPU/strategy
1420
+ if len(response_pool) < 100:
1421
+ self.embedding_batch_size = 16
1422
+ self.search_batch_size = 8
1423
+ self.max_batch_size = 32
1424
+ self.min_batch_size = 4
1425
+ else:
1426
+ self.embedding_batch_size = 64
1427
+ self.search_batch_size = 32
1428
+ self.min_batch_size = max(8, batch_size // 4)
1429
+ self.max_batch_size = 64
1430
+
1431
+ def save_cache(self, cache_dir: Path) -> None:
1432
+ """Save all cached data for future runs."""
1433
+ cache_dir = Path(cache_dir)
1434
+ cache_dir.mkdir(parents=True, exist_ok=True)
1435
+
1436
+ logger.info(f"Saving cache to {cache_dir}")
1437
+
1438
+ # Save embeddings cache
1439
+ embeddings_path = cache_dir / "query_embeddings.npy"
1440
+ np.save(
1441
+ embeddings_path,
1442
+ {k: v.numpy() if hasattr(v, 'numpy') else v
1443
+ for k, v in self.query_embeddings_cache.items()}
1444
+ )
1445
+
1446
+ # Save hard negatives and processed pairs
1447
+ with open(cache_dir / "hard_negatives.json", 'w') as f:
1448
+ # Tuple keys are not JSON-serializable, so store records of [query, positive, negatives]
+ json.dump([[q, p, negs] for (q, p), negs in self.hard_negatives_cache.items()], f)
1449
+
1450
+ with open(cache_dir / "processed_pairs.json", 'w') as f:
1451
+ json.dump(self.processed_pairs, f)
1452
+
1453
+ logger.info("Cache saved successfully")
1454
+
1455
+ def load_cache(self, cache_dir: Path) -> bool:
1456
+ """Load cached data if available."""
1457
+ cache_dir = Path(cache_dir)
1458
+ required_files = [
1459
+ "query_embeddings.npy",
1460
+ "hard_negatives.json",
1461
+ "processed_pairs.json"
1462
+ ]
1463
+
1464
+ if not all((cache_dir / f).exists() for f in required_files):
1465
+ logger.info("Cache files not found")
1466
+ return False
1467
+
1468
+ try:
1469
+ logger.info("Loading cache...")
1470
+
1471
+ # Load embeddings
1472
+ self.query_embeddings_cache = np.load(
1473
+ cache_dir / "query_embeddings.npy",
1474
+ allow_pickle=True
1475
+ ).item()
1476
+
1477
+ # Load other caches
1478
+ with open(cache_dir / "hard_negatives.json", 'r') as f:
1479
+ # Rebuild the (query, positive) tuple keys from the stored [query, positive, negatives] records
+ self.hard_negatives_cache = {(q, p): negs for q, p, negs in json.load(f)}
1480
+
1481
+ with open(cache_dir / "processed_pairs.json", 'r') as f:
1482
+ self.processed_pairs = json.load(f)
1483
+
1484
+ logger.info(f"Cache loaded successfully: {len(self.processed_pairs)} pairs")
1485
+ return True
1486
+
1487
+ except Exception as e:
1488
+ logger.error(f"Error loading cache: {e}")
1489
+ return False
1490
+
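A minimal usage sketch for the cache round-trip (the directory name and the `pipeline`/`dialogues` variables are assumptions, not part of this commit):

cache_dir = Path("training_cache")        # hypothetical location
if not pipeline.load_cache(cache_dir):    # cache miss: rebuild and persist
    pipeline.preprocess_dialogues(dialogues)
    pipeline.save_cache(cache_dir)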
1491
+ def _adjust_batch_size(self) -> None:
1492
+ """Dynamically adjust batch size based on GPU memory usage."""
1493
+ if self.memory_monitor:
1494
+ if self.memory_monitor.should_reduce_batch_size():
1495
+ new_size = max(self.min_batch_size, self.current_batch_size // 2)
1496
+ if new_size != self.current_batch_size:
1497
+ logger.info(f"Reducing batch size to {new_size} due to high memory usage")
1499
+ self.current_batch_size = new_size
1500
+ gc.collect()
1501
+ if tf.config.list_physical_devices('GPU'):
1502
+ tf.keras.backend.clear_session()
1503
+
1504
+ elif self.memory_monitor.can_increase_batch_size():
1505
+ new_size = min(self.max_batch_size, int(self.current_batch_size * self.batch_increase_factor)) # More gradual increase
1506
+ if new_size != self.current_batch_size:
1507
+ logger.info(f"Increasing batch size to {new_size}")
1509
+ self.current_batch_size = new_size
1510
+
1511
+ def _add_progress_metrics(self, pbar, **metrics) -> None:
1512
+ """Add memory and batch size metrics to progress bars."""
1513
+ if self.memory_monitor:
1514
+ gpu_usage = self.memory_monitor.get_memory_usage()
1515
+ metrics['gpu_mem'] = f"{gpu_usage:.1%}"
1516
+ metrics['batch_size'] = self.current_batch_size
1517
+ pbar.set_postfix(**metrics)
1518
+
1519
+ def preprocess_dialogues(self, dialogues: List[dict]) -> None:
1520
+ """Preprocess all dialogues with error recovery and caching."""
1521
+ retry_count = 0
1522
+
1523
+ while retry_count < self.max_retries:
1524
+ try:
1525
+ self._preprocess_dialogues_internal(dialogues)
1526
+ break
1527
+ except Exception as e:
1528
+ retry_count += 1
1529
+ logger.warning(f"Preprocessing attempt {retry_count} failed: {e}")
1530
+ if retry_count == self.max_retries:
1531
+ logger.error("Max retries reached. Falling back to CPU processing")
1532
+ self._fallback_to_cpu_processing(dialogues)
1533
+
1534
+ def _preprocess_dialogues_internal(self, dialogues: List[dict]) -> None:
1535
+ """Internal preprocessing implementation with progress tracking."""
1536
+ logger.info("Starting dialogue preprocessing...")
1537
+
1538
+ # Collect unique queries and pairs
1539
+ unique_queries = set()
1540
+ query_positive_pairs = []
1541
+
1542
+ with tqdm(total=len(dialogues), desc="Collecting dialogue pairs") as pbar:
1543
+ for dialogue in dialogues:
1544
+ pairs = self._extract_pairs_from_dialogue(dialogue)
1545
+ for query, positive in pairs:
1546
+ unique_queries.add(query)
1547
+ query_positive_pairs.append((query, positive))
1548
+ pbar.update(1)
1549
+ self._add_progress_metrics(pbar, pairs=len(query_positive_pairs))
1550
+
1551
+ # Precompute embeddings
1552
+ logger.info("Precomputing query embeddings...")
1553
+ self.precompute_query_embeddings(list(unique_queries))
1554
+
1555
+ # Find hard negatives
1556
+ logger.info("Finding hard negatives for all pairs...")
1557
+ self._find_hard_negatives_for_pairs(query_positive_pairs)
1558
 
1559
+ def precompute_query_embeddings(self, queries: List[str]) -> None:
1560
+ """Precompute embeddings for all unique queries in batches."""
1561
+ unique_queries = list(set(queries))
1562
+
1563
+ with tqdm(total=len(unique_queries), desc="Precomputing query embeddings") as pbar:
1564
+ for i in range(0, len(unique_queries), self.embedding_batch_size):
1565
+ # Adjust batch size based on memory
1566
+ self._adjust_batch_size()
1567
+ batch_size = min(self.embedding_batch_size, len(unique_queries) - i)
1568
+
1569
+ # Get batch of queries
1570
+ batch_queries = unique_queries[i:i + batch_size]
1571
+
1572
+ try:
1573
+ # Tokenize batch
1574
+ encoded = self.tokenizer(
1575
+ batch_queries,
1576
+ padding=True,
1577
+ truncation=True,
1578
+ max_length=self.max_length,
1579
+ return_tensors='tf'
1580
+ )
1581
+
1582
+ # Get embeddings
1583
+ embeddings = self.encoder(encoded['input_ids'], training=False)
1584
+ embeddings_np = embeddings.numpy().astype('float32')
1585
+
1586
+ # Normalize for similarity search
1587
+ faiss.normalize_L2(embeddings_np)
1588
+
1589
+ # Cache embeddings
1590
+ for query, emb in zip(batch_queries, embeddings_np):
1591
+ self.query_embeddings_cache[query] = emb
1592
+
1593
+ pbar.update(len(batch_queries))
1594
+ self._add_progress_metrics(
1595
+ pbar,
1596
+ cached=len(self.query_embeddings_cache),
1597
+ batch_size=batch_size
1598
+ )
1599
+
1600
+ except Exception as e:
1601
+ logger.warning(f"Error processing batch: {e}")
1602
+ # Reduce the embedding batch size and move on; this chunk is skipped rather than retried
1603
+ self.embedding_batch_size = max(self.min_batch_size, self.embedding_batch_size // 2)
1604
+ continue
1605
+
1606
+ # Memory cleanup after successful batch
1607
+ if i % (self.embedding_batch_size * 10) == 0:
1608
+ gc.collect()
1609
+ if tf.config.list_physical_devices('GPU'):
1610
+ tf.keras.backend.clear_session()
1611
+
1612
+ logger.info(f"Cached embeddings for {len(self.query_embeddings_cache)} unique queries")
1613
+
1614
+ def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
1615
+ """Extract query-response pairs from a dialogue."""
1616
+ pairs = []
1617
+ turns = dialogue.get('turns', [])
1618
+
1619
+ for i in range(len(turns) - 1):
1620
+ current_turn = turns[i]
1621
+ next_turn = turns[i+1]
1622
+
1623
+ if (current_turn.get('speaker') == 'user' and
1624
+ next_turn.get('speaker') == 'assistant' and
1625
+ 'text' in current_turn and
1626
+ 'text' in next_turn):
1627
+
1628
+ query = current_turn['text'].strip()
1629
+ positive = next_turn['text'].strip()
1630
+ pairs.append((query, positive))
1631
+
1632
+ return pairs
1633
+
1634
+ def _fallback_to_cpu_processing(self, dialogues: List[dict]) -> None:
1635
+ """Fallback processing method using CPU only."""
1636
+ logger.info("Falling back to CPU-only processing")
1637
+ # Reset GPU-specific settings
1638
+ self.current_batch_size = self.min_batch_size
1639
+ self.embedding_batch_size = 32
1640
+ self.search_batch_size = 16
1641
+
1642
+ # Attempt preprocessing with reduced batches
1643
+ self._preprocess_dialogues_internal(dialogues)
1644
+
1645
+ def process_dialogues(self, dialogues: List[dict]) -> Generator[Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]], None, None]:
1646
+ """
1647
+ Process dialogues using cached data with dynamic batch sizing.
1648
+ Yields (q_tokens['input_ids'], p_tokens['input_ids'], attention_mask) tuples.
1649
+ """
1650
+ # Preprocess if not already done
1651
+ if not self.processed_pairs:
1652
+ self.preprocess_dialogues(dialogues)
1653
+
1654
+ # Generate batches from cached data
1655
+ current_queries = []
1656
+ current_positives = []
1657
+
1658
+ # Counters for logging
1659
+ total_examples_yielded = 0
1660
+ total_batches_yielded = 0
1661
+
1662
+ with tqdm(total=len(self.processed_pairs), desc="Generating training batches", leave=False) as pbar:
1663
+ for i, (query, positive) in enumerate(self.processed_pairs):
1664
+ # Periodically adjust batch size
1665
+ if i % 10 == 0:  # re-check memory and adjust batch size every 10 pairs
1666
+ self._adjust_batch_size()
1667
+
1668
+ # Add original pair
1669
+ current_queries.append(query)
1670
+ current_positives.append(positive)
1671
+
1672
+ # Add cached hard negatives for each query
1673
+ hard_negatives = self.hard_negatives_cache.get((query, positive), [])
1674
+ for neg_text in hard_negatives:
1675
+ current_queries.append(query)
1676
+ current_positives.append(neg_text)
1677
+
1678
+ # If we have enough examples to form a full batch, yield it
1679
+ while len(current_queries) >= self.current_batch_size:
1680
+ batch_queries = current_queries[:self.current_batch_size]
1681
+ batch_positives = current_positives[:self.current_batch_size]
1682
+
1683
+ # Update counters and logs
1684
+ batch_size_to_yield = len(batch_queries)
1685
+ total_examples_yielded += batch_size_to_yield
1686
+ total_batches_yielded += 1
1687
+
1688
+ yield self._prepare_batch(batch_queries, batch_positives, pad_to_batch_size=False)
1689
+
1690
+ # Remove used entries
1691
+ current_queries = current_queries[self.current_batch_size:]
1692
+ current_positives = current_positives[self.current_batch_size:]
1693
+
1694
+ # Update progress bar
1695
+ pbar.update(1)
1696
+ self._add_progress_metrics(
1697
+ pbar,
1698
+ pairs_processed=pbar.n,
1699
+ pending_pairs=len(current_queries)
1700
+ )
1701
 
1702
+ # After the loop, if anything is left, yield a final partial batch
1703
+ if current_queries:
1704
+ leftover_size = len(current_queries)
1705
+ total_examples_yielded += leftover_size
1706
+ total_batches_yielded += 1
1707
 
1708
+ yield self._prepare_batch(
1709
+ current_queries,
1710
+ current_positives,
1711
+ pad_to_batch_size=True
1712
+ )
1713
+
1714
+ def _find_hard_negatives_for_pairs(self, query_positive_pairs: List[Tuple[str, str]]) -> None:
1715
+ """Process pairs in batches to find hard negatives with GPU acceleration."""
1716
+ total_pairs = len(query_positive_pairs)
1717
+
1718
+ # Use smaller batch size for small datasets
1719
+ if len(self.response_pool) < 1000:
1720
+ batch_size = min(8, self.search_batch_size)
1721
+ else:
1722
+ batch_size = self.search_batch_size
1723
+
1724
+ try:
1725
+ pbar = tqdm(total=total_pairs, desc="Finding hard negatives")
1726
+ is_tqdm = True
1727
+ except ImportError:
1728
+ pbar = None
1729
+ is_tqdm = False
1730
+ logger.info("Progress bar disabled - continuing without visual progress")
1731
+
1732
+ for i in range(0, total_pairs, batch_size):
1733
+ self._adjust_batch_size()
1734
+
1735
+ batch_pairs = query_positive_pairs[i:i + batch_size]
1736
+ batch_queries, batch_positives = zip(*batch_pairs)
1737
+
1738
+ batch_negatives = self._find_hard_negatives_batch(
1739
+ list(batch_queries),
1740
+ list(batch_positives)
1741
+ )
1742
+
1743
+ for query, positive, negatives in zip(batch_queries, batch_positives, batch_negatives):
1744
+ self.hard_negatives_cache[(query, positive)] = negatives
1745
+ self.processed_pairs.append((query, positive))
1746
+
1747
+ if is_tqdm:
1748
+ pbar.update(len(batch_pairs))
1749
+ self._add_progress_metrics(
1750
+ pbar,
1751
+ cached=len(self.processed_pairs),
1752
+ progress=f"{i+len(batch_pairs)}/{total_pairs}"
1753
+ )
1754
+
1755
+ if is_tqdm:
1756
+ pbar.close()
1757
+
1758
+ def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
1759
+ """Find hard negatives for a batch of queries with error handling and retries."""
1760
+ retry_count = 0
1761
+ total_responses = len(self.response_pool)
1762
+
1763
+ # For very small datasets (testing), just use random sampling
1764
+ if total_responses < 100:
1765
+ all_negatives = []
1766
+ for positive in positives:
1767
+ available = [r for r in self.response_pool if r.strip() != positive.strip()]
1768
+ if available:
1769
+ negatives = list(np.random.choice(
1770
+ available,
1771
+ size=min(self.neg_samples, len(available)),
1772
+ replace=False
1773
+ ))
1774
+ else:
1775
+ negatives = []
1776
+ # Pad with empty strings if needed
1777
+ while len(negatives) < self.neg_samples:
1778
+ negatives.append("")
1779
+ all_negatives.append(negatives)
1780
+ return all_negatives
1781
+
1782
+ while retry_count < self.max_retries:
1783
+ try:
1784
+ # Get cached embeddings and ensure they're the right type
1785
+ query_embeddings = np.vstack([
1786
+ self.query_embeddings_cache[q] for q in queries
1787
+ ]).astype(np.float32)
1788
 
1789
+ if not query_embeddings.flags['C_CONTIGUOUS']:
1790
+ query_embeddings = np.ascontiguousarray(query_embeddings)
1791
+
1792
+ # Normalize embeddings
1793
+ faiss.normalize_L2(query_embeddings)
1794
+
1795
+ k = 1  # retrieve a single nearest neighbour; widen to min(total_responses - 1, max(3, self.neg_samples + 2)) if more candidates are needed
1796
+ #logger.debug(f"Searching with k={k} among {total_responses} responses")
1797
+
1798
+ assert query_embeddings.dtype == np.float32, f"Embeddings are not float32: {query_embeddings.dtype}"
1799
 
1800
+ try:
1801
+ distances, indices = self.index.search(query_embeddings, k)
1802
+ except RuntimeError as e:
1803
+ logger.error(f"FAISS search failed: {e}")
1804
+ return self._fallback_random_negatives(queries, positives)
 
1805
 
1806
+ # Process results
1807
+ all_negatives = []
1808
+ for i, (query_indices, query, positive) in enumerate(zip(indices, queries, positives)):
1809
+ negatives = []
1810
+ positive_strip = positive.strip()
1811
+
1812
+ # Filter valid indices and deduplicate
1813
+ seen = {positive_strip}
1814
+ for idx in query_indices:
1815
+ if idx >= 0 and idx < total_responses:
1816
+ candidate = self.response_pool[idx].strip()
1817
+ if candidate and candidate not in seen: # Check for non-empty strings
1818
+ seen.add(candidate)
1819
+ negatives.append(candidate)
1820
+ if len(negatives) >= self.neg_samples:
1821
+ break
1822
+
1823
+ # If we don't have enough negatives, use random sampling from remaining pool
1824
+ if len(negatives) < self.neg_samples:
1825
+ available = [r for r in self.response_pool if r.strip() not in seen and r.strip()]
1826
+ if available:
1827
+ additional = np.random.choice(
1828
+ available,
1829
+ size=min(self.neg_samples - len(negatives), len(available)),
1830
+ replace=False
1831
+ )
1832
+ negatives.extend(additional)
1833
+
1834
+ # Still pad with empty strings if needed
1835
+ while len(negatives) < self.neg_samples:
1836
+ negatives.append("")
1837
+
1838
+ all_negatives.append(negatives)
1839
 
1840
+ return all_negatives
1841
+
1842
+ except Exception as e:
1843
+ retry_count += 1
1844
+ logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
1845
+ if retry_count == self.max_retries:
1846
+ logger.error("Max retries reached for hard negative search")
1847
+ return [[] for _ in queries] # Return empty lists on complete failure
1848
+ gc.collect()
1849
+ if tf.config.list_physical_devices('GPU'):
1850
+ tf.keras.backend.clear_session()
 
1851
 
1852
+ def _fallback_random_negatives(self, queries: List[str], positives: List[str]) -> List[List[str]]:
1853
+ """Fallback to random sampling when similarity search fails."""
1854
+ logger.warning("Falling back to random negative sampling")
1855
+ all_negatives = []
1856
+ for positive in positives:
1857
+ available = [r for r in self.response_pool if r.strip() != positive.strip()]
1858
+ negatives = list(np.random.choice(
1859
+ available,
1860
+ size=min(self.neg_samples, len(available)),
1861
+ replace=False
1862
+ )) if available else []
1863
+ while len(negatives) < self.neg_samples:
1864
+ negatives.append("")
1865
+ all_negatives.append(negatives)
1866
+ return all_negatives
1867
 
1868
+ def _prepare_batch(
1869
+ self,
1870
+ queries: List[str],
1871
+ positives: List[str],
1872
+ pad_to_batch_size: bool = False
1873
+ ) -> Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor]]:
1874
+ """Prepare a batch with dynamic padding and memory optimization."""
1875
+ actual_size = len(queries)
1876
+
1877
+ # Handle padding if requested and needed
1878
+ if pad_to_batch_size and actual_size < self.current_batch_size:
1879
+ padding_needed = self.current_batch_size - actual_size
1880
+ queries.extend([queries[0]] * padding_needed)
1881
+ positives.extend([positives[0]] * padding_needed)
1882
+ # Create attention mask for padded examples
1883
+ attention_mask = tf.concat([
1884
+ tf.ones((actual_size,), dtype=tf.float32),
1885
+ tf.zeros((padding_needed,), dtype=tf.float32)
1886
+ ], axis=0)
1887
+ else:
1888
+ attention_mask = None
1889
+
1890
+ try:
1891
+ # Tokenize batch
1892
+ q_tokens = self.tokenizer(
1893
+ queries,
1894
+ padding='max_length',
1895
+ truncation=True,
1896
+ max_length=self.max_length,
1897
+ return_tensors='tf'
1898
+ )
1899
+ p_tokens = self.tokenizer(
1900
+ positives,
1901
+ padding='max_length',
1902
+ truncation=True,
1903
+ max_length=self.max_length,
1904
+ return_tensors='tf'
1905
+ )
1906
+
1907
+ return q_tokens['input_ids'], p_tokens['input_ids'], attention_mask
1908
+
1909
+ except Exception as e:
1910
+ logger.error(f"Error preparing batch: {e}")
1911
+ # Emergency memory cleanup
1912
+ gc.collect()
1913
+ if tf.config.list_physical_devices('GPU'):
1914
+ tf.keras.backend.clear_session()
1915
+ raise
1916
+
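When a final partial batch is padded up to current_batch_size, the attention mask built here is what keeps the duplicated rows from influencing training: train_step/validation_step multiply the per-example loss by the mask and renormalize. A toy check (values assumed):

import tensorflow as tf

mask = tf.concat([tf.ones((3,)), tf.zeros((1,))], axis=0)       # 3 real rows, 1 padded row
per_example_loss = tf.constant([0.2, 0.4, 0.6, 99.0])           # padded row would dominate if unmasked
loss = tf.reduce_sum(per_example_loss * mask) / tf.reduce_sum(mask)  # -> 0.4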
1917
+ def estimate_total_pairs(self, dialogues: List[dict]) -> int:
1918
+ """Estimate total number of training pairs including hard negatives."""
1919
+ base_pairs = sum(
1920
+ len([
1921
+ 1 for i in range(len(d.get('turns', [])) - 1)
1922
+ if (d['turns'][i].get('speaker') == 'user' and
1923
+ d['turns'][i+1].get('speaker') == 'assistant')
1924
+ ])
1925
+ for d in dialogues
1926
+ )
1927
+ # Account for hard negatives
1928
+ return base_pairs * (1 + self.neg_samples)
1929
+
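As a quick sanity check of the estimate (numbers illustrative): a file with 30 user→assistant turn pairs and neg_samples = 1 yields 30 * (1 + 1) = 60 candidate training examples before batching.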
1930
+ def cleanup(self):
1931
+ """Cleanup resources and memory."""
1932
+ self.query_embeddings_cache.clear()
1933
+ gc.collect()
1934
+ if tf.config.list_physical_devices('GPU'):
1935
+ tf.keras.backend.clear_session()
conversation_summarizer.py CHANGED
@@ -25,9 +25,9 @@ class DeviceAwareModel:
25
  self.strategy = None
26
 
27
  if self.device == 'GPU':
28
- # Enable mixed precision for better performance
29
- policy = tf.keras.mixed_precision.Policy('mixed_float16')
30
- tf.keras.mixed_precision.set_global_policy(policy)
31
 
32
  # Setup distribution strategy for multi-GPU if available
33
  gpus = tf.config.list_physical_devices('GPU')
 
25
  self.strategy = None
26
 
27
  if self.device == 'GPU':
28
+ # # Enable mixed precision for better performance
29
+ # policy = tf.keras.mixed_precision.Policy('mixed_float16')
30
+ # tf.keras.mixed_precision.set_global_policy(policy)
31
 
32
  # Setup distribution strategy for multi-GPU if available
33
  gpus = tf.config.list_physical_devices('GPU')
environment_setup.py CHANGED
@@ -122,14 +122,14 @@ class EnvironmentSetup:
122
  except (subprocess.SubprocessError, FileNotFoundError):
123
  logger.warning("Could not detect specific GPU model")
124
 
125
- # Enable XLA
126
- tf.config.optimizer.set_jit(True)
127
- logger.info("XLA compilation enabled for Colab GPU")
128
 
129
- # Set mixed precision policy
130
- policy = tf.keras.mixed_precision.Policy('mixed_float16')
131
- tf.keras.mixed_precision.set_global_policy(policy)
132
- logger.info("Mixed precision training enabled (float16)")
133
 
134
  strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
135
  return "GPU", strategy
@@ -187,20 +187,24 @@ class EnvironmentSetup:
187
  stderr=subprocess.DEVNULL
188
  ).decode('utf-8').strip()
189
 
190
- if "T4" in gpu_name:
 
 
 
 
191
  # T4 optimizations
192
  logger.info("Optimizing for Colab T4 GPU")
193
- base_batch_size = min(base_batch_size * 2, 32) # T4 can handle larger batches
194
  elif "V100" in gpu_name:
195
  # V100 optimizations
196
  logger.info("Optimizing for Colab V100 GPU")
197
- base_batch_size = min(base_batch_size * 3, 48) # V100 can handle even larger batches
198
  except (subprocess.SubprocessError, FileNotFoundError):
199
  logger.warning("Could not detect specific GPU model, using default settings")
200
 
201
  elif self.device_type == "TPU":
202
  # TPU optimizations
203
- base_batch_size = min(base_batch_size * 4, 64) # TPUs can handle very large batches
204
  logger.info("Optimizing for Colab TPU")
205
 
206
  logger.info(f"Optimized batch size for Colab: {base_batch_size}")
 
122
  except (subprocess.SubprocessError, FileNotFoundError):
123
  logger.warning("Could not detect specific GPU model")
124
 
125
+ # # Enable XLA
126
+ # tf.config.optimizer.set_jit(True)
127
+ # logger.info("XLA compilation enabled for Colab GPU")
128
 
129
+ # # Set mixed precision policy
130
+ # policy = tf.keras.mixed_precision.Policy('mixed_float16')
131
+ # tf.keras.mixed_precision.set_global_policy(policy)
132
+ # logger.info("Mixed precision training enabled (float16)")
133
 
134
  strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
135
  return "GPU", strategy
 
187
  stderr=subprocess.DEVNULL
188
  ).decode('utf-8').strip()
189
 
190
+ if "A100" in gpu_name:
191
+ # A100 optimizations - has 40GB or 80GB variants
192
+ logger.info("Optimizing for Colab A100 GPU")
193
+ base_batch_size = min(base_batch_size * 8, 128) # A100 can handle much larger batches
194
+ elif "T4" in gpu_name:
195
  # T4 optimizations
196
  logger.info("Optimizing for Colab T4 GPU")
197
+ base_batch_size = min(base_batch_size * 2, 32)
198
  elif "V100" in gpu_name:
199
  # V100 optimizations
200
  logger.info("Optimizing for Colab V100 GPU")
201
+ base_batch_size = min(base_batch_size * 3, 48)
202
  except (subprocess.SubprocessError, FileNotFoundError):
203
  logger.warning("Could not detect specific GPU model, using default settings")
204
 
205
  elif self.device_type == "TPU":
206
  # TPU optimizations
207
+ base_batch_size = min(base_batch_size * 4, 64)
208
  logger.info("Optimizing for Colab TPU")
209
 
210
  logger.info(f"Optimized batch size for Colab: {base_batch_size}")
gpu_monitor.py ADDED
@@ -0,0 +1,68 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import faiss
4
+ import json
5
+ from pathlib import Path
6
+ from typing import List, Dict, Tuple, Optional, Generator
7
+ from dataclasses import dataclass
8
+ import threading
9
+ from queue import Queue
10
+ import gc
11
+ try:
12
+ from tqdm.notebook import tqdm
13
+ except ImportError:
14
+ from tqdm import tqdm
15
+
16
+ @dataclass
17
+ class GPUMemoryStats:
18
+ total: int
19
+ used: int
20
+ free: int
21
+
22
+ class GPUMemoryMonitor:
23
+ """Monitor GPU memory usage with safe CPU fallback."""
24
+ def __init__(self):
25
+ self.has_gpu = False
26
+ try:
27
+ gpus = tf.config.list_physical_devices('GPU')
28
+ self.has_gpu = len(gpus) > 0
29
+ except:
30
+ pass
31
+
32
+ def get_memory_stats(self) -> Optional[GPUMemoryStats]:
33
+ """Get current GPU memory statistics."""
34
+ if not self.has_gpu:
35
+ return None
36
+
37
+ try:
38
+ memory_info = tf.config.experimental.get_memory_info('GPU:0')
39
+ return GPUMemoryStats(
40
+ total=memory_info['peak'],
41
+ used=memory_info['current'],
42
+ free=memory_info['peak'] - memory_info['current']
43
+ )
44
+ except:
45
+ return None
46
+
47
+ def get_memory_usage(self) -> float:
48
+ """Get current GPU memory usage as a percentage."""
49
+ if not self.has_gpu:
50
+ return 0.0
51
+ stats = self.get_memory_stats()
52
+ if stats is None or stats.total == 0:
53
+ return 0.0
54
+ return stats.used / stats.total
55
+
56
+ def should_reduce_batch_size(self) -> bool:
57
+ """Check if batch size should be reduced based on memory usage."""
58
+ if not self.has_gpu:
59
+ return False
60
+ usage = self.get_memory_usage()
61
+ return usage > 0.90
62
+
63
+ def can_increase_batch_size(self) -> bool:
64
+ """Check if batch size can be increased based on memory usage."""
65
+ if not self.has_gpu:
66
+ return True # Allow increase on CPU
67
+ usage = self.get_memory_usage()
68
+ return usage < 0.70
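A minimal sketch of how the monitor is meant to be consulted, mirroring the _adjust_batch_size logic in StreamingDataPipeline (the batch-size bounds here are assumptions):

monitor = GPUMemoryMonitor()
batch_size = 32
if monitor.should_reduce_batch_size():      # > 90% of observed GPU memory in use
    batch_size = max(4, batch_size // 2)
elif monitor.can_increase_batch_size():     # < 70% in use (or running on CPU)
    batch_size = min(64, int(batch_size * 1.25))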
run_model_train.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from chatbot_model import RetrievalChatbot, ChatbotConfig
2
  from environment_setup import EnvironmentSetup
3
  from response_quality_checker import ResponseQualityChecker
@@ -33,11 +34,12 @@ def run_interactive_chat(chatbot, quality_checker):
33
 
34
  def main():
35
  # Initialize environment
 
36
  env = EnvironmentSetup()
37
  env.initialize()
38
 
39
- DEBUG_SAMPLES = 5
40
- EPOCHS = 1 if DEBUG_SAMPLES else 20
41
  TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
42
 
43
  # Optimize batch size for Colab
@@ -54,23 +56,16 @@ def main():
54
  dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
55
 
56
  # Initialize chatbot and verify FAISS index
57
- with env.strategy.scope():
58
- chatbot = RetrievalChatbot(config, dialogues)
 
59
  chatbot.verify_faiss_index()
60
-
61
- # Prepare dataset
62
- logger.info("Preparing dataset...")
63
- q_tensor, p_tensor = chatbot.prepare_dataset(dialogues)
64
- quality_checker = ResponseQualityChecker(chatbot=chatbot)
65
-
66
- # Train model
67
- logger.info("Starting training...")
68
- chatbot.train(
69
- q_pad=q_tensor,
70
- p_pad=p_tensor,
71
  epochs=EPOCHS,
72
  batch_size=batch_size,
73
- validation_split=0.2,
74
  )
75
 
76
  # Save final model
 
1
+ import tensorflow as tf
2
  from chatbot_model import RetrievalChatbot, ChatbotConfig
3
  from environment_setup import EnvironmentSetup
4
  from response_quality_checker import ResponseQualityChecker
 
34
 
35
  def main():
36
  # Initialize environment
37
+ tf.keras.backend.clear_session()
38
  env = EnvironmentSetup()
39
  env.initialize()
40
 
41
+ DEBUG_SAMPLES = 15
42
+ EPOCHS = 5 if DEBUG_SAMPLES else 20
43
  TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'
44
 
45
  # Optimize batch size for Colab
 
56
  dialogues = RetrievalChatbot.load_training_data(data_path=TRAINING_DATA_PATH, debug_samples=DEBUG_SAMPLES)
57
 
58
  # Initialize chatbot and verify FAISS index
59
+ #with env.strategy.scope():
60
+ chatbot = RetrievalChatbot(config, dialogues)
61
+ chatbot.build_models()
62
  chatbot.verify_faiss_index()
63
+
64
+ chatbot.train_streaming(
65
+ dialogues=dialogues,
66
  epochs=EPOCHS,
67
  batch_size=batch_size,
68
+ use_lr_schedule=True,
69
  )
70
 
71
  # Save final model