JoeArmani committed on
Commit
d7fc7a7
·
1 Parent(s): c111c20

more structural updates

.gitignore CHANGED
@@ -187,3 +187,5 @@ new_iteration/cache/*
187
  new_iteration/data_prep_iterative_models/*
188
  new_iteration/training_data/*
189
  new_iteration/processed_outputs/*
 
 
 
187
  new_iteration/data_prep_iterative_models/*
188
  new_iteration/training_data/*
189
  new_iteration/processed_outputs/*
190
+ raw_datasets/*
191
+
chatbot_model.py CHANGED
@@ -24,25 +24,25 @@ logger = config_logger(__name__)
24
 
25
  @dataclass
26
  class ChatbotConfig:
27
- """Configuration for the RetrievalChatbot."""
28
  max_context_token_limit: int = 512
29
  embedding_dim: int = 768
30
  encoder_units: int = 256
31
  num_attention_heads: int = 8
32
  dropout_rate: float = 0.2
33
  l2_reg_weight: float = 0.001
34
- learning_rate: float = 0.001
35
  min_text_length: int = 3
36
- max_context_turns: int = 5
37
  warmup_steps: int = 200
38
  pretrained_model: str = 'distilbert-base-uncased'
39
  cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
 
40
  dtype: str = 'float32'
41
  freeze_embeddings: bool = False
42
  embedding_batch_size: int = 64
43
  search_batch_size: int = 64
44
  max_batch_size: int = 64
45
- neg_samples: int = 10
46
  max_retries: int = 3
47
 
48
  def to_dict(self) -> Dict:
@@ -57,7 +57,7 @@ class ChatbotConfig:
57
  if k in cls.__dataclass_fields__})
58
 
59
  class EncoderModel(tf.keras.Model):
60
- """Dual encoder model with pretrained embeddings."""
61
  def __init__(
62
  self,
63
  config: ChatbotConfig,
@@ -71,7 +71,7 @@ class EncoderModel(tf.keras.Model):
71
  self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
72
  self._freeze_layers()
73
 
74
- # Add Pooling layer (Global Average Pooling), Projection layer, Dropout, and Normalization
75
  self.pooler = tf.keras.layers.GlobalAveragePooling1D()
76
  self.projection = tf.keras.layers.Dense(
77
  config.embedding_dim,
@@ -86,7 +86,7 @@ class EncoderModel(tf.keras.Model):
86
  )
87
 
88
  def _freeze_layers(self):
89
- """Freeze layers of the pretrained model based on configuration."""
90
  if self.config.freeze_embeddings:
91
  self.pretrained.trainable = False
92
  logger.info("All pretrained layers frozen.")
@@ -95,29 +95,29 @@ class EncoderModel(tf.keras.Model):
95
  for i, layer in enumerate(self.pretrained.layers):
96
  if isinstance(layer, tf.keras.layers.Layer):
97
  if hasattr(layer, 'trainable'):
98
- # Freeze the first transformer block
99
  if i < 1:
100
  layer.trainable = False
101
  logger.info(f"Layer {i} frozen.")
102
  else:
103
  layer.trainable = True
 
104
 
105
  def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
106
  """Forward pass."""
107
  # Get pretrained embeddings
108
  pretrained_outputs = self.pretrained(inputs, training=training)
109
- x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
110
 
111
  # Apply pooling, projection, dropout, and normalization
112
- x = self.pooler(x) # Shape: [batch_size, 768]
113
- x = self.projection(x) # Shape: [batch_size, 768]
114
  x = self.dropout(x, training=training)
115
- x = self.normalize(x) # Shape: [batch_size, 768]
116
 
117
  return x
118
 
119
  def get_config(self) -> dict:
120
- """Return the config of the model."""
121
  config = super().get_config()
122
  config.update({
123
  "config": self.config.to_dict(),
@@ -126,7 +126,10 @@ class EncoderModel(tf.keras.Model):
126
  return config
127
 
128
  class RetrievalChatbot(DeviceAwareModel):
129
- """Retrieval-based chatbot using pretrained embeddings and FAISS for similarity search."""
 
 
 
130
  def __init__(
131
  self,
132
  config: ChatbotConfig,
@@ -142,7 +145,7 @@ class RetrievalChatbot(DeviceAwareModel):
142
  self.device = device or self._setup_default_device()
143
  self.mode = mode.lower()
144
 
145
- # Initialize reranker, summarizer, tokenizer, encoder, and memory monitor
146
  self.reranker = reranker or self._initialize_reranker()
147
  self.tokenizer = self._initialize_tokenizer()
148
  self.encoder = self._initialize_encoder()
@@ -154,14 +157,9 @@ class RetrievalChatbot(DeviceAwareModel):
154
  config=self.config,
155
  tokenizer=self.tokenizer,
156
  encoder=self.encoder,
157
- index_file_path='new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
158
  response_pool=[],
159
  max_length=self.config.max_context_token_limit,
160
  query_embeddings_cache={},
161
- neg_samples=self.config.neg_samples,
162
- index_type='IndexFlatIP',
163
- nlist=100, # Not used with IndexFlatIP
164
- max_retries=self.config.max_retries
165
  )
166
 
167
  # Collect unique responses from dialogues
@@ -197,7 +195,7 @@ class RetrievalChatbot(DeviceAwareModel):
197
  """Initialize the Summarizer."""
198
  return Summarizer(
199
  tokenizer=self.tokenizer,
200
- model_name="t5-small",
201
  max_summary_length=self.config.max_context_token_limit // 4,
202
  device=self.device,
203
  max_summary_rounds=2
@@ -229,17 +227,18 @@ class RetrievalChatbot(DeviceAwareModel):
229
  new_vocab_size = len(self.tokenizer)
230
  encoder.pretrained.resize_token_embeddings(new_vocab_size)
231
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
 
232
  return encoder
233
 
234
  def _load_faiss_index_and_responses(self) -> None:
235
  """Load FAISS index and response pool for inference."""
236
  try:
237
- logger.info(f"Loading FAISS index from {self.data_pipeline.index_file_path}...")
238
- self.data_pipeline.load_faiss_index(self.data_pipeline.index_file_path)
239
  logger.info("FAISS index loaded successfully.")
240
 
241
- # Load response pool associated with the FAISS index
242
- response_pool_path = self.data_pipeline.index_file_path.replace('.index', '_responses.json')
243
  if os.path.exists(response_pool_path):
244
  with open(response_pool_path, 'r', encoding='utf-8') as f:
245
  self.data_pipeline.response_pool = json.load(f)
@@ -263,29 +262,24 @@ class RetrievalChatbot(DeviceAwareModel):
263
  """
264
  load_dir = Path(load_dir)
265
 
266
- # 1) Load config
267
  with open(load_dir / "config.json", "r") as f:
268
  config = ChatbotConfig.from_dict(json.load(f))
269
 
270
- # 2) Initialize chatbot
271
  chatbot = cls(config, mode=mode)
272
 
273
- # 3) Load DistilBERT from huggingface folder
274
- chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
275
- load_dir / "shared_encoder",
276
- config=config
277
- )
278
 
279
  dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
280
  _ = chatbot.encoder(dummy_input, training=False)
281
 
282
- # 4) Load tokenizer
283
  chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
284
  logger.info(f"Models and tokenizer loaded from {load_dir}")
285
 
286
-
287
-
288
- # 5) Load the custom top layers' weights
289
  custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
290
  if custom_weights_path.exists():
291
  chatbot.encoder.load_weights(str(custom_weights_path))
@@ -293,7 +287,7 @@ class RetrievalChatbot(DeviceAwareModel):
293
  else:
294
  logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
295
 
296
- # 6) If in inference mode, load FAISS, etc.
297
  if mode == 'inference':
298
  cls._prepare_model_for_inference(chatbot, load_dir)
299
 
@@ -301,7 +295,7 @@ class RetrievalChatbot(DeviceAwareModel):
301
 
302
  @classmethod
303
  def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
304
- """Internal method to load inference components."""
305
  try:
306
  # Load FAISS index
307
  faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
@@ -332,7 +326,7 @@ class RetrievalChatbot(DeviceAwareModel):
332
  raise
333
 
334
  def save_models(self, save_dir: Union[str, Path]):
335
- """Save models and configuration."""
336
  save_dir = Path(save_dir)
337
  save_dir.mkdir(parents=True, exist_ok=True)
338
 
@@ -340,21 +334,13 @@ class RetrievalChatbot(DeviceAwareModel):
340
  with open(save_dir / "config.json", "w") as f:
341
  json.dump(self.config.to_dict(), f, indent=2)
342
 
343
- # Save the HF DistilBERT submodule:
344
  self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
345
-
346
- # ALSO save custom top-level layers' weights
347
  self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
348
-
349
- # Save tokenizer
350
  self.tokenizer.save_pretrained(save_dir / "tokenizer")
351
-
352
  logger.info(f"Models and tokenizer saved to {save_dir}.")
353
 
354
- def sigmoid(self, x: float) -> float:
355
- return 1 / (1 + np.exp(-x))
356
-
357
- def retrieve_responses_cross_encoder(
358
  self,
359
  query: str,
360
  top_k: int = 10,
@@ -363,20 +349,20 @@ class RetrievalChatbot(DeviceAwareModel):
363
  summarize_threshold: int = 512
364
  ) -> List[Tuple[str, float]]:
365
  """
366
- Retrieve top-k responses with optional domain-based boosting
367
- and cross-encoder re-ranking.
368
-
369
  Args:
370
  query: The user's input text.
371
- top_k: Number of final results to return.
372
- reranker: CrossEncoderReranker for refined scoring, if available.
373
- summarizer: Summarizer for long queries, if desired.
374
- summarize_threshold: Summarize if query wordcount > threshold.
375
-
376
  Returns:
377
  List of (response_text, final_score).
378
  """
379
- # 1) Optional query summarization
 
 
 
380
  if summarizer and len(query.split()) > summarize_threshold:
381
  logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
382
  query = summarizer.summarize_text(query)
@@ -393,17 +379,17 @@ class RetrievalChatbot(DeviceAwareModel):
393
 
394
  texts = [item[0] for item in faiss_candidates]
395
 
396
- # Re-rank these boosted candidates
397
  if not reranker:
398
  reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
399
 
 
400
  ce_logits = reranker.rerank(query, texts, max_length=256)
401
 
402
- # Combine cross-encoder score with the base FAISS score (simple multiplicative approach)
403
  final_candidates = []
404
  for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
405
- ce_prob = self.sigmoid(logit) # [0...1]
406
- faiss_norm = (faiss_score + 1)/2.0 # [0...1]
407
  combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
408
  length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
409
 
@@ -415,22 +401,22 @@ class RetrievalChatbot(DeviceAwareModel):
415
  # Return top_k
416
  return final_candidates[:top_k]
417
 
418
- DOMAIN_KEYWORDS = {
419
- 'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
420
- 'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
421
- 'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
422
- 'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
423
- 'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
424
- 'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
425
- }
426
-
427
  def extract_keywords(self, query: str) -> List[str]:
428
  """
429
  Return any domain keywords present in the query (lowercased).
430
  """
 
 
 
 
 
 
 
 
 
431
  query_lower = query.lower()
432
  found = set()
433
- for domain, kw_list in self.DOMAIN_KEYWORDS.items():
434
  for kw in kw_list:
435
  if kw in query_lower:
436
  found.add(kw)
@@ -456,7 +442,7 @@ class RetrievalChatbot(DeviceAwareModel):
456
 
457
  def detect_domain_from_query(self, query: str) -> str:
458
  """
459
- Detect the domain of the query based on keywords.
460
  """
461
  domain_patterns = {
462
  'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
@@ -476,8 +462,7 @@ class RetrievalChatbot(DeviceAwareModel):
476
 
477
  def is_numeric_response(self, text: str) -> bool:
478
  """
479
- Return True if `text` is purely digits (and/or spaces),
480
- with optional punctuation like '.' at the end.
481
  """
482
  pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
483
  return bool(re.match(pattern, text.strip()))
@@ -486,18 +471,16 @@ class RetrievalChatbot(DeviceAwareModel):
486
  self,
487
  query: str,
488
  domain: str = 'other',
489
- top_k: int = 5,
490
- boost_factor: float = 1.05
491
  ) -> List[Tuple[str, float]]:
492
  """
493
  Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
494
-
495
  Args:
496
  query (str): The user input text.
497
- domain (str, optional): The detected domain. Defaults to 'other'.
498
- top_k (int, optional): Number of top results to return. Defaults to 5.
499
- boost_factor (float, optional): Factor to boost scores for keyword matches. Defaults to 1.3.
500
-
501
  Returns:
502
  List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
503
  """
@@ -508,7 +491,7 @@ class RetrievalChatbot(DeviceAwareModel):
508
  # Search the index
509
  distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
510
 
511
- # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors)
512
  candidates = []
513
  for rank, idx in enumerate(indices[0]):
514
  if idx < 0:
@@ -545,8 +528,7 @@ class RetrievalChatbot(DeviceAwareModel):
545
  boosted = []
546
  for (resp_text, resp_domain, score) in in_domain:
547
  new_score = score
548
- # If the domain is known AND the response text
549
- # shares any query keywords, apply a small boost
550
  if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
551
  new_score *= boost_factor
552
 
@@ -558,7 +540,7 @@ class RetrievalChatbot(DeviceAwareModel):
558
  # Sort boosted responses
559
  boosted.sort(key=lambda x: x[1], reverse=True)
560
 
561
- # Debug
562
  # for resp, score in boosted[:100]:
563
  # logger.debug(f"Candidate: '{resp}' with score {score}")
564
 
@@ -572,8 +554,7 @@ class RetrievalChatbot(DeviceAwareModel):
572
  top_k: int = 10,
573
  ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
574
  """
575
- Example chat method that always uses cross-encoder re-ranking
576
- if self.reranker is available.
577
  """
578
  @self.run_on_device
579
  def get_response(self_arg, query_arg):
@@ -581,7 +562,7 @@ class RetrievalChatbot(DeviceAwareModel):
581
  conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
582
 
583
  # Retrieve and re-rank
584
- results = self_arg.retrieve_responses_cross_encoder(
585
  query=conversation_str,
586
  top_k=top_k,
587
  reranker=self_arg.reranker,
@@ -605,7 +586,9 @@ class RetrievalChatbot(DeviceAwareModel):
605
  query: str,
606
  conversation_history: Optional[List[Tuple[str, str]]]
607
  ) -> str:
608
- """Build conversation context with better memory management."""
 
 
609
  if not conversation_history:
610
  return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
611
 
@@ -636,12 +619,12 @@ class RetrievalChatbot(DeviceAwareModel):
636
  ) -> None:
637
  """
638
  Train the retrieval model using a pre-prepared TFRecord dataset.
639
- This method handles:
640
  - Checkpoint loading/restoring
641
  - LR scheduling
642
  - Epoch/iteration tracking
643
- - Optional training-history logging
644
- - Basic early stopping
 
645
  """
646
  logger.info("Starting training with pre-prepared TFRecord dataset...")
647
 
@@ -673,7 +656,7 @@ class RetrievalChatbot(DeviceAwareModel):
673
  steps_per_epoch = math.ceil(train_size / batch_size)
674
  val_steps = math.ceil(val_size / batch_size)
675
  total_steps = steps_per_epoch * epochs
676
- buffer_size = max(1, total_pairs // 10) # 10% of the dataset
677
 
678
  logger.info(f"Training pairs: {train_size}")
679
  logger.info(f"Validation pairs: {val_size}")
@@ -695,7 +678,7 @@ class RetrievalChatbot(DeviceAwareModel):
695
  self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
696
  logger.info("Using fixed learning rate.")
697
 
698
- # Initialize optimizer with dummy step
699
  dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
700
  with tf.GradientTape() as tape:
701
  dummy_output = self.encoder(dummy_input)
@@ -710,6 +693,7 @@ class RetrievalChatbot(DeviceAwareModel):
710
  model=self.encoder
711
  )
712
 
 
713
  manager = tf.train.CheckpointManager(
714
  checkpoint,
715
  directory=checkpoint_dir,
@@ -717,18 +701,18 @@ class RetrievalChatbot(DeviceAwareModel):
717
  checkpoint_name='ckpt'
718
  )
719
 
720
- # Restore from existing checkpoint if present
721
  latest_checkpoint = manager.latest_checkpoint
722
  history_path = Path(checkpoint_dir) / 'training_history.json'
723
 
724
- # If you want to log all epoch losses across runs
725
  if not hasattr(self, 'history'):
726
  self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
727
 
728
  if latest_checkpoint and not test_mode:
729
- # Add checkpoint inspection
730
- logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
731
- reader = tf.train.load_checkpoint(latest_checkpoint)
732
  # shape_from_key = reader.get_variable_to_shape_map()
733
  # dtype_from_key = reader.get_variable_to_dtype_map()
734
  # logger.info("\nCheckpoint Variables:")
@@ -752,11 +736,11 @@ class RetrievalChatbot(DeviceAwareModel):
752
  if initial_epoch == 0:
753
  initial_epoch = ckpt_number
754
 
755
- # Assign to checkpoint.epoch so we keep counting from that
756
  checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
757
  logger.info(f"Resuming from epoch {initial_epoch}")
758
 
759
- # If you want to load old history from file:
760
  if history_path.exists():
761
  try:
762
  with open(history_path, 'r') as f:
@@ -765,7 +749,10 @@ class RetrievalChatbot(DeviceAwareModel):
765
  except Exception as e:
766
  logger.warning(f"Could not load history, starting fresh: {e}")
767
 
768
- # Fix for custom weights not being saved in the full model.
 
 
 
769
  self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
770
  logger.info(f"Manually saved custom weights after restore.")
771
  else:
@@ -782,13 +769,13 @@ class RetrievalChatbot(DeviceAwareModel):
782
  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
783
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
784
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
785
-
786
  # Parse dataset
787
  dataset = tf.data.TFRecordDataset(tfrecord_file_path)
788
-
789
- # Optional: test/debug mode with small subset
790
  if test_mode:
791
- subset_size = 150
792
  dataset = dataset.take(subset_size)
793
  logger.info(f"TEST MODE: Using only {subset_size} examples")
794
  # Recompute sizes, steps, epochs, etc., as needed
@@ -804,38 +791,36 @@ class RetrievalChatbot(DeviceAwareModel):
804
  early_stopping_patience = 2
805
  logger.info(f"New training pairs: {train_size}")
806
  logger.info(f"New validation pairs: {val_size}")
807
-
808
  dataset = dataset.map(
809
- lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.config.neg_samples),
810
  num_parallel_calls=tf.data.AUTOTUNE
811
  )
812
-
813
  # Train/val split
814
  train_dataset = dataset.take(train_size)
815
  val_dataset = dataset.skip(train_size).take(val_size)
816
-
817
  # Shuffle and batch
818
  train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
819
  train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
820
  train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
821
-
822
  val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
823
  val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
824
  val_dataset = val_dataset.cache()
825
-
826
  # Training loop
827
  best_val_loss = float("inf")
828
  epochs_no_improve = 0
829
-
830
  for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
831
  checkpoint.epoch.assign(epoch)
832
  logger.info(f"Starting Epoch {epoch}...")
833
-
834
- # --- Training Phase ---
835
  epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
836
  batches_processed = 0
837
-
838
- # Progress bar
839
  try:
840
  train_pbar = tqdm(
841
  total=steps_per_epoch,
@@ -846,7 +831,8 @@ class RetrievalChatbot(DeviceAwareModel):
846
  except ImportError:
847
  train_pbar = None
848
  is_tqdm_train = False
849
-
 
850
  for q_batch, p_batch, n_batch in train_dataset:
851
  loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
852
  epoch_loss_avg(loss)
@@ -874,54 +860,54 @@ class RetrievalChatbot(DeviceAwareModel):
874
  "lr": f"{current_lr:.2e}",
875
  "batches": f"{batches_processed}/{steps_per_epoch}"
876
  })
877
-
878
  gc.collect()
879
-
880
  # End the epoch early if we've processed all steps
881
  if batches_processed >= steps_per_epoch:
882
  break
883
-
884
  if is_tqdm_train and train_pbar:
885
  train_pbar.close()
886
-
887
- # --- Validation Phase ---
888
  val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
889
  val_batches_processed = 0
890
-
891
  try:
892
  val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
893
  is_tqdm_val = True
894
  except ImportError:
895
  val_pbar = None
896
  is_tqdm_val = False
897
-
898
  last_valid_val_loss = None
899
  valid_batches = False
900
-
901
  for q_batch, p_batch, n_batch in val_dataset:
902
  # If batch is too small, skip
903
  if tf.shape(q_batch)[0] < 2:
904
  logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
905
  continue
906
-
907
  valid_batches = True
908
  val_loss = self.validation_step(q_batch, p_batch, n_batch)
909
  val_loss_avg(val_loss)
910
  last_valid_val_loss = val_loss
911
  val_batches_processed += 1
912
-
913
  if is_tqdm_val:
914
  val_pbar.update(1)
915
  val_pbar.set_postfix({
916
  "val_loss": f"{val_loss.numpy():.4f}",
917
  "batches": f"{val_batches_processed}/{val_steps}"
918
  })
919
-
920
  gc.collect()
921
-
922
  if val_batches_processed >= val_steps:
923
  break
924
-
925
  if not valid_batches:
926
  # If no valid batch is found, fallback
927
  logger.warning("No valid validation batches in this epoch")
@@ -931,29 +917,29 @@ class RetrievalChatbot(DeviceAwareModel):
931
  else:
932
  val_loss = epoch_loss_avg.result()
933
  val_loss_avg(val_loss)
934
-
935
  if is_tqdm_val and val_pbar:
936
  val_pbar.close()
937
-
938
  # End of epoch: final stats
939
  train_loss = epoch_loss_avg.result().numpy()
940
  val_loss = val_loss_avg.result().numpy()
941
  logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
942
-
943
  # TensorBoard epoch logs
944
  with train_summary_writer.as_default():
945
  tf.summary.scalar("epoch_loss", train_loss, step=epoch)
946
  with val_summary_writer.as_default():
947
  tf.summary.scalar("val_loss", val_loss, step=epoch)
948
-
949
  # Save checkpoint
950
  manager.save()
951
-
952
- # (Optional) Save model for quick testing/inference
953
  model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
954
  self.save_models(model_save_path)
955
  logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
956
-
957
  # Update local history
958
  self.history['train_loss'].append(train_loss)
959
  self.history['val_loss'].append(val_loss)
@@ -972,13 +958,12 @@ class RetrievalChatbot(DeviceAwareModel):
972
  return obj
973
 
974
  json_history = convert_to_py_floats(self.history)
975
-
976
  # Save training history to file every epoch
977
- # (Create or overwrite the file so we always have the latest.)
978
  with open(history_path, 'w') as f:
979
  json.dump(json_history, f)
980
  logger.info(f"Saved training history to {history_path}")
981
-
982
  # Early stopping
983
  if val_loss < best_val_loss - min_delta:
984
  best_val_loss = val_loss
@@ -990,7 +975,7 @@ class RetrievalChatbot(DeviceAwareModel):
990
  if epochs_no_improve >= early_stopping_patience:
991
  logger.info("Early stopping triggered.")
992
  break
993
-
994
  logger.info("Training completed!")
995
 
996
  @tf.function
@@ -1004,37 +989,25 @@ class RetrievalChatbot(DeviceAwareModel):
1004
  Single training step using queries, positives, and hard negatives.
1005
  """
1006
  with tf.GradientTape() as tape:
1007
- # Encode queries
1008
  q_enc = self.encoder(q_batch, training=True) # [batch_size, embed_dim]
1009
-
1010
- # Encode positives
1011
  p_enc = self.encoder(p_batch, training=True) # [batch_size, embed_dim]
1012
-
1013
- # Encode negatives
1014
- # n_batch: [batch_size, neg_samples, max_length]
1015
  shape = tf.shape(n_batch)
1016
  bs = shape[0]
1017
  neg_samples = shape[1]
1018
 
1019
- # Flatten negatives to feed them in one pass:
1020
- # => [batch_size * neg_samples, max_length]
1021
  n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1022
  n_enc_flat = self.encoder(n_batch_flat, training=True) # [bs*neg_samples, embed_dim]
1023
 
1024
  # Reshape back => [batch_size, neg_samples, embed_dim]
1025
  n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1026
 
1027
- # Combine the positive embedding and negative embeddings along dim=1
1028
- # => shape [batch_size, 1 + neg_samples, embed_dim]
1029
- # The first column is the positive; subsequent columns are negatives
1030
- combined_p_n = tf.concat(
1031
- [tf.expand_dims(p_enc, axis=1), n_enc],
1032
- axis=1
1033
- ) # [bs, (1+neg_samples), embed_dim]
1034
 
1035
- # Now compute scores: dot product of q_enc with each column in combined_p_n
1036
- # We'll use `tf.einsum` to handle the batch dimension properly
1037
- # dot_products => shape [batch_size, (1+neg_samples)]
1038
  dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
1039
  labels = tf.zeros([bs], dtype=tf.int32) # Keep labels as int32
1040
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
@@ -1043,14 +1016,13 @@ class RetrievalChatbot(DeviceAwareModel):
1043
  )
1044
  loss = tf.cast(tf.reduce_mean(loss), tf.float32)
1045
 
1046
- # Calculate gradients
1047
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
1048
  gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
1049
  max_grad_norm = tf.constant(1.5, dtype=tf.float32)
1050
  gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
1051
  post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
1052
 
1053
- # Apply gradients
1054
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1055
 
1056
  return loss, gradients_norm, post_clip_norm
@@ -1064,6 +1036,7 @@ class RetrievalChatbot(DeviceAwareModel):
1064
  ) -> tf.Tensor:
1065
  """
1066
  Single validation step using queries, positives, and hard negatives.
 
1067
  """
1068
  q_enc = self.encoder(q_batch, training=False)
1069
  p_enc = self.encoder(p_batch, training=False)
@@ -1082,7 +1055,7 @@ class RetrievalChatbot(DeviceAwareModel):
1082
  )
1083
 
1084
  dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
1085
- labels = tf.zeros([bs], dtype=tf.int32) # Keep labels as int32
1086
 
1087
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1088
  labels=labels,
@@ -1098,7 +1071,9 @@ class RetrievalChatbot(DeviceAwareModel):
1098
  peak_lr: float,
1099
  warmup_steps: int
1100
  ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
1101
- """Create a custom learning rate schedule with warmup and cosine decay."""
 
 
1102
  class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
1103
  def __init__(
1104
  self,
@@ -1110,11 +1085,11 @@ class RetrievalChatbot(DeviceAwareModel):
1110
  self.total_steps = tf.cast(total_steps, tf.float32)
1111
  self.peak_lr = tf.cast(peak_lr, tf.float32)
1112
 
1113
- # Adjust warmup_steps to not exceed half of total_steps
1114
  adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
1115
  self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)
1116
 
1117
- # Calculate and store constants
1118
  self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
1119
  self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)
1120
 
@@ -1128,21 +1103,20 @@ class RetrievalChatbot(DeviceAwareModel):
1128
  def __call__(self, step):
1129
  step = tf.cast(step, tf.float32)
1130
 
1131
- # Warmup phase
1132
  warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
1133
  warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
1134
 
1135
- # Decay phase
1136
  decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
1137
  decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
1138
  decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
1139
  cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
1140
  decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
1141
 
1142
- # Choose between warmup and decay
1143
  final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
1144
 
1145
- # Ensure learning rate is valid
1146
  final_lr = tf.maximum(self.min_lr, final_lr)
1147
  final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
1148
 
 
24
 
25
  @dataclass
26
  class ChatbotConfig:
27
+ """RetrievalChatbot Config"""
28
  max_context_token_limit: int = 512
29
  embedding_dim: int = 768
30
  encoder_units: int = 256
31
  num_attention_heads: int = 8
32
  dropout_rate: float = 0.2
33
  l2_reg_weight: float = 0.001
34
+ learning_rate: float = 0.0005
35
  min_text_length: int = 3
36
+ max_context_turns: int = 20
37
  warmup_steps: int = 200
38
  pretrained_model: str = 'distilbert-base-uncased'
39
  cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
40
+ summarizer_model: str = 't5-small'
41
  dtype: str = 'float32'
42
  freeze_embeddings: bool = False
43
  embedding_batch_size: int = 64
44
  search_batch_size: int = 64
45
  max_batch_size: int = 64
 
46
  max_retries: int = 3
47
 
48
  def to_dict(self) -> Dict:
 
57
  if k in cls.__dataclass_fields__})
58
 
59
  class EncoderModel(tf.keras.Model):
60
+ """Dual encoder model with pretrained DistilBERT embeddings."""
61
  def __init__(
62
  self,
63
  config: ChatbotConfig,
 
71
  self.pretrained = TFAutoModel.from_pretrained(config.pretrained_model)
72
  self._freeze_layers()
73
 
74
+ # Add Global Average Pooling, Projection, Dropout, and Normalization layers
75
  self.pooler = tf.keras.layers.GlobalAveragePooling1D()
76
  self.projection = tf.keras.layers.Dense(
77
  config.embedding_dim,
 
86
  )
87
 
88
  def _freeze_layers(self):
89
+ """Freeze n layers of the pretrained model"""
90
  if self.config.freeze_embeddings:
91
  self.pretrained.trainable = False
92
  logger.info("All pretrained layers frozen.")
 
95
  for i, layer in enumerate(self.pretrained.layers):
96
  if isinstance(layer, tf.keras.layers.Layer):
97
  if hasattr(layer, 'trainable'):
 
98
  if i < 1:
99
  layer.trainable = False
100
  logger.info(f"Layer {i} frozen.")
101
  else:
102
  layer.trainable = True
103
+ logger.info(f"Layer {i} trainable.")
104
 
105
  def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
106
  """Forward pass."""
107
  # Get pretrained embeddings
108
  pretrained_outputs = self.pretrained(inputs, training=training)
109
+ x = pretrained_outputs.last_hidden_state # Shape: [batch_size, seq_len, embedding_dim]
110
 
111
  # Apply pooling, projection, dropout, and normalization
112
+ x = self.pooler(x) # Shape: [batch_size, 768]
113
+ x = self.projection(x) # Shape: [batch_size, 768]
114
  x = self.dropout(x, training=training)
115
+ x = self.normalize(x) # Shape: [batch_size, 768]
116
 
117
  return x
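
Because the forward pass ends with L2 normalization, downstream code can treat plain dot products as cosine similarities (this is what the IndexFlatIP retrieval further down relies on). A minimal numpy sketch of that equivalence:

import numpy as np

a = np.array([3.0, 4.0])                 # un-normalized embedding
b = np.array([1.0, 0.0])

a_n = a / np.linalg.norm(a)              # L2-normalize, like the final layer
b_n = b / np.linalg.norm(b)

# Dot product of normalized vectors == cosine similarity of the originals
print(np.dot(a_n, b_n))                                        # 0.6
print(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))  # 0.6
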
118
 
119
  def get_config(self) -> dict:
120
+ """Return the model config"""
121
  config = super().get_config()
122
  config.update({
123
  "config": self.config.to_dict(),
 
126
  return config
127
 
128
  class RetrievalChatbot(DeviceAwareModel):
129
+ """
130
+ Retrieval-based learning chatbot model.
131
+ Uses trained embeddings and FAISS for similarity search.
132
+ """
133
  def __init__(
134
  self,
135
  config: ChatbotConfig,
 
145
  self.device = device or self._setup_default_device()
146
  self.mode = mode.lower()
147
 
148
+ # Initialize reranker, summarizer, tokenizer, and encoder
149
  self.reranker = reranker or self._initialize_reranker()
150
  self.tokenizer = self._initialize_tokenizer()
151
  self.encoder = self._initialize_encoder()
 
157
  config=self.config,
158
  tokenizer=self.tokenizer,
159
  encoder=self.encoder,
 
160
  response_pool=[],
161
  max_length=self.config.max_context_token_limit,
162
  query_embeddings_cache={},
 
 
 
 
163
  )
164
 
165
  # Collect unique responses from dialogues
 
195
  """Initialize the Summarizer."""
196
  return Summarizer(
197
  tokenizer=self.tokenizer,
198
+ model_name=self.config.summarizer_model,
199
  max_summary_length=self.config.max_context_token_limit // 4,
200
  device=self.device,
201
  max_summary_rounds=2
 
227
  new_vocab_size = len(self.tokenizer)
228
  encoder.pretrained.resize_token_embeddings(new_vocab_size)
229
  logger.info(f"Token embeddings resized to: {new_vocab_size}")
230
+
231
  return encoder
232
 
233
  def _load_faiss_index_and_responses(self) -> None:
234
  """Load FAISS index and response pool for inference."""
235
  try:
236
+ logger.info(f"Loading FAISS index from {self.data_pipeline.faiss_index_file_path}...")
237
+ self.data_pipeline.load_faiss_index(self.data_pipeline.faiss_index_file_path)
238
  logger.info("FAISS index loaded successfully.")
239
 
240
+ # Load response pool
241
+ response_pool_path = self.data_pipeline.faiss_index_file_path.replace('.index', '_responses.json')
242
  if os.path.exists(response_pool_path):
243
  with open(response_pool_path, 'r', encoding='utf-8') as f:
244
  self.data_pipeline.response_pool = json.load(f)
 
262
  """
263
  load_dir = Path(load_dir)
264
 
265
+ # Load config
266
  with open(load_dir / "config.json", "r") as f:
267
  config = ChatbotConfig.from_dict(json.load(f))
268
 
269
+ # Initialize chatbot
270
  chatbot = cls(config, mode=mode)
271
 
272
+ # Load DistilBERT
273
+ chatbot.encoder.pretrained = TFAutoModel.from_pretrained(load_dir / "shared_encoder", config=config)
 
 
 
274
 
275
  dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
276
  _ = chatbot.encoder(dummy_input, training=False)
277
 
278
+ # Load tokenizer
279
  chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
280
  logger.info(f"Models and tokenizer loaded from {load_dir}")
281
 
282
+ # Load the custom weights
 
 
283
  custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
284
  if custom_weights_path.exists():
285
  chatbot.encoder.load_weights(str(custom_weights_path))
 
287
  else:
288
  logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
289
 
290
+ # Handle 'inference' mode: load FAISS, etc.
291
  if mode == 'inference':
292
  cls._prepare_model_for_inference(chatbot, load_dir)
293
 
 
295
 
296
  @classmethod
297
  def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
298
+ """Load inference components."""
299
  try:
300
  # Load FAISS index
301
  faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
 
326
  raise
327
 
328
  def save_models(self, save_dir: Union[str, Path]):
329
+ """Save model and config"""
330
  save_dir = Path(save_dir)
331
  save_dir.mkdir(parents=True, exist_ok=True)
332
 
 
334
  with open(save_dir / "config.json", "w") as f:
335
  json.dump(self.config.to_dict(), f, indent=2)
336
 
337
+ # Save the HF DistilBERT submodule, custom top-level layers, and tokenizer
338
  self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
 
 
339
  self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
 
 
340
  self.tokenizer.save_pretrained(save_dir / "tokenizer")
 
341
  logger.info(f"Models and tokenizer saved to {save_dir}.")
342
 
343
+ def retrieve_responses(
 
 
 
344
  self,
345
  query: str,
346
  top_k: int = 10,
 
349
  summarize_threshold: int = 512
350
  ) -> List[Tuple[str, float]]:
351
  """
352
+ Retrieve top-k responses using FAISS and cross-encoder re-ranking.
 
 
353
  Args:
354
  query: The user's input text.
355
+ top_k: Number of final re-ranked results to return.
356
+ reranker: CrossEncoderReranker for refined scoring.
357
+ summarizer: Summarizer for long queries.
358
+ summarize_threshold: Summarize when the query word count exceeds this threshold.
 
359
  Returns:
360
  List of (response_text, final_score).
361
  """
362
+ def sigmoid(x: float) -> float:
363
+ return 1 / (1 + np.exp(-x))
364
+
365
+ # Query summarization
366
  if summarizer and len(query.split()) > summarize_threshold:
367
  logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
368
  query = summarizer.summarize_text(query)
 
379
 
380
  texts = [item[0] for item in faiss_candidates]
381
 
 
382
  if not reranker:
383
  reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
384
 
385
+ # Re-rank the texts (candidates) from FAISS search using the cross-encoder
386
  ce_logits = reranker.rerank(query, texts, max_length=256)
387
 
388
+ # Combine scores from FAISS and cross-encoder
389
  final_candidates = []
390
  for (resp_text, faiss_score), logit in zip(faiss_candidates, ce_logits):
391
+ ce_prob = sigmoid(logit) # now in range [0...1]
392
+ faiss_norm = (faiss_score + 1)/2.0 # now in range [0...1]
393
  combined_score = 0.85 * ce_prob + 0.15 * faiss_norm
394
  length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
395
 
 
401
  # Return top_k
402
  return final_candidates[:top_k]
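
To make the 0.85/0.15 blend concrete, here is the arithmetic for one hypothetical candidate (the logit and FAISS score are made-up values):

import numpy as np

def sigmoid(x: float) -> float:
    return 1 / (1 + np.exp(-x))

# Suppose the cross-encoder emits a logit of 2.0 and FAISS returned
# a cosine similarity of 0.40 for the same candidate.
ce_prob    = sigmoid(2.0)        # ~0.881, mapped into [0, 1]
faiss_norm = (0.40 + 1) / 2.0    # 0.70, cosine [-1, 1] -> [0, 1]

combined = 0.85 * ce_prob + 0.15 * faiss_norm
print(round(combined, 3))        # ~0.854: dominated by the cross-encoder
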
403
 
 
 
 
 
 
 
 
 
 
404
  def extract_keywords(self, query: str) -> List[str]:
405
  """
406
  Return any domain keywords present in the query (lowercased).
407
  """
408
+ domain_keywords = {
409
+ 'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
410
+ 'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
411
+ 'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
412
+ 'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
413
+ 'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
414
+ 'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
415
+ }
416
+
417
  query_lower = query.lower()
418
  found = set()
419
+ for domain, kw_list in domain_keywords.items():
420
  for kw in kw_list:
421
  if kw in query_lower:
422
  found.add(kw)
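
A trimmed, standalone sketch of the same substring scan (dictionary shortened; matching is case-insensitive because the query is lowercased):

domain_keywords = {'restaurant': ['table', 'brunch', 'fine dining']}

def extract_keywords(query: str) -> list:
    query_lower = query.lower()
    return sorted({kw for kws in domain_keywords.values()
                   for kw in kws if kw in query_lower})

print(extract_keywords("Can I book a Table for brunch?"))  # ['brunch', 'table']
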
 
442
 
443
  def detect_domain_from_query(self, query: str) -> str:
444
  """
445
+ Detect the query's domain from keywords. Used to boost FAISS search results.
446
  """
447
  domain_patterns = {
448
  'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
 
462
 
463
  def is_numeric_response(self, text: str) -> bool:
464
  """
465
+ Return True if `text` contains only digits, spaces, and separators such as '.' or ','.
 
466
  """
467
  pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
468
  return bool(re.match(pattern, text.strip()))
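
A quick check of what the pattern accepts and rejects:

import re

pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'

for text in ["1234", " 12 34 ", "3.50", "12,000", "table 4", "4pm"]:
    print(repr(text), bool(re.match(pattern, text.strip())))
# '1234'    True
# ' 12 34 ' True
# '3.50'    True   ('.' and ',' are allowed as separators)
# '12,000'  True
# 'table 4' False
# '4pm'     False
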
 
471
  self,
472
  query: str,
473
  domain: str = 'other',
474
+ top_k: int = 10,
475
+ boost_factor: float = 1.15
476
  ) -> List[Tuple[str, float]]:
477
  """
478
  Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
 
479
  Args:
480
  query (str): The user input text.
481
+ domain (str): The detected domain, one of ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other'].
482
+ top_k (int): Number of top results to return.
483
+ boost_factor (float, optional): Factor to boost scores for keyword matches.
 
484
  Returns:
485
  List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
486
  """
 
491
  # Search the index
492
  distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
493
 
494
+ # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
495
  candidates = []
496
  for rank, idx in enumerate(indices[0]):
497
  if idx < 0:
 
528
  boosted = []
529
  for (resp_text, resp_domain, score) in in_domain:
530
  new_score = score
531
+ # If the domain is known AND the response text shares any query keywords, boost it
 
532
  if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
533
  new_score *= boost_factor
534
 
 
540
  # Sort boosted responses
541
  boosted.sort(key=lambda x: x[1], reverse=True)
542
 
543
+ # Debug logging (see FAISS responses)
544
  # for resp, score in boosted[:100]:
545
  # logger.debug(f"Candidate: '{resp}' with score {score}")
546
 
 
554
  top_k: int = 10,
555
  ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
556
  """
557
+ Live chat with the chatbot. Uses the same processing flow as validation, except for context handling and quality checking.
 
558
  """
559
  @self.run_on_device
560
  def get_response(self_arg, query_arg):
 
562
  conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
563
 
564
  # Retrieve and re-rank
565
+ results = self_arg.retrieve_responses(
566
  query=conversation_str,
567
  top_k=top_k,
568
  reranker=self_arg.reranker,
 
586
  query: str,
587
  conversation_history: Optional[List[Tuple[str, str]]]
588
  ) -> str:
589
+ """
590
+ Build conversation context string from conversation history.
591
+ """
592
  if not conversation_history:
593
  return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
594
 
 
619
  ) -> None:
620
  """
621
  Train the retrieval model using a pre-prepared TFRecord dataset.
 
622
  - Checkpoint loading/restoring
623
  - LR scheduling
624
  - Epoch/iteration tracking
625
+ - Training-history logging
626
+ - Early stopping
627
+ - Custom loss function (contrastive loss with hard negative sampling)
628
  """
629
  logger.info("Starting training with pre-prepared TFRecord dataset...")
630
 
 
656
  steps_per_epoch = math.ceil(train_size / batch_size)
657
  val_steps = math.ceil(val_size / batch_size)
658
  total_steps = steps_per_epoch * epochs
659
+ buffer_size = max(1, total_pairs // 2) # 50% of the dataset for shuffling
660
 
661
  logger.info(f"Training pairs: {train_size}")
662
  logger.info(f"Validation pairs: {val_size}")
 
678
  self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
679
  logger.info("Using fixed learning rate.")
680
 
681
+ # Dummy step to force initialization
682
  dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
683
  with tf.GradientTape() as tape:
684
  dummy_output = self.encoder(dummy_input)
 
693
  model=self.encoder
694
  )
695
 
696
+ # Create a CheckpointManager
697
  manager = tf.train.CheckpointManager(
698
  checkpoint,
699
  directory=checkpoint_dir,
 
701
  checkpoint_name='ckpt'
702
  )
703
 
704
+ # Restore from the latest checkpoint if one exists
705
  latest_checkpoint = manager.latest_checkpoint
706
  history_path = Path(checkpoint_dir) / 'training_history.json'
707
 
708
+ # Track epoch losses across runs, including runs resumed from a checkpoint
709
  if not hasattr(self, 'history'):
710
  self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
711
 
712
  if latest_checkpoint and not test_mode:
713
+ # Debug checkpoint loading
714
+ # logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
715
+ # reader = tf.train.load_checkpoint(latest_checkpoint)
716
  # shape_from_key = reader.get_variable_to_shape_map()
717
  # dtype_from_key = reader.get_variable_to_dtype_map()
718
  # logger.info("\nCheckpoint Variables:")
 
736
  if initial_epoch == 0:
737
  initial_epoch = ckpt_number
738
 
739
+ # Assign to checkpoint.epoch for counting
740
  checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
741
  logger.info(f"Resuming from epoch {initial_epoch}")
742
 
743
+ # Load history from file:
744
  if history_path.exists():
745
  try:
746
  with open(history_path, 'r') as f:
 
749
  except Exception as e:
750
  logger.warning(f"Could not load history, starting fresh: {e}")
751
 
752
+ # Re-save so the custom top-level weights are included in the full saved model.
753
+ # This was a bugfix to recover weights from a checkpoint without retraining:
754
+ # before save_models was updated, only the DistilBERT weights were saved (custom layers were missed).
755
+ # No longer needed, but harmless to keep.
756
  self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
757
  logger.info(f"Manually saved custom weights after restore.")
758
  else:
 
769
  train_summary_writer = tf.summary.create_file_writer(train_log_dir)
770
  val_summary_writer = tf.summary.create_file_writer(val_log_dir)
771
  logger.info(f"TensorBoard logs will be saved in {log_dir}")
772
+
773
  # Parse dataset
774
  dataset = tf.data.TFRecordDataset(tfrecord_file_path)
775
+
776
+ # Debug mode uses a small subset of the data. Useful for CPU debugging.
777
  if test_mode:
778
+ subset_size = 200
779
  dataset = dataset.take(subset_size)
780
  logger.info(f"TEST MODE: Using only {subset_size} examples")
781
  # Recompute sizes, steps, epochs, etc., as needed
 
791
  early_stopping_patience = 2
792
  logger.info(f"New training pairs: {train_size}")
793
  logger.info(f"New validation pairs: {val_size}")
794
+
795
  dataset = dataset.map(
796
+ lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.data_pipeline.neg_samples),
797
  num_parallel_calls=tf.data.AUTOTUNE
798
  )
799
+
800
  # Train/val split
801
  train_dataset = dataset.take(train_size)
802
  val_dataset = dataset.skip(train_size).take(val_size)
803
+
804
  # Shuffle and batch
805
  train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
806
  train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
807
  train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
808
+
809
  val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
810
  val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
811
  val_dataset = val_dataset.cache()
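
parse_tfrecord_fn is defined elsewhere in the repo; a hypothetical sketch consistent with the (q_batch, p_batch, n_batch) shapes consumed below (the feature names here are assumptions):

import tensorflow as tf

def parse_tfrecord_fn(example_proto, max_length, neg_samples):
    # Hypothetical feature spec; real names/serialization live in the data pipeline
    feature_spec = {
        'query_ids':    tf.io.FixedLenFeature([max_length], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
    }
    parsed = tf.io.parse_single_example(example_proto, feature_spec)
    q = tf.cast(parsed['query_ids'], tf.int32)
    p = tf.cast(parsed['positive_ids'], tf.int32)
    n = tf.cast(tf.reshape(parsed['negative_ids'], [neg_samples, max_length]), tf.int32)
    return q, p, n
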
812
+
813
  # Training loop
814
  best_val_loss = float("inf")
815
  epochs_no_improve = 0
816
+
817
  for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
818
  checkpoint.epoch.assign(epoch)
819
  logger.info(f"Starting Epoch {epoch}...")
820
+
 
821
  epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
822
  batches_processed = 0
823
+
 
824
  try:
825
  train_pbar = tqdm(
826
  total=steps_per_epoch,
 
831
  except ImportError:
832
  train_pbar = None
833
  is_tqdm_train = False
834
+
835
+ # --- Training ---
836
  for q_batch, p_batch, n_batch in train_dataset:
837
  loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
838
  epoch_loss_avg(loss)
 
860
  "lr": f"{current_lr:.2e}",
861
  "batches": f"{batches_processed}/{steps_per_epoch}"
862
  })
863
+
864
  gc.collect()
865
+
866
  # End the epoch early if we've processed all steps
867
  if batches_processed >= steps_per_epoch:
868
  break
869
+
870
  if is_tqdm_train and train_pbar:
871
  train_pbar.close()
872
+
873
+ # --- Validation ---
874
  val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
875
  val_batches_processed = 0
876
+
877
  try:
878
  val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")
879
  is_tqdm_val = True
880
  except ImportError:
881
  val_pbar = None
882
  is_tqdm_val = False
883
+
884
  last_valid_val_loss = None
885
  valid_batches = False
886
+
887
  for q_batch, p_batch, n_batch in val_dataset:
888
  # If batch is too small, skip
889
  if tf.shape(q_batch)[0] < 2:
890
  logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
891
  continue
892
+
893
  valid_batches = True
894
  val_loss = self.validation_step(q_batch, p_batch, n_batch)
895
  val_loss_avg(val_loss)
896
  last_valid_val_loss = val_loss
897
  val_batches_processed += 1
898
+
899
  if is_tqdm_val:
900
  val_pbar.update(1)
901
  val_pbar.set_postfix({
902
  "val_loss": f"{val_loss.numpy():.4f}",
903
  "batches": f"{val_batches_processed}/{val_steps}"
904
  })
905
+
906
  gc.collect()
907
+
908
  if val_batches_processed >= val_steps:
909
  break
910
+
911
  if not valid_batches:
912
  # If no valid batch is found, fallback
913
  logger.warning("No valid validation batches in this epoch")
 
917
  else:
918
  val_loss = epoch_loss_avg.result()
919
  val_loss_avg(val_loss)
920
+
921
  if is_tqdm_val and val_pbar:
922
  val_pbar.close()
923
+
924
  # End of epoch: final stats
925
  train_loss = epoch_loss_avg.result().numpy()
926
  val_loss = val_loss_avg.result().numpy()
927
  logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
928
+
929
  # TensorBoard epoch logs
930
  with train_summary_writer.as_default():
931
  tf.summary.scalar("epoch_loss", train_loss, step=epoch)
932
  with val_summary_writer.as_default():
933
  tf.summary.scalar("val_loss", val_loss, step=epoch)
934
+
935
  # Save checkpoint
936
  manager.save()
937
+
938
+ # Save model for iterative testing/inference
939
  model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
940
  self.save_models(model_save_path)
941
  logger.info(f"Saved model for epoch {epoch} at {model_save_path}")
942
+
943
  # Update local history
944
  self.history['train_loss'].append(train_loss)
945
  self.history['val_loss'].append(val_loss)
 
958
  return obj
959
 
960
  json_history = convert_to_py_floats(self.history)
961
+
962
  # Save training history to file every epoch
 
963
  with open(history_path, 'w') as f:
964
  json.dump(json_history, f)
965
  logger.info(f"Saved training history to {history_path}")
966
+
967
  # Early stopping
968
  if val_loss < best_val_loss - min_delta:
969
  best_val_loss = val_loss
 
975
  if epochs_no_improve >= early_stopping_patience:
976
  logger.info("Early stopping triggered.")
977
  break
978
+
979
  logger.info("Training completed!")
980
 
981
  @tf.function
 
989
  Single training step using queries, positives, and hard negatives.
990
  """
991
  with tf.GradientTape() as tape:
992
+ # Encode queries, positives, and negatives
993
  q_enc = self.encoder(q_batch, training=True) # [batch_size, embed_dim]
 
 
994
  p_enc = self.encoder(p_batch, training=True) # [batch_size, embed_dim]
 
 
 
995
  shape = tf.shape(n_batch)
996
  bs = shape[0]
997
  neg_samples = shape[1]
998
 
999
+ # Flatten negatives to feed them in one pass: [batch_size * neg_samples, max_length]
 
1000
  n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
1001
  n_enc_flat = self.encoder(n_batch_flat, training=True) # [bs*neg_samples, embed_dim]
1002
 
1003
  # Reshape back => [batch_size, neg_samples, embed_dim]
1004
  n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])
1005
 
1006
+ # Combine the positive embedding and negative embeddings along dim=1: shape [batch_size, 1 + neg_samples, embed_dim]
1007
+ # The first column is the positive; the remaining columns are negatives
1008
+ combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1) # [bs, (1+neg_samples), embed_dim]
 
 
 
 
1009
 
1010
+ # Compute scores: dot product of q_enc with each column in combined_p_n. `tf.einsum` handles the batch dimension
 
 
1011
  dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
1012
  labels = tf.zeros([bs], dtype=tf.int32) # Keep labels as int32
1013
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
 
1016
  )
1017
  loss = tf.cast(tf.reduce_mean(loss), tf.float32)
1018
 
1019
+ # Calculate gradients and clip
1020
  gradients = tape.gradient(loss, self.encoder.trainable_variables)
1021
  gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
1022
  max_grad_norm = tf.constant(1.5, dtype=tf.float32)
1023
  gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
1024
  post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
1025
 
 
1026
  self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
1027
 
1028
  return loss, gradients_norm, post_clip_norm
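
A toy run of the scoring and loss above, showing why the labels are all zeros (column 0 holds the positive):

import tensorflow as tf

# 2 queries, 1 positive and 3 negatives each, embed_dim 4
# (random stand-ins for encoder outputs)
bs, neg_samples, dim = 2, 3, 4
q_enc        = tf.random.normal([bs, dim])
combined_p_n = tf.random.normal([bs, 1 + neg_samples, dim])  # col 0 = positive

# Dot product of each query with its 1+neg_samples candidates
dot_products = tf.einsum('bd,bkd->bk', q_enc, combined_p_n)  # [2, 4]

# Label 0 marks the positive column; softmax CE pushes it above the negatives
labels = tf.zeros([bs], dtype=tf.int32)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=dot_products)
print(tf.reduce_mean(loss).numpy())
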
 
1036
  ) -> tf.Tensor:
1037
  """
1038
  Single validation step using queries, positives, and hard negatives.
1039
+ Same idea as train_step, but without gradient updates.
1040
  """
1041
  q_enc = self.encoder(q_batch, training=False)
1042
  p_enc = self.encoder(p_batch, training=False)
 
1055
  )
1056
 
1057
  dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
1058
+ labels = tf.zeros([bs], dtype=tf.int32)
1059
 
1060
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1061
  labels=labels,
 
1071
  peak_lr: float,
1072
  warmup_steps: int
1073
  ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
1074
+ """
1075
+ Custom learning rate schedule with warmup and cosine decay.
1076
+ """
1077
  class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
1078
  def __init__(
1079
  self,
 
1085
  self.total_steps = tf.cast(total_steps, tf.float32)
1086
  self.peak_lr = tf.cast(peak_lr, tf.float32)
1087
 
1088
+ # Cap warmup_steps at 10% of total_steps
1089
  adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
1090
  self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)
1091
 
1092
+ # Calculate constants
1093
  self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
1094
  self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)
1095
 
 
1103
  def __call__(self, step):
1104
  step = tf.cast(step, tf.float32)
1105
 
1106
+ # Warmup
1107
  warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
1108
  warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor
1109
 
1110
+ # Decay
1111
  decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
1112
  decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
1113
  decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
1114
  cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
1115
  decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
1116
 
 
1117
  final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)
1118
 
1119
+ # Ensure valid lr
1120
  final_lr = tf.maximum(self.min_lr, final_lr)
1121
  final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)
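Review note: the schedule above is worth eyeballing before a long run. A self-contained sketch of the same warmup + cosine-decay curve in plain Python; `total_steps`, `peak_lr`, and `warmup_steps` are illustrative values, not the trained config:

```python
import math

# Standalone version of the schedule above: linear warmup from 10% of
# peak_lr, then cosine decay down to a 1% floor. Illustrative numbers only.
total_steps, peak_lr, warmup_steps = 1000.0, 5e-4, 100.0
initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01

def lr_at(step: float) -> float:
    if step < warmup_steps:
        return initial_lr + (peak_lr - initial_lr) * min(1.0, step / warmup_steps)
    decay_factor = min(max(0.0, (step - warmup_steps) / (total_steps - warmup_steps)), 1.0)
    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_factor))
    return min_lr + (peak_lr - min_lr) * cosine_decay

for s in (0, 50, 100, 500, 1000):
    print(s, f"{lr_at(s):.2e}")  # ramps to 5e-4 at step 100, decays to 5e-6
```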
 
chatbot_validator.py CHANGED
@@ -113,7 +113,7 @@ class ChatbotValidator:
  logger.info(f"\nTest Case {i}: {query}")
 
  # Retrieve top_k responses, then evaluate with quality checker
- responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k, reranker=reranker)
+ responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=reranker)
  quality_metrics = self.quality_checker.check_response_quality(query, responses)
 
  # Aggregate metrics and log
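Review note: for reference, a hypothetical call site for the renamed method, matching the keyword arguments visible in this diff; the `chatbot` and `reranker` objects, and the assumption that it returns (response, score) pairs, are illustrative only:

```python
# Hypothetical usage of the renamed entry point. The signature matches the
# diff above; the (response, score) return shape is an assumption.
responses = chatbot.retrieve_responses("Can I change my flight?", top_k=5, reranker=reranker)
for response, score in responses:
    print(f"{score:.3f}  {response}")
```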
{data_augmentation → data_augmentation_code}/augmentation_processing_pipeline.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/back_translator.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/dialogue_augmenter.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/main.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/paraphraser.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/pipeline_config.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/quality_metrics.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/schema_guided_dialogue_processor.py RENAMED
File without changes
{data_augmentation → data_augmentation_code}/taskmaster_processor.py RENAMED
File without changes
validate_model.py → run_chatbot_validation.py RENAMED
@@ -39,7 +39,7 @@ def run_interactive_chat(chatbot, quality_checker):
  else:
  print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
 
- def validate_chatbot():
+ def run_chatbot_validation():
  # Initialize environment
  env = EnvironmentSetup()
  env.initialize()
@@ -86,15 +86,15 @@ def validate_chatbot():
  try:
  chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
  logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
- logger.info("FAISS dimensions:", chatbot.data_pipeline.index.d)
- logger.info("FAISS index type:", type(chatbot.data_pipeline.index))
- logger.info("FAISS index total vectors:", chatbot.data_pipeline.index.ntotal)
- logger.info("FAISS is_trained:", chatbot.data_pipeline.index.is_trained)
+ logger.info(f"FAISS dimensions: {chatbot.data_pipeline.index.d}")
+ logger.info(f"FAISS index type: {type(chatbot.data_pipeline.index)}")
+ logger.info(f"FAISS index total vectors: {chatbot.data_pipeline.index.ntotal}")
+ logger.info(f"FAISS is_trained: {chatbot.data_pipeline.index.is_trained}")
 
  with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
  chatbot.data_pipeline.response_pool = json.load(f)
  logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
- logger.info("\nTotal responses in pool:", len(chatbot.data_pipeline.response_pool))
+ logger.info(f"\nTotal responses in pool: {len(chatbot.data_pipeline.response_pool)}")
 
  # Validate dimension consistency
  chatbot.data_pipeline.validate_faiss_index()
@@ -130,4 +130,4 @@ def validate_chatbot():
  run_interactive_chat(chatbot, quality_checker)
 
  if __name__ == "__main__":
- validate_chatbot()
+ run_chatbot_validation()
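Review note: the logger changes above fix a real bug, not just style. The stdlib `logging` module interpolates positional arguments into the message %-style at emit time, so `logger.info("FAISS dimensions:", index.d)` has no placeholder to fill and logging reports a formatting error instead of the value. A minimal sketch of the failure and two working forms (the committed f-string, and the equivalent lazy %-style):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
dims = 768  # stand-in for chatbot.data_pipeline.index.d

# Broken: the message has no %s placeholder, so arg interpolation fails at
# emit time and logging prints "--- Logging error ---" instead of the value.
logger.info("FAISS dimensions:", dims)

# Fixed, as committed: format eagerly with an f-string.
logger.info(f"FAISS dimensions: {dims}")

# Equivalent lazy style: formatting is deferred until the record is emitted.
logger.info("FAISS dimensions: %s", dims)
```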
tf_data_pipeline.py CHANGED
@@ -24,19 +24,19 @@ class TFDataPipeline:
  config,
  tokenizer,
  encoder,
- index_file_path: str,
  response_pool: List[str],
- max_length: int,
  query_embeddings_cache: dict,
- neg_samples: int = 5,
+ max_length: int = 512,
+ neg_samples: int = 10,
  index_type: str = 'IndexFlatIP',
+ faiss_index_file_path: str = 'new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
  nlist: int = 100,
  max_retries: int = 3
  ):
  self.config = config
  self.tokenizer = tokenizer
  self.encoder = encoder
- self.index_file_path = index_file_path
+ self.faiss_index_file_path = faiss_index_file_path
  self.response_pool = response_pool
  self.max_length = max_length
  self.neg_samples = neg_samples
@@ -53,9 +53,9 @@ class TFDataPipeline:
  self.build_text_to_domain_map()
 
  # Initialize FAISS index
- if os.path.exists(index_file_path):
- logger.info(f"Loading existing FAISS index from {index_file_path}...")
- self.index = faiss.read_index(index_file_path)
+ if os.path.exists(faiss_index_file_path):
+ logger.info(f"Loading existing FAISS index from {faiss_index_file_path}...")
+ self.index = faiss.read_index(faiss_index_file_path)
  self.validate_faiss_index()
  logger.info("FAISS index loaded and validated successfully.")
  else:
@@ -83,18 +83,18 @@ class TFDataPipeline:
  self.query_embeddings_cache[query] = hf[query][:]
  logger.info(f"Embeddings cache loaded from {cache_file_path}.")
 
- def save_faiss_index(self, index_file_path: str):
- faiss.write_index(self.index, index_file_path)
- logger.info(f"FAISS index saved to {index_file_path}")
+ def save_faiss_index(self, faiss_index_file_path: str):
+ faiss.write_index(self.index, faiss_index_file_path)
+ logger.info(f"FAISS index saved to {faiss_index_file_path}")
 
- def load_faiss_index(self, index_file_path: str):
+ def load_faiss_index(self, faiss_index_file_path: str):
  """Load FAISS index from specified file path."""
- if os.path.exists(index_file_path):
- self.index = faiss.read_index(index_file_path)
- logger.info(f"FAISS index loaded from {index_file_path}.")
+ if os.path.exists(faiss_index_file_path):
+ self.index = faiss.read_index(faiss_index_file_path)
+ logger.info(f"FAISS index loaded from {faiss_index_file_path}.")
  else:
- logger.error(f"FAISS index file not found at {index_file_path}.")
- raise FileNotFoundError(f"FAISS index file not found at {index_file_path}.")
+ logger.error(f"FAISS index file not found at {faiss_index_file_path}.")
+ raise FileNotFoundError(f"FAISS index file not found at {faiss_index_file_path}.")
 
  def validate_faiss_index(self):
  """Validates FAISS index dimensionality."""