JoeArmani committed
Commit 5b413d1 · 1 Parent(s): ee0f664

training and inference updates

chatbot_model.py CHANGED
@@ -1,7 +1,6 @@
1
- import time
2
  from transformers import TFAutoModel, AutoTokenizer
3
  import tensorflow as tf
4
- import numpy as np
5
  from typing import List, Tuple, Dict, Optional, Union, Any
6
  import math
7
  from dataclasses import dataclass
@@ -66,23 +65,17 @@ class EncoderModel(tf.keras.Model):
66
  super().__init__(name=name, **kwargs)
67
  self.config = config
68
 
69
- # Load pretrained model
70
  self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
71
-
72
- # Freeze layers based on config
73
  self._freeze_layers()
74
 
75
- # Pooling layer (Global Average Pooling)
76
  self.pooler = tf.keras.layers.GlobalAveragePooling1D()
77
-
78
- # Projection layer
79
  self.projection = tf.keras.layers.Dense(
80
  config.embedding_dim,
81
  activation='tanh',
82
  name="projection"
83
  )
84
-
85
- # Dropout and normalization
86
  self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
87
  self.normalize = tf.keras.layers.Lambda(
88
  lambda x: tf.nn.l2_normalize(x, axis=1),
@@ -110,13 +103,13 @@ class EncoderModel(tf.keras.Model):
110
  """Forward pass."""
111
  # Get pretrained embeddings
112
  pretrained_outputs = self.pretrained(inputs, training=training)
113
- x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
114
 
115
  # Apply pooling, projection, dropout, and normalization
116
- x = self.pooler(x) # Shape: [batch_size, 768]
117
- x = self.projection(x) # Shape: [batch_size, 768]
118
- x = self.dropout(x, training=training) # Apply dropout
119
- x = self.normalize(x) # Shape: [batch_size, 768]
120
 
121
  return x
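For orientation, a minimal sketch of how this encoder is meant to be exercised end to end, using the ChatbotConfig and EncoderModel defined in this file (the default-constructed config, sequence length, and omission of special-token resizing are assumptions):

    import tensorflow as tf
    from transformers import AutoTokenizer

    config = ChatbotConfig()  # assumed to carry pretrained_model, embedding_dim, dropout_rate, etc.
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
    encoder = EncoderModel(config, name="shared_encoder")

    # Tokenize a toy query and run the forward pass shown above.
    enc = tokenizer(["how do I reset my password?"], padding="max_length",
                    truncation=True, max_length=64, return_tensors="tf")
    emb = encoder(enc["input_ids"], training=False)  # [1, config.embedding_dim]

    # The output is L2-normalized, so dot products behave as cosine similarities.
    tf.debugging.assert_near(tf.norm(emb, axis=1), tf.ones([1]), atol=1e-5)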
122
 
@@ -134,12 +127,11 @@ class RetrievalChatbot(DeviceAwareModel):
134
  def __init__(
135
  self,
136
  config: ChatbotConfig,
137
- dialogues: List[dict] = [],
138
  device: str = None,
139
  strategy=None,
140
  reranker: Optional[CrossEncoderReranker] = None,
141
  summarizer: Optional[Summarizer] = None,
142
- mode: str = 'preparation'
143
  ):
144
  super().__init__()
145
  self.config = config
@@ -147,17 +139,37 @@ class RetrievalChatbot(DeviceAwareModel):
147
  self.device = device or self._setup_default_device()
148
  self.mode = mode.lower()
149
 
150
- # Initialize reranker, summarizer, tokenizer, and memory monitor
151
  self.reranker = reranker or self._initialize_reranker()
152
- self.summarizer = summarizer or self._initialize_summarizer()
153
  self.tokenizer = self._initialize_tokenizer()
 
 
154
  self.memory_monitor = GPUMemoryMonitor()
155
 
156
- # # Initialize models
157
- # self.min_batch_size = 8
158
- # self.max_batch_size = 128
159
- # self.current_batch_size = 32
 
161
  # Initialize training history
162
  self.history = {
163
  "train_loss": [],
@@ -165,15 +177,7 @@ class RetrievalChatbot(DeviceAwareModel):
165
  "train_metrics": {},
166
  "val_metrics": {}
167
  }
168
-
169
- # Collect unique responses from dialogues
170
- if self.mode == 'preparation':
171
- # Collect unique responses from dialogues only in preparation mode
172
- self.response_pool, self.unique_responses = self._collect_responses(dialogues)
173
- else:
174
- # In training mode, assume response_pool is handled via TFRecord
175
- self.response_pool = []
176
- self.unique_responses = []
177
 
178
  def _setup_default_device(self) -> str:
179
  """Set up default device if none is provided."""
@@ -189,8 +193,13 @@ class RetrievalChatbot(DeviceAwareModel):
189
 
190
  def _initialize_summarizer(self) -> Summarizer:
191
  """Initialize the Summarizer."""
192
- logger.info("Initializing default Summarizer...")
193
- return Summarizer(device=self.device)
 
194
 
195
  def _initialize_tokenizer(self) -> AutoTokenizer:
196
  """Initialize the tokenizer and add special tokens."""
@@ -207,559 +216,127 @@ class RetrievalChatbot(DeviceAwareModel):
207
  )
208
  return tokenizer
209
 
210
- def _collect_responses(self, dialogues: List[dict]) -> Tuple[List[str], List[str]]:
211
- """
212
- Collect unique responses from dialogues.
213
- Returns:
214
- response_pool: List of all possible responses.
215
- unique_responses: List of unique responses.
216
- """
217
- logger.info("Collecting unique responses from dialogues...")
218
- responses = set()
219
- for dialogue in dialogues:
220
- turns = dialogue.get('turns', [])
221
- for turn in turns:
222
- if turn.get('speaker') == 'assistant' and 'text' in turn:
223
- response = turn['text'].strip()
224
- if len(response) >= self.config.min_text_length:
225
- responses.add(response)
226
- response_pool = list(responses)
227
- unique_responses = list(responses) # Assuming uniqueness
228
- logger.info(f"Collected {len(response_pool)} unique responses.")
229
- return response_pool, unique_responses
230
-
231
- def build_models(self):
232
- """Initialize the shared encoder and FAISS index."""
233
- logger.info("Building encoder model...")
234
- tf.keras.backend.clear_session()
235
-
236
- # Shared encoder for both queries and responses
237
- self.encoder = EncoderModel(
238
  self.config,
239
  name="shared_encoder",
240
  )
241
 
242
- # Resize token embeddings after adding special tokens
243
  new_vocab_size = len(self.tokenizer)
244
- self.encoder.pretrained.resize_token_embeddings(new_vocab_size)
245
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
246
-
247
- if self.mode == 'preparation':
248
- # Initialize FAISS index only in preparation mode
249
- self._initialize_faiss()
250
- # Compute and index embeddings
251
- self._compute_and_index_embeddings()
252
- else:
253
- # In training mode, skip FAISS indexing from dialogues
254
- logger.info("Training mode: Skipping FAISS index initialization from dialogues.")
255
-
256
- # Retrieve embedding dimension from encoder
257
- embedding_dim = self.config.embedding_dim
258
- vocab_size = len(self.tokenizer)
259
-
260
- logger.info(f"Encoder Embedding Dimension: {embedding_dim}")
261
- logger.info(f"Encoder Embedding Vocabulary Size: {vocab_size}")
262
- if vocab_size >= embedding_dim:
263
- logger.info("Encoder model built and embeddings resized successfully.")
264
- else:
265
- logger.error("Vocabulary size is less than embedding dimension.")
266
- raise ValueError("Vocabulary size is less than embedding dimension.")
267
 
268
- def _adjust_batch_size(self) -> None:
269
- """Dynamically adjust batch size based on GPU memory usage."""
270
- if self.memory_monitor.should_reduce_batch_size():
271
- new_size = max(self.min_batch_size, self.current_batch_size // 2)
272
- if new_size != self.current_batch_size:
273
- logger.info(f"Reducing batch size to {new_size} due to high memory usage")
274
- self.current_batch_size = new_size
275
- gc.collect()
276
- if tf.config.list_physical_devices('GPU'):
277
- tf.keras.backend.clear_session()
278
- elif self.memory_monitor.can_increase_batch_size():
279
- new_size = min(self.max_batch_size, self.current_batch_size * 2)
280
- if new_size != self.current_batch_size:
281
- logger.info(f"Increasing batch size to {new_size}")
282
- self.current_batch_size = new_size
283
-
284
- def _initialize_faiss(self):
285
- """Initialize FAISS with safe GPU handling and memory monitoring."""
286
- logger.info("Initializing FAISS index...")
287
-
288
- # Detect if we have GPU-enabled FAISS
289
- self.faiss_gpu = False
290
- self.gpu_resources = []
291
-
292
- try:
293
- if hasattr(faiss, 'get_num_gpus'):
294
- ngpus = faiss.get_num_gpus()
295
- if ngpus > 0:
296
- # Configure GPU resources with memory limit
297
- for i in range(ngpus):
298
- res = faiss.StandardGpuResources()
299
- # Set temp memory to 1/4 of total memory to avoid OOM
300
- if self.memory_monitor.has_gpu:
301
- stats = self.memory_monitor.get_memory_stats()
302
- if stats:
303
- temp_memory = int(stats.total * 0.25) # 25% of total memory
304
- res.setTempMemory(temp_memory)
305
- self.gpu_resources.append(res)
306
- self.faiss_gpu = True
307
- logger.info(f"FAISS GPU resources initialized on {ngpus} GPUs")
308
- except Exception as e:
309
- logger.warning(f"Using CPU due to GPU initialization error: {e}")
310
-
311
  try:
312
- # Create appropriate index based on dataset size
313
- if len(self.unique_responses) < 1000:
314
- logger.info("Small dataset detected, using simple FlatIP index")
315
- self.index = faiss.IndexFlatIP(self.config.embedding_dim)
 
316
  else:
317
- # For larger datasets, consider using more efficient indices like IVF
318
- self.index = faiss.IndexFlatIP(self.config.embedding_dim)
 
 
 
 
319
 
320
- # Move to GPU(s) if available and needed
321
- if self.faiss_gpu and self.gpu_resources:
322
- try:
323
- if len(self.gpu_resources) > 1:
324
- self.index = faiss.index_cpu_to_gpus_list(self.index, self.gpu_resources)
325
- logger.info("FAISS index distributed across multiple GPUs")
326
- else:
327
- self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, self.index)
328
- logger.info("FAISS index moved to single GPU")
329
- except Exception as e:
330
- logger.warning(f"Failed to move index to GPU: {e}. Falling back to CPU")
331
- self.faiss_gpu = False
332
  except Exception as e:
333
- logger.error(f"Error initializing FAISS: {e}")
334
  raise
335
-
336
- def encode_responses(
337
- self,
338
- responses: List[str],
339
- batch_size: int = 64
340
- ) -> tf.Tensor:
341
- """
342
- Encodes responses with more conservative memory management.
343
  """
344
- if not responses:
345
- logger.info("No responses to encode. Returning empty tensor.")
346
- return tf.constant([], dtype=tf.float32)
347
-
348
- all_embeddings = []
349
- self.current_batch_size = batch_size
350
-
351
- if self.memory_monitor.has_gpu:
352
- batch_size = 128
353
-
354
- total_processed = 0
355
 
356
- with tqdm(total=len(responses), desc="Encoding responses") as pbar:
357
- while total_processed < len(responses):
358
- # Monitor memory and adjust batch size
359
- if self.memory_monitor.has_gpu:
360
- gpu_usage = self.memory_monitor.get_memory_usage()
361
- if gpu_usage > 0.8: # Over 80% usage
362
- self.current_batch_size = max(128, self.current_batch_size // 2)
363
- logger.info(f"High GPU memory usage ({gpu_usage:.1%}), reducing batch size to {self.current_batch_size}")
364
- gc.collect()
365
- tf.keras.backend.clear_session()
366
-
367
- # Get batch
368
- end_idx = min(total_processed + self.current_batch_size, len(responses))
369
- batch_texts = responses[total_processed:end_idx]
370
-
371
- try:
372
- # Tokenize
373
- encodings = self.tokenizer(
374
- batch_texts,
375
- padding='max_length',
376
- truncation=True,
377
- max_length=self.config.max_context_token_limit,
378
- return_tensors='tf'
379
- )
380
-
381
- # Encode
382
- embeddings_batch = self.encoder(encodings['input_ids'], training=False)
383
-
384
- # Cast to float32
385
- if embeddings_batch.dtype != tf.float32:
386
- embeddings_batch = tf.cast(embeddings_batch, tf.float32)
387
-
388
- # Store
389
- all_embeddings.append(embeddings_batch)
390
-
391
- # Update progress
392
- batch_processed = len(batch_texts)
393
- total_processed += batch_processed
394
-
395
- # Update progress bar
396
- if self.memory_monitor.has_gpu:
397
- gpu_usage = self.memory_monitor.get_memory_usage()
398
- pbar.set_postfix({
399
- 'GPU mem': f'{gpu_usage:.1%}',
400
- 'batch_size': self.current_batch_size
401
- })
402
- pbar.update(batch_processed)
403
-
404
- # Memory cleanup every 1000 samples
405
- if total_processed % 1000 == 0:
406
- gc.collect()
407
- if tf.config.list_physical_devices('GPU'):
408
- tf.keras.backend.clear_session()
409
-
410
- except tf.errors.ResourceExhaustedError:
411
- logger.warning("GPU memory exhausted during encoding, reducing batch size")
412
- self.current_batch_size = max(8, self.current_batch_size // 2)
413
- continue
414
-
415
- except Exception as e:
416
- logger.error(f"Error during encoding: {str(e)}")
417
- raise
418
-
419
- # Concatenate results
420
- if not all_embeddings:
421
- logger.info("No embeddings were encoded. Returning empty tensor.")
422
- return tf.constant([], dtype=tf.float32)
423
-
424
- if len(all_embeddings) == 1:
425
- final_embeddings = all_embeddings[0]
426
- else:
427
- final_embeddings = tf.concat(all_embeddings, axis=0)
428
-
429
- return final_embeddings
430
-
431
- def _train_faiss_index(self, response_embeddings: np.ndarray) -> None:
432
- """Train FAISS index with better memory management and robust fallback mechanisms."""
433
- if self.index.is_trained:
434
- logger.info("Index already trained, skipping training phase")
435
- return
436
-
437
- logger.info("Starting FAISS index training...")
438
 
439
- try:
440
- # First attempt: Try training with smaller subset
441
- subset_size = min(5000, len(response_embeddings)) # Reduced from 10000
442
- logger.info(f"Using {subset_size} samples for initial training attempt")
443
- subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False)
444
- training_embeddings = response_embeddings[subset_idx].copy() # Make a copy
445
-
446
- # Ensure contiguous memory layout
447
- training_embeddings = np.ascontiguousarray(training_embeddings)
448
-
449
- # Force cleanup before training
450
- gc.collect()
451
- if tf.config.list_physical_devices('GPU'):
452
- tf.keras.backend.clear_session()
453
-
454
- # Verify data properties
455
- logger.info(f"FAISS training data shape: {training_embeddings.shape}")
456
- logger.info(f"FAISS training data dtype: {training_embeddings.dtype}")
457
-
458
- logger.info("Starting initial training attempt...")
459
- self.index.train(training_embeddings)
460
- logger.info("Training completed successfully")
461
-
462
- except (RuntimeError, Exception) as e:
463
- logger.warning(f"Initial training attempt failed: {str(e)}")
464
- logger.info("Attempting fallback strategy...")
465
-
466
- try:
467
- # Move to CPU for more stable training
468
- if self.faiss_gpu:
469
- logger.info("Moving index to CPU for fallback training")
470
- cpu_index = faiss.index_gpu_to_cpu(self.index)
471
- else:
472
- cpu_index = self.index
473
-
474
- # Create simpler index type if needed
475
- if isinstance(cpu_index, faiss.IndexIVFFlat):
476
- logger.info("Creating simpler FlatL2 index for fallback")
477
- cpu_index = faiss.IndexFlatL2(self.config.embedding_dim)
478
-
479
- # Use even smaller subset for CPU training
480
- subset_size = min(2000, len(response_embeddings))
481
- subset_idx = np.random.choice(len(response_embeddings), subset_size, replace=False)
482
- fallback_embeddings = response_embeddings[subset_idx].copy()
483
-
484
- # Ensure data is properly formatted
485
- if not fallback_embeddings.flags['C_CONTIGUOUS']:
486
- fallback_embeddings = np.ascontiguousarray(fallback_embeddings)
487
- if fallback_embeddings.dtype != np.float32:
488
- fallback_embeddings = fallback_embeddings.astype(np.float32)
489
-
490
- # Train on CPU
491
- logger.info("Training fallback index on CPU...")
492
- cpu_index.train(fallback_embeddings)
493
-
494
- # Move back to GPU if needed
495
- if self.faiss_gpu:
496
- logger.info("Moving trained index back to GPU...")
497
- if len(self.gpu_resources) > 1:
498
- self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
499
- else:
500
- self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
501
- else:
502
- self.index = cpu_index
503
-
504
- logger.info("Fallback training completed successfully")
505
-
506
- except Exception as e2:
507
- logger.error(f"Fallback training also failed: {str(e2)}")
508
- logger.warning("Creating basic brute-force index as last resort")
509
-
510
- try:
511
- # Create basic brute-force index as last resort
512
- dim = response_embeddings.shape[1]
513
- basic_index = faiss.IndexFlatL2(dim)
514
-
515
- if self.faiss_gpu:
516
- if len(self.gpu_resources) > 1:
517
- self.index = faiss.index_cpu_to_gpus_list(basic_index, self.gpu_resources)
518
- else:
519
- self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, basic_index)
520
- else:
521
- self.index = basic_index
522
-
523
- logger.info("Basic index created as fallback")
524
-
525
- except Exception as e3:
526
- logger.error(f"All training attempts failed: {str(e3)}")
527
- raise RuntimeError("Unable to create working FAISS index")
528
-
529
- def _add_vectors_to_index(self, response_embeddings: np.ndarray) -> None:
530
- """Add vectors to FAISS index with enhanced memory management."""
531
- logger.info("Starting vector addition process...")
532
 
533
- # Even smaller batches
534
- initial_batch_size = 128
535
- min_batch_size = 32
536
- max_batch_size = 1024
537
 
538
- total_added = 0
539
- retry_count = 0
540
- max_retries = 5
 
 
541
 
542
- while total_added < len(response_embeddings):
543
- try:
544
- # Monitor memory
545
- if self.memory_monitor.has_gpu:
546
- gpu_usage = self.memory_monitor.get_memory_usage()
547
- #logger.info(f"GPU memory usage before batch: {gpu_usage:.1%}")
548
-
549
- # Force cleanup if memory usage is high
550
- if gpu_usage > 0.7: # Lower threshold to 70%
551
- logger.info("High memory usage detected, forcing cleanup")
552
- gc.collect()
553
- tf.keras.backend.clear_session()
554
-
555
- # Get batch
556
- end_idx = min(total_added + initial_batch_size, len(response_embeddings))
557
- batch = response_embeddings[total_added:end_idx]
558
-
559
- # Add batch
560
- self.index.add(batch)
561
-
562
- # Update progress
563
- batch_size = len(batch)
564
- total_added += batch_size
565
-
566
- # Memory cleanup every few batches
567
- if total_added % (initial_batch_size * 5) == 0:
568
- gc.collect()
569
- if tf.config.list_physical_devices('GPU'):
570
- tf.keras.backend.clear_session()
571
-
572
- # Gradually increase batch size
573
- if initial_batch_size < max_batch_size:
574
- initial_batch_size = min(initial_batch_size + 25, max_batch_size)
575
-
576
- except Exception as e:
577
- logger.warning(f"Error adding batch: {str(e)}")
578
- retry_count += 1
579
-
580
- if retry_count > max_retries:
581
- logger.error("Max retries exceeded.")
582
- raise
583
-
584
- # Reduce batch size
585
- initial_batch_size = max(min_batch_size, initial_batch_size // 2)
586
- logger.info(f"Reducing batch size to {initial_batch_size} and retrying...")
587
-
588
- # Cleanup
589
- gc.collect()
590
- if tf.config.list_physical_devices('GPU'):
591
- tf.keras.backend.clear_session()
592
-
593
- time.sleep(1) # Brief pause before retry
594
 
595
- logger.info(f"Successfully added all {total_added} vectors to index")
596
-
597
- def _add_vectors_cpu_fallback(self, remaining_embeddings: np.ndarray, already_added: int = 0) -> None:
598
- """CPU fallback with extra safeguards and progress tracking."""
599
- logger.info(f"CPU Fallback: Adding {len(remaining_embeddings)} remaining vectors...")
600
601
  try:
602
- # Move index to CPU
603
- if self.faiss_gpu:
604
- logger.info("Moving index to CPU...")
605
- cpu_index = faiss.index_gpu_to_cpu(self.index)
 
606
  else:
607
- cpu_index = self.index
608
-
609
- # Add remaining vectors on CPU with very small batches
610
- batch_size = 128
611
- total_added = already_added
612
-
613
- for i in range(0, len(remaining_embeddings), batch_size):
614
- end_idx = min(i + batch_size, len(remaining_embeddings))
615
- batch = remaining_embeddings[i:end_idx]
616
-
617
- # Add batch
618
- cpu_index.add(batch)
619
-
620
- # Update progress
621
- total_added += len(batch)
622
- if i % (batch_size * 10) == 0:
623
- logger.info(f"Added {total_added} vectors total "
624
- f"({i}/{len(remaining_embeddings)} in current phase)")
625
-
626
- # Periodic cleanup
627
- if i % (batch_size * 20) == 0:
628
- gc.collect()
629
 
630
- # Move back to GPU if needed
631
- if self.faiss_gpu:
632
- logger.info("Moving index back to GPU...")
633
- if len(self.gpu_resources) > 1:
634
- self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
635
- else:
636
- self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
637
  else:
638
- self.index = cpu_index
639
-
640
- logger.info("CPU fallback completed successfully")
641
-
642
- except Exception as e:
643
- logger.error(f"Error during CPU fallback: {str(e)}")
644
- raise
645
-
646
- def _compute_and_index_embeddings(self):
647
- """Compute embeddings and build FAISS index with simpler handling."""
648
- logger.info("Computing embeddings and indexing with FAISS...")
649
-
650
- try:
651
- # Encode responses with memory monitoring
652
- logger.info("Encoding unique responses")
653
- response_embeddings = self.encode_responses(self.unique_responses)
654
- response_embeddings = response_embeddings.numpy()
655
-
656
- # Memory cleanup after encoding
657
- gc.collect()
658
- if tf.config.list_physical_devices('GPU'):
659
- tf.keras.backend.clear_session()
660
-
661
- # Ensure float32 and memory contiguous
662
- response_embeddings = response_embeddings.astype('float32')
663
- response_embeddings = np.ascontiguousarray(response_embeddings)
664
-
665
- # Log memory state before normalization
666
- if self.memory_monitor.has_gpu:
667
- stats = self.memory_monitor.get_memory_stats()
668
- if stats:
669
- logger.info(f"GPU memory before normalization: {stats.used/1e9:.2f}GB used")
670
-
671
- # Normalize embeddings
672
- logger.info("Normalizing embeddings with FAISS")
673
- faiss.normalize_L2(response_embeddings)
674
-
675
- # Create and initialize simple FlatIP index
676
- dim = response_embeddings.shape[1]
677
- if self.faiss_gpu:
678
- cpu_index = faiss.IndexFlatIP(dim)
679
- if len(self.gpu_resources) > 1:
680
- self.index = faiss.index_cpu_to_gpus_list(cpu_index, self.gpu_resources)
681
- else:
682
- self.index = faiss.index_cpu_to_gpu(self.gpu_resources[0], 0, cpu_index)
683
- else:
684
- self.index = faiss.IndexFlatIP(dim)
685
-
686
- # Add vectors to index
687
- self._add_vectors_to_index(response_embeddings)
688
-
689
- # Store responses and embeddings
690
- self.response_pool = self.unique_responses
691
- self.response_embeddings = response_embeddings
692
-
693
- # Final memory cleanup
694
- gc.collect()
695
- if tf.config.list_physical_devices('GPU'):
696
- tf.keras.backend.clear_session()
697
-
698
- # Log final state
699
- logger.info(f"Successfully indexed {self.index.ntotal} responses")
700
- if self.memory_monitor.has_gpu:
701
- stats = self.memory_monitor.get_memory_stats()
702
- if stats:
703
- logger.info(f"Final GPU memory usage: {stats.used/1e9:.2f}GB used")
704
-
705
- logger.info("Indexing completed successfully")
706
707
  except Exception as e:
708
- logger.error(f"Error during indexing: {e}")
709
- # Ensure cleanup even on error
710
- gc.collect()
711
- if tf.config.list_physical_devices('GPU'):
712
- tf.keras.backend.clear_session()
713
  raise
714
-
715
- def verify_faiss_index(self):
716
- """Verify that FAISS index matches the response pool, if index exists."""
717
- if not hasattr(self, 'index') or self.index is None:
718
- logger.info("FAISS index not initialized. Skipping verification.")
719
- return
720
-
721
- indexed_size = self.index.ntotal
722
- pool_size = len(self.response_pool)
723
- logger.info(f"FAISS index size: {indexed_size}")
724
- logger.info(f"Response pool size: {pool_size}")
725
- if indexed_size != pool_size:
726
- logger.warning("Mismatch between FAISS index size and response pool size.")
727
- else:
728
- logger.info("FAISS index correctly matches the response pool.")
729
-
730
- def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> tf.Tensor:
731
- """Encode a query with optional conversation context."""
732
- # Prepare query with context
733
- if context:
734
- context_str = ' '.join([
735
- f"{self.special_tokens['user']} {q} "
736
- f"{self.special_tokens['assistant']} {r}"
737
- for q, r in context[-self.config.max_context_turns:]
738
- ])
739
- query = f"{context_str} {self.special_tokens['user']} {query}"
740
- else:
741
- query = f"{self.special_tokens['user']} {query}"
742
 
743
- # Tokenize and encode
744
- encodings = self.tokenizer(
745
- [query],
746
- padding='max_length',
747
- truncation=True,
748
- max_length=self.config.max_context_token_limit,
749
- return_tensors='tf'
750
- )
751
- input_ids = encodings['input_ids']
752
 
753
- # Verify token IDs
754
- max_id = tf.reduce_max(input_ids).numpy()
755
- new_vocab_size = len(self.tokenizer)
756
 
757
- if max_id >= new_vocab_size:
758
- logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
759
- raise ValueError("Token ID exceeds vocabulary size.")
760
 
761
- # Get embeddings from the shared encoder
762
- return self.encoder(input_ids, training=False)
763
 
764
  def retrieve_responses_cross_encoder(
765
  self,
@@ -786,7 +363,7 @@ class RetrievalChatbot(DeviceAwareModel):
786
 
787
  # 2) Dense retrieval
788
  dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) # [(resp, dense_score), ...]
789
-
790
  if not dense_topk:
791
  return []
792
 
@@ -800,75 +377,228 @@ class RetrievalChatbot(DeviceAwareModel):
800
  combined.sort(key=lambda x: x[1], reverse=True)
801
 
802
  return combined
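A standalone sketch of the retrieve-then-re-rank combination used above, with dummy scores standing in for the FAISS and cross-encoder outputs (the reranker is assumed to return one score per candidate, in candidate order):

    from typing import List, Tuple

    def combine_with_cross_scores(dense_topk: List[Tuple[str, float]],
                                  cross_scores: List[float]) -> List[Tuple[str, float]]:
        # Keep the candidate text, swap in the cross-encoder score, sort descending.
        combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
        combined.sort(key=lambda x: x[1], reverse=True)
        return combined

    dense_topk = [("reset it from the settings page", 0.71), ("please contact support", 0.69)]
    print(combine_with_cross_scores(dense_topk, cross_scores=[0.12, 0.88]))
    # [('please contact support', 0.88), ('reset it from the settings page', 0.12)]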
 
803
 
804
  def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
805
  """Retrieve top-k responses using FAISS."""
806
- if not hasattr(self, 'index') or self.index is None:
807
  logger.warning("FAISS index not initialized. Cannot retrieve responses.")
808
  return []
809
 
810
- # Encode the query
811
- q_emb = self.encode_query(query) # Shape: [1, embedding_dim]
812
  q_emb_np = q_emb.numpy().astype('float32') # Ensure type match
813
 
814
  # Normalize the query embedding for cosine similarity
815
  faiss.normalize_L2(q_emb_np)
816
 
817
  # Search the FAISS index
818
- distances, indices = self.index.search(q_emb_np, top_k)
819
 
820
  # Map indices to responses and distances to similarities
821
  top_responses = []
822
  for i, idx in enumerate(indices[0]):
823
- if idx < len(self.response_pool):
824
- top_responses.append((self.response_pool[idx], float(distances[0][i])))
825
  else:
826
  logger.warning(f"FAISS returned invalid index {idx}. Skipping.")
827
 
828
  return top_responses
829
-
830
- def save_models(self, save_dir: Union[str, Path]):
831
- """Save models and configuration."""
832
- save_dir = Path(save_dir)
833
- save_dir.mkdir(parents=True, exist_ok=True)
834
 
835
- # Save config
836
- with open(save_dir / "config.json", "w") as f:
837
- json.dump(self.config.to_dict(), f, indent=2)
838
-
839
- # Save models
840
- self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
841
 
842
- # Save tokenizer
843
- self.tokenizer.save_pretrained(save_dir / "tokenizer")
844
 
845
- logger.info(f"Models and tokenizer saved to {save_dir}.")
846
-
847
- @classmethod
848
- def load_models(cls, load_dir: Union[str, Path]) -> 'RetrievalChatbot':
849
- """Load saved models and configuration."""
850
- load_dir = Path(load_dir)
851
 
852
- # Load config
853
- with open(load_dir / "config.json", "r") as f:
854
- config = ChatbotConfig.from_dict(json.load(f))
855
-
856
- # Initialize chatbot
857
- chatbot = cls(config)
 
858
 
859
- # Load models
860
- chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
861
- load_dir / "shared_encoder",
862
- config=config
863
- )
 
864
 
865
- # Load tokenizer
866
- chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
 
867
 
868
- logger.info(f"Models and tokenizer loaded from {load_dir}.")
869
- return chatbot
 
870
 
871
- def train_streaming(
872
  self,
873
  tfrecord_file_path: str,
874
  epochs: int = 20,
@@ -876,10 +606,12 @@ class RetrievalChatbot(DeviceAwareModel):
876
  validation_split: float = 0.2,
877
  checkpoint_dir: str = "checkpoints/",
878
  use_lr_schedule: bool = True,
879
- peak_lr: float = 2e-5,
880
  warmup_steps_ratio: float = 0.1,
881
  early_stopping_patience: int = 3,
882
  min_delta: float = 1e-4,
 
 
883
  ) -> None:
884
  """Training using a pre-prepared TFRecord dataset."""
885
  logger.info("Starting training with pre-prepared TFRecord dataset...")
@@ -908,8 +640,8 @@ class RetrievalChatbot(DeviceAwareModel):
908
  negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
909
  negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
910
 
911
- return query_ids, positive_ids, negative_ids
912
-
913
  # Calculate total steps by counting the number of records in the TFRecord
914
  raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
915
  total_pairs = sum(1 for _ in raw_dataset)
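For context, a sketch of what parse_fn is assumed to look like for this TFRecord layout; the feature names and the [neg_samples, max_length] reshape come from the hunk above, while the fixed-length feature spec and the concrete sizes are assumptions:

    max_length = 512   # assumed: config.max_context_token_limit
    neg_samples = 1    # assumed: negatives stored per training pair

    feature_spec = {
        'query_ids':    tf.io.FixedLenFeature([max_length], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
    }

    def parse_fn(example_proto):
        parsed = tf.io.parse_single_example(example_proto, feature_spec)
        query_ids    = tf.cast(parsed['query_ids'], tf.int32)
        positive_ids = tf.cast(parsed['positive_ids'], tf.int32)
        negative_ids = tf.reshape(tf.cast(parsed['negative_ids'], tf.int32),
                                  [neg_samples, max_length])
        return query_ids, positive_ids, negative_ids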
@@ -920,6 +652,7 @@ class RetrievalChatbot(DeviceAwareModel):
920
  steps_per_epoch = math.ceil(train_size / batch_size)
921
  val_steps = math.ceil(val_size / batch_size)
922
  total_steps = steps_per_epoch * epochs
 
923
 
924
  logger.info(f"Training pairs: {train_size}")
925
  logger.info(f"Validation pairs: {val_size}")
@@ -942,9 +675,42 @@ class RetrievalChatbot(DeviceAwareModel):
942
  logger.info("Using fixed learning rate.")
943
 
944
  # Initialize checkpoint manager
945
- checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
946
- manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
 
948
  # Setup TensorBoard
949
  log_dir = Path(checkpoint_dir) / "tensorboard_logs"
950
  log_dir.mkdir(parents=True, exist_ok=True)
@@ -960,20 +726,47 @@ class RetrievalChatbot(DeviceAwareModel):
960
 
961
  # Create the full dataset
962
  dataset = tf.data.TFRecordDataset(tfrecord_file_path)
 
 
963
  dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
964
- dataset = dataset.shuffle(buffer_size=10000) # Adjust buffer size as needed
965
- dataset = dataset.batch(batch_size, drop_remainder=True)
966
- dataset = dataset.prefetch(tf.data.AUTOTUNE)
967
-
968
- # Split into training and validation
969
  train_dataset = dataset.take(train_size)
970
  val_dataset = dataset.skip(train_size).take(val_size)
971
 
 
 
 
972
  # Training loop
973
  best_val_loss = float("inf")
974
  epochs_no_improve = 0
975
 
976
- for epoch in range(1, epochs + 1):
977
  # --- Training Phase ---
978
  epoch_loss_avg = tf.keras.metrics.Mean()
979
  batches_processed = 0
@@ -987,13 +780,28 @@ class RetrievalChatbot(DeviceAwareModel):
987
  logger.info("Training progress bar disabled")
988
 
989
  for q_batch, p_batch, n_batch in train_dataset:
990
- loss = self.train_step(q_batch, p_batch, n_batch)
 
 
 
991
  epoch_loss_avg(loss)
992
  batches_processed += 1
993
 
994
  # Log to TensorBoard
995
  with train_summary_writer.as_default():
996
- tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed)
 
 
 
997
 
998
  # Update progress bar
999
  if use_lr_schedule:
@@ -1005,6 +813,8 @@ class RetrievalChatbot(DeviceAwareModel):
1005
  train_pbar.update(1)
1006
  train_pbar.set_postfix({
1007
  "loss": f"{loss.numpy():.4f}",
 
 
1008
  "lr": f"{current_lr:.2e}",
1009
  "batches": f"{batches_processed}/{steps_per_epoch}"
1010
  })
@@ -1064,6 +874,11 @@ class RetrievalChatbot(DeviceAwareModel):
1064
 
1065
  # Save checkpoint
1066
  manager.save()
 
 
 
 
 
1067
 
1068
  # Store metrics in history
1069
  self.history['train_loss'].append(train_loss)
@@ -1074,8 +889,14 @@ class RetrievalChatbot(DeviceAwareModel):
1074
  else:
1075
  current_lr = float(self.optimizer.learning_rate.numpy())
1076
 
 
1077
  self.history.setdefault('learning_rate', []).append(current_lr)
1078
 
 
 
 
 
 
1079
  # Early stopping logic
1080
  if val_loss < best_val_loss - min_delta:
1081
  best_val_loss = val_loss
@@ -1144,10 +965,19 @@ class RetrievalChatbot(DeviceAwareModel):
1144
  )
1145
  loss = tf.reduce_mean(loss)
1146
 
1147
- # Apply gradients
1148
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
 
 
1149
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1150
- return loss
 
1151
 
1152
  @tf.function
1153
  def validation_step(
@@ -1185,316 +1015,6 @@ class RetrievalChatbot(DeviceAwareModel):
1185
  loss = tf.reduce_mean(loss)
1186
 
1187
  return loss
1188
- # def train_streaming(
1189
- # self,
1190
- # dialogues: List[dict],
1191
- # epochs: int = 20,
1192
- # batch_size: int = 16,
1193
- # validation_split: float = 0.2,
1194
- # checkpoint_dir: str = "checkpoints/",
1195
- # use_lr_schedule: bool = True,
1196
- # peak_lr: float = 2e-5,
1197
- # warmup_steps_ratio: float = 0.1,
1198
- # early_stopping_patience: int = 3,
1199
- # min_delta: float = 1e-4,
1200
- # neg_samples: int = 1
1201
- # ) -> None:
1202
- # """Streaming training with tf.data pipeline."""
1203
- # logger.info("Starting streaming training pipeline with tf.data...")
1204
-
1205
- # # Initialize TFDataPipeline (replaces StreamingDataPipeline)
1206
- # dataset_preparer = TFDataPipeline(
1207
- # embedding_batch_size=self.config.embedding_batch_size,
1208
- # tokenizer=self.tokenizer,
1209
- # encoder=self.encoder,
1210
- # index=self.index, # Pass CPU version of FAISS index
1211
- # response_pool=self.response_pool,
1212
- # max_length=self.config.max_context_token_limit,
1213
- # neg_samples=neg_samples
1214
- # )
1215
-
1216
- # # Calculate total steps for learning rate schedule
1217
- # total_pairs = dataset_preparer.estimate_total_pairs(dialogues)
1218
- # train_size = int(total_pairs * (1 - validation_split))
1219
- # val_size = int(total_pairs * validation_split)
1220
- # steps_per_epoch = int(math.ceil(train_size / batch_size))
1221
- # val_steps = int(math.ceil(val_size / batch_size))
1222
- # total_steps = steps_per_epoch * epochs
1223
-
1224
- # logger.info(f"Total pairs: {total_pairs}")
1225
- # logger.info(f"Training pairs: {train_size}")
1226
- # logger.info(f"Validation pairs: {val_size}")
1227
- # logger.info(f"Steps per epoch: {steps_per_epoch}")
1228
- # logger.info(f"Validation steps: {val_steps}")
1229
- # logger.info(f"Total steps: {total_steps}")
1230
-
1231
- # # Set up optimizer with learning rate schedule
1232
- # if use_lr_schedule:
1233
- # warmup_steps = int(total_steps * warmup_steps_ratio)
1234
- # lr_schedule = self._get_lr_schedule(
1235
- # total_steps=total_steps,
1236
- # peak_lr=peak_lr,
1237
- # warmup_steps=warmup_steps
1238
- # )
1239
- # self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
1240
- # logger.info("Using custom learning rate schedule.")
1241
- # else:
1242
- # self.optimizer = tf.keras.optimizers.Adam(learning_rate=peak_lr)
1243
- # logger.info("Using fixed learning rate.")
1244
-
1245
- # # Initialize checkpoint manager
1246
- # checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.encoder)
1247
- # manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
1248
-
1249
- # # Setup TensorBoard
1250
- # log_dir = Path(checkpoint_dir) / "tensorboard_logs"
1251
- # log_dir.mkdir(parents=True, exist_ok=True)
1252
- # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
1253
- # train_log_dir = str(log_dir / f"train_{current_time}")
1254
- # val_log_dir = str(log_dir / f"val_{current_time}")
1255
- # train_summary_writer = tf.summary.create_file_writer(train_log_dir)
1256
- # val_summary_writer = tf.summary.create_file_writer(val_log_dir)
1257
- # logger.info(f"TensorBoard logs will be saved in {log_dir}")
1258
-
1259
- # # Create training and validation datasets
1260
- # train_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).take(train_size)
1261
- # val_dataset = dataset_preparer.get_tf_dataset(dialogues, batch_size).skip(train_size).take(val_size)
1262
-
1263
- # # Training loop
1264
- # best_val_loss = float("inf")
1265
- # epochs_no_improve = 0
1266
-
1267
- # for epoch in range(1, epochs + 1):
1268
- # # --- Training Phase ---
1269
- # epoch_loss_avg = tf.keras.metrics.Mean()
1270
- # batches_processed = 0
1271
-
1272
- # try:
1273
- # train_pbar = tqdm(total=steps_per_epoch, desc=f"Training Epoch {epoch}", unit="batch")
1274
- # is_tqdm_train = True
1275
- # except ImportError:
1276
- # train_pbar = None
1277
- # is_tqdm_train = False
1278
- # logger.info("Training progress bar disabled")
1279
-
1280
- # for q_batch, p_batch, n_batch in train_dataset:
1281
- # #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
1282
- # loss = self.train_step(q_batch, p_batch, n_batch)
1283
- # epoch_loss_avg(loss)
1284
- # batches_processed += 1
1285
-
1286
- # # Log to TensorBoard
1287
- # with train_summary_writer.as_default():
1288
- # tf.summary.scalar("loss", loss, step=(epoch - 1) * steps_per_epoch + batches_processed)
1289
-
1290
- # # Update progress bar
1291
- # if use_lr_schedule:
1292
- # current_lr = float(lr_schedule(self.optimizer.iterations))
1293
- # else:
1294
- # current_lr = float(self.optimizer.learning_rate.numpy())
1295
-
1296
- # if is_tqdm_train:
1297
- # train_pbar.update(1)
1298
- # train_pbar.set_postfix({
1299
- # "loss": f"{loss.numpy():.4f}",
1300
- # "lr": f"{current_lr:.2e}",
1301
- # "batches": f"{batches_processed}/{steps_per_epoch}"
1302
- # })
1303
-
1304
- # # Memory cleanup
1305
- # gc.collect()
1306
-
1307
- # if batches_processed >= steps_per_epoch:
1308
- # break
1309
-
1310
- # if is_tqdm_train and train_pbar:
1311
- # train_pbar.close()
1312
-
1313
- # # --- Validation Phase ---
1314
- # val_loss_avg = tf.keras.metrics.Mean()
1315
- # val_batches_processed = 0
1316
-
1317
- # try:
1318
- # val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
1319
- # is_tqdm_val = True
1320
- # except ImportError:
1321
- # val_pbar = None
1322
- # is_tqdm_val = False
1323
- # logger.info("Validation progress bar disabled")
1324
-
1325
- # for q_batch, p_batch, n_batch in val_dataset:
1326
- # #p_batch = p_n_batch[:, 0, :] # Extract positive from (positive, negative) pair
1327
- # val_loss = self.validation_step(q_batch, p_batch, n_batch)
1328
- # val_loss_avg(val_loss)
1329
- # val_batches_processed += 1
1330
-
1331
- # if is_tqdm_val:
1332
- # val_pbar.update(1)
1333
- # val_pbar.set_postfix({
1334
- # "val_loss": f"{val_loss.numpy():.4f}",
1335
- # "batches": f"{val_batches_processed}/{val_steps}"
1336
- # })
1337
-
1338
- # # Memory cleanup
1339
- # gc.collect()
1340
-
1341
-
1342
- # if val_batches_processed >= val_steps:
1343
- # break
1344
-
1345
- # if is_tqdm_val and val_pbar:
1346
- # val_pbar.close()
1347
-
1348
- # # End of epoch: compute final epoch stats, log, and save checkpoint
1349
- # train_loss = epoch_loss_avg.result().numpy()
1350
- # val_loss = val_loss_avg.result().numpy()
1351
- # logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
1352
-
1353
- # # Log epoch metrics
1354
- # with train_summary_writer.as_default():
1355
- # tf.summary.scalar("epoch_loss", train_loss, step=epoch)
1356
- # with val_summary_writer.as_default():
1357
- # tf.summary.scalar("val_loss", val_loss, step=epoch)
1358
-
1359
- # # Save checkpoint
1360
- # manager.save()
1361
-
1362
- # # Store metrics in history
1363
- # self.history['train_loss'].append(train_loss)
1364
- # self.history['val_loss'].append(val_loss)
1365
-
1366
- # if use_lr_schedule:
1367
- # current_lr = float(lr_schedule(self.optimizer.iterations))
1368
- # else:
1369
- # current_lr = float(self.optimizer.learning_rate.numpy())
1370
-
1371
- # self.history.setdefault('learning_rate', []).append(current_lr)
1372
-
1373
- # # Early stopping logic
1374
- # if val_loss < best_val_loss - min_delta:
1375
- # best_val_loss = val_loss
1376
- # epochs_no_improve = 0
1377
- # logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
1378
- # else:
1379
- # epochs_no_improve += 1
1380
- # logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
1381
- # if epochs_no_improve >= early_stopping_patience:
1382
- # logger.info("Early stopping triggered.")
1383
- # break
1384
-
1385
- # logger.info("Streaming training completed!")
1386
-
1387
-
1388
- # @tf.function
1389
- # def train_step(
1390
- # self,
1391
- # q_batch: tf.Tensor,
1392
- # p_batch: tf.Tensor,
1393
- # n_batch: tf.Tensor,
1394
- # attention_mask: Optional[tf.Tensor] = None
1395
- # ) -> tf.Tensor:
1396
- # """
1397
- # Single training step that uses queries, positives, and negatives in a
1398
- # contrastive/InfoNCE style. The label is always 0 (the positive) vs.
1399
- # the negative alternatives.
1400
- # """
1401
- # with tf.GradientTape() as tape:
1402
- # # Encode queries
1403
- # q_enc = self.encoder(q_batch, training=True) # [batch_size, embed_dim]
1404
-
1405
- # # Encode positives
1406
- # p_enc = self.encoder(p_batch, training=True) # [batch_size, embed_dim]
1407
-
1408
- # # Encode negatives
1409
- # # n_batch: [batch_size, neg_samples, max_length]
1410
- # shape = tf.shape(n_batch)
1411
- # bs = shape[0]
1412
- # neg_samples = shape[1]
1413
-
1414
- # # Flatten negatives to feed them in one pass:
1415
- # # => [batch_size * neg_samples, max_length]
1416
- # n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1417
- # n_enc_flat = self.encoder(n_batch_flat, training=True) # [bs*neg_samples, embed_dim]
1418
-
1419
- # # Reshape back => [batch_size, neg_samples, embed_dim]
1420
- # n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1421
-
1422
- # # Combine the positive embedding and negative embeddings along dim=1
1423
- # # => shape [batch_size, 1 + neg_samples, embed_dim]
1424
- # # The first column is the positive; subsequent columns are negatives
1425
- # combined_p_n = tf.concat(
1426
- # [tf.expand_dims(p_enc, axis=1), n_enc],
1427
- # axis=1
1428
- # ) # [bs, (1+neg_samples), embed_dim]
1429
-
1430
- # # Now compute scores: dot product of q_enc with each column in combined_p_n
1431
- # # We'll use `tf.einsum` to handle the batch dimension properly
1432
- # # dot_products => shape [batch_size, (1+neg_samples)]
1433
- # dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
1434
-
1435
- # # The label for each row is 0 (the first column is the correct/positive)
1436
- # labels = tf.zeros([bs], dtype=tf.int32)
1437
-
1438
- # # Cross-entropy over the [batch_size, 1+neg_samples] scores
1439
- # loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1440
- # labels=labels,
1441
- # logits=dot_products
1442
- # )
1443
- # loss = tf.reduce_mean(loss)
1444
-
1445
- # # If there's an attention_mask you want to apply (less common in this scenario),
1446
- # # you could do something like:
1447
- # if attention_mask is not None:
1448
- # loss = loss * attention_mask
1449
- # loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1450
-
1451
- # # Apply gradients
1452
- # gradients = tape.gradient(loss, self.encoder.trainable_variables)
1453
- # self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1454
- # return loss
1455
-
1456
- # @tf.function
1457
- # def validation_step(
1458
- # self,
1459
- # q_batch: tf.Tensor,
1460
- # p_batch: tf.Tensor,
1461
- # n_batch: tf.Tensor,
1462
- # attention_mask: Optional[tf.Tensor] = None
1463
- # ) -> tf.Tensor:
1464
- # """
1465
- # Single validation step with queries, positives, and negatives.
1466
- # Uses the same loss calculation as train_step, but `training=False`.
1467
- # """
1468
- # q_enc = self.encoder(q_batch, training=False)
1469
- # p_enc = self.encoder(p_batch, training=False)
1470
-
1471
- # shape = tf.shape(n_batch)
1472
- # bs = shape[0]
1473
- # neg_samples = shape[1]
1474
-
1475
- # n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1476
- # n_enc_flat = self.encoder(n_batch_flat, training=False)
1477
- # n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1478
-
1479
- # combined_p_n = tf.concat(
1480
- # [tf.expand_dims(p_enc, axis=1), n_enc],
1481
- # axis=1
1482
- # )
1483
-
1484
- # dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)
1485
- # labels = tf.zeros([bs], dtype=tf.int32)
1486
-
1487
- # loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1488
- # labels=labels,
1489
- # logits=dot_products
1490
- # )
1491
- # loss = tf.reduce_mean(loss)
1492
-
1493
- # if attention_mask is not None:
1494
- # loss = loss * attention_mask
1495
- # loss = tf.reduce_sum(loss) / tf.reduce_sum(attention_mask)
1496
-
1497
- # return loss
1498
 
1499
  def _get_lr_schedule(
1500
  self,
@@ -1561,75 +1081,3 @@ class RetrievalChatbot(DeviceAwareModel):
1561
  }
1562
 
1563
  return CustomSchedule(total_steps, peak_lr, warmup_steps)
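The body of CustomSchedule is not shown in this hunk; one common shape that matches its constructor signature is linear warmup to peak_lr followed by cosine decay, sketched below (illustrative only, not necessarily the committed implementation):

    class WarmupCosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Linear warmup to peak_lr, then cosine decay toward zero."""
        def __init__(self, total_steps: int, peak_lr: float, warmup_steps: int):
            super().__init__()
            self.total_steps = float(total_steps)
            self.peak_lr = peak_lr
            self.warmup_steps = float(max(1, warmup_steps))

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            warmup_lr = self.peak_lr * (step / self.warmup_steps)
            progress = tf.minimum(1.0, (step - self.warmup_steps) /
                                  tf.maximum(1.0, self.total_steps - self.warmup_steps))
            cosine_lr = self.peak_lr * 0.5 * (1.0 + tf.cos(math.pi * progress))
            return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)

    optimizer = tf.keras.optimizers.Adam(
        learning_rate=WarmupCosineSchedule(total_steps=10_000, peak_lr=2e-5, warmup_steps=1_000))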
1564
-
1565
- def _cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray:
1566
- """Compute cosine similarity between two numpy arrays."""
1567
- normalized_emb1 = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
1568
- normalized_emb2 = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
1569
- return np.dot(normalized_emb1, normalized_emb2.T)
1570
-
1571
- def chat(
1572
- self,
1573
- query: str,
1574
- conversation_history: Optional[List[Tuple[str, str]]] = None,
1575
- quality_checker: Optional['ResponseQualityChecker'] = None,
1576
- top_k: int = 5,
1577
- ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
1578
- """
1579
- Example chat method that always uses cross-encoder re-ranking
1580
- if self.reranker is available.
1581
- """
1582
- @self.run_on_device
1583
- def get_response(self_arg, query_arg): # Add parameters that match decorator's expectations
1584
- # 1) Build conversation context string
1585
- conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
1586
-
1587
- # 2) Retrieve + cross-encoder re-rank
1588
- results = self_arg.retrieve_responses_cross_encoder(
1589
- query=conversation_str,
1590
- top_k=top_k,
1591
- reranker=self_arg.reranker,
1592
- summarizer=self_arg.summarizer,
1593
- summarize_threshold=512
1594
- )
1595
-
1596
- # 3) Handle empty or confidence
1597
- if not results:
1598
- return (
1599
- "I'm sorry, but I couldn't find a relevant response.",
1600
- [],
1601
- {}
1602
- )
1603
-
1604
- if quality_checker:
1605
- metrics = quality_checker.check_response_quality(query_arg, results)
1606
- if not metrics.get('is_confident', False):
1607
- return (
1608
- "I need more information to provide a good answer. Could you please clarify?",
1609
- results,
1610
- metrics
1611
- )
1612
- return results[0][0], results, metrics
1613
-
1614
- return results[0][0], results, {}
1615
-
1616
- return get_response(self, query)
1617
-
1618
- def _build_conversation_context(
1619
- self,
1620
- query: str,
1621
- conversation_history: Optional[List[Tuple[str, str]]]
1622
- ) -> str:
1623
- """Build conversation context with better memory management."""
1624
- if not conversation_history:
1625
- return f"{self.special_tokens['user']} {query}"
1626
-
1627
- conversation_parts = []
1628
- for user_txt, assistant_txt in conversation_history:
1629
- conversation_parts.extend([
1630
- f"{self.special_tokens['user']} {user_txt}",
1631
- f"{self.special_tokens['assistant']} {assistant_txt}"
1632
- ])
1633
-
1634
- conversation_parts.append(f"{self.special_tokens['user']} {query}")
1635
- return "\n".join(conversation_parts)
 
1
+ import os
2
  from transformers import TFAutoModel, AutoTokenizer
3
  import tensorflow as tf
 
4
  from typing import List, Tuple, Dict, Optional, Union, Any
5
  import math
6
  from dataclasses import dataclass
 
65
  super().__init__(name=name, **kwargs)
66
  self.config = config
67
 
68
+ # Load pretrained model and freeze layers based on config
69
  self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
 
 
70
  self._freeze_layers()
71
 
72
+ # Add Pooling layer (Global Average Pooling), Projection layer, Dropout, and Normalization
73
  self.pooler = tf.keras.layers.GlobalAveragePooling1D()
 
 
74
  self.projection = tf.keras.layers.Dense(
75
  config.embedding_dim,
76
  activation='tanh',
77
  name="projection"
78
  )
 
 
79
  self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
80
  self.normalize = tf.keras.layers.Lambda(
81
  lambda x: tf.nn.l2_normalize(x, axis=1),
 
103
  """Forward pass."""
104
  # Get pretrained embeddings
105
  pretrained_outputs = self.pretrained(inputs, training=training)
106
+ x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
107
 
108
  # Apply pooling, projection, dropout, and normalization
109
+ x = self.pooler(x) # Shape: [batch_size, 768]
110
+ x = self.projection(x) # Shape: [batch_size, 768]
111
+ x = self.dropout(x, training=training)
112
+ x = self.normalize(x) # Shape: [batch_size, 768]
113
 
114
  return x
115
 
 
127
  def __init__(
128
  self,
129
  config: ChatbotConfig,
 
130
  device: str = None,
131
  strategy=None,
132
  reranker: Optional[CrossEncoderReranker] = None,
133
  summarizer: Optional[Summarizer] = None,
134
+ mode: str = 'training'
135
  ):
136
  super().__init__()
137
  self.config = config
 
139
  self.device = device or self._setup_default_device()
140
  self.mode = mode.lower()
141
 
142
+ # Initialize reranker, summarizer, tokenizer, encoder, and memory monitor
143
  self.reranker = reranker or self._initialize_reranker()
 
144
  self.tokenizer = self._initialize_tokenizer()
145
+ self.encoder = self._initialize_encoder()
146
+ self.summarizer = summarizer or self._initialize_summarizer()
147
  self.memory_monitor = GPUMemoryMonitor()
148
 
149
+ # Initialize data pipeline
150
+ logger.info("Initializing TFDataPipeline.")
151
+ self.data_pipeline = TFDataPipeline(
152
+ config=self.config,
153
+ tokenizer=self.tokenizer,
154
+ encoder=self.encoder,
155
+ index_file_path='path/to/index', # TODO: update this placeholder path
156
+ response_pool=[],
157
+ max_length=self.config.max_context_token_limit,
158
+ query_embeddings_cache={},
159
+ neg_samples=self.config.neg_samples,
160
+ index_type='IndexFlatIP',
161
+ nlist=100, # Not used with IndexFlatIP
162
+ max_retries=self.config.max_retries
163
+ )
164
 
165
+ # Collect unique responses from dialogues
166
+ if self.mode == 'inference':
167
+ logger.info("Mode set to 'inference'. Loading FAISS index and response pool.")
168
+ self._load_faiss_index_and_responses()
169
+ elif self.mode != 'training':
170
+ logger.error(f"Unsupported mode in RetrievalChatbot init: {self.mode}")
171
+ raise ValueError(f"Unsupported mode in RetrievalChatbot init: {self.mode}")
172
+
173
  # Initialize training history
174
  self.history = {
175
  "train_loss": [],
 
177
  "train_metrics": {},
178
  "val_metrics": {}
179
  }
180
+
 
181
 
182
  def _setup_default_device(self) -> str:
183
  """Set up default device if none is provided."""
 
193
 
194
  def _initialize_summarizer(self) -> Summarizer:
195
  """Initialize the Summarizer."""
196
+ return Summarizer(
197
+ tokenizer=self.tokenizer,
198
+ model_name="t5-small",
199
+ max_summary_length=self.config.max_context_token_limit // 4,
200
+ device=self.device,
201
+ max_summary_rounds=2
202
+ )
203
 
204
  def _initialize_tokenizer(self) -> AutoTokenizer:
205
  """Initialize the tokenizer and add special tokens."""
 
216
  )
217
  return tokenizer
218
 
219
+ def _initialize_encoder(self) -> EncoderModel:
220
+ """Initialize the EncoderModel and resize token embeddings."""
221
+ logger.info("Initializing encoder model...")
222
+ encoder = EncoderModel(
 
 
223
  self.config,
224
  name="shared_encoder",
225
  )
226
 
 
227
  new_vocab_size = len(self.tokenizer)
228
+ encoder.pretrained.resize_token_embeddings(new_vocab_size)
229
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
230
+ return encoder
 
231
 
232
+ def _load_faiss_index_and_responses(self) -> None:
233
+ """Load FAISS index and response pool for inference."""
 
 
234
  try:
235
+ logger.info(f"Loading FAISS index from {self.data_pipeline.index_file_path}...")
236
+ self.data_pipeline.load_faiss_index(self.data_pipeline.index_file_path)
237
+ logger.info("FAISS index loaded successfully.")
238
+
239
+ # Load response pool associated with the FAISS index
240
+ response_pool_path = self.data_pipeline.index_file_path.replace('.index', '_responses.json')
241
+ if os.path.exists(response_pool_path):
242
+ with open(response_pool_path, 'r', encoding='utf-8') as f:
243
+ self.data_pipeline.response_pool = json.load(f)
244
+ logger.info(f"Loaded {len(self.data_pipeline.response_pool)} responses from {response_pool_path}.")
245
  else:
246
+ logger.error(f"Response pool file not found at {response_pool_path}.")
247
+ raise FileNotFoundError(f"Response pool file not found at {response_pool_path}.")
248
+
249
+ # Validate FAISS index and response pool
250
+ self.data_pipeline.validate_faiss_index()
251
+ logger.info("FAISS index and response pool validated successfully.")
252
 
 
 
253
  except Exception as e:
254
+ logger.error(f"Failed to load FAISS index and response pool: {e}")
255
  raise
256
+
257
+ @classmethod
258
+ def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
 
 
 
 
 
259
  """
260
+ Load saved models and configuration.
 
261
 
262
+ Args:
263
+ load_dir (Union[str, Path]): Directory containing saved model files
264
+ mode (str): Either 'training' or 'inference'. In inference mode,
265
+ also loads FAISS index and response pool.
266
+ """
267
+ load_dir = Path(load_dir)
 
268
 
269
+ # Load config
270
+ with open(load_dir / "config.json", "r") as f:
271
+ config = ChatbotConfig.from_dict(json.load(f))
 
272
 
273
+ # Initialize chatbot with appropriate mode
274
+ chatbot = cls(config, mode=mode)
 
 
275
 
276
+ # Load models
277
+ chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
278
+ load_dir / "shared_encoder",
279
+ config=config
280
+ )
281
 
282
+ # Load tokenizer
283
+ chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
284
+ logger.info(f"Models and tokenizer loaded from {load_dir}")
 
285
 
286
+ # If in inference mode, load additional components
287
+ if mode == 'inference':
288
+ cls._prepare_model_for_inference(chatbot, load_dir)
 
 
289
 
290
+ return chatbot
291
+
292
+ @classmethod
293
+ def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
294
+ """Internal method to load inference components."""
295
  try:
296
+ # Load FAISS index
297
+ faiss_path = load_dir / 'faiss_index.bin'
298
+ if faiss_path.exists():
299
+ chatbot.index = faiss.read_index(str(faiss_path))
300
+ logger.info("FAISS index loaded successfully")
301
  else:
302
+ raise FileNotFoundError(f"FAISS index not found at {faiss_path}")
 
 
 
303
 
304
+ # Load response pool
305
+ response_pool_path = load_dir / 'response_pool.json'
306
+ if response_pool_path.exists():
307
+ with open(response_pool_path, 'r') as f:
308
+ chatbot.response_pool = json.load(f)
309
+ logger.info(f"Loaded {len(chatbot.response_pool)} responses")
 
310
  else:
311
+ raise FileNotFoundError(f"Response pool not found at {response_pool_path}")
 
 
 
312
 
313
+ # Verify dimensions match
314
+ if chatbot.index.d != chatbot.config.embedding_dim:
315
+ raise ValueError(
316
+ f"FAISS index dimension {chatbot.index.d} doesn't match "
317
+ f"model dimension {chatbot.config.embedding_dim}"
318
+ )
319
+
320
  except Exception as e:
321
+ logger.error(f"Error loading inference components: {e}")
 
 
 
 
322
  raise
323
+
324
+ def save_models(self, save_dir: Union[str, Path]):
325
+ """Save models and configuration."""
326
+ save_dir = Path(save_dir)
327
+ save_dir.mkdir(parents=True, exist_ok=True)
 
 
 
328
 
329
+ # Save config
330
+ with open(save_dir / "config.json", "w") as f:
331
+ json.dump(self.config.to_dict(), f, indent=2)
 
 
 
 
 
 
332
 
333
+ # Save models
334
+ self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
 
335
 
336
+ # Save tokenizer
337
+ self.tokenizer.save_pretrained(save_dir / "tokenizer")
 
338
 
339
+ logger.info(f"Models and tokenizer saved to {save_dir}.")
 
340
 
341
  def retrieve_responses_cross_encoder(
342
  self,
 
363
 
364
  # 2) Dense retrieval
365
  dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) # [(resp, dense_score), ...]
366
+
367
  if not dense_topk:
368
  return []
369
 
 
377
  combined.sort(key=lambda x: x[1], reverse=True)
378
 
379
  return combined
380
+ # def retrieve_responses_cross_encoder(
381
+ # self,
382
+ # query: str,
383
+ # top_k: int,
384
+ # reranker: Optional[CrossEncoderReranker] = None,
385
+ # summarizer: Optional[Summarizer] = None,
386
+ # summarize_threshold: int = 512 # Summarize over 512 tokens
387
+ # ) -> List[Tuple[str, float]]:
388
+ # """
389
+ # Retrieve top-k from FAISS, then re-rank them with a cross-encoder.
390
+ # Optionally summarize the user query if it's too long.
391
+ # """
392
+ # if reranker is None:
393
+ # reranker = self.reranker
394
+ # if summarizer is None:
395
+ # summarizer = self.summarizer
396
+
397
+ # # Optional summarization
398
+ # if summarizer and len(query.split()) > summarize_threshold:
399
+ # logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
400
+ # query = summarizer.summarize_text(query)
401
+ # logger.info(f"Summarized query: {query}")
402
+
403
+ # # 2) Dense retrieval
404
+ # dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) # [(resp, dense_score), ...]
405
+
406
+ # if not dense_topk:
407
+ # return []
408
+
409
+ # # 3) Cross-encoder rerank
410
+ # candidate_texts = [pair[0] for pair in dense_topk]
411
+ # cross_scores = reranker.rerank(query, candidate_texts, max_length=256)
412
+
413
+ # # Combine
414
+ # combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
415
+ # # Sort descending by cross-encoder score
416
+ # combined.sort(key=lambda x: x[1], reverse=True)
417
+
418
+ # return combined
419
 
420
  def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
421
  """Retrieve top-k responses using FAISS."""
422
+ if not hasattr(self.data_pipeline, 'index') or self.data_pipeline.index is None:
423
  logger.warning("FAISS index not initialized. Cannot retrieve responses.")
424
  return []
425
 
426
+ # Encode the query using TFDataPipeline's method
427
+ q_emb = self.data_pipeline.encode_query(query)
428
  q_emb_np = q_emb.numpy().astype('float32') # Ensure type match
429
 
430
  # Normalize the query embedding for cosine similarity
431
  faiss.normalize_L2(q_emb_np)
432
 
433
  # Search the FAISS index
434
+ distances, indices = self.data_pipeline.index.search(q_emb_np, top_k)
435
 
436
  # Map indices to responses and distances to similarities
437
  top_responses = []
438
  for i, idx in enumerate(indices[0]):
439
+ if idx < len(self.data_pipeline.response_pool):
440
+ top_responses.append((self.data_pipeline.response_pool[idx], float(distances[0][i])))
441
  else:
442
  logger.warning(f"FAISS returned invalid index {idx}. Skipping.")
443
 
444
  return top_responses
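The normalize-then-inner-product pattern above is what makes IndexFlatIP scores behave as cosine similarities; a standalone sketch of the same idea (dimensions and values are illustrative):

import numpy as np
import faiss

dim = 768                                   # e.g. config.embedding_dim
index = faiss.IndexFlatIP(dim)              # exact inner-product search
vecs = np.random.rand(10, dim).astype('float32')
faiss.normalize_L2(vecs)                    # unit-length rows: inner product == cosine
index.add(vecs)

q = np.random.rand(1, dim).astype('float32')
faiss.normalize_L2(q)
scores, ids = index.search(q, 5)            # top-5 cosine scores and row indices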
445
+ # def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
446
+ # """Retrieve top-k responses using FAISS."""
447
+ # if not hasattr(self, 'index') or self.index is None:
448
+ # logger.warning("FAISS index not initialized. Cannot retrieve responses.")
449
+ # return []
450
 
451
+ # # Encode the query
452
+ # q_emb = self.encode_query(query) # Shape: [1, embedding_dim]
453
+ # q_emb_np = q_emb.numpy().astype('float32') # Ensure type match
 
 
 
454
 
455
+ # # Normalize the query embedding for cosine similarity
456
+ # faiss.normalize_L2(q_emb_np)
457
 
458
+ # # Search the FAISS index
459
+ # distances, indices = self.index.search(q_emb_np, top_k)
 
 
 
 
460
 
461
+ # # Map indices to responses and distances to similarities
462
+ # top_responses = []
463
+ # for i, idx in enumerate(indices[0]):
464
+ # if idx < len(self.response_pool):
465
+ # top_responses.append((self.response_pool[idx], float(distances[0][i])))
466
+ # else:
467
+ # logger.warning(f"FAISS returned invalid index {idx}. Skipping.")
468
 
469
+ # return top_responses
470
+
471
+ def chat(
472
+ self,
473
+ query: str,
474
+ conversation_history: Optional[List[Tuple[str, str]]] = None,
475
+ quality_checker: Optional['ResponseQualityChecker'] = None,
476
+ top_k: int = 5,
477
+ ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
478
+ """
479
+ Example chat method that always uses cross-encoder re-ranking
480
+ if self.reranker is available.
481
+ """
482
+ @self.run_on_device
483
+ def get_response(self_arg, query_arg):
484
+ # 1) Build conversation context string
485
+ conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
486
+
487
+ # 2) Retrieve + cross-encoder re-rank
488
+ results = self_arg.retrieve_responses_cross_encoder(
489
+ query=conversation_str,
490
+ top_k=top_k,
491
+ reranker=self_arg.reranker,
492
+ summarizer=self_arg.summarizer,
493
+ summarize_threshold=512
494
+ )
495
+
496
+ # 3) Handle empty or confidence
497
+ if not results:
498
+ return (
499
+ "I'm sorry, but I couldn't find a relevant response.",
500
+ [],
501
+ {}
502
+ )
503
+
504
+ if quality_checker:
505
+ metrics = quality_checker.check_response_quality(query_arg, results)
506
+ if not metrics.get('is_confident', False):
507
+ return (
508
+ "I need more information to provide a good answer. Could you please clarify?",
509
+ results,
510
+ metrics
511
+ )
512
+ return results[0][0], results, metrics
513
+
514
+ return results[0][0], results, {}
515
 
516
+ return get_response(self, query)
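A minimal usage sketch for chat() with a short history (texts are illustrative; quality_checker is optional):

history = [("I need a table for two tonight", "Sure, which restaurant do you prefer?")]
response, candidates, metrics = chatbot.chat(
    query="Somewhere downtown around 7pm",
    conversation_history=history,
    quality_checker=quality_checker,   # ResponseQualityChecker instance, optional
    top_k=5,
)
print(response)
for text, score in candidates[:3]:
    print(f"{score:.3f}  {text}")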
517
+ # def chat(
518
+ # self,
519
+ # query: str,
520
+ # conversation_history: Optional[List[Tuple[str, str]]] = None,
521
+ # quality_checker: Optional['ResponseQualityChecker'] = None,
522
+ # top_k: int = 5,
523
+ # ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
524
+ # """
525
+ # Example chat method that always uses cross-encoder re-ranking
526
+ # if self.reranker is available.
527
+ # """
528
+ # @self.run_on_device
529
+ # def get_response(self_arg, query_arg): # Add parameters that match decorator's expectations
530
+ # # 1) Build conversation context string
531
+ # conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
532
+
533
+ # # 2) Retrieve + cross-encoder re-rank
534
+ # results = self_arg.retrieve_responses_cross_encoder(
535
+ # query=conversation_str,
536
+ # top_k=top_k,
537
+ # reranker=self_arg.reranker,
538
+ # summarizer=self_arg.summarizer,
539
+ # summarize_threshold=512
540
+ # )
541
+
542
+ # # 3) Handle empty or confidence
543
+ # if not results:
544
+ # return (
545
+ # "I'm sorry, but I couldn't find a relevant response.",
546
+ # [],
547
+ # {}
548
+ # )
549
+
550
+ # if quality_checker:
551
+ # metrics = quality_checker.check_response_quality(query_arg, results)
552
+ # if not metrics.get('is_confident', False):
553
+ # return (
554
+ # "I need more information to provide a good answer. Could you please clarify?",
555
+ # results,
556
+ # metrics
557
+ # )
558
+ # return results[0][0], results, metrics
559
+
560
+ # return results[0][0], results, {}
561
 
562
+ # return get_response(self, query)
563
+
564
+ def _build_conversation_context(
565
+ self,
566
+ query: str,
567
+ conversation_history: Optional[List[Tuple[str, str]]]
568
+ ) -> str:
569
+ """Build conversation context with better memory management."""
570
+ if not conversation_history:
571
+ return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
572
+
573
+ conversation_parts = []
574
+ for user_txt, assistant_txt in conversation_history:
575
+ conversation_parts.extend([
576
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
577
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
578
+ ])
579
+
580
+ conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
581
+ return "\n".join(conversation_parts)
582
+ # def _build_conversation_context(
583
+ # self,
584
+ # query: str,
585
+ # conversation_history: Optional[List[Tuple[str, str]]]
586
+ # ) -> str:
587
+ # """Build conversation context with better memory management."""
588
+ # if not conversation_history:
589
+ # return f"{self.special_tokens['user']} {query}"
590
+
591
+ # conversation_parts = []
592
+ # for user_txt, assistant_txt in conversation_history:
593
+ # conversation_parts.extend([
594
+ # f"{self.special_tokens['user']} {user_txt}",
595
+ # f"{self.special_tokens['assistant']} {assistant_txt}"
596
+ # ])
597
+
598
+ # conversation_parts.append(f"{self.special_tokens['user']} {query}")
599
+ # return "\n".join(conversation_parts)
600
 
601
+ def train_model(
602
  self,
603
  tfrecord_file_path: str,
604
  epochs: int = 20,
 
606
  validation_split: float = 0.2,
607
  checkpoint_dir: str = "checkpoints/",
608
  use_lr_schedule: bool = True,
609
+ peak_lr: float = 1e-5,
610
  warmup_steps_ratio: float = 0.1,
611
  early_stopping_patience: int = 3,
612
  min_delta: float = 1e-4,
613
+ test_mode: bool = False,
614
+ initial_epoch: int = 0
615
  ) -> None:
616
  """Training using a pre-prepared TFRecord dataset."""
617
  logger.info("Starting training with pre-prepared TFRecord dataset...")
 
640
  negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
641
  negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
642
 
643
+ return query_ids, positive_ids, negative_ids
644
+
645
  # Calculate total steps by counting the number of records in the TFRecord
646
  raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
647
  total_pairs = sum(1 for _ in raw_dataset)
 
652
  steps_per_epoch = math.ceil(train_size / batch_size)
653
  val_steps = math.ceil(val_size / batch_size)
654
  total_steps = steps_per_epoch * epochs
655
+ buffer_size = total_pairs // 10 # 10% of the dataset
656
 
657
  logger.info(f"Training pairs: {train_size}")
658
  logger.info(f"Validation pairs: {val_size}")
 
675
  logger.info("Using fixed learning rate.")
676
 
677
  # Initialize checkpoint manager
678
+ checkpoint = tf.train.Checkpoint(
679
+ epoch=tf.Variable(0),
680
+ optimizer=self.optimizer,
681
+ model=self.encoder,
682
+ variables=self.encoder.variables
683
+ )
684
+ manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3, checkpoint_name='ckpt')
685
 
686
+ # Restore from checkpoint if available
687
+ latest_checkpoint = manager.latest_checkpoint
688
+ if latest_checkpoint:
689
+ history_path = Path(checkpoint_dir) / 'training_history.json'
690
+ if history_path.exists():
691
+ try:
692
+ with open(history_path, 'r') as f:
693
+ self.history = json.load(f)
694
+ logger.info(f"Loaded previous training history from {history_path}")
695
+ except Exception as e:
696
+ logger.warning(f"Could not load history, starting fresh: {e}")
697
+ self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
698
+ else:
699
+ self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
700
+
701
+ status = checkpoint.restore(latest_checkpoint)
702
+ status.expect_partial()
703
+
704
+ logger.info(f"Restored from checkpoint: {latest_checkpoint}")
705
+ # Get the checkpoint number to validate initial_epoch
706
+ ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
707
+ if initial_epoch == 0:
708
+ initial_epoch = ckpt_number
709
+ logger.info(f"Resuming from epoch {initial_epoch}")
710
+ else:
711
+ logger.info("Starting training from scratch")
712
+ initial_epoch = 0
713
+
714
  # Setup TensorBoard
715
  log_dir = Path(checkpoint_dir) / "tensorboard_logs"
716
  log_dir.mkdir(parents=True, exist_ok=True)
 
726
 
727
  # Create the full dataset
728
  dataset = tf.data.TFRecordDataset(tfrecord_file_path)
729
+
730
+ # Test mode for debugging
731
+ if test_mode:
732
+ subset_size = 200
733
+ dataset = dataset.take(subset_size)
734
+ logger.info(f"TEST MODE: Using only {subset_size} examples")
735
+ # Recalculate sizes
736
+ total_pairs = subset_size
737
+ train_size = int(total_pairs * (1 - validation_split))
738
+ val_size = total_pairs - train_size
739
+ steps_per_epoch = math.ceil(train_size / batch_size)
740
+ val_steps = math.ceil(val_size / batch_size)
741
+ total_steps = steps_per_epoch * epochs
742
+ buffer_size = total_pairs // 10 # 10% of the dataset
743
+ epochs = min(epochs, 5) # Limit epochs in test mode
744
+ early_stopping_patience = 2
745
+ logger.info(f"New training pairs: {train_size}")
746
+ logger.info(f"New validation pairs: {val_size}")
747
+
748
  dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
749
+
750
+ # Split into training and validation sets
 
 
 
751
  train_dataset = dataset.take(train_size)
752
  val_dataset = dataset.skip(train_size).take(val_size)
753
 
754
+ # Shuffle the training data
755
+ train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
756
+
757
+ # Batch both datasets
758
+ train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
759
+ train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
760
+
761
+ val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
762
+ val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
763
+ val_dataset = val_dataset.cache()
764
+
765
  # Training loop
766
  best_val_loss = float("inf")
767
  epochs_no_improve = 0
768
 
769
+ for epoch in range(initial_epoch + 1, epochs + 1):
770
  # --- Training Phase ---
771
  epoch_loss_avg = tf.keras.metrics.Mean()
772
  batches_processed = 0
 
780
  logger.info("Training progress bar disabled")
781
 
782
  for q_batch, p_batch, n_batch in train_dataset:
783
+ loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
784
+
785
+ # Check for gradient issues
786
+ grad_norm_value = float(grad_norm.numpy())
787
+ post_clip_value = float(post_clip_norm.numpy())
788
+ if grad_norm_value < 1e-7:
789
+ logger.warning(f"Potential vanishing gradient detected: norm = {grad_norm_value:.2e}")
790
+ elif grad_norm_value > 100:
791
+ logger.warning(f"Potential exploding gradient detected: norm = {grad_norm_value:.2e}")
792
+
793
+ if grad_norm_value != post_clip_value:
794
+ logger.info(f"Gradient clipped: {grad_norm_value:.2e} -> {post_clip_value:.2e}")
795
+
796
  epoch_loss_avg(loss)
797
  batches_processed += 1
798
 
799
  # Log to TensorBoard
800
  with train_summary_writer.as_default():
801
+ step = (epoch - 1) * steps_per_epoch + batches_processed
802
+ tf.summary.scalar("loss", loss, step=step)
803
+ tf.summary.scalar("gradient_norm_pre_clip", grad_norm, step=step)
804
+ tf.summary.scalar("gradient_norm_post_clip", post_clip_norm, step=step)
805
 
806
  # Update progress bar
807
  if use_lr_schedule:
 
813
  train_pbar.update(1)
814
  train_pbar.set_postfix({
815
  "loss": f"{loss.numpy():.4f}",
816
+ "pre_clip": f"{grad_norm_value:.2e}",
817
+ "post_clip": f"{post_clip_value:.2e}",
818
  "lr": f"{current_lr:.2e}",
819
  "batches": f"{batches_processed}/{steps_per_epoch}"
820
  })
 
874
 
875
  # Save checkpoint
876
  manager.save()
877
+
878
+ # Save model after each epoch for testing/inference
879
+ model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
880
+ self.save_models(model_save_path)
881
+ logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
882
 
883
  # Store metrics in history
884
  self.history['train_loss'].append(train_loss)
 
889
  else:
890
  current_lr = float(self.optimizer.learning_rate.numpy())
891
 
892
+ # Log learning rate
893
  self.history.setdefault('learning_rate', []).append(current_lr)
894
 
895
+ # Save history to file
896
+ with open(history_path, 'w') as f:
897
+ json.dump(self.history, f)
898
+ logger.info(f"Saved training history to {history_path}")
899
+
900
  # Early stopping logic
901
  if val_loss < best_val_loss - min_delta:
902
  best_val_loss = val_loss
 
965
  )
966
  loss = tf.reduce_mean(loss)
967
 
968
+ # Calculate gradients
969
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
970
+ gradients_norm = tf.linalg.global_norm(gradients)
971
+
972
+ # Clip gradients if norm exceeds threshold
973
+ max_grad_norm = 1.0
974
+ gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
975
+ post_clip_norm = tf.linalg.global_norm(gradients)
976
+
977
+ # Apply gradients
978
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
979
+
980
+ return loss, gradients_norm, post_clip_norm
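For intuition, tf.clip_by_global_norm rescales every tensor by clip_norm / max(global_norm, clip_norm), so gradient directions are preserved; a tiny numeric sketch:

import tensorflow as tf

grads = [tf.constant([3.0, 4.0])]               # global norm = 5.0
clipped, norm = tf.clip_by_global_norm(grads, 1.0)
# norm -> 5.0 (pre-clip), clipped[0] -> [0.6, 0.8], post-clip global norm -> 1.0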
981
 
982
  @tf.function
983
  def validation_step(
 
1015
  loss = tf.reduce_mean(loss)
1016
 
1017
  return loss
 
1018
 
1019
  def _get_lr_schedule(
1020
  self,
 
1081
  }
1082
 
1083
  return CustomSchedule(total_steps, peak_lr, warmup_steps)
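CustomSchedule itself is not shown in this hunk; a sketch of a schedule with the same constructor shape, assuming linear warmup to peak_lr followed by cosine decay (not the actual implementation):

import math
import tensorflow as tf

class WarmupCosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to peak_lr, then cosine decay towards zero by total_steps."""
    def __init__(self, total_steps: int, peak_lr: float, warmup_steps: int):
        super().__init__()
        self.total_steps = float(total_steps)
        self.peak_lr = peak_lr
        self.warmup_steps = float(max(warmup_steps, 1))

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_lr = self.peak_lr * step / self.warmup_steps
        progress = (step - self.warmup_steps) / tf.maximum(self.total_steps - self.warmup_steps, 1.0)
        cosine_lr = self.peak_lr * 0.5 * (1.0 + tf.cos(math.pi * progress))
        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)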
 
chatbot_validator.py CHANGED
@@ -1,23 +1,23 @@
1
  from typing import Dict, List, Tuple, Any, Optional
2
  import numpy as np
3
- from logger_config import config_logger
4
 
 
5
  logger = config_logger(__name__)
6
 
7
  class ChatbotValidator:
8
  """Handles automated validation and performance analysis for the chatbot."""
9
-
10
  def __init__(self, chatbot, quality_checker):
11
  """
12
  Initialize the validator.
13
-
14
  Args:
15
  chatbot: RetrievalChatbot instance
16
  quality_checker: ResponseQualityChecker instance
17
  """
18
  self.chatbot = chatbot
19
  self.quality_checker = quality_checker
20
-
21
  # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
22
  self.domain_queries = {
23
  'restaurant': [
@@ -59,50 +59,50 @@ class ChatbotValidator:
59
 
60
  def run_validation(
61
  self,
62
- num_examples: int = 10,
63
  top_k: int = 10,
64
  domains: Optional[List[str]] = None
65
  ) -> Dict[str, Any]:
66
  """
67
  Run comprehensive validation across specified domains.
68
-
69
  Args:
70
  num_examples: Number of test queries per domain
71
  top_k: Number of responses to retrieve for each query
72
  domains: Optional list of specific domains to test
73
-
74
  Returns:
75
  Dict containing detailed validation metrics and domain-specific performance
76
  """
77
  logger.info("\n=== Running Enhanced Automatic Validation ===")
78
-
79
  # Select domains to test
80
  test_domains = domains if domains else list(self.domain_queries.keys())
81
  metrics_history = []
82
  domain_metrics = {}
83
-
84
  # Run validation for each domain
85
  for domain in test_domains:
86
  domain_metrics[domain] = []
87
  queries = self.domain_queries[domain][:num_examples]
88
-
89
  logger.info(f"\n=== Testing {domain.title()} Domain ===")
90
-
91
  for i, query in enumerate(queries, 1):
92
  logger.info(f"\nTest Case {i}:")
93
  logger.info(f"Query: {query}")
94
-
95
  # Get responses with increased top_k
96
  responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)
97
-
98
- # Enhanced quality checking with context
99
  quality_metrics = self.quality_checker.check_response_quality(query, responses)
100
-
101
  # Add domain info
102
  quality_metrics['domain'] = domain
103
  metrics_history.append(quality_metrics)
104
  domain_metrics[domain].append(quality_metrics)
105
-
106
  # Detailed logging
107
  self._log_validation_results(query, responses, quality_metrics, i)
108
 
@@ -110,12 +110,12 @@ class ChatbotValidator:
110
  aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
111
  domain_analysis = self._analyze_domain_performance(domain_metrics)
112
  confidence_analysis = self._analyze_confidence_distribution(metrics_history)
113
-
114
  aggregate_metrics.update({
115
  'domain_performance': domain_analysis,
116
  'confidence_analysis': confidence_analysis
117
  })
118
-
119
  self._log_validation_summary(aggregate_metrics)
120
  return aggregate_metrics
121
 
@@ -129,7 +129,7 @@ class ChatbotValidator:
129
  'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
130
  'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
131
  'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
132
-
133
  # Additional statistical metrics
134
  'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
135
  'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
@@ -141,7 +141,7 @@ class ChatbotValidator:
141
  def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
142
  """Analyze performance by domain."""
143
  domain_analysis = {}
144
-
145
  for domain, metrics in domain_metrics.items():
146
  domain_analysis[domain] = {
147
  'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
@@ -150,13 +150,13 @@ class ChatbotValidator:
150
  'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
151
  'num_samples': len(metrics)
152
  }
153
-
154
  return domain_analysis
155
 
156
  def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
157
  """Analyze the distribution of confidence scores."""
158
  scores = [m.get('top_score', 0) for m in metrics_history]
159
-
160
  return {
161
  'percentile_25': np.percentile(scores, 25),
162
  'percentile_50': np.percentile(scores, 50),
@@ -180,7 +180,7 @@ class ChatbotValidator:
180
  for metric, value in metrics.items():
181
  if isinstance(value, (int, float)):
182
  logger.info(f" {metric}: {value:.4f}")
183
-
184
  logger.info("\nTop Responses:")
185
  for i, (response, score) in enumerate(responses[:3], 1):
186
  logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
@@ -190,18 +190,18 @@ class ChatbotValidator:
190
  def _log_validation_summary(self, metrics: Dict[str, Any]):
191
  """Log comprehensive validation summary."""
192
  logger.info("\n=== Validation Summary ===")
193
-
194
  logger.info("\nOverall Metrics:")
195
  for metric, value in metrics.items():
196
  if isinstance(value, (int, float)):
197
  logger.info(f"{metric}: {value:.4f}")
198
-
199
  logger.info("\nDomain Performance:")
200
  for domain, domain_metrics in metrics['domain_performance'].items():
201
  logger.info(f"\n{domain.title()}:")
202
  for metric, value in domain_metrics.items():
203
  logger.info(f" {metric}: {value:.4f}")
204
-
205
  logger.info("\nConfidence Distribution:")
206
  for percentile, value in metrics['confidence_analysis'].items():
207
  logger.info(f"{percentile}: {value:.4f}")
 
1
  from typing import Dict, List, Tuple, Any, Optional
2
  import numpy as np
 
3
 
4
+ from logger_config import config_logger
5
  logger = config_logger(__name__)
6
 
7
  class ChatbotValidator:
8
  """Handles automated validation and performance analysis for the chatbot."""
9
+
10
  def __init__(self, chatbot, quality_checker):
11
  """
12
  Initialize the validator.
13
+
14
  Args:
15
  chatbot: RetrievalChatbot instance
16
  quality_checker: ResponseQualityChecker instance
17
  """
18
  self.chatbot = chatbot
19
  self.quality_checker = quality_checker
20
+
21
  # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
22
  self.domain_queries = {
23
  'restaurant': [
 
59
 
60
  def run_validation(
61
  self,
62
+ num_examples: int = 5,
63
  top_k: int = 10,
64
  domains: Optional[List[str]] = None
65
  ) -> Dict[str, Any]:
66
  """
67
  Run comprehensive validation across specified domains.
68
+
69
  Args:
70
  num_examples: Number of test queries per domain
71
  top_k: Number of responses to retrieve for each query
72
  domains: Optional list of specific domains to test
73
+
74
  Returns:
75
  Dict containing detailed validation metrics and domain-specific performance
76
  """
77
  logger.info("\n=== Running Enhanced Automatic Validation ===")
78
+
79
  # Select domains to test
80
  test_domains = domains if domains else list(self.domain_queries.keys())
81
  metrics_history = []
82
  domain_metrics = {}
83
+
84
  # Run validation for each domain
85
  for domain in test_domains:
86
  domain_metrics[domain] = []
87
  queries = self.domain_queries[domain][:num_examples]
88
+
89
  logger.info(f"\n=== Testing {domain.title()} Domain ===")
90
+
91
  for i, query in enumerate(queries, 1):
92
  logger.info(f"\nTest Case {i}:")
93
  logger.info(f"Query: {query}")
94
+
95
  # Get responses with increased top_k
96
  responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)
97
+
98
+ # Quality check the retrieved responses (validation queries carry no conversation context)
99
  quality_metrics = self.quality_checker.check_response_quality(query, responses)
100
+
101
  # Add domain info
102
  quality_metrics['domain'] = domain
103
  metrics_history.append(quality_metrics)
104
  domain_metrics[domain].append(quality_metrics)
105
+
106
  # Detailed logging
107
  self._log_validation_results(query, responses, quality_metrics, i)
108
 
 
110
  aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
111
  domain_analysis = self._analyze_domain_performance(domain_metrics)
112
  confidence_analysis = self._analyze_confidence_distribution(metrics_history)
113
+
114
  aggregate_metrics.update({
115
  'domain_performance': domain_analysis,
116
  'confidence_analysis': confidence_analysis
117
  })
118
+
119
  self._log_validation_summary(aggregate_metrics)
120
  return aggregate_metrics
121
 
 
129
  'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
130
  'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
131
  'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
132
+
133
  # Additional statistical metrics
134
  'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
135
  'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
 
141
  def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
142
  """Analyze performance by domain."""
143
  domain_analysis = {}
144
+
145
  for domain, metrics in domain_metrics.items():
146
  domain_analysis[domain] = {
147
  'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
 
150
  'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
151
  'num_samples': len(metrics)
152
  }
153
+
154
  return domain_analysis
155
 
156
  def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
157
  """Analyze the distribution of confidence scores."""
158
  scores = [m.get('top_score', 0) for m in metrics_history]
159
+
160
  return {
161
  'percentile_25': np.percentile(scores, 25),
162
  'percentile_50': np.percentile(scores, 50),
 
180
  for metric, value in metrics.items():
181
  if isinstance(value, (int, float)):
182
  logger.info(f" {metric}: {value:.4f}")
183
+
184
  logger.info("\nTop Responses:")
185
  for i, (response, score) in enumerate(responses[:3], 1):
186
  logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
 
190
  def _log_validation_summary(self, metrics: Dict[str, Any]):
191
  """Log comprehensive validation summary."""
192
  logger.info("\n=== Validation Summary ===")
193
+
194
  logger.info("\nOverall Metrics:")
195
  for metric, value in metrics.items():
196
  if isinstance(value, (int, float)):
197
  logger.info(f"{metric}: {value:.4f}")
198
+
199
  logger.info("\nDomain Performance:")
200
  for domain, domain_metrics in metrics['domain_performance'].items():
201
  logger.info(f"\n{domain.title()}:")
202
  for metric, value in domain_metrics.items():
203
  logger.info(f" {metric}: {value:.4f}")
204
+
205
  logger.info("\nConfidence Distribution:")
206
  for percentile, value in metrics['confidence_analysis'].items():
207
  logger.info(f"{percentile}: {value:.4f}")
conversation_summarizer.py CHANGED
@@ -49,7 +49,15 @@ class Summarizer(DeviceAwareModel):
49
  Handles long conversations by intelligent chunking and progressive summarization.
50
  """
51
 
52
- def __init__(self, model_name="t5-small", max_summary_length=128, device=None, max_summary_rounds=2):
 
 
 
53
  self.setup_device(device)
54
 
55
  # Initialize model within strategy scope if using distribution
@@ -63,12 +71,11 @@ class Summarizer(DeviceAwareModel):
63
  self.max_summary_rounds = max_summary_rounds
64
 
65
  def _setup_model(self, model_name):
66
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
67
  self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
68
 
69
  # Optimize model for inference
70
- self.model.predict = tf.function(
71
- self.model.predict,
72
  input_signature=[
73
  {
74
  'input_ids': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
 
49
  Handles long conversations by intelligent chunking and progressive summarization.
50
  """
51
 
52
+ def __init__(
53
+ self,
54
+ tokenizer: AutoTokenizer,
55
+ model_name="t5-small",
56
+ max_summary_length=128,
57
+ device=None,
58
+ max_summary_rounds=2
59
+ ):
60
+ self.tokenizer = tokenizer # Injected tokenizer
61
  self.setup_device(device)
62
 
63
  # Initialize model within strategy scope if using distribution
 
71
  self.max_summary_rounds = max_summary_rounds
72
 
73
  def _setup_model(self, model_name):
 
74
  self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
75
 
76
  # Optimize model for inference
77
+ self.model.generate = tf.function(
78
+ self.model.generate,
79
  input_signature=[
80
  {
81
  'input_ids': tf.TensorSpec(shape=[None, None], dtype=tf.int32),
environment_setup.py CHANGED
@@ -122,15 +122,6 @@ class EnvironmentSetup:
122
  except (subprocess.SubprocessError, FileNotFoundError):
123
  logger.warning("Could not detect specific GPU model")
124
 
125
- # # Enable XLA
126
- # tf.config.optimizer.set_jit(True)
127
- # logger.info("XLA compilation enabled for Colab GPU")
128
-
129
- # # Set mixed precision policy
130
- # policy = tf.keras.mixed_precision.Policy('mixed_float16')
131
- # tf.keras.mixed_precision.set_global_policy(policy)
132
- # logger.info("Mixed precision training enabled (float16)")
133
-
134
  strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
135
  return "GPU", strategy
136
 
 
122
  except (subprocess.SubprocessError, FileNotFoundError):
123
  logger.warning("Could not detect specific GPU model")
124
 
 
 
 
 
 
 
 
 
 
125
  strategy = tf.distribute.OneDeviceStrategy("/GPU:0")
126
  return "GPU", strategy
127
 
run_data_preparer.py → prepare_data.py RENAMED
@@ -1,6 +1,7 @@
1
  import os
2
  import sys
3
  import faiss
 
4
  import pickle
5
  from transformers import AutoTokenizer
6
  from tqdm.auto import tqdm
@@ -52,36 +53,24 @@ def main():
52
  config = ChatbotConfig()
53
  logger.info(f"Chatbot Configuration: {config}")
54
 
55
- # Initialize tokenizer
56
  try:
57
  tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
58
  logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
59
- except Exception as e:
60
- logger.error(f"Failed to load tokenizer: {e}")
61
- sys.exit(1)
62
-
63
- # Add special tokens
64
- try:
65
  tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
66
  logger.info("Added special tokens to tokenizer.")
67
  except Exception as e:
68
- logger.error(f"Failed to add special tokens: {e}")
69
  sys.exit(1)
70
 
71
- # Initialize encoder model
72
  try:
73
  encoder = EncoderModel(config=config)
74
  logger.info("EncoderModel initialized successfully.")
75
- except Exception as e:
76
- logger.error(f"Failed to initialize EncoderModel: {e}")
77
- sys.exit(1)
78
-
79
- # Resize token embeddings in encoder to match tokenizer
80
- try:
81
  encoder.pretrained.resize_token_embeddings(len(tokenizer))
82
  logger.info(f"Token embeddings resized to: {len(tokenizer)}")
83
  except Exception as e:
84
- logger.error(f"Failed to resize token embeddings: {e}")
85
  sys.exit(1)
86
 
87
  # Load JSON dialogues
@@ -116,6 +105,8 @@ def main():
116
  max_length=config.max_context_token_limit,
117
  neg_samples=config.neg_samples,
118
  query_embeddings_cache=query_embeddings_cache,
 
 
119
  max_retries=config.max_retries
120
  )
121
  logger.info("TFDataPipeline initialized successfully.")
@@ -135,17 +126,22 @@ def main():
135
  # Compute and add response embeddings to FAISS index
136
  try:
137
  logger.info("Computing and adding response embeddings to FAISS index...")
138
- data_pipeline._compute_and_index_response_embeddings()
139
  logger.info("Response embeddings computed and added to FAISS index.")
140
  except Exception as e:
141
  logger.error(f"Failed to compute or add response embeddings: {e}")
142
  sys.exit(1)
143
 
144
- # Save FAISS index
145
  try:
146
  logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
147
  faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
148
  logger.info("FAISS index saved successfully.")
 
 
 
 
 
149
  except Exception as e:
150
  logger.error(f"Failed to save FAISS index: {e}")
151
  sys.exit(1)
 
1
  import os
2
  import sys
3
  import faiss
4
+ import json
5
  import pickle
6
  from transformers import AutoTokenizer
7
  from tqdm.auto import tqdm
 
53
  config = ChatbotConfig()
54
  logger.info(f"Chatbot Configuration: {config}")
55
 
56
+ # Initialize tokenizer and add special tokens
57
  try:
58
  tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
59
  logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
 
 
 
 
 
 
60
  tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
61
  logger.info("Added special tokens to tokenizer.")
62
  except Exception as e:
63
+ logger.error(f"Failed to load tokenizer: {e}")
64
  sys.exit(1)
65
 
66
+ # Initialize encoder model and resize token embeddings
67
  try:
68
  encoder = EncoderModel(config=config)
69
  logger.info("EncoderModel initialized successfully.")
 
 
 
 
 
 
70
  encoder.pretrained.resize_token_embeddings(len(tokenizer))
71
  logger.info(f"Token embeddings resized to: {len(tokenizer)}")
72
  except Exception as e:
73
+ logger.error(f"Failed to initialize EncoderModel: {e}")
74
  sys.exit(1)
75
 
76
  # Load JSON dialogues
 
105
  max_length=config.max_context_token_limit,
106
  neg_samples=config.neg_samples,
107
  query_embeddings_cache=query_embeddings_cache,
108
+ index_type='IndexFlatIP',
109
+ nlist=100,
110
  max_retries=config.max_retries
111
  )
112
  logger.info("TFDataPipeline initialized successfully.")
 
126
  # Compute and add response embeddings to FAISS index
127
  try:
128
  logger.info("Computing and adding response embeddings to FAISS index...")
129
+ data_pipeline.compute_and_index_response_embeddings()
130
  logger.info("Response embeddings computed and added to FAISS index.")
131
  except Exception as e:
132
  logger.error(f"Failed to compute or add response embeddings: {e}")
133
  sys.exit(1)
134
 
135
+ # Save FAISS index and response pool
136
  try:
137
  logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
138
  faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
139
  logger.info("FAISS index saved successfully.")
140
+
141
+ response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
142
+ with open(response_pool_path, 'w', encoding='utf-8') as f:
143
+ json.dump(data_pipeline.response_pool, f, indent=2)
144
+ logger.info(f"Response pool saved to {response_pool_path}.")
145
  except Exception as e:
146
  logger.error(f"Failed to save FAISS index: {e}")
147
  sys.exit(1)
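As a quick sanity check after this step, the saved artifacts can be reloaded and cross-checked (a sketch; assumes FAISS_INDEX_PATH ends in '.index', matching the replace() above):

index = faiss.read_index(FAISS_INDEX_PATH)
with open(FAISS_INDEX_PATH.replace('.index', '_responses.json'), 'r', encoding='utf-8') as f:
    responses = json.load(f)
assert index.ntotal == len(responses), "index size and response pool size should match"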
response_quality_checker.py CHANGED
@@ -6,14 +6,14 @@ from logger_config import config_logger
6
  logger = config_logger(__name__)
7
 
8
  if TYPE_CHECKING:
9
- from chatbot_model import RetrievalChatbot
10
 
11
  class ResponseQualityChecker:
12
  """Enhanced quality checking with dynamic thresholds."""
13
-
14
  def __init__(
15
  self,
16
- chatbot: 'RetrievalChatbot',
17
  confidence_threshold: float = 0.6,
18
  diversity_threshold: float = 0.15,
19
  min_response_length: int = 5,
@@ -23,15 +23,15 @@ class ResponseQualityChecker:
23
  self.diversity_threshold = diversity_threshold
24
  self.min_response_length = min_response_length
25
  self.similarity_cap = similarity_cap
26
- self.chatbot = chatbot
27
-
28
  # Dynamic thresholds based on response patterns
29
  self.thresholds = {
30
  'relevance': 0.35,
31
  'length_score': 0.85,
32
  'score_gap': 0.07
33
  }
34
-
35
  def check_response_quality(
36
  self,
37
  query: str,
@@ -39,11 +39,11 @@ class ResponseQualityChecker:
39
  ) -> Dict[str, Any]:
40
  """
41
  Evaluate the quality of responses based on various metrics.
42
-
43
  Args:
44
  query: The user's query
45
  responses: List of (response_text, score) tuples
46
-
47
  Returns:
48
  Dict containing quality metrics and confidence assessment
49
  """
@@ -56,7 +56,7 @@ class ResponseQualityChecker:
56
  'response_length_score': 0.0,
57
  'top_3_score_gap': 0.0
58
  }
59
-
60
  # Calculate core metrics
61
  metrics = {
62
  'response_diversity': self.calculate_diversity(responses),
@@ -67,10 +67,10 @@ class ResponseQualityChecker:
67
  'top_score': responses[0][1],
68
  'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
69
  }
70
-
71
  # Determine confidence using thresholds
72
  metrics['is_confident'] = self._determine_confidence(metrics)
73
-
74
  logger.info(f"Quality metrics: {metrics}")
75
  return metrics
76
 
@@ -78,44 +78,45 @@ class ResponseQualityChecker:
78
  """Calculate relevance as weighted similarity between query and responses."""
79
  if not responses:
80
  return 0.0
81
-
82
  # Get embeddings
83
- query_embedding = self.encode_query(query)
84
- response_embeddings = [self.encode_text(response) for response, _ in responses]
85
-
86
- # Compute similarities with decreasing weights for later responses
 
87
  similarities = cosine_similarity([query_embedding], response_embeddings)[0]
 
 
88
  weights = np.array([1.0 / (i + 1) for i in range(len(similarities))])
89
-
90
  return np.average(similarities, weights=weights)
91
 
92
  def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
93
  """Calculate diversity with length normalization and similarity capping."""
94
  if not responses:
95
  return 0.0
96
-
97
- embeddings = [self.encode_text(response) for response, _ in responses]
 
98
  if len(embeddings) < 2:
99
  return 1.0
100
-
101
- # Calculate similarities and apply cap
102
  similarity_matrix = cosine_similarity(embeddings)
 
 
 
103
  similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)
104
-
105
- # Apply length normalization
106
- lengths = [len(resp[0].split()) for resp in responses]
107
- length_ratios = np.array([min(a, b) / max(a, b) for a in lengths for b in lengths])
108
- length_ratios = length_ratios.reshape(len(responses), len(responses))
109
-
110
- # Combine factors with weights
111
- adjusted_similarity = (similarity_matrix * 0.7 + length_ratios * 0.3)
112
-
113
- # Calculate final score
114
- sum_similarities = np.sum(adjusted_similarity) - len(responses)
115
- num_pairs = len(responses) * (len(responses) - 1)
116
  avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
117
-
118
- return 1 - avg_similarity
 
 
119
 
120
  def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
121
  """Determine confidence using primary and secondary conditions."""
@@ -125,20 +126,20 @@ class ResponseQualityChecker:
125
  metrics['response_diversity'] >= self.diversity_threshold,
126
  metrics['response_length_score'] >= self.thresholds['length_score']
127
  ]
128
-
129
  # Secondary conditions (majority must be met)
130
  secondary_conditions = [
131
  metrics['query_response_relevance'] >= self.thresholds['relevance'],
132
  metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
133
  metrics['top_score'] >= (self.confidence_threshold * 1.1) # Extra confidence boost
134
  ]
135
-
136
  return all(primary_conditions) and sum(secondary_conditions) >= 2
137
 
138
  def _calculate_length_score(self, response: str) -> float:
139
  """Calculate length score with penalty for very short or long responses."""
140
  words = len(response.split())
141
-
142
  if words < self.min_response_length:
143
  return words / self.min_response_length
144
  elif words > 50: # Penalty for very long responses
@@ -150,21 +151,4 @@ class ResponseQualityChecker:
150
  if len(scores) < top_n + 1:
151
  return 0.0
152
  gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
153
- return np.mean(gaps)
154
-
155
- def encode_text(self, text: str) -> np.ndarray:
156
- """Encode response text to embedding."""
157
- embedding_tensor = self.chatbot.encode_responses([text])
158
- embedding = embedding_tensor.numpy()[0].astype('float32')
159
- return self._normalize_embedding(embedding)
160
-
161
- def encode_query(self, query: str) -> np.ndarray:
162
- """Encode query text to embedding."""
163
- embedding_tensor = self.chatbot.encode_query(query)
164
- embedding = embedding_tensor.numpy()[0].astype('float32')
165
- return self._normalize_embedding(embedding)
166
-
167
- def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
168
- """Normalize embedding vector."""
169
- norm = np.linalg.norm(embedding)
170
- return embedding / norm if norm > 0 else embedding
 
6
  logger = config_logger(__name__)
7
 
8
  if TYPE_CHECKING:
9
+ from tf_data_pipeline import TFDataPipeline
10
 
11
  class ResponseQualityChecker:
12
  """Enhanced quality checking with dynamic thresholds."""
13
+
14
  def __init__(
15
  self,
16
+ data_pipeline: 'TFDataPipeline',
17
  confidence_threshold: float = 0.6,
18
  diversity_threshold: float = 0.15,
19
  min_response_length: int = 5,
 
23
  self.diversity_threshold = diversity_threshold
24
  self.min_response_length = min_response_length
25
  self.similarity_cap = similarity_cap
26
+ self.data_pipeline = data_pipeline # Reference to TFDataPipeline
27
+
28
  # Dynamic thresholds based on response patterns
29
  self.thresholds = {
30
  'relevance': 0.35,
31
  'length_score': 0.85,
32
  'score_gap': 0.07
33
  }
34
+
35
  def check_response_quality(
36
  self,
37
  query: str,
 
39
  ) -> Dict[str, Any]:
40
  """
41
  Evaluate the quality of responses based on various metrics.
42
+
43
  Args:
44
  query: The user's query
45
  responses: List of (response_text, score) tuples
46
+
47
  Returns:
48
  Dict containing quality metrics and confidence assessment
49
  """
 
56
  'response_length_score': 0.0,
57
  'top_3_score_gap': 0.0
58
  }
59
+
60
  # Calculate core metrics
61
  metrics = {
62
  'response_diversity': self.calculate_diversity(responses),
 
67
  'top_score': responses[0][1],
68
  'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
69
  }
70
+
71
  # Determine confidence using thresholds
72
  metrics['is_confident'] = self._determine_confidence(metrics)
73
+
74
  logger.info(f"Quality metrics: {metrics}")
75
  return metrics
76
 
 
78
  """Calculate relevance as weighted similarity between query and responses."""
79
  if not responses:
80
  return 0.0
81
+
82
  # Get embeddings
83
+ query_embedding = self.data_pipeline.encode_query(query)
84
+ response_texts = [resp for resp, _ in responses]
85
+ response_embeddings = self.data_pipeline.encode_responses(response_texts)
86
+
87
+ # Compute similarities
88
  similarities = cosine_similarity([query_embedding], response_embeddings)[0]
89
+
90
+ # Apply decreasing weights for later responses
91
  weights = np.array([1.0 / (i + 1) for i in range(len(similarities))])
92
+
93
  return np.average(similarities, weights=weights)
94
 
95
  def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
96
  """Calculate diversity with length normalization and similarity capping."""
97
  if not responses:
98
  return 0.0
99
+
100
+ response_texts = [resp for resp, _ in responses]
101
+ embeddings = self.data_pipeline.encode_responses(response_texts)
102
  if len(embeddings) < 2:
103
  return 1.0
104
+
105
+ # Calculate pairwise cosine similarities
106
  similarity_matrix = cosine_similarity(embeddings)
107
+ np.fill_diagonal(similarity_matrix, 0) # Exclude self-similarity
108
+
109
+ # Apply similarity cap
110
  similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)
111
+
112
+ # Calculate average similarity
113
+ sum_similarities = np.sum(similarity_matrix)
114
+ num_pairs = len(embeddings) * (len(embeddings) - 1)
 
 
 
115
  avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
116
+
117
+ # Diversity is inversely related to average similarity
118
+ diversity_score = 1 - avg_similarity
119
+ return diversity_score
120
 
121
  def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
122
  """Determine confidence using primary and secondary conditions."""
 
126
  metrics['response_diversity'] >= self.diversity_threshold,
127
  metrics['response_length_score'] >= self.thresholds['length_score']
128
  ]
129
+
130
  # Secondary conditions (majority must be met)
131
  secondary_conditions = [
132
  metrics['query_response_relevance'] >= self.thresholds['relevance'],
133
  metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
134
  metrics['top_score'] >= (self.confidence_threshold * 1.1) # Extra confidence boost
135
  ]
136
+
137
  return all(primary_conditions) and sum(secondary_conditions) >= 2
138
 
139
  def _calculate_length_score(self, response: str) -> float:
140
  """Calculate length score with penalty for very short or long responses."""
141
  words = len(response.split())
142
+
143
  if words < self.min_response_length:
144
  return words / self.min_response_length
145
  elif words > 50: # Penalty for very long responses
 
151
  if len(scores) < top_n + 1:
152
  return 0.0
153
  gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
154
+ return np.mean(gaps)
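A minimal usage sketch for the refactored checker, which now depends on the data pipeline rather than the chatbot (texts and scores are illustrative):

checker = ResponseQualityChecker(data_pipeline=data_pipeline)
responses = [("We have a table at 7pm, does that work?", 0.82),
             ("Could you tell me the party size?", 0.74)]
metrics = checker.check_response_quality("Book a table for two tonight", responses)
if metrics['is_confident']:
    print(responses[0][0])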
 
 
 
 
 
test_trained_model.py DELETED
File without changes
tf_data_pipeline.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
11
  from typing import Union, Optional, List, Tuple, Generator
12
  from transformers import AutoTokenizer
13
  from typing import List, Tuple, Generator
 
14
  from gpu_monitor import GPUMemoryMonitor
15
 
16
  from logger_config import config_logger
@@ -31,7 +32,6 @@ class TFDataPipeline:
31
  nlist: int = 100,
32
  max_retries: int = 3
33
  ):
34
- #self.embedding_batch_size = embedding_batch_size
35
  self.config = config
36
  self.tokenizer = tokenizer
37
  self.encoder = encoder
@@ -64,14 +64,6 @@ class TFDataPipeline:
64
  dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
65
  self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
66
  self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
67
-
68
- def validate_faiss_index(self):
69
- """Validates that the FAISS index has the correct dimensionality."""
70
- expected_dim = self.encoder.config.embedding_dim
71
- if self.index.d != expected_dim:
72
- logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
73
- raise ValueError("FAISS index dimensionality mismatch.")
74
- logger.info("FAISS index dimension validated successfully.")
75
 
76
  def save_embeddings_cache_hdf5(self, cache_file_path: str):
77
  """Save the embeddings cache to an HDF5 file."""
@@ -92,8 +84,21 @@ class TFDataPipeline:
92
  logger.info(f"FAISS index saved to {index_file_path}")
93
 
94
  def load_faiss_index(self, index_file_path: str):
95
- self.index = faiss.read_index(index_file_path)
96
- logger.info(f"FAISS index loaded from {index_file_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def save_tokenizer(self, tokenizer_dir: str):
99
  self.tokenizer.save_pretrained(tokenizer_dir)
@@ -102,19 +107,6 @@ class TFDataPipeline:
102
  def load_tokenizer(self, tokenizer_dir: str):
103
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
104
  logger.info(f"Tokenizer loaded from {tokenizer_dir}")
105
-
106
- def estimate_total_pairs(self, dialogues: List[dict]) -> int:
107
- """Estimate total number of training pairs including hard negatives."""
108
- base_pairs = sum(
109
- len([
110
- 1 for i in range(len(d.get('turns', [])) - 1)
111
- if (d['turns'][i].get('speaker') == 'user' and
112
- d['turns'][i+1].get('speaker') == 'assistant')
113
- ])
114
- for d in dialogues
115
- )
116
- # Account for hard negatives
117
- return base_pairs * (1 + self.neg_samples)
118
 
119
  @staticmethod
120
  def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
@@ -179,7 +171,7 @@ class TFDataPipeline:
179
 
180
  return pairs
181
 
182
- def _compute_and_index_response_embeddings(self):
183
  """
184
  Computes embeddings for the response pool and adds them to the FAISS index with progress bars.
185
  """
@@ -239,49 +231,6 @@ class TFDataPipeline:
239
 
240
  # **Sanity Check:** Verify the number of embeddings in FAISS index
241
  logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
242
- # def _compute_and_index_response_embeddings(self):
243
- # """
244
- # Computes embeddings for the response pool and adds them to the FAISS index.
245
- # """
246
- # logger.info("Computing embeddings for the response pool...")
247
-
248
- # # Ensure all responses are strings
249
- # if not all(isinstance(response, str) for response in self.response_pool):
250
- # logger.error("All elements in response_pool must be strings.")
251
- # raise ValueError("Invalid data type in response_pool.")
252
-
253
- # # Proceed with tokenization
254
- # encoded_responses = self.tokenizer(
255
- # self.response_pool,
256
- # padding=True,
257
- # truncation=True,
258
- # max_length=self.max_length,
259
- # return_tensors='tf'
260
- # )
261
- # response_ids = encoded_responses['input_ids']
262
-
263
- # # Compute embeddings in batches
264
- # batch_size = getattr(self, 'embedding_batch_size', 64) # Default to 64 if not set
265
- # embeddings = []
266
- # for i in range(0, len(response_ids), batch_size):
267
- # batch_ids = response_ids[i:i+batch_size]
268
- # # Compute embeddings
269
- # batch_embeddings = self.encoder(batch_ids, training=False).numpy()
270
- # # Normalize embeddings if using inner product or cosine similarity
271
- # faiss.normalize_L2(batch_embeddings)
272
- # embeddings.append(batch_embeddings)
273
-
274
- # if embeddings:
275
- # embeddings = np.vstack(embeddings).astype(np.float32)
276
- # # Add embeddings to FAISS index
277
- # logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
278
- # self.index.add(embeddings)
279
- # logger.info("Response embeddings added to FAISS index.")
280
- # else:
281
- # logger.warning("No embeddings to add to FAISS index.")
282
-
283
- # # **Sanity Check:** Verify the number of embeddings in FAISS index
284
- # logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
285
 
286
  def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
287
  """Find hard negatives for a batch of queries with error handling and retries."""
@@ -355,58 +304,109 @@ class TFDataPipeline:
355
  if tf.config.list_physical_devices('GPU'):
356
  tf.keras.backend.clear_session()
357
 
358
- def _tokenize_and_encode(self, queries: List[str], positives: List[str], negatives: List[List[str]]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
359
  """
360
- Tokenize and encode the queries, positives, and negatives.
 
 
361
  Returns:
362
- query_ids: [batch_size, max_length]
363
- positive_ids: [batch_size, max_length]
364
- negative_ids: [batch_size, neg_samples, max_length]
365
  """
366
- # Tokenize queries
367
- q_enc = self.tokenizer(
368
- queries,
369
- padding="max_length",
370
- truncation=True,
371
- max_length=self.max_length,
372
- return_tensors="np"
373
- )
374
- # Tokenize positives
375
- p_enc = self.tokenizer(
376
- positives,
377
- padding="max_length",
 
 
 
 
378
  truncation=True,
379
  max_length=self.max_length,
380
- return_tensors="np"
381
  )
382
- # Tokenize negatives
383
- # Flatten negatives
384
- flattened_negatives = [neg for sublist in negatives for neg in sublist]
385
- if len(flattened_negatives) == 0:
386
- # No negatives at all: return a zero array
387
- n_ids = np.zeros((len(queries), self.neg_samples, self.max_length), dtype=np.int32)
388
- else:
389
- n_enc = self.tokenizer(
390
- flattened_negatives,
391
- padding="max_length",
392
- truncation=True,
393
- max_length=self.max_length,
394
- return_tensors="np"
395
- )
396
- n_input_ids = n_enc["input_ids"]
397
-
398
- # Reshape to [batch_size, neg_samples, max_length]
399
- batch_size = len(queries)
400
- n_ids = n_input_ids.reshape(batch_size, self.neg_samples, self.max_length)
401
 
402
- # Convert to int32
403
- query_ids = q_enc["input_ids"].astype(np.int32)
404
- positive_ids = p_enc["input_ids"].astype(np.int32)
405
- negative_ids = n_ids.astype(np.int32)
406
 
407
- return query_ids, positive_ids, negative_ids
 
 
408
 
409
- # Testing updated batch tokenization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
411
  """
412
  Processes dialogues in batches and saves to a TFRecord file using optimized batch tokenization and encoding.
@@ -522,83 +522,6 @@ class TFDataPipeline:
522
  pbar.update(1)
523
 
524
  logger.info(f"Data preparation complete. TFRecord saved.")
525
- # def prepare_and_save_data(self, dialogues: List[dict], tfrecord_file_path: str, batch_size: int = 32):
526
- # """Processes dialogues in batches and saves to a TFRecord file."""
527
- # with tf.io.TFRecordWriter(tfrecord_file_path) as writer:
528
- # total_dialogues = len(dialogues)
529
- # logger.debug(f"Total dialogues to process: {total_dialogues}")
530
-
531
- # with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
532
- # for i in range(0, total_dialogues, batch_size):
533
- # batch_dialogues = dialogues[i:i+batch_size]
534
- # # Process each batch_dialogues
535
- # # Extract pairs, find negatives, tokenize, and serialize
536
- # # Example:
537
- # for dialogue in batch_dialogues:
538
- # pairs = self._extract_pairs_from_dialogue(dialogue)
539
- # queries = []
540
- # positives = []
541
-
542
- # for query, positive in pairs:
543
- # queries.append(query)
544
- # positives.append(positive)
545
-
546
- # if queries:
547
- # # **Compute and cache query embeddings before searching**
548
- # self._compute_embeddings(queries)
549
-
550
- # # Find hard negatives
551
- # hard_negatives = self._find_hard_negatives_batch(queries, positives)
552
-
553
- # # for idx, negatives in enumerate(hard_negatives[:5]): # Log first 5 examples
554
- # # logger.debug(f"Query: {queries[idx]}")
555
- # # logger.debug(f"Positive: {positives[idx]}")
556
- # # logger.debug(f"Hard Negatives: {negatives}")
557
- # # Tokenize and encode
558
- # query_ids, positive_ids, negative_ids = self._tokenize_and_encode(queries, positives, hard_negatives)
559
-
560
- # # Serialize each example and write to TFRecord
561
- # for q_id, p_id, n_id in zip(query_ids, positive_ids, negative_ids):
562
- # feature = {
563
- # 'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
564
- # 'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
565
- # 'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
566
- # }
567
- # example = tf.train.Example(features=tf.train.Features(feature=feature))
568
- # writer.write(example.SerializeToString())
569
-
570
- # pbar.update(len(batch_dialogues))
571
- # logger.info(f"Data preparation complete. TFRecord saved at {tfrecord_file_path}")
572
-
573
- def _tokenize_negatives_tf(self, negatives):
574
- """Tokenizes negatives using tf.py_function."""
575
- # Handle the case where negatives is an empty tensor
576
- if tf.size(negatives) == 0:
577
- return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
578
-
579
- # Convert EagerTensor to a list of strings
580
- negatives_list = []
581
- for neg_list in negatives.numpy():
582
- decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] # Filter out empty strings
583
- negatives_list.append(decoded_negs)
584
-
585
- # Flatten the list of lists
586
- flattened_negatives = [neg for sublist in negatives_list for neg in sublist]
587
-
588
- # Tokenize the flattened negatives
589
- if flattened_negatives:
590
- n_tokens = self.tokenizer(
591
- flattened_negatives,
592
- padding='max_length',
593
- truncation=True,
594
- max_length=self.max_length,
595
- return_tensors='tf'
596
- )
597
- # Reshape the tokens
598
- n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [-1, self.neg_samples, self.max_length])
599
- return n_tokens_reshaped
600
- else:
601
- return tf.zeros([0, self.neg_samples, self.max_length], dtype=tf.int32)
602
 
603
  def _compute_embeddings(self, queries: List[str]) -> None:
604
  new_queries = [q for q in queries if q not in self.query_embeddings_cache]
@@ -642,51 +565,6 @@ class TFDataPipeline:
642
  hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
643
  yield (query, positive, hard_negatives)
644
  pbar.update(1)
645
-
646
- def _prepare_batch(self, queries: tf.Tensor, positives: tf.Tensor, negatives: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
647
- """Prepares a batch of data for training."""
648
-
649
- # Convert EagerTensors to lists of strings
650
- queries_list = [query.decode("utf-8") for query in queries.numpy()]
651
- positives_list = [pos.decode("utf-8") for pos in positives.numpy()]
652
-
653
- # Tokenize queries and positives
654
- q_tokens = self.tokenizer(queries_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
655
- p_tokens = self.tokenizer(positives_list, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
656
-
657
- # Decode negatives and ensure they are lists of strings
658
- negatives_list = []
659
- for neg_list in negatives.numpy():
660
- decoded_negs = [neg.decode("utf-8") for neg in neg_list if neg] # Filter out empty strings
661
- negatives_list.append(decoded_negs)
662
-
663
- # Flatten negatives for tokenization if there are any valid negatives
664
- flattened_negatives = [neg for sublist in negatives_list for neg in sublist if neg]
665
-
666
- # Tokenize negatives if there are any
667
- n_tokens_reshaped = None
668
- if flattened_negatives:
669
- n_tokens = self.tokenizer(flattened_negatives, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='tf')
670
-
671
- # Reshape n_tokens to match the expected shape based on the number of negatives per query
672
- # This part may need adjustment if the number of negatives varies per query
673
- n_tokens_reshaped = tf.reshape(n_tokens['input_ids'], [len(queries_list), -1, self.max_length])
674
- else:
675
- # Create a placeholder tensor for the case where there are no negatives
676
- n_tokens_reshaped = tf.zeros([len(queries_list), 0, self.max_length], dtype=tf.int32)
677
-
678
- # Ensure n_tokens_reshaped has a consistent shape even when there are no negatives
679
- # Adjust shape to [batch_size, num_neg_samples, max_length]
680
- if n_tokens_reshaped.shape[1] != self.neg_samples:
681
- # Pad or truncate the second dimension to match neg_samples
682
- padding = tf.zeros([len(queries_list), tf.maximum(0, self.neg_samples - n_tokens_reshaped.shape[1]), self.max_length], dtype=tf.int32)
683
- n_tokens_reshaped = tf.concat([n_tokens_reshaped, padding], axis=1)
684
- n_tokens_reshaped = n_tokens_reshaped[:, :self.neg_samples, :]
685
-
686
- # Concatenate the positive and negative examples along the 'neg_samples' dimension
687
- combined_p_n_tokens = tf.concat([tf.expand_dims(p_tokens['input_ids'], axis=1), n_tokens_reshaped], axis=1)
688
-
689
- return q_tokens['input_ids'], combined_p_n_tokens
690
 
691
  def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
692
  """
@@ -714,32 +592,6 @@ class TFDataPipeline:
714
 
715
  dataset = dataset.prefetch(tf.data.AUTOTUNE)
716
  return dataset
717
- # def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
718
- # """
719
- # Creates a tf.data.Dataset for streaming training that yields
720
- # (input_ids_query, input_ids_positive, input_ids_negatives).
721
- # """
722
- # # 1) Start with a generator dataset
723
- # dataset = tf.data.Dataset.from_generator(
724
- # lambda: self.data_generator(dialogues),
725
- # output_signature=(
726
- # tf.TensorSpec(shape=(), dtype=tf.string), # Query (single string)
727
- # tf.TensorSpec(shape=(), dtype=tf.string), # Positive (single string)
728
- # tf.TensorSpec(shape=(None,), dtype=tf.string) # Hard Negatives (list of strings)
729
- # )
730
- # )
731
-
732
- # # 2) Batch the raw strings
733
- # dataset = dataset.batch(batch_size)
734
-
735
- # # 3) Now map them through a tokenize step (via py_function)
736
- # dataset = dataset.map(
737
- # lambda q, p, n: self._tokenize_triple(q, p, n),
738
- # num_parallel_calls=1 #tf.data.AUTOTUNE
739
- # )
740
-
741
- # dataset = dataset.prefetch(tf.data.AUTOTUNE)
742
- # return dataset
743
 
744
  def _tokenize_triple(
745
  self,
@@ -861,71 +713,3 @@ class TFDataPipeline:
861
  n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
862
 
863
  return q_ids, p_ids, n_ids
864
-
865
- # def parse_tfrecord_fn(example_proto, max_length, neg_samples):
866
- # """
867
- # Parses a single TFRecord example.
868
-
869
- # Args:
870
- # example_proto: A serialized TFRecord example.
871
- # max_length: The maximum sequence length for tokenization.
872
- # neg_samples: The number of hard negatives per query.
873
-
874
- # Returns:
875
- # A tuple of (query_ids, positive_ids, negative_ids).
876
- # """
877
- # feature_description = {
878
- # 'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
879
- # 'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
880
- # 'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
881
- # }
882
- # parsed_features = tf.io.parse_single_example(example_proto, feature_description)
883
-
884
- # query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
885
- # positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
886
- # negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
887
- # negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])
888
-
889
- # return query_ids, positive_ids, negative_ids
890
-
891
- # def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
892
- # """Find hard negatives for a batch of queries with error handling and retries."""
893
- # retry_count = 0
894
- # total_responses = len(self.response_pool)
895
-
896
- # while retry_count < self.max_retries:
897
- # try:
898
- # query_embeddings = np.vstack([
899
- # self.query_embeddings_cache[q] for q in queries
900
- # ]).astype(np.float32)
901
-
902
- # query_embeddings = np.ascontiguousarray(query_embeddings)
903
- # faiss.normalize_L2(query_embeddings)
904
-
905
- # k = 1 # TODO: try higher k for better results
906
- # #logger.debug(f"Searching with k={k} among {total_responses} responses")
907
-
908
- # distances, indices = self.index.search(query_embeddings, k)
909
-
910
- # all_negatives = []
911
- # for query_indices, query, positive in zip(indices, queries, positives):
912
- # negatives = []
913
- # positive_strip = positive.strip()
914
- # seen = {positive_strip}
915
-
916
- # for idx in query_indices:
917
- # if idx >= 0 and idx < total_responses:
918
- # candidate = self.response_pool[idx].strip()
919
- # if candidate and candidate not in seen:
920
- # seen.add(candidate)
921
- # negatives.append(candidate)
922
- # if len(negatives) >= self.neg_samples:
923
- # break
924
-
925
- # # Pad with a special empty negative if necessary
926
- # while len(negatives) < self.neg_samples:
927
- # negatives.append("<EMPTY_NEGATIVE>") # Use a special token
928
-
929
- # all_negatives.append(negatives)
930
-
931
- # return all_negatives
 
11
  from typing import Union, Optional, List, Tuple, Generator
12
  from transformers import AutoTokenizer
13
  from typing import List, Tuple, Generator
14
+ from transformers import AutoTokenizer
15
  from gpu_monitor import GPUMemoryMonitor
16
 
17
  from logger_config import config_logger
 
32
  nlist: int = 100,
33
  max_retries: int = 3
34
  ):
 
35
  self.config = config
36
  self.tokenizer = tokenizer
37
  self.encoder = encoder
 
64
  dimension = self.query_embeddings_cache[next(iter(self.query_embeddings_cache))].shape[0]
65
  self.index.train(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
66
  self.index.add(np.array(list(self.query_embeddings_cache.values())).astype(np.float32))
 
67
 
68
  def save_embeddings_cache_hdf5(self, cache_file_path: str):
69
  """Save the embeddings cache to an HDF5 file."""
 
84
  logger.info(f"FAISS index saved to {index_file_path}")
85
 
86
  def load_faiss_index(self, index_file_path: str):
87
+ """Load the FAISS index from the specified file path."""
88
+ if os.path.exists(index_file_path):
89
+ self.index = faiss.read_index(index_file_path)
90
+ logger.info(f"FAISS index loaded from {index_file_path}.")
91
+ else:
92
+ logger.error(f"FAISS index file not found at {index_file_path}.")
93
+ raise FileNotFoundError(f"FAISS index file not found at {index_file_path}.")
94
+
95
+ def validate_faiss_index(self):
96
+ """Validates that the FAISS index has the correct dimensionality."""
97
+ expected_dim = self.encoder.config.embedding_dim
98
+ if self.index.d != expected_dim:
99
+ logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
100
+ raise ValueError("FAISS index dimensionality mismatch.")
101
+ logger.info("FAISS index dimension validated successfully.")
102
 
103
  def save_tokenizer(self, tokenizer_dir: str):
104
  self.tokenizer.save_pretrained(tokenizer_dir)
 
107
  def load_tokenizer(self, tokenizer_dir: str):
108
  self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
109
  logger.info(f"Tokenizer loaded from {tokenizer_dir}")
 
110
 
111
  @staticmethod
112
  def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
 
171
 
172
  return pairs
173
 
174
+ def compute_and_index_response_embeddings(self):
175
  """
176
  Computes embeddings for the response pool and adds them to the FAISS index with progress bars.
177
  """
 
231
 
232
  # **Sanity Check:** Verify the number of embeddings in FAISS index
233
  logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
 
234
 
235
  def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
236
  """Find hard negatives for a batch of queries with error handling and retries."""
 
304
  if tf.config.list_physical_devices('GPU'):
305
  tf.keras.backend.clear_session()
306
 
307
+ def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> np.ndarray:
308
  """
309
+ Encode a query with optional conversation context into an embedding vector.
310
+
311
+ Args:
312
+ query (str): The user query.
313
+ context (Optional[List[Tuple[str, str]]]): Optional conversation history as a list of (user, assistant) tuples.
314
+
315
  Returns:
316
+ np.ndarray: The normalized embedding vector for the query.
 
 
317
  """
318
+ # Prepare query with context
319
+ if context:
320
+ context_str = ' '.join([
321
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {q} "
322
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
323
+ for q, r in context[-self.config.max_context_turns:]
324
+ ])
325
+ query = f"{context_str} {self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]}" \
326
+ f" {query}"
327
+ else:
328
+ query = f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
329
+
330
+ # Tokenize and encode
331
+ encodings = self.tokenizer(
332
+ [query],
333
+ padding='max_length',
334
  truncation=True,
335
  max_length=self.max_length,
336
+ return_tensors='np' # Use NumPy arrays for compatibility with FAISS
337
  )
338
+ input_ids = encodings['input_ids']
339
+
340
+ # Verify token IDs
341
+ max_id = np.max(input_ids)
342
+ new_vocab_size = len(self.tokenizer)
343
+
344
+ if max_id >= new_vocab_size:
345
+ logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
346
+ raise ValueError("Token ID exceeds vocabulary size.")
347
+
348
+ # Get embeddings from the shared encoder
349
+ embeddings = self.encoder(input_ids, training=False).numpy()
350
+
351
+ # Normalize embeddings for cosine similarity
352
+ faiss.normalize_L2(embeddings)
353
+
354
+ return embeddings[0] # Return as a 1D array
 
 
355
 
356
+ def encode_responses(self, responses: List[str], context: Optional[List[Tuple[str, str]]] = None) -> np.ndarray:
357
+ """
358
+ Encode a list of responses into embedding vectors.
 
359
 
360
+ Args:
361
+ responses (List[str]): List of response texts.
362
+ context (Optional[List[Tuple[str, str]]]): Optional conversation history as a list of (user, assistant) tuples.
363
 
364
+ Returns:
365
+ np.ndarray: Array of normalized embedding vectors.
366
+ """
367
+ # Prepare responses with context if provided
368
+ if context:
369
+ prepared_responses = []
370
+ for response in responses:
371
+ context_str = ' '.join([
372
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {q} "
373
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
374
+ for q, r in context[-self.config.max_context_turns:]
375
+ ])
376
+ full_response = f"{context_str} {self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {response}"
377
+ prepared_responses.append(full_response)
378
+ else:
379
+ prepared_responses = [
380
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
381
+ for resp in responses
382
+ ]
383
+
384
+ # Tokenize and encode
385
+ encodings = self.tokenizer(
386
+ prepared_responses,
387
+ padding='max_length',
388
+ truncation=True,
389
+ max_length=self.max_length,
390
+ return_tensors='np' # Use NumPy arrays for compatibility with FAISS
391
+ )
392
+ input_ids = encodings['input_ids']
393
+
394
+ # Verify token IDs
395
+ max_id = np.max(input_ids)
396
+ new_vocab_size = len(self.tokenizer)
397
+
398
+ if max_id >= new_vocab_size:
399
+ logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
400
+ raise ValueError("Token ID exceeds vocabulary size.")
401
+
402
+ # Get embeddings from the shared encoder
403
+ embeddings = self.encoder(input_ids, training=False).numpy()
404
+
405
+ # Normalize embeddings for cosine similarity
406
+ faiss.normalize_L2(embeddings)
407
+
408
+ return embeddings.astype('float32')
409
+
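A short retrieval sketch built on the new encoders; it assumes a TFDataPipeline instance (e.g. chatbot.data_pipeline) whose FAISS index and response_pool are already loaded, and the query text and top_k value are purely illustrative:

query_vec = pipeline.encode_query(
    "my package never arrived",
    context=[("hi", "hello, how can I help?")]
)
# FAISS expects a 2D float32 batch; encode_query returns a single normalized vector.
distances, indices = pipeline.index.search(query_vec.reshape(1, -1).astype('float32'), 5)
candidates = [pipeline.response_pool[i] for i in indices[0] if i >= 0]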
410
  def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
411
  """
412
  Processes dialogues in batches and saves to a TFRecord file using optimized batch tokenization and encoding.
 
522
  pbar.update(1)
523
 
524
  logger.info(f"Data preparation complete. TFRecord saved.")
525
 
526
  def _compute_embeddings(self, queries: List[str]) -> None:
527
  new_queries = [q for q in queries if q not in self.query_embeddings_cache]
 
565
  hard_negatives = self._find_hard_negatives_batch([query], [positive])[0]
566
  yield (query, positive, hard_negatives)
567
  pbar.update(1)
568
 
569
  def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
570
  """
 
592
 
593
  dataset = dataset.prefetch(tf.data.AUTOTUNE)
594
  return dataset
595
 
596
  def _tokenize_triple(
597
  self,
 
713
  n_ids = n_ids.astype(np.int32) # shape [batch_size, neg_samples, max_len]
714
 
715
  return q_ids, p_ids, n_ids
 
run_model_train.py → train_model.py RENAMED
@@ -1,36 +1,11 @@
1
  import tensorflow as tf
2
  from chatbot_model import RetrievalChatbot, ChatbotConfig
3
  from environment_setup import EnvironmentSetup
4
- from response_quality_checker import ResponseQualityChecker
5
- from chatbot_validator import ChatbotValidator
6
  from training_plotter import TrainingPlotter
7
 
8
- # Configure logging
9
  from logger_config import config_logger
10
  logger = config_logger(__name__)
11
 
12
- def run_interactive_chat(chatbot, quality_checker):
13
- """Separate function for interactive chat loop"""
14
- while True:
15
- user_input = input("You: ")
16
- if user_input.lower() in ['quit', 'exit', 'bye']:
17
- print("Assistant: Goodbye!")
18
- break
19
-
20
- response, candidates, metrics = chatbot.chat(
21
- query=user_input,
22
- conversation_history=None,
23
- quality_checker=quality_checker,
24
- top_k=5
25
- )
26
-
27
- print(f"Assistant: {response}")
28
-
29
- if metrics.get('is_confident', False):
30
- print("\nAlternative responses:")
31
- for resp, score in candidates[1:4]:
32
- print(f"Score: {score:.4f} - {resp}")
33
-
34
  def inspect_tfrecord(tfrecord_file_path, num_examples=3):
35
  def parse_example(example_proto):
36
  feature_description = {
@@ -53,7 +28,7 @@ def inspect_tfrecord(tfrecord_file_path, num_examples=3):
53
  def main():
54
 
55
  # Quick test to inspect TFRecord
56
- #inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
57
 
58
  # Initialize environment
59
  tf.keras.backend.clear_session()
@@ -65,49 +40,40 @@ def main():
65
  TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
66
 
67
  # Optimize batch size for Colab
68
- batch_size = env.optimize_batch_size(base_batch_size=16)
69
-
70
 
71
- # Initialize configuration
72
- config = ChatbotConfig(
73
- embedding_dim=768, # DistilBERT
74
- max_context_token_limit=512,
75
- freeze_embeddings=False,
76
- )
77
 
78
  # Initialize chatbot
79
- #with env.strategy.scope():
80
  chatbot = RetrievalChatbot(config, mode='training')
81
- chatbot.build_models()
82
 
83
- if chatbot.mode == 'preparation':
84
- chatbot.verify_faiss_index()
85
-
86
- chatbot.train_streaming(
87
  tfrecord_file_path=TF_RECORD_FILE_PATH,
88
  epochs=EPOCHS,
89
  batch_size=batch_size,
90
  use_lr_schedule=True,
 
 
91
  )
92
 
93
  # Save final model
94
  model_save_path = env.training_dirs['base'] / 'final_model'
95
  chatbot.save_models(model_save_path)
96
 
97
- # Run automatic validation
98
- quality_checker = ResponseQualityChecker(chatbot=chatbot)
99
- validator = ChatbotValidator(chatbot, quality_checker)
100
- validation_metrics = validator.run_validation(num_examples=5)
101
- logger.info(f"Validation Metrics: {validation_metrics}")
102
-
103
  # Plot and save training history
104
  plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
105
  plotter.plot_training_history(chatbot.history)
106
- plotter.plot_validation_metrics(validation_metrics)
107
-
108
- # Run interactive chat
109
- logger.info("\nStarting interactive chat session...")
110
- run_interactive_chat(chatbot, quality_checker)
111
 
112
  if __name__ == "__main__":
113
  main()
 
1
  import tensorflow as tf
2
  from chatbot_model import RetrievalChatbot, ChatbotConfig
3
  from environment_setup import EnvironmentSetup
 
 
4
  from training_plotter import TrainingPlotter
5
 
 
6
  from logger_config import config_logger
7
  logger = config_logger(__name__)
8
 
9
  def inspect_tfrecord(tfrecord_file_path, num_examples=3):
10
  def parse_example(example_proto):
11
  feature_description = {
 
28
  def main():
29
 
30
  # Quick test to inspect TFRecord
31
+ # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
32
 
33
  # Initialize environment
34
  tf.keras.backend.clear_session()
 
40
  TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
41
 
42
  # Optimize batch size for Colab
43
+ batch_size = 32 # env.optimize_batch_size(base_batch_size=16)
 
44
 
45
+ # Initialize config
46
+ config = ChatbotConfig()
47
 
48
  # Initialize chatbot
 
49
  chatbot = RetrievalChatbot(config, mode='training')
 
50
 
51
+ # Load from a checkpoint
52
+ checkpoint_dir = 'checkpoints/'
53
+ latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
54
+ initial_epoch = 0
55
+ if latest_checkpoint:
56
+ ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
57
+ initial_epoch = ckpt_number
58
+ logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
59
+
60
+ # Train the model
61
+ chatbot.train_model(
62
  tfrecord_file_path=TF_RECORD_FILE_PATH,
63
  epochs=EPOCHS,
64
  batch_size=batch_size,
65
  use_lr_schedule=True,
66
+ test_mode=False,
67
+ initial_epoch=initial_epoch
68
  )
69
 
70
  # Save final model
71
  model_save_path = env.training_dirs['base'] / 'final_model'
72
  chatbot.save_models(model_save_path)
73
 
74
  # Plot and save training history
75
  plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
76
  plotter.plot_training_history(chatbot.history)
 
 
 
 
 
77
 
78
  if __name__ == "__main__":
79
  main()
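The resume logic above assumes TensorFlow's default ckpt-N checkpoint naming; a small worked example of the parsing step (the checkpoint path is a hypothetical value, not from this commit):

# Hypothetical value returned by tf.train.latest_checkpoint('checkpoints/')
latest_checkpoint = 'checkpoints/ckpt-7'
initial_epoch = int(latest_checkpoint.split('ckpt-')[-1])   # -> 7, passed to train_model
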
validate_model.py ADDED
@@ -0,0 +1,117 @@
1
+ import os
2
+ import json
3
+ from chatbot_model import ChatbotConfig, RetrievalChatbot
4
+ from response_quality_checker import ResponseQualityChecker
5
+ from chatbot_validator import ChatbotValidator
6
+ from training_plotter import TrainingPlotter
7
+ from environment_setup import EnvironmentSetup
8
+
9
+ from logger_config import config_logger
10
+ logger = config_logger(__name__)
11
+
12
+ def run_interactive_chat(chatbot, quality_checker):
13
+ """Separate function for interactive chat loop"""
14
+ while True:
15
+ try:
16
+ user_input = input("You: ")
17
+ except (KeyboardInterrupt, EOFError):
18
+ print("\nAssistant: Goodbye!")
19
+ break
20
+
21
+ if user_input.lower() in ['quit', 'exit', 'bye']:
22
+ print("Assistant: Goodbye!")
23
+ break
24
+
25
+ response, candidates, metrics = chatbot.chat(
26
+ query=user_input,
27
+ conversation_history=None,
28
+ quality_checker=quality_checker,
29
+ top_k=5
30
+ )
31
+
32
+ print(f"Assistant: {response}")
33
+
34
+ if metrics.get('is_confident', False):
35
+ print("\nAlternative responses:")
36
+ for resp, score in candidates[1:4]:
37
+ print(f"Score: {score:.4f} - {resp}")
38
+ else:
39
+ print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
40
+
41
+ # TODO:
42
+ def validate_chatbot():
43
+ # Initialize environment
44
+ env = EnvironmentSetup()
45
+ env.initialize()
46
+
47
+ MODEL_DIR = 'models'
48
+ FAISS_INDICES_DIR = os.path.join(MODEL_DIR, 'faiss_indices')
49
+ FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
50
+ FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
51
+ RESPONSE_POOL_PRODUCTION_PATH = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
52
+ RESPONSE_POOL_TEST_PATH = FAISS_INDEX_TEST_PATH.replace('.index', '_responses.json')
53
+ ENVIRONMENT = 'production' # or 'test'
54
+ if ENVIRONMENT == 'test':
55
+ FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
56
+ RESPONSE_POOL_PATH = RESPONSE_POOL_TEST_PATH
57
+ else:
58
+ FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
59
+ RESPONSE_POOL_PATH = RESPONSE_POOL_PRODUCTION_PATH
60
+
61
+ # Load config
62
+ config = ChatbotConfig()
63
+
64
+ # Initialize RetrievalChatbot in 'inference' mode
65
+ try:
66
+ chatbot = RetrievalChatbot(config=config, mode='inference')
67
+ logger.info("RetrievalChatbot initialized in 'inference' mode.")
68
+ except Exception as e:
69
+ logger.error(f"Failed to initialize RetrievalChatbot: {e}")
70
+ return
71
+
72
+ # Ensure FAISS index and response pool are accessible, then load
73
+ if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
74
+ logger.error("FAISS index or response pool file is missing.")
75
+ return
76
+
77
+ try:
78
+ chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
79
+ logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
80
+
81
+ with open(RESPONSE_POOL_PATH, 'r', encoding='utf-8') as f:
82
+ chatbot.data_pipeline.response_pool = json.load(f)
83
+ logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
84
+
85
+ chatbot.data_pipeline.validate_faiss_index()
86
+ logger.info("FAISS index and response pool validated successfully.")
87
+ except Exception as e:
88
+ logger.error(f"Failed to load FAISS index: {e}")
89
+ return
90
+
91
+ # Initialize ResponseQualityChecker and ChatbotValidator
92
+ quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
93
+ validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
94
+ logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
95
+
96
+ # Run validation
97
+ try:
98
+ validation_metrics = validator.run_validation(num_examples=5)
99
+ logger.info(f"Validation Metrics: {validation_metrics}")
100
+ except Exception as e:
101
+ logger.error(f"Validation process failed: {e}")
102
+ return
103
+
104
+ # Plot validation_metrics
105
+ try:
106
+ plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
107
+ plotter.plot_validation_metrics(validation_metrics)
108
+ logger.info("Validation metrics plotted successfully.")
109
+ except Exception as e:
110
+ logger.error(f"Failed to plot validation metrics: {e}")
111
+
112
+ # Run interactive chat
113
+ logger.info("\nStarting interactive chat session...")
114
+ run_interactive_chat(chatbot, quality_checker)
115
+
116
+ if __name__ == '__main__':
117
+ validate_chatbot()