JoeArmani committed on
Commit 7a0020b · 1 Parent(s): d53c64b

updates - new iteration with type token

.gitignore CHANGED
@@ -180,4 +180,10 @@ cache/*
180
  !cache/.gitkeep
181
  training_data/*
182
  !training_data/.gitkeep
183
- augmented_dialogues.json
180
  !cache/.gitkeep
181
  training_data/*
182
  !training_data/.gitkeep
183
+ augmented_dialogues.json
184
+
185
+ checkpoints_old_REMOVE/*
186
+ new_iteration/cache/*
187
+ new_iteration/data_prep_iterative_models/*
188
+ new_iteration/training_data/*
189
+ new_iteration/processed_outputs/*
build_faiss_index.py ADDED
@@ -0,0 +1,161 @@
1
+ import os
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import faiss
6
+ import numpy as np
7
+ import tensorflow as tf
8
+ from transformers import AutoTokenizer, TFAutoModel
9
+ from tqdm.auto import tqdm
10
+
11
+ from chatbot_model import ChatbotConfig, EncoderModel
12
+ from tf_data_pipeline import TFDataPipeline
13
+ from logger_config import config_logger
14
+
15
+ logger = config_logger(__name__)
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+ def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
19
+ """
20
+ Perform a quick sanity check to ensure the model is loaded correctly.
21
+ """
22
+ sample_response = "This is a test response."
23
+ encoded_sample = tokenizer(
24
+ [sample_response],
25
+ padding=True,
26
+ truncation=True,
27
+ max_length=config.max_context_token_limit,
28
+ return_tensors='tf'
29
+ )
30
+
31
+ # Get embedding
32
+ sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
33
+
34
+ # Check shape
35
+ if sample_embedding.shape[1] != config.embedding_dim:
36
+ logger.error(
37
+ f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
38
+ f"got {sample_embedding.shape[1]}"
39
+ )
40
+ raise ValueError("Embedding dimension mismatch.")
41
+ else:
42
+ logger.info("Embedding dimension matches the configuration.")
43
+
44
+ # Check normalization
45
+ embedding_norm = np.linalg.norm(sample_embedding, axis=1)
46
+ if not np.allclose(embedding_norm, 1.0, atol=1e-5):
47
+ logger.error("Embeddings are not properly normalized.")
48
+ raise ValueError("Embeddings are not normalized.")
49
+ else:
50
+ logger.info("Embeddings are properly normalized.")
51
+
52
+ logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
53
+
54
+ def build_faiss_index():
55
+ """
56
+ Rebuild the FAISS index by:
57
+ 1) Loading your config.json
58
+ 2) Initializing encoder + loading submodule & custom weights
59
+ 3) Loading tokenizer from disk
60
+ 4) Creating a TFDataPipeline
61
+ 5) Setting the pipeline's response_pool from a JSON file
62
+ 6) Using pipeline.compute_and_index_response_embeddings()
63
+ 7) Saving the FAISS index
64
+ """
65
+ # Directories
66
+ MODELS_DIR = Path("models")
67
+ FAISS_DIR = MODELS_DIR / "faiss_indices"
68
+ FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
69
+ RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
70
+ TOKENIZER_DIR = MODELS_DIR / "tokenizer"
71
+ SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
72
+ CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
73
+
74
+ # 1) Load ChatbotConfig
75
+ config_path = MODELS_DIR / "config.json"
76
+ if config_path.exists():
77
+ with open(config_path, "r", encoding="utf-8") as f:
78
+ config_dict = json.load(f)
79
+ config = ChatbotConfig.from_dict(config_dict)
80
+ logger.info(f"Loaded ChatbotConfig from {config_path}")
81
+ else:
82
+ config = ChatbotConfig()
83
+ logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
84
+
85
+ # 2) Initialize the EncoderModel
86
+ encoder = EncoderModel(config=config)
87
+ logger.info("EncoderModel instantiated (empty).")
88
+
89
+ # Overwrite the submodule from 'shared_encoder' directory
90
+ if SHARED_ENCODER_DIR.exists():
91
+ logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
92
+ encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
93
+ logger.info("Loaded HF submodule into encoder.pretrained.")
94
+ else:
95
+ logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
96
+
97
+ # Build model once, then load custom weights (projection, etc.)
98
+ dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
99
+ _ = encoder(dummy_input, training=False) # builds the layers
100
+
101
+ if CUSTOM_WEIGHTS_PATH.exists():
102
+ logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
103
+ encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
104
+ logger.info("Custom top-level weights loaded successfully.")
105
+ else:
106
+ logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
107
+
108
+ # 3) Load tokenizer
109
+ if TOKENIZER_DIR.exists():
110
+ logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
111
+ tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
112
+ else:
113
+ logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
114
+ tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
115
+ #tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
116
+
117
+ # 4) Quick sanity check
118
+ sanity_check(encoder, tokenizer, config)
119
+
120
+ # 5) Prepare a TFDataPipeline
121
+ pipeline = TFDataPipeline(
122
+ config=config,
123
+ tokenizer=tokenizer,
124
+ encoder=encoder,
125
+ index_file_path=str(FAISS_INDEX_PATH),
126
+ response_pool=[],
127
+ max_length=config.max_context_token_limit,
128
+ query_embeddings_cache={},
129
+ neg_samples=config.neg_samples,
130
+ index_type='IndexFlatIP',
131
+ nlist=100,
132
+ max_retries=config.max_retries
133
+ )
134
+
135
+ # 6) Load the existing response pool
136
+ if not RESPONSES_PATH.exists():
137
+ logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
138
+ raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
139
+
140
+ with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
141
+ response_pool = json.load(f)
142
+ logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
143
+
144
+ pipeline.response_pool = response_pool # assign to pipeline
145
+
146
+ # 7) Build (or rebuild) the FAISS index from pipeline method
147
+ # This does all the compute-embeddings + index.add in one place
148
+ logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
149
+ pipeline.compute_and_index_response_embeddings()
150
+
151
+ # 8) Save the rebuilt FAISS index
152
+ pipeline.save_faiss_index(str(FAISS_INDEX_PATH))
153
+
154
+ # Verify
155
+ loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
156
+ logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
157
+
158
+ return loaded_index, pipeline.response_pool
159
+
160
+ if __name__ == "__main__":
161
+ build_faiss_index()
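
A quick smoke test for the rebuilt index might look like the sketch below. It assumes the models/faiss_indices paths used above, a working faiss install, and that the stored embeddings are L2-normalized; the random query vector stands in for a real encoder output.

import json
import faiss
import numpy as np

# Load the index and its response pool (paths assumed from build_faiss_index.py above).
index = faiss.read_index("models/faiss_indices/faiss_index_production.index")
with open("models/faiss_indices/faiss_index_production_responses.json", "r", encoding="utf-8") as f:
    responses = json.load(f)

# One vector per response, or the index and pool are out of sync.
assert index.ntotal == len(responses)

# IndexFlatIP returns cosine similarity only when vectors are unit-norm.
query_vec = np.random.rand(1, index.d).astype("float32")
faiss.normalize_L2(query_vec)
scores, ids = index.search(query_vec, 5)
for score, idx in zip(scores[0], ids[0]):
    print(f"{score:.3f}  {responses[idx]}")
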
chatbot_model.py CHANGED
@@ -10,6 +10,8 @@ from pathlib import Path
10
  import datetime
11
  import faiss
12
  import gc
 
 
13
  from tf_data_pipeline import TFDataPipeline
14
  from response_quality_checker import ResponseQualityChecker
15
  from cross_encoder_reranker import CrossEncoderReranker
@@ -31,7 +33,7 @@ class ChatbotConfig:
31
  num_attention_heads: int = 8
32
  dropout_rate: float = 0.2
33
  l2_reg_weight: float = 0.001
34
- learning_rate: float = 0.0005
35
  min_text_length: int = 3
36
  max_context_turns: int = 5
37
  warmup_steps: int = 200
@@ -41,7 +43,7 @@ class ChatbotConfig:
41
  embedding_batch_size: int = 64
42
  search_batch_size: int = 64
43
  max_batch_size: int = 64
44
- neg_samples: int = 3
45
  max_retries: int = 3
46
 
47
  def to_dict(self) -> Dict:
@@ -54,7 +56,7 @@ class ChatbotConfig:
54
  """Create config from dictionary."""
55
  return cls(**{k: v for k, v in config_dict.items()
56
  if k in cls.__dataclass_fields__})
57
-
58
  class EncoderModel(tf.keras.Model):
59
  """Dual encoder model with pretrained embeddings."""
60
  def __init__(
@@ -154,7 +156,7 @@ class RetrievalChatbot(DeviceAwareModel):
154
  config=self.config,
155
  tokenizer=self.tokenizer,
156
  encoder=self.encoder,
157
- index_file_path='path/to/index', # Update as needed # TODO: Update this path
158
  response_pool=[],
159
  max_length=self.config.max_context_token_limit,
160
  query_embeddings_cache={},
@@ -260,32 +262,49 @@ class RetrievalChatbot(DeviceAwareModel):
260
  def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
261
  """
262
  Load saved models and configuration.
263
-
264
- Args:
265
- load_dir (Union[str, Path]): Directory containing saved model files
266
- mode (str): Either 'training' or 'inference'. In inference mode,
267
- also loads FAISS index and response pool.
268
  """
269
  load_dir = Path(load_dir)
270
 
271
- # Load config
272
  with open(load_dir / "config.json", "r") as f:
273
  config = ChatbotConfig.from_dict(json.load(f))
274
 
275
- # Initialize chatbot with appropriate mode
276
  chatbot = cls(config, mode=mode)
277
 
278
- # Load models
279
  chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
280
  load_dir / "shared_encoder",
281
  config=config
282
  )
283
 
284
- # Load tokenizer
285
  chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
286
  logger.info(f"Models and tokenizer loaded from {load_dir}")
287
 
288
- # If in inference mode, load additional components
289
  if mode == 'inference':
290
  cls._prepare_model_for_inference(chatbot, load_dir)
291
 
@@ -296,7 +315,7 @@ class RetrievalChatbot(DeviceAwareModel):
296
  """Internal method to load inference components."""
297
  try:
298
  # Load FAISS index
299
- faiss_path = load_dir / 'faiss_index.bin'
300
  if faiss_path.exists():
301
  chatbot.index = faiss.read_index(str(faiss_path))
302
  logger.info("FAISS index loaded successfully")
@@ -304,7 +323,7 @@ class RetrievalChatbot(DeviceAwareModel):
304
  raise FileNotFoundError(f"FAISS index not found at {faiss_path}")
305
 
306
  # Load response pool
307
- response_pool_path = load_dir / 'response_pool.json'
308
  if response_pool_path.exists():
309
  with open(response_pool_path, 'r') as f:
310
  chatbot.response_pool = json.load(f)
@@ -332,9 +351,12 @@ class RetrievalChatbot(DeviceAwareModel):
332
  with open(save_dir / "config.json", "w") as f:
333
  json.dump(self.config.to_dict(), f, indent=2)
334
 
335
- # Save models
336
  self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
337
 
338
  # Save tokenizer
339
  self.tokenizer.save_pretrained(save_dir / "tokenizer")
340
 
@@ -343,139 +365,270 @@ class RetrievalChatbot(DeviceAwareModel):
343
  def retrieve_responses_cross_encoder(
344
  self,
345
  query: str,
346
- top_k: int,
347
  reranker: Optional[CrossEncoderReranker] = None,
348
  summarizer: Optional[Summarizer] = None,
349
- summarize_threshold: int = 512 # Summarize over 512 tokens
350
  ) -> List[Tuple[str, float]]:
351
  """
352
- Retrieve top-k from FAISS, then re-rank them with a cross-encoder.
353
- Optionally summarize the user query if it's too long.
  """
355
- if reranker is None:
356
- reranker = self.reranker
357
- if summarizer is None:
358
- summarizer = self.summarizer
359
-
360
- # Optional summarization
361
  if summarizer and len(query.split()) > summarize_threshold:
362
- logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
363
  query = summarizer.summarize_text(query)
364
- logger.info(f"Summarized query: {query}")
365
-
366
- # 2) Dense retrieval
367
- dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) # [(resp, dense_score), ...]
368
 
369
- if not dense_topk:
370
- return []
371
 
372
- # 3) Cross-encoder rerank
373
- candidate_texts = [pair[0] for pair in dense_topk]
374
- cross_scores = reranker.rerank(query, candidate_texts, max_length=256)
375
 
376
- # Combine
377
- combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
378
- # Sort descending by cross-encoder score
379
- combined.sort(key=lambda x: x[1], reverse=True)
380
 
381
- return combined
382
- # def retrieve_responses_cross_encoder(
383
- # self,
384
- # query: str,
385
- # top_k: int,
386
- # reranker: Optional[CrossEncoderReranker] = None,
387
- # summarizer: Optional[Summarizer] = None,
388
- # summarize_threshold: int = 512 # Summarize over 512 tokens
389
- # ) -> List[Tuple[str, float]]:
390
- # """
391
- # Retrieve top-k from FAISS, then re-rank them with a cross-encoder.
392
- # Optionally summarize the user query if it's too long.
393
- # """
394
- # if reranker is None:
395
- # reranker = self.reranker
396
- # if summarizer is None:
397
- # summarizer = self.summarizer
398
-
399
- # # Optional summarization
400
- # if summarizer and len(query.split()) > summarize_threshold:
401
- # logger.info(f"Query is long. Summarizing before cross-encoder. Original length: {len(query.split())}")
402
- # query = summarizer.summarize_text(query)
403
- # logger.info(f"Summarized query: {query}")
404
 
405
- # # 2) Dense retrieval
406
- # dense_topk = self.retrieve_responses_faiss(query, top_k=top_k) # [(resp, dense_score), ...]
407
 
408
- # if not dense_topk:
409
- # return []
410
 
411
- # # 3) Cross-encoder rerank
412
- # candidate_texts = [pair[0] for pair in dense_topk]
413
- # cross_scores = reranker.rerank(query, candidate_texts, max_length=256)
414
 
415
- # # Combine
416
- # combined = [(text, score) for (text, _), score in zip(dense_topk, cross_scores)]
417
- # # Sort descending by cross-encoder score
418
- # combined.sort(key=lambda x: x[1], reverse=True)
419
 
420
- # return combined
421
-
422
- def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
423
- """Retrieve top-k responses using FAISS."""
424
- if not hasattr(self.data_pipeline, 'index') or self.data_pipeline.index is None:
425
- logger.warning("FAISS index not initialized. Cannot retrieve responses.")
426
- return []
427
-
428
- # Encode the query using TFDataPipeline's method
429
- q_emb = self.data_pipeline.encode_query(query) # Ensure encode_query is within TFDataPipeline
430
- q_emb_np = q_emb.numpy().astype('float32') # Ensure type match
431
-
432
- # Normalize the query embedding for cosine similarity
433
- faiss.normalize_L2(q_emb_np)
434
-
435
- # Search the FAISS index
436
- distances, indices = self.data_pipeline.index.search(q_emb_np, top_k)
437
-
438
- # Map indices to responses and distances to similarities
439
- top_responses = []
440
- for i, idx in enumerate(indices[0]):
441
- if idx < len(self.data_pipeline.response_pool):
442
- top_responses.append((self.data_pipeline.response_pool[idx], float(distances[0][i])))
443
  else:
444
- logger.warning(f"FAISS returned invalid index {idx}. Skipping.")
445
-
446
  return top_responses
447
- # def retrieve_responses_faiss(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
448
- # """Retrieve top-k responses using FAISS."""
449
- # if not hasattr(self, 'index') or self.index is None:
450
- # logger.warning("FAISS index not initialized. Cannot retrieve responses.")
451
- # return []
452
-
453
  # # Encode the query
454
- # q_emb = self.encode_query(query) # Shape: [1, embedding_dim]
455
- # q_emb_np = q_emb.numpy().astype('float32') # Ensure type match
456
-
457
- # # Normalize the query embedding for cosine similarity
458
- # faiss.normalize_L2(q_emb_np)
459
-
460
- # # Search the FAISS index
461
- # distances, indices = self.index.search(q_emb_np, top_k)
462
 
463
- # # Map indices to responses and distances to similarities
464
- # top_responses = []
465
- # for i, idx in enumerate(indices[0]):
466
- # if idx < len(self.response_pool):
467
- # top_responses.append((self.response_pool[idx], float(distances[0][i])))
468
  # else:
469
- # logger.warning(f"FAISS returned invalid index {idx}. Skipping.")
470
-
471
- # return top_responses
472
 
473
  def chat(
474
  self,
475
  query: str,
476
  conversation_history: Optional[List[Tuple[str, str]]] = None,
477
  quality_checker: Optional['ResponseQualityChecker'] = None,
478
- top_k: int = 5,
479
  ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
480
  """
481
  Example chat method that always uses cross-encoder re-ranking
@@ -516,52 +669,6 @@ class RetrievalChatbot(DeviceAwareModel):
516
  return results[0][0], results, {}
517
 
518
  return get_response(self, query)
519
- # def chat(
520
- # self,
521
- # query: str,
522
- # conversation_history: Optional[List[Tuple[str, str]]] = None,
523
- # quality_checker: Optional['ResponseQualityChecker'] = None,
524
- # top_k: int = 5,
525
- # ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
526
- # """
527
- # Example chat method that always uses cross-encoder re-ranking
528
- # if self.reranker is available.
529
- # """
530
- # @self.run_on_device
531
- # def get_response(self_arg, query_arg): # Add parameters that match decorator's expectations
532
- # # 1) Build conversation context string
533
- # conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
534
-
535
- # # 2) Retrieve + cross-encoder re-rank
536
- # results = self_arg.retrieve_responses_cross_encoder(
537
- # query=conversation_str,
538
- # top_k=top_k,
539
- # reranker=self_arg.reranker,
540
- # summarizer=self_arg.summarizer,
541
- # summarize_threshold=512
542
- # )
543
-
544
- # # 3) Handle empty or confidence
545
- # if not results:
546
- # return (
547
- # "I'm sorry, but I couldn't find a relevant response.",
548
- # [],
549
- # {}
550
- # )
551
-
552
- # if quality_checker:
553
- # metrics = quality_checker.check_response_quality(query_arg, results)
554
- # if not metrics.get('is_confident', False):
555
- # return (
556
- # "I need more information to provide a good answer. Could you please clarify?",
557
- # results,
558
- # metrics
559
- # )
560
- # return results[0][0], results, metrics
561
-
562
- # return results[0][0], results, {}
563
-
564
- # return get_response(self, query)
565
 
566
  def _build_conversation_context(
567
  self,
@@ -581,24 +688,6 @@ class RetrievalChatbot(DeviceAwareModel):
581
 
582
  conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
583
  return "\n".join(conversation_parts)
584
- # def _build_conversation_context(
585
- # self,
586
- # query: str,
587
- # conversation_history: Optional[List[Tuple[str, str]]]
588
- # ) -> str:
589
- # """Build conversation context with better memory management."""
590
- # if not conversation_history:
591
- # return f"{self.special_tokens['user']} {query}"
592
-
593
- # conversation_parts = []
594
- # for user_txt, assistant_txt in conversation_history:
595
- # conversation_parts.extend([
596
- # f"{self.special_tokens['user']} {user_txt}",
597
- # f"{self.special_tokens['assistant']} {assistant_txt}"
598
- # ])
599
-
600
- # conversation_parts.append(f"{self.special_tokens['user']} {query}")
601
- # return "\n".join(conversation_parts)
602
 
603
  def train_model(
604
  self,
@@ -707,23 +796,14 @@ class RetrievalChatbot(DeviceAwareModel):
707
  self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
708
 
709
  if latest_checkpoint and not test_mode:
710
- # Debug info before restore
711
- logger.info("\nEncoder Variables:")
712
- for var in self.encoder.variables:
713
- logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")
714
-
715
- logger.info("\nOptimizer Variables:")
716
- for var in self.optimizer.variables:
717
- logger.info(f"{var.name}: {var.dtype} - Shape: {var.shape}")
718
-
719
  # Add checkpoint inspection
720
- logger.info("\nTrying to load checkpoint from: ", latest_checkpoint)
721
  reader = tf.train.load_checkpoint(latest_checkpoint)
722
- shape_from_key = reader.get_variable_to_shape_map()
723
- dtype_from_key = reader.get_variable_to_dtype_map()
724
- logger.info("\nCheckpoint Variables:")
725
- for key in shape_from_key:
726
- logger.info(f"{key}: dtype={dtype_from_key[key]} - Shape: {shape_from_key[key]}")
727
 
728
  status = checkpoint.restore(latest_checkpoint)
729
  status.assert_consumed()
@@ -754,6 +834,10 @@ class RetrievalChatbot(DeviceAwareModel):
754
  logger.info(f"Loaded previous training history from {history_path}")
755
  except Exception as e:
756
  logger.warning(f"Could not load history, starting fresh: {e}")
757
  else:
758
  logger.info("Starting training from scratch")
759
  checkpoint.epoch.assign(tf.cast(0, tf.int32))
 
10
  import datetime
11
  import faiss
12
  import gc
13
+
14
+ import re
15
  from tf_data_pipeline import TFDataPipeline
16
  from response_quality_checker import ResponseQualityChecker
17
  from cross_encoder_reranker import CrossEncoderReranker
 
33
  num_attention_heads: int = 8
34
  dropout_rate: float = 0.2
35
  l2_reg_weight: float = 0.001
36
+ learning_rate: float = 0.001
37
  min_text_length: int = 3
38
  max_context_turns: int = 5
39
  warmup_steps: int = 200
 
43
  embedding_batch_size: int = 64
44
  search_batch_size: int = 64
45
  max_batch_size: int = 64
46
+ neg_samples: int = 10
47
  max_retries: int = 3
48
 
49
  def to_dict(self) -> Dict:
 
56
  """Create config from dictionary."""
57
  return cls(**{k: v for k, v in config_dict.items()
58
  if k in cls.__dataclass_fields__})
59
+
60
  class EncoderModel(tf.keras.Model):
61
  """Dual encoder model with pretrained embeddings."""
62
  def __init__(
 
156
  config=self.config,
157
  tokenizer=self.tokenizer,
158
  encoder=self.encoder,
159
+ index_file_path='new_iteration/data_prep_iterative_models/faiss_indices/faiss_index_production.index',
160
  response_pool=[],
161
  max_length=self.config.max_context_token_limit,
162
  query_embeddings_cache={},
 
262
  def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
263
  """
264
  Load saved models and configuration.
265
  """
266
  load_dir = Path(load_dir)
267
 
268
+ # 1) Load config
269
  with open(load_dir / "config.json", "r") as f:
270
  config = ChatbotConfig.from_dict(json.load(f))
271
 
272
+ # 2) Initialize chatbot
273
  chatbot = cls(config, mode=mode)
274
 
275
+ # 3) Load DistilBERT from huggingface folder
276
  chatbot.encoder.pretrained = TFAutoModel.from_pretrained(
277
  load_dir / "shared_encoder",
278
  config=config
279
  )
280
 
281
+ dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
282
+ _ = chatbot.encoder(dummy_input, training=False)
283
+
284
+ # # Then load your custom weights
285
+ # custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
286
+ # if custom_weights_path.exists():
287
+ # logger.info(f"Loading custom top-level weights from {custom_weights_path}")
288
+ # chatbot.encoder.load_weights(str(custom_weights_path))
289
+ # logger.info("Custom top-level weights loaded successfully.")
290
+ # else:
291
+ # logger.warning(f"Custom weights file not found at {custom_weights_path}.")
292
+
293
+ # 4) Load tokenizer
294
  chatbot.tokenizer = AutoTokenizer.from_pretrained(load_dir / "tokenizer")
295
  logger.info(f"Models and tokenizer loaded from {load_dir}")
296
 
297
+
298
+
299
+ # 5) Load the custom top layers' weights
300
+ custom_weights_path = load_dir / "encoder_custom_weights.weights.h5"
301
+ if custom_weights_path.exists():
302
+ chatbot.encoder.load_weights(str(custom_weights_path))
303
+ logger.info("Loaded custom encoder weights for projection/dropout/etc.")
304
+ else:
305
+ logger.warning(f"No custom encoder weights found at {custom_weights_path}. The top-level projection layer won't have learned parameters.")
306
+
307
+ # 6) If in inference mode, load FAISS, etc.
308
  if mode == 'inference':
309
  cls._prepare_model_for_inference(chatbot, load_dir)
310
 
 
315
  """Internal method to load inference components."""
316
  try:
317
  # Load FAISS index
318
+ faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
319
  if faiss_path.exists():
320
  chatbot.index = faiss.read_index(str(faiss_path))
321
  logger.info("FAISS index loaded successfully")
 
323
  raise FileNotFoundError(f"FAISS index not found at {faiss_path}")
324
 
325
  # Load response pool
326
+ response_pool_path = load_dir / 'faiss_indices/faiss_index_production_responses.json'
327
  if response_pool_path.exists():
328
  with open(response_pool_path, 'r') as f:
329
  chatbot.response_pool = json.load(f)
 
351
  with open(save_dir / "config.json", "w") as f:
352
  json.dump(self.config.to_dict(), f, indent=2)
353
 
354
+ # Save the HF DistilBERT submodule:
355
  self.encoder.pretrained.save_pretrained(save_dir / "shared_encoder")
356
 
357
+ # ALSO save custom top-level layers' weights
358
+ self.encoder.save_weights(save_dir / "encoder_custom_weights.weights.h5")
359
+
360
  # Save tokenizer
361
  self.tokenizer.save_pretrained(save_dir / "tokenizer")
362
 
 
365
  def retrieve_responses_cross_encoder(
366
  self,
367
  query: str,
368
+ top_k: int = 10,
369
  reranker: Optional[CrossEncoderReranker] = None,
370
  summarizer: Optional[Summarizer] = None,
371
+ summarize_threshold: int = 512
372
  ) -> List[Tuple[str, float]]:
373
  """
374
+ Retrieve top-k responses with optional domain-based boosting
375
+ and cross-encoder re-ranking.
376
+
377
+ Args:
378
+ query: The user's input text.
379
+ top_k: Number of final results to return.
380
+ reranker: CrossEncoderReranker for refined scoring, if available.
381
+ summarizer: Summarizer for long queries, if desired.
382
+ summarize_threshold: Summarize if query wordcount > threshold.
383
+
384
+ Returns:
385
+ List of (response_text, final_score).
386
  """
387
+ # 1) Optional query summarization
388
  if summarizer and len(query.split()) > summarize_threshold:
389
+ logger.info(f"Query is long ({len(query.split())} words). Summarizing.")
390
  query = summarizer.summarize_text(query)
391
+ logger.info(f"Summarized Query: {query}")
392
 
393
+ detected_domain = self.detect_domain_from_query(query)
394
+ logger.debug(f"Detected domain '{detected_domain}' for query: {query}")
395
 
396
+ # 2) Retrieve more initial candidates from FAISS
397
+ initial_k = min(top_k * 10, len(self.data_pipeline.response_pool))
398
+ dense_candidates = self.retrieve_responses_faiss(query, domain=detected_domain, top_k=initial_k)
399
 
400
+ boosted_candidates = dense_candidates
401
+
402
+ # 4) If we have a cross-encoder, re-rank these boosted candidates
403
+ if not reranker:
404
+ logger.warning("No CrossEncoderReranker provided; creating a default one.")
405
+ reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
406
+
407
+ texts = [item[0] for item in boosted_candidates]
408
+ ce_scores = reranker.rerank(query, texts, max_length=256)
409
+
410
+ # Combine cross-encoder score with the base FAISS score (simple multiplicative approach)
411
+ final_candidates = []
412
+ for (resp_text, faiss_score), ce_score in zip(boosted_candidates, ce_scores):
413
+ # TODO: dial this in.
414
+ alpha = 0.8
415
+ combined_score = alpha * ce_score + (1 - alpha) * faiss_score
416
+ length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
417
+ #combined_score = ce_score * faiss_score
418
+ final_candidates.append((resp_text, combined_score))
419
+
420
+ # Sort descending by combined score
421
+ final_candidates.sort(key=lambda x: x[1], reverse=True)
422
+
423
+ # Return top_k
424
+ return final_candidates[:top_k]
425
+
426
+ DOMAIN_KEYWORDS = {
427
+ 'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
428
+ 'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
429
+ 'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
430
+ 'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
431
+ 'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
432
+ 'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
433
+ }
434
+
435
+ def extract_keywords(self, query: str) -> List[str]:
436
+ """
437
+ Extract keywords from the query based on DOMAIN_KEYWORDS.
438
+ """
439
+ query_lower = query.lower()
440
+ keywords = set()
441
+ for domain, kws in self.DOMAIN_KEYWORDS.items():
442
+ for kw in kws:
443
+ if kw in query_lower:
444
+ keywords.add(kw)
445
+ return list(keywords)
446
+
447
+ def length_adjust_score(self, resp_text: str, base_score: float) -> float:
448
+ # Apply a short penalty
449
+ words = len(resp_text.split())
450
+ if words < 3:
451
+ # big penalty or skip entirely
452
+ return base_score * 0.1 # or base_score - 0.01
453
+
454
+ # Add a mild bonus for lines that exceed 12 words:
455
+ if words > 12:
456
+ # e.g. +0.002 * (words - 12)
457
+ bonus = 0.002 * (words - 12)
458
+ base_score += bonus
459
+
460
+ return base_score
461
+
462
+ def detect_domain_from_query(self, query: str) -> str:
463
+ """
464
+ Detect the domain of the query based on keywords.
465
+ """
466
+ domain_patterns = {
467
+ 'restaurant': r'\b(restaurant|dining|food|dine|reservation|table|menu|cuisine|eat|place\s?to\s?eat|hungry|chef|dish|meal|fork|knife|spoon|brunch|bistro|buffet|catering|gourmet|fast\s?food|fine\s?dining|takeaway|delivery|restaurant\s?booking)\b',
468
+ 'movie': r'\b(movie|cinema|film|ticket|showtime|showing|theater|flick|screening|film\s?ticket|film\s?show|blockbuster|premiere|trailer|director|actor|actress|plot|genre|screen|sequel|animation|documentary)\b',
469
+ 'ride_share': r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff|driver|cab|hailing|rideshare|ride\s?hailing|carpool|chauffeur|transit|transportation|hail\s?ride)\b',
470
+ 'coffee': r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano|barista|brew|cappuccino|macchiato|iced\s?coffee|cold\s?brew|espresso\s?machine|coffee\s?shop|tea|chai|java|bean|roast|decaf)\b',
471
+ 'pizza': r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria|slice|pie|margherita|deep\s?dish|thin\s?crust|cheese|oven|tossed|sauce|garlic\s?bread|calzone)\b',
472
+ 'auto': r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change|garage|auto\s?shop|tire|check\s?engine|battery|transmission|brake|engine\s?diagnostics|carwash|detail|alignment|exhaust|spark\s?plug|dashboard)\b',
473
+ }
474
 
475
+ # Check for matches
476
+ for domain, pattern in domain_patterns.items():
477
+ if re.search(pattern, query.lower()):
478
+ return domain
479
 
480
+ return 'other'
481
+
482
+ def is_numeric_response(self, text: str) -> bool:
483
+ """
484
+ Return True if `text` is purely digits (and/or spaces).
485
+ e.g.: "4 3 13" -> True, " 42 " -> True, "hello 42" -> False
486
+ """
487
+ pattern = r'^\s*[0-9]+(\s+[0-9]+)*\s*$'
488
+ return bool(re.match(pattern, text))
489
+
490
+ def retrieve_responses_faiss(
491
+ self,
492
+ query: str,
493
+ domain: str = 'other',
494
+ top_k: int = 5,
495
+ boost_factor: float = 1.3
496
+ ) -> List[Tuple[str, float]]:
497
+ """
498
+ Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
499
 
500
+ Args:
501
+ query (str): The user input text.
502
+ domain (str, optional): The detected domain. Defaults to 'other'.
503
+ top_k (int, optional): Number of top results to return. Defaults to 5.
504
+ boost_factor (float, optional): Factor to boost scores for keyword matches. Defaults to 1.3.
505
 
506
+ Returns:
507
+ List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
508
+ """
509
+ # Encode the query
510
+ q_emb = self.data_pipeline.encode_query(query)
511
+ q_emb_np = q_emb.reshape(1, -1).astype('float32')
512
+
513
+ # Search the index
514
+ distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10) # Adjust multiplier as needed
515
+
516
+ # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors)
517
+ candidates = []
518
+ for rank, idx in enumerate(indices[0]):
519
+ if idx == -1:
520
+ continue # FAISS may return -1 for invalid indices
521
+ response = self.data_pipeline.response_pool[idx]
522
+ text = response.get('text', '')
523
+ cand_domain = response.get('domain', 'other')
524
+ score = distances[0][rank]
525
+
526
+ # Filter out numeric responses and very short texts
527
+ if not self.is_numeric_response(text) and len(text.split()) >= self.config.min_text_length:
528
+ candidates.append((text, cand_domain, score))
529
+
530
+ if not candidates:
531
+ logger.warning("No valid candidates found after initial filtering.")
532
+ return []
533
 
534
+ # Sort candidates by score descending
535
+ candidates.sort(key=lambda x: x[2], reverse=True)
 
 
536
 
537
+ # Filter in-domain responses
538
+ if domain != 'other':
539
+ in_domain_responses = [c for c in candidates if c[1] == domain]
540
+ if not in_domain_responses:
541
+ logger.info(f"No in-domain responses found for domain '{domain}'. Falling back to all candidates.")
542
+ in_domain_responses = candidates
543
+ else:
544
+ in_domain_responses = candidates
545
+
546
+ # Boost responses containing query keywords
547
+ query_keywords = self.extract_keywords(query)
548
+ boosted_responses = []
549
+ for resp_text, resp_domain, score in in_domain_responses:
550
+ if any(kw in resp_text.lower() for kw in query_keywords):
551
+ boosted_score = score * boost_factor
552
+ logger.debug(f"Boosting response: '{resp_text}' by factor {boost_factor}")
553
  else:
554
+ boosted_score = score
555
+ boosted_responses.append((resp_text, boosted_score))
556
+
557
+ # Sort boosted responses
558
+ boosted_responses.sort(key=lambda x: x[1], reverse=True)
559
+
560
+ # Select top_k responses
561
+ top_responses = boosted_responses[:top_k]
562
+ logger.debug(f"Top {top_k} responses selected.")
563
+
564
  return top_responses
565
+ # def retrieve_responses_faiss(
566
+ # self,
567
+ # query: str,
568
+ # domain: str = 'other',
569
+ # top_k: int = 5,
570
+ # boost_factor: float = 1.3
571
+ # ) -> List[Tuple[str, float]]:
572
+ # """
573
+ # Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
574
+
575
+ # Args:
576
+ # query: The user input text
577
+ # top_k: Number of top results to return
578
+
579
+ # Returns:
580
+ # List of (response_text, similarity) sorted by descending similarity
581
+ # """
582
  # # Encode the query
583
+ # q_emb = self.data_pipeline.encode_query(query)
584
+ # q_emb_np = q_emb.reshape(1, -1).astype('float32')
585
+
586
+ # # Search the index
587
+ # distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10) # distances: shape [1, k], indices: shape [1, k]
588
+
589
+ # # IndexFlatIP: 'distances' are cosine similarities in [-1, +1].
590
+ # candidates = []
591
+ # for rank, idx in enumerate(indices[0]):
592
+ # text = self.response_pool[idx]['text']
593
+ # cand_domain = self.response_pool[idx]['domain']
594
+ # score = distances[0][rank]
595
+
596
+ # # filter out responses with only numbers or too few words
597
+ # word_count = len(text.split())
598
+ # if not self.is_numeric_resonse(text) and word_count >= 2:
599
+ # candidates.append((text, cand_domain, score))
600
+
601
+ # # Filter to in-domain responses
602
+ # candidates.sort(key=lambda x: x[2], reverse=True)
603
+ # in_domain_responses = [(text, score) for (text, cand_domain, score) in candidates if cand_domain == domain]
604
 
605
+ # # Boost keyword matching responses
606
+ # query_keywords = self.extract_keywords(query)
607
+ # boosted_responses = []
608
+ # for (resp_text, domain, score) in in_domain_responses:
609
+ # # Check if any keyword is present in the response text
610
+ # for kw in query_keywords:
611
+ # if kw in resp_text.lower():
612
+ # boosted_score = score * boost_factor
613
+ # print(f"Boosting response: '{resp_text}' by factor {boost_factor}")
614
+ # break
615
  # else:
616
+ # boosted_score = score
617
+ # boosted_responses.append((resp_text, domain, boosted_score))
618
+
619
+ # # Debug
620
+ # logger.debug("\nFAISS Search Results (top 15 for debug):")
621
+ # for i, (resp, score) in enumerate(boosted_responses[:15], start=1):
622
+ # logger.debug(f"{i}. Score: {score:.4f} -> {resp[:60]}")
623
+
624
+ # return boosted_responses[:top_k]
625
 
626
  def chat(
627
  self,
628
  query: str,
629
  conversation_history: Optional[List[Tuple[str, str]]] = None,
630
  quality_checker: Optional['ResponseQualityChecker'] = None,
631
+ top_k: int = 10,
632
  ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any]]:
633
  """
634
  Example chat method that always uses cross-encoder re-ranking
 
669
  return results[0][0], results, {}
670
 
671
  return get_response(self, query)
672
 
673
  def _build_conversation_context(
674
  self,
 
688
 
689
  conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
690
  return "\n".join(conversation_parts)
691
 
692
  def train_model(
693
  self,
 
796
  self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
797
 
798
  if latest_checkpoint and not test_mode:
799
  # Add checkpoint inspection
800
+ logger.info(f"\nTrying to load checkpoint from: {latest_checkpoint}")
801
  reader = tf.train.load_checkpoint(latest_checkpoint)
802
+ # shape_from_key = reader.get_variable_to_shape_map()
803
+ # dtype_from_key = reader.get_variable_to_dtype_map()
804
+ # logger.info("\nCheckpoint Variables:")
805
+ # for key in shape_from_key:
806
+ # logger.info(f"{key}: dtype={dtype_from_key[key]} - Shape: {shape_from_key[key]}")
807
 
808
  status = checkpoint.restore(latest_checkpoint)
809
  status.assert_consumed()
 
834
  logger.info(f"Loaded previous training history from {history_path}")
835
  except Exception as e:
836
  logger.warning(f"Could not load history, starting fresh: {e}")
837
+
838
+ # Fix for custom weights not being saved in the full model.
839
+ self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
840
+ logger.info(f"Manually saved custom weights after restore.")
841
  else:
842
  logger.info("Starting training from scratch")
843
  checkpoint.epoch.assign(tf.cast(0, tf.int32))
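
For reference, the blended scoring introduced in retrieve_responses_cross_encoder reduces to a small pure function. This sketch uses the alpha = 0.8 weight and the length adjustment from the diff; the function name and the example values are illustrative, not part of the module.

def blend_scores(ce_score: float, faiss_score: float, response_text: str, alpha: float = 0.8) -> float:
    # Weighted mix of cross-encoder relevance and FAISS cosine similarity.
    combined = alpha * ce_score + (1 - alpha) * faiss_score
    words = len(response_text.split())
    if words < 3:
        return combined * 0.1              # heavy penalty for near-empty responses
    if words > 12:
        combined += 0.002 * (words - 12)   # mild bonus for fuller responses
    return combined

# A strong cross-encoder score outweighs a middling FAISS similarity.
print(blend_scores(0.92, 0.41, "Sure, I can book a table for four at 7pm tonight."))
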
chatbot_validator.py CHANGED
@@ -1,30 +1,41 @@
1
  from typing import Dict, List, Tuple, Any, Optional
2
  import numpy as np
3
-
4
  from logger_config import config_logger
 
 
5
  logger = config_logger(__name__)
6
 
 
7
  class ChatbotValidator:
8
- """Handles automated validation and performance analysis for the chatbot."""
9
 
10
  def __init__(self, chatbot, quality_checker):
11
  """
12
  Initialize the validator.
13
 
14
  Args:
15
- chatbot: RetrievalChatbot instance
16
  quality_checker: ResponseQualityChecker instance
17
  """
18
  self.chatbot = chatbot
19
  self.quality_checker = quality_checker
20
 
21
- # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
 
22
  self.domain_queries = {
23
  'restaurant': [
24
  "I'd like to make a reservation for dinner tonight.",
25
- "Can you book a table for 4 people at an Italian place?",
26
- "Do you have any availability for tomorrow at 7pm?",
27
- "I need to change my dinner reservation time.",
28
  "What's the wait time for a table right now?"
29
  ],
30
  'movie_tickets': [
@@ -38,8 +49,8 @@ class ChatbotValidator:
38
  "I need a ride from the airport to downtown.",
39
  "How much would it cost to get to the mall?",
40
  "Can you book a car for tomorrow morning?",
41
- "Is there a driver available now?",
42
- "What's the estimated arrival time?"
43
  ],
44
  'services': [
45
  "I need to schedule an oil change for my car.",
@@ -61,7 +72,9 @@ class ChatbotValidator:
61
  self,
62
  num_examples: int = 5,
63
  top_k: int = 10,
64
- domains: Optional[List[str]] = None
 
 
65
  ) -> Dict[str, Any]:
66
  """
67
  Run comprehensive validation across specified domains.
@@ -69,36 +82,55 @@ class ChatbotValidator:
69
  Args:
70
  num_examples: Number of test queries per domain
71
  top_k: Number of responses to retrieve for each query
72
- domains: Optional list of specific domains to test
 
 
73
 
74
  Returns:
75
  Dict containing detailed validation metrics and domain-specific performance
76
  """
77
  logger.info("\n=== Running Enhanced Automatic Validation ===")
78
 
79
- # Select domains to test
80
  test_domains = domains if domains else list(self.domain_queries.keys())
 
 
81
  metrics_history = []
82
  domain_metrics = {}
83
 
84
  # Run validation for each domain
85
  for domain in test_domains:
86
  domain_metrics[domain] = []
87
- queries = self.domain_queries[domain][:num_examples]
88
 
89
  logger.info(f"\n=== Testing {domain.title()} Domain ===")
90
 
91
  for i, query in enumerate(queries, 1):
92
- logger.info(f"\nTest Case {i}:")
93
- logger.info(f"Query: {query}")
94
 
95
- # Get responses with increased top_k
96
- responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)
97
 
98
- # Enhanced quality checking with context (assuming no context here)
99
  quality_metrics = self.quality_checker.check_response_quality(query, responses)
100
 
101
- # Add domain info
102
  quality_metrics['domain'] = domain
103
  metrics_history.append(quality_metrics)
104
  domain_metrics[domain].append(quality_metrics)
@@ -106,11 +138,12 @@ class ChatbotValidator:
106
  # Detailed logging
107
  self._log_validation_results(query, responses, quality_metrics, i)
108
 
109
- # Calculate and log overall metrics
110
  aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
111
  domain_analysis = self._analyze_domain_performance(domain_metrics)
112
  confidence_analysis = self._analyze_confidence_distribution(metrics_history)
113
 
 
114
  aggregate_metrics.update({
115
  'domain_performance': domain_analysis,
116
  'confidence_analysis': confidence_analysis
@@ -120,48 +153,74 @@ class ChatbotValidator:
120
  return aggregate_metrics
121
 
122
  def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
123
- """Calculate comprehensive aggregate metrics."""
124
  metrics = {
125
  'num_queries_tested': len(metrics_history),
126
- 'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
127
- 'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
128
- 'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
129
- 'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
130
- 'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
131
- 'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
 
132
 
133
  # Additional statistical metrics
134
- 'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
135
- 'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
136
- 'min_score': np.min([m.get('top_score', 0) for m in metrics_history]),
137
- 'max_score': np.max([m.get('top_score', 0) for m in metrics_history])
138
  }
139
  return metrics
140
 
141
- def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
142
- """Analyze performance by domain."""
143
- domain_analysis = {}
144
-
145
- for domain, metrics in domain_metrics.items():
146
- domain_analysis[domain] = {
147
- 'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
148
- 'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
149
- 'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
150
- 'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
151
- 'num_samples': len(metrics)
  }
153
 
154
- return domain_analysis
155
 
156
  def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
157
- """Analyze the distribution of confidence scores."""
158
- scores = [m.get('top_score', 0) for m in metrics_history]
159
 
 
160
  return {
161
- 'percentile_25': np.percentile(scores, 25),
162
- 'percentile_50': np.percentile(scores, 50),
163
- 'percentile_75': np.percentile(scores, 75),
164
- 'percentile_90': np.percentile(scores, 90)
165
  }
166
 
167
  def _log_validation_results(
@@ -171,37 +230,51 @@ class ChatbotValidator:
171
  metrics: Dict[str, Any],
172
  case_num: int
173
  ):
174
- """Log detailed validation results."""
175
- logger.info(f"\nTest Case {case_num}:")
176
- logger.info(f"Query: {query}")
177
- logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
178
- logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")
179
- logger.info("\nQuality Metrics:")
180
- for metric, value in metrics.items():
181
- if isinstance(value, (int, float)):
182
- logger.info(f" {metric}: {value:.4f}")
 
 
183
 
184
- logger.info("\nTop Responses:")
185
- for i, (response, score) in enumerate(responses[:3], 1):
186
- logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
187
- if i == 1 and not metrics.get('is_confident', False):
188
- logger.info(" [Low Confidence]")
189
 
190
  def _log_validation_summary(self, metrics: Dict[str, Any]):
191
- """Log comprehensive validation summary."""
192
  logger.info("\n=== Validation Summary ===")
193
 
 
194
  logger.info("\nOverall Metrics:")
195
  for metric, value in metrics.items():
 
196
  if isinstance(value, (int, float)):
197
  logger.info(f"{metric}: {value:.4f}")
198
 
 
 
199
  logger.info("\nDomain Performance:")
200
- for domain, domain_metrics in metrics['domain_performance'].items():
201
  logger.info(f"\n{domain.title()}:")
202
- for metric, value in domain_metrics.items():
203
  logger.info(f" {metric}: {value:.4f}")
204
 
 
 
205
  logger.info("\nConfidence Distribution:")
206
- for percentile, value in metrics['confidence_analysis'].items():
207
- logger.info(f"{percentile}: {value:.4f}")
 
1
  from typing import Dict, List, Tuple, Any, Optional
2
  import numpy as np
3
+ import random
4
  from logger_config import config_logger
5
+ from cross_encoder_reranker import CrossEncoderReranker
6
+
7
  logger = config_logger(__name__)
8
 
9
+
10
  class ChatbotValidator:
11
+ """
12
+ Handles automated validation and performance analysis for the chatbot.
13
+
14
+ This validator executes domain-specific test queries, obtains candidate
15
+ responses via the chatbot, then evaluates them with a quality checker.
16
+ It aggregates metrics across queries and domains, logs intermediate
17
+ results, and returns a comprehensive summary.
18
+ """
19
 
20
  def __init__(self, chatbot, quality_checker):
21
  """
22
  Initialize the validator.
23
 
24
  Args:
25
+ chatbot: RetrievalChatbot instance for inference
26
  quality_checker: ResponseQualityChecker instance
27
  """
28
  self.chatbot = chatbot
29
  self.quality_checker = quality_checker
30
 
31
+ # Basic domain-specific test queries (easy examples)
32
+ # Taskmaster-1 and Schema-Guided style
33
  self.domain_queries = {
34
  'restaurant': [
35
  "I'd like to make a reservation for dinner tonight.",
36
+ "Can you book a table for 4 at an Italian restaurant?",
37
+ "Is there any availability to dine tomorrow at 7pm?",
38
+ "I'd like to cancel my reservation for tonight.",
39
  "What's the wait time for a table right now?"
40
  ],
41
  'movie_tickets': [
 
49
  "I need a ride from the airport to downtown.",
50
  "How much would it cost to get to the mall?",
51
  "Can you book a car for tomorrow morning?",
52
+ "Is there a driver available right now?",
53
+ "What's the estimated arrival time for the driver?"
54
  ],
55
  'services': [
56
  "I need to schedule an oil change for my car.",
 
72
  self,
73
  num_examples: int = 5,
74
  top_k: int = 10,
75
+ domains: Optional[List[str]] = None,
76
+ randomize: bool = False,
77
+ seed: int = 42
78
  ) -> Dict[str, Any]:
79
  """
80
  Run comprehensive validation across specified domains.
 
82
  Args:
83
  num_examples: Number of test queries per domain
84
  top_k: Number of responses to retrieve for each query
85
+ domains: Optional list of domain keys to test. If None, test all.
86
+ randomize: If True, randomly select queries from the domain lists
87
+ seed: Random seed for consistent sampling if randomize=True
88
 
89
  Returns:
90
  Dict containing detailed validation metrics and domain-specific performance
91
  """
92
  logger.info("\n=== Running Enhanced Automatic Validation ===")
93
 
94
+ # Select which domains to test
95
  test_domains = domains if domains else list(self.domain_queries.keys())
96
+
97
+ # Initialize results
98
  metrics_history = []
99
  domain_metrics = {}
100
+
101
+ reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")
102
+
103
+ # Prepare random selection if needed
104
+ rng = random.Random(seed)
105
 
106
  # Run validation for each domain
107
  for domain in test_domains:
108
+ # Avoid errors if domain key missing
109
+ if domain not in self.domain_queries:
110
+ logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
111
+ continue
112
+
113
+ all_queries = self.domain_queries[domain]
114
+ if randomize:
115
+ queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
116
+ else:
117
+ queries = all_queries[:num_examples]
118
+
119
+ # Store domain-level metrics
120
  domain_metrics[domain] = []
 
121
 
122
  logger.info(f"\n=== Testing {domain.title()} Domain ===")
123
 
124
  for i, query in enumerate(queries, 1):
125
+ logger.info(f"\nTest Case {i}: {query}")
 
126
 
127
+ # Retrieve top_k responses (including cross-encoder re-ranking if available)
128
+ responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k, reranker=reranker)
129
 
130
+ # Evaluate with quality checker
131
  quality_metrics = self.quality_checker.check_response_quality(query, responses)
132
 
133
+ # Save domain info
134
  quality_metrics['domain'] = domain
135
  metrics_history.append(quality_metrics)
136
  domain_metrics[domain].append(quality_metrics)
 
138
  # Detailed logging
139
  self._log_validation_results(query, responses, quality_metrics, i)
140
 
141
+ # Final aggregation
142
  aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
143
  domain_analysis = self._analyze_domain_performance(domain_metrics)
144
  confidence_analysis = self._analyze_confidence_distribution(metrics_history)
145
 
146
+ # Combine into one dictionary
147
  aggregate_metrics.update({
148
  'domain_performance': domain_analysis,
149
  'confidence_analysis': confidence_analysis
 
153
  return aggregate_metrics
154
 
155
  def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
156
+ """
157
+ Calculate comprehensive aggregate metrics over all tested queries.
158
+ """
159
+ if not metrics_history:
160
+ logger.warning("No metrics to aggregate. Returning empty summary.")
161
+ return {}
162
+
163
+ top_scores = [m.get('top_score', 0.0) for m in metrics_history]
164
+
165
+ # The length-based metrics are robust to missing or zero-length data
166
  metrics = {
167
  'num_queries_tested': len(metrics_history),
168
+ 'avg_top_response_score': np.mean(top_scores),
169
+ 'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_history]),
170
+ 'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history]),
171
+ 'avg_length_score': np.mean([m.get('response_length_score', 0.0) for m in metrics_history]),
172
+ 'avg_score_gap': np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history]),
173
+ 'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
174
+ for m in metrics_history]),
175
 
176
  # Additional statistical metrics
177
+ 'median_top_score': np.median(top_scores),
178
+ 'score_std': np.std(top_scores),
179
+ 'min_score': np.min(top_scores),
180
+ 'max_score': np.max(top_scores)
181
  }
182
  return metrics
183
 
184
+ def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
185
+ """
186
+ Analyze performance by domain, returning a nested dict.
187
+ """
188
+ analysis = {}
189
+
190
+ for domain, metrics_list in domain_metrics.items():
191
+ if not metrics_list:
192
+ analysis[domain] = {}
193
+ continue
194
+
195
+ top_scores = [m.get('top_score', 0.0) for m in metrics_list]
196
+
197
+ analysis[domain] = {
198
+ 'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
199
+ for m in metrics_list]),
200
+ 'avg_relevance': np.mean([m.get('query_response_relevance', 0.0)
201
+ for m in metrics_list]),
202
+ 'avg_diversity': np.mean([m.get('response_diversity', 0.0)
203
+ for m in metrics_list]),
204
+ 'avg_top_score': np.mean(top_scores),
205
+ 'num_samples': len(metrics_list)
206
  }
207
 
208
+ return analysis
209
 
210
  def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
211
+ """
212
+ Analyze the distribution of top scores to gauge system confidence levels.
213
+ """
214
+ if not metrics_history:
215
+ return {'percentile_25': 0.0, 'percentile_50': 0.0,
216
+ 'percentile_75': 0.0, 'percentile_90': 0.0}
217
 
218
+ scores = [m.get('top_score', 0.0) for m in metrics_history]
219
  return {
220
+ 'percentile_25': float(np.percentile(scores, 25)),
221
+ 'percentile_50': float(np.percentile(scores, 50)),
222
+ 'percentile_75': float(np.percentile(scores, 75)),
223
+ 'percentile_90': float(np.percentile(scores, 90))
224
  }
225
 
226
  def _log_validation_results(
 
230
  metrics: Dict[str, Any],
231
  case_num: int
232
  ):
233
+ """
234
+ Log detailed validation results for each test case.
235
+ """
236
+ domain = metrics.get('domain', 'Unknown')
237
+ is_confident = metrics.get('is_confident', False)
238
+
239
+ logger.info(f"Domain: {domain} | Confidence: {'Yes' if is_confident else 'No'}")
240
+ logger.info("Quality Metrics:")
241
+ for k, v in metrics.items():
242
+ if isinstance(v, (int, float)):
243
+ logger.info(f" {k}: {v:.4f}")
244
 
245
+ logger.info("Top 3 Responses:")
246
+ for i, (resp_text, score) in enumerate(responses[:3], 1):
247
+ logger.info(f"{i}) Score: {score:.4f} | {resp_text}")
248
+ if i == 1 and not is_confident:
249
+ logger.info(" [Low Confidence on Top Response]")
250
 
251
  def _log_validation_summary(self, metrics: Dict[str, Any]):
252
+ """
253
+ Log a summary of all validation metrics and domain performance.
254
+ """
255
+ if not metrics:
256
+ logger.info("No metrics to summarize.")
257
+ return
258
+
259
  logger.info("\n=== Validation Summary ===")
260
 
261
+ # Overall
262
  logger.info("\nOverall Metrics:")
263
  for metric, value in metrics.items():
264
+ # Skip sub-dicts here
265
  if isinstance(value, (int, float)):
266
  logger.info(f"{metric}: {value:.4f}")
267
 
268
+ # Domain performance
269
+ domain_perf = metrics.get('domain_performance', {})
270
  logger.info("\nDomain Performance:")
271
+ for domain, domain_stats in domain_perf.items():
272
  logger.info(f"\n{domain.title()}:")
273
+ for metric, value in domain_stats.items():
274
  logger.info(f" {metric}: {value:.4f}")
275
 
276
+ # Confidence distribution
277
+ conf_analysis = metrics.get('confidence_analysis', {})
278
  logger.info("\nConfidence Distribution:")
279
+ for pct, val in conf_analysis.items():
280
+ logger.info(f" {pct}: {val:.4f}")
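For reference, the summary statistics logged above reduce to a few NumPy calls over the collected metrics history. A minimal, self-contained sketch (the metrics list below is made up) of how the confidence rate and score percentiles come out:

import numpy as np

# Hypothetical per-case validation metrics
metrics_history = [
    {"top_score": 0.72, "is_confident": True},
    {"top_score": 0.41, "is_confident": False},
    {"top_score": 0.63, "is_confident": True},
]

top_scores = [m.get("top_score", 0.0) for m in metrics_history]
confidence_rate = np.mean([1.0 if m.get("is_confident", False) else 0.0
                           for m in metrics_history])
percentiles = {f"percentile_{p}": float(np.percentile(top_scores, p))
               for p in (25, 50, 75, 90)}
print(confidence_rate, percentiles)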
conversation_summarizer.py CHANGED
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
9
  @dataclass
10
  class ChatConfig:
11
  max_sequence_length: int = 512
12
- default_top_k: int = 5
13
  chunk_size: int = 512
14
  chunk_overlap: int = 256
15
  min_confidence_score: float = 0.7
 
9
  @dataclass
10
  class ChatConfig:
11
  max_sequence_length: int = 512
12
+ default_top_k: int = 10
13
  chunk_size: int = 512
14
  chunk_overlap: int = 256
15
  min_confidence_score: float = 0.7
cross_encoder_reranker.py CHANGED
@@ -1,19 +1,28 @@
1
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
  import tensorflow as tf
3
- from typing import List, Tuple
 
4
 
5
  from logger_config import config_logger
6
  logger = config_logger(__name__)
7
 
8
  class CrossEncoderReranker:
9
  """
10
- Cross-Encoder Re-Ranker: Takes (query, candidate) pairs,
11
- outputs a single relevance score (one logit).
12
  """
13
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
 
 
 
 
 
 
 
 
14
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
16
- # Model outputs shape [batch_size, 1] -> Interpret the logit as relevance score.
17
 
18
  def rerank(
19
  self,
@@ -22,13 +31,21 @@ class CrossEncoderReranker:
22
  max_length: int = 256
23
  ) -> List[float]:
24
  """
25
- Returns a list of re_scores, one for each candidate, indicating
26
- how relevant the candidate is to the query.
 
 
 
 
 
 
 
 
27
  """
28
- # Build (query, candidate) pairs
29
  pair_texts = [(query, candidate) for candidate in candidates]
30
 
31
- # Tokenize the entire batch
32
  encodings = self.tokenizer(
33
  pair_texts,
34
  padding=True,
@@ -37,15 +54,24 @@ class CrossEncoderReranker:
37
  return_tensors="tf"
38
  )
39
 
40
- # Forward pass -> logits shape [batch_size, 1]
41
  outputs = self.model(
42
  input_ids=encodings["input_ids"],
43
  attention_mask=encodings["attention_mask"],
44
- token_type_ids=encodings.get("token_type_ids")
45
  )
46
 
47
- logits = outputs.logits
48
- # Flatten to shape [batch_size]
49
- scores = tf.reshape(logits, [-1]).numpy()
 
 
 
 
 
 
 
 
 
50
 
51
  return scores.tolist()
 
1
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
  import tensorflow as tf
3
+ from typing import List
4
+ import numpy as np
5
 
6
  from logger_config import config_logger
7
  logger = config_logger(__name__)
8
 
9
  class CrossEncoderReranker:
10
  """
11
+ Cross-Encoder Re-Ranker that takes (query, candidate) pairs,
12
+ outputs a single relevance score in [0,1].
13
  """
14
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"):
15
+ """
16
+ Initialize the cross-encoder with a pretrained model.
17
+
18
+ Args:
19
+ model_name: Name of a HF cross-encoder model. Must be
20
+ compatible with TFAutoModelForSequenceClassification.
21
+ """
22
+ logger.info(f"Initializing CrossEncoderReranker with {model_name}...")
23
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
24
  self.model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
25
+ logger.info("Cross encoder model loaded successfully.")
26
 
27
  def rerank(
28
  self,
 
31
  max_length: int = 256
32
  ) -> List[float]:
33
  """
34
+ Compute relevance scores for each candidate w.r.t. the query.
35
+
36
+ Args:
37
+ query: User's query text.
38
+ candidates: List of candidate response texts.
39
+ max_length: Max token length for each (query, candidate) pair.
40
+
41
+ Returns:
42
+ A list of float scores in [0,1], one per candidate,
43
+ indicating model's predicted relevance.
44
  """
45
+ # 1) Build (query, candidate) pairs
46
  pair_texts = [(query, candidate) for candidate in candidates]
47
 
48
+ # 2) Tokenize the entire batch
49
  encodings = self.tokenizer(
50
  pair_texts,
51
  padding=True,
 
54
  return_tensors="tf"
55
  )
56
 
57
+ # 3) Forward pass -> logits shape [batch_size, 1]
58
  outputs = self.model(
59
  input_ids=encodings["input_ids"],
60
  attention_mask=encodings["attention_mask"],
61
+ token_type_ids=encodings.get("token_type_ids") # Some models need token_type_ids
62
  )
63
 
64
+ logits = outputs.logits # shape [batch_size, 1]
65
+         # 4) Convert logits -> [0,1] range via sigmoid
66
+         #    This cross-encoder emits a single relevance logit per pair,
67
+         #    so a sigmoid maps it to a usable [0,1] score.
68
+ scores = tf.nn.sigmoid(logits) # shape [batch_size, 1]
69
+
70
+ # 5) Flatten to a 1D NumPy array of floats
71
+ scores = tf.reshape(scores, [-1])
72
+ scores = scores.numpy().astype(float)
73
+
74
+ # logger.debug(f"Cross-Encoder raw logits: {logits.numpy().flatten().tolist()}")
75
+ # logger.debug(f"Cross-Encoder sigmoid scores: {scores.tolist()}")
76
 
77
  return scores.tolist()
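In practice the re-ranker sits after retrieval: the bi-encoder/FAISS search proposes the top-k candidates, and the cross-encoder re-scores each (query, candidate) pair before the final sort. A minimal usage sketch (assumes the HF checkpoint downloads successfully and that the candidates were already retrieved):

from cross_encoder_reranker import CrossEncoderReranker

reranker = CrossEncoderReranker()  # cross-encoder/ms-marco-MiniLM-L-12-v2 by default

query = "Can I get a table for two tonight?"
candidates = [
    "Sure, what time would you like the reservation?",
    "Your ride has been booked.",
    "We have open tables at 7 and 8 pm.",
]

scores = reranker.rerank(query, candidates, max_length=256)

# Re-sort candidates by cross-encoder relevance, highest first
reranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
for text, score in reranked:
    print(f"{score:.3f}  {text}")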
new_iteration/pipeline_config.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ @dataclass
4
+ class PipelineConfig:
5
+ """Minimal pipeline config."""
6
+ max_length: int = 512 # max length if you want to skip long utterances
7
+ min_turns: int = 4 # minimum total turns (user + assistant)
8
+ min_user_words: int = 3 # min words in each user turn
9
+ debug: bool = True # enable debug prints
new_iteration/run_taskmaster_processor.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ from pipeline_config import PipelineConfig
6
+ from taskmaster_processor import TaskmasterProcessor
7
+
8
+ def main():
9
+ # 1) Setup config
10
+ config = PipelineConfig(
11
+ max_length=512,
12
+ min_turns=3,
13
+ min_user_words=3,
14
+ debug=True
15
+ )
16
+
17
+ # 2) Instantiate processor
18
+ base_dir = "datasets/taskmaster"
19
+ processor = TaskmasterProcessor(config)
20
+
21
+ # 3) Load raw dialogues
22
+ dialogues = processor.load_taskmaster_dataset(base_dir=base_dir, max_examples=None)
23
+
24
+ # 4) Filter & convert to final structure
25
+ final_dialogues = processor.filter_and_convert(dialogues)
26
+
27
+ # 5) Save final data
28
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
29
+ output_dir = Path("processed_outputs")
30
+ output_dir.mkdir(parents=True, exist_ok=True)
31
+ out_file = output_dir / f"taskmaster_only_{timestamp}.json"
32
+
33
+ with open(out_file, 'w', encoding='utf-8') as f:
34
+ json.dump(final_dialogues, f, indent=2)
35
+
36
+ print(f"[Taskmaster Only] Kept {len(final_dialogues)} dialogues => {out_file}")
37
+
38
+ if __name__ == "__main__":
39
+ main()
new_iteration/taskmaster_processor.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from pathlib import Path
4
+ from typing import List, Dict, Any, Optional
5
+ from dataclasses import dataclass, field
6
+
7
+ from pipeline_config import PipelineConfig
8
+
9
+ @dataclass
10
+ class TaskmasterDialogue:
11
+ """Structured representation of a Taskmaster-1 dialogue."""
12
+ conversation_id: str
13
+ instruction_id: Optional[str]
14
+ scenario: Optional[str]
15
+ domain: str
16
+ turns: List[Dict[str, Any]] = field(default_factory=list)
17
+
18
+ def validate(self) -> bool:
19
+ """Check if this dialogue has an ID and a list of turns."""
20
+ return bool(self.conversation_id and isinstance(self.turns, list))
21
+
22
+ class TaskmasterProcessor:
23
+ """
24
+ Loads Taskmaster-1 dialogues, extracts domain from scenario,
25
+ filters them, and outputs a final pipeline-friendly format.
26
+ """
27
+ def __init__(self, config: PipelineConfig):
28
+ self.config = config
29
+
30
+ def load_taskmaster_dataset(self, base_dir: str, max_examples: Optional[int] = None) -> List[TaskmasterDialogue]:
31
+ """
32
+ Load and parse Taskmaster JSON for self-dialogs & woz-dialogs (Taskmaster-1).
33
+ Combines scenario text + conversation utterances to detect domain more robustly.
34
+ """
35
+ required_files = {
36
+ "self-dialogs": "self-dialogs.json",
37
+ "woz-dialogs": "woz-dialogs.json",
38
+ "ontology": "ontology.json", # we might not actively use this, but let's expect it
39
+ }
40
+ # Check for missing
41
+ missing = [k for k, v in required_files.items() if not Path(base_dir, v).exists()]
42
+ if missing:
43
+ raise FileNotFoundError(f"Missing Taskmaster files: {missing}")
44
+
45
+ # Load ontology (optional usage)
46
+ ontology_path = Path(base_dir, required_files["ontology"])
47
+ with open(ontology_path, 'r', encoding='utf-8') as f:
48
+ ontology = json.load(f)
49
+ if self.config.debug:
50
+ print(f"[TaskmasterProcessor] Loaded ontology with {len(ontology.keys())} top-level keys (unused).")
51
+
52
+ dialogues: List[TaskmasterDialogue] = []
53
+
54
+ # We'll read the 2 main files
55
+ file_keys = ["self-dialogs", "woz-dialogs"]
56
+ for file_key in file_keys:
57
+ file_path = Path(base_dir, required_files[file_key])
58
+ with open(file_path, 'r', encoding='utf-8') as f:
59
+ raw_data = json.load(f)
60
+
61
+ for d in raw_data:
62
+ conversation_id = d.get("conversation_id", "")
63
+ instruction_id = d.get("instruction_id", None)
64
+ scenario_text = d.get("scenario", "") # old scenario approach
65
+
66
+ # Collect utterances -> turns
67
+ utterances = d.get("utterances", [])
68
+ turns = self._process_utterances(utterances)
69
+
70
+ # Instead of only using scenario_text, we combine scenario + turn texts.
71
+ # We'll pass everything to _extract_domain
72
+ domain = self._extract_domain(
73
+ scenario_text,
74
+ turns # pass the entire turn list so we can pick up domain keywords
75
+ )
76
+
77
+ # Create a structured object
78
+ new_dlg = TaskmasterDialogue(
79
+ conversation_id=conversation_id,
80
+ instruction_id=instruction_id,
81
+ scenario=scenario_text,
82
+ domain=domain,
83
+ turns=turns
84
+ )
85
+ dialogues.append(new_dlg)
86
+
87
+ if max_examples and len(dialogues) >= max_examples:
88
+ break
89
+
90
+ if self.config.debug:
91
+ print(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
92
+ return dialogues
93
+
94
+ def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
95
+ """
96
+ Combine scenario text + all turn texts to detect the domain more robustly.
97
+ """
98
+ # 1) Combine scenario + conversation text
99
+ combined_text = scenario.lower()
100
+ for turn in turns:
101
+ text = turn.get('text', '').strip().lower()
102
+ combined_text += " " + text
103
+
104
+ # 2) Expanded domain patterns (edit or expand as you wish)
105
+ domain_patterns = {
106
+ 'restaurant': r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat)\b',
107
+ 'movie': r'\b(movie|cinema|film|ticket|showtime|theater)\b',
108
+ 'ride_share': r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff)\b',
109
+ 'coffee': r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano)\b',
110
+ 'pizza': r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria)\b',
111
+ 'auto': r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change)\b'
112
+ }
113
+
114
+ # 3) Return first matched domain or 'other'
115
+ for dom, pattern in domain_patterns.items():
116
+ if re.search(pattern, combined_text):
117
+ print(f"Matched domain: {dom}")
118
+ return dom
119
+
120
+ print("No domain match, returning 'other'")
121
+ return 'other'
122
+
123
+ def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]:
124
+ """Map speaker to user/assistant, store text."""
125
+ turns = []
126
+ for utt in utterances:
127
+ speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user'
128
+ text = utt.get('text', '').strip()
129
+ turns.append({
130
+ 'speaker': speaker,
131
+ 'text': text
132
+ })
133
+ return turns
134
+
135
+ def filter_and_convert(self, dialogues: List[TaskmasterDialogue]) -> List[Dict]:
136
+ """
137
+ Filter out dialogues that don't meet min turns / min user words,
138
+ then convert them to final pipeline dict:
139
+
140
+ {
141
+ "dialogue_id": "...",
142
+ "domain": "...",
143
+ "turns": [
144
+ {"speaker": "user", "text": "..."},
145
+ ...
146
+ ]
147
+ }
148
+ """
149
+ results = []
150
+ for dlg in dialogues:
151
+ if not dlg.validate():
152
+ continue
153
+
154
+ if len(dlg.turns) < self.config.min_turns:
155
+ continue
156
+
157
+ # Check user-turn min words
158
+ keep = True
159
+ for turn in dlg.turns:
160
+ if turn['speaker'] == 'user':
161
+ word_count = len(turn['text'].split())
162
+ if word_count < self.config.min_user_words:
163
+ keep = False
164
+ break
165
+ if not keep:
166
+ continue
167
+
168
+ pipeline_dlg = {
169
+ 'dialogue_id': dlg.conversation_id,
170
+ 'domain': dlg.domain,
171
+ 'turns': dlg.turns # or you can refine further if needed
172
+ }
173
+ results.append(pipeline_dlg)
174
+
175
+ if self.config.debug:
176
+ print(f"[TaskmasterProcessor] Filtered down to {len(results)} dialogues.")
177
+ return results
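To sanity-check the filtering rules without the full Taskmaster dataset on disk, the processor can be driven with a hand-built dialogue; only filter_and_convert is exercised here (the example dialogue is made up):

from pipeline_config import PipelineConfig
from taskmaster_processor import TaskmasterProcessor, TaskmasterDialogue

config = PipelineConfig(min_turns=2, min_user_words=3, debug=True)
processor = TaskmasterProcessor(config)

dlg = TaskmasterDialogue(
    conversation_id="demo-001",
    instruction_id=None,
    scenario="Book a table at an Italian restaurant.",
    domain="restaurant",
    turns=[
        {"speaker": "user", "text": "I need a dinner reservation for two."},
        {"speaker": "assistant", "text": "Sure, what time works for you?"},
    ],
)

final = processor.filter_and_convert([dlg])
print(final)  # [{'dialogue_id': 'demo-001', 'domain': 'restaurant', 'turns': [...]}]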
prepare_data.py CHANGED
@@ -3,10 +3,13 @@ import sys
3
  import faiss
4
  import json
5
  import pickle
6
- from transformers import AutoTokenizer
 
7
  from tqdm.auto import tqdm
 
 
 
8
  from chatbot_model import ChatbotConfig, EncoderModel
9
- from environment_setup import EnvironmentSetup
10
  from tf_data_pipeline import TFDataPipeline
11
  from logger_config import config_logger
12
 
@@ -14,32 +17,23 @@ logger = config_logger(__name__)
14
 
15
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
 
17
- def cleanup_test_indices(faiss_dir, test_prefix='test_'):
18
- test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
19
- for file in test_files:
20
- file_path = os.path.join(faiss_dir, file)
21
- os.remove(file_path)
22
- logger.info(f"Removed test FAISS index file: {file_path}")
23
-
24
  def main():
25
  # Constants
26
- MODELS_DIR = 'models'
27
- PROCESSED_DATA_DIR = 'processed_outputs'
28
- CACHE_DIR = 'cache'
29
  TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
30
  FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
31
- TF_RECORD_DIR = 'training_data'
32
  FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
33
- FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
34
- ENVIRONMENT = 'production' # or 'test'
35
- if ENVIRONMENT == 'test':
36
- FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
37
- else:
38
- FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
39
- JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
40
  CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
41
- TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
42
- DEBUG_SAMPLES = None
 
 
 
 
43
 
44
  # Ensure output directories exist
45
  os.makedirs(MODELS_DIR, exist_ok=True)
@@ -49,58 +43,120 @@ def main():
49
  os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
50
  os.makedirs(TF_RECORD_DIR, exist_ok=True)
51
 
52
- # Initialize configuration
53
- config = ChatbotConfig()
54
- logger.info(f"Chatbot Configuration: {config}")
 
 
 
 
 
 
 
 
 
55
 
56
- # Initialize tokenizer and add special tokens
57
  try:
58
- tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
59
- logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
60
- tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
61
- logger.info("Added special tokens to tokenizer.")
 
 
 
 
 
 
 
 
 
62
  except Exception as e:
63
- logger.error(f"Failed to load tokenizer: {e}")
64
  sys.exit(1)
65
 
66
- # Initialize encoder model and resize token embeddings
67
  try:
68
  encoder = EncoderModel(config=config)
69
  logger.info("EncoderModel initialized successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  encoder.pretrained.resize_token_embeddings(len(tokenizer))
71
  logger.info(f"Token embeddings resized to: {len(tokenizer)}")
 
72
  except Exception as e:
73
  logger.error(f"Failed to initialize EncoderModel: {e}")
74
  sys.exit(1)
75
 
76
  # Load JSON dialogues
77
  try:
78
- dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
79
- logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
 
 
 
 
80
  except Exception as e:
81
  logger.error(f"Failed to load dialogues: {e}")
82
  sys.exit(1)
83
 
84
  # Load or initialize query_embeddings_cache
85
- try:
86
- if os.path.exists(CACHE_FILE):
 
87
  with open(CACHE_FILE, 'rb') as f:
88
  query_embeddings_cache = pickle.load(f)
89
  logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
90
- else:
91
- query_embeddings_cache = {}
92
- logger.info("Initialized empty query embeddings cache.")
93
- except Exception as e:
94
- logger.error(f"Failed to load or initialize query embeddings cache: {e}")
95
- sys.exit(1)
96
 
97
  # Initialize TFDataPipeline
98
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  data_pipeline = TFDataPipeline(
100
  config=config,
101
  tokenizer=tokenizer,
102
  encoder=encoder,
103
- index_file_path=FAISS_INDEX_PATH,
104
  response_pool=[],
105
  max_length=config.max_context_token_limit,
106
  neg_samples=config.neg_samples,
@@ -114,48 +170,55 @@ def main():
114
  logger.error(f"Failed to initialize TFDataPipeline: {e}")
115
  sys.exit(1)
116
 
117
- # Collect unique assistant responses from dialogues
118
  try:
119
- response_pool = data_pipeline.collect_responses(dialogues)
120
- data_pipeline.response_pool = response_pool
121
- logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
 
 
 
122
  except Exception as e:
123
  logger.error(f"Failed to collect responses: {e}")
124
  sys.exit(1)
125
 
126
- # Compute and add response embeddings to FAISS index
 
127
  try:
128
- logger.info("Computing and adding response embeddings to FAISS index...")
129
- data_pipeline.compute_and_index_response_embeddings()
130
- logger.info("Response embeddings computed and added to FAISS index.")
131
- except Exception as e:
132
- logger.error(f"Failed to compute or add response embeddings: {e}")
133
- sys.exit(1)
 
 
 
 
 
 
 
 
 
 
134
 
135
- # Save FAISS index and response pool
136
- try:
137
- logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
138
- faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
139
- logger.info("FAISS index saved successfully.")
140
-
141
- response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
142
- with open(response_pool_path, 'w', encoding='utf-8') as f:
143
- json.dump(data_pipeline.response_pool, f, indent=2)
144
- logger.info(f"Response pool saved to {response_pool_path}.")
145
  except Exception as e:
146
- logger.error(f"Failed to save FAISS index: {e}")
147
  sys.exit(1)
148
 
149
- # Prepare and save training data as TFRecords
150
  try:
151
- logger.info("Starting data preparation and saving as TFRecord...")
152
- data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
153
- logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
 
 
 
154
  except Exception as e:
155
  logger.error(f"Failed during data preparation and saving: {e}")
156
  sys.exit(1)
157
 
158
- # Save query embeddings cache
159
  try:
160
  with open(CACHE_FILE, 'wb') as f:
161
  pickle.dump(data_pipeline.query_embeddings_cache, f)
@@ -164,7 +227,7 @@ def main():
164
  logger.error(f"Failed to save query embeddings cache: {e}")
165
  sys.exit(1)
166
 
167
- # Save Tokenizer (including special tokens)
168
  try:
169
  tokenizer.save_pretrained(TOKENIZER_DIR)
170
  logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
@@ -173,6 +236,7 @@ def main():
173
  sys.exit(1)
174
 
175
  logger.info("Data preparation pipeline completed successfully.")
176
-
 
177
  if __name__ == "__main__":
178
- main()
 
3
  import faiss
4
  import json
5
  import pickle
6
+ import tensorflow as tf
7
+ from transformers import AutoTokenizer, TFAutoModel
8
  from tqdm.auto import tqdm
9
+ from pathlib import Path
10
+
11
+ # Your existing modules
12
  from chatbot_model import ChatbotConfig, EncoderModel
 
13
  from tf_data_pipeline import TFDataPipeline
14
  from logger_config import config_logger
15
 
 
17
 
18
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
 
 
 
 
 
 
 
 
20
  def main():
21
  # Constants
22
+ MODELS_DIR = 'new_iteration/data_prep_iterative_models'
23
+ PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
24
+ CACHE_DIR = 'new_iteration/cache'
25
  TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
26
  FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
27
+ TF_RECORD_DIR = 'new_iteration/training_data'
28
  FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
29
+ JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
 
 
 
 
 
 
30
  CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
31
+ TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
32
+
33
+ # Decide whether to load the **custom** fine-tuned model or just base DistilBERT.
34
+ # True for custom, False for base DistilBERT.
35
+ LOAD_CUSTOM_MODEL = True
36
+ NUM_NEG_SAMPLES = 10
37
 
38
  # Ensure output directories exist
39
  os.makedirs(MODELS_DIR, exist_ok=True)
 
43
  os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
44
  os.makedirs(TF_RECORD_DIR, exist_ok=True)
45
 
46
+ # Initialize config
47
+ config_json = Path(MODELS_DIR) / "config.json"
48
+ if config_json.exists():
49
+ with open(config_json, "r", encoding="utf-8") as f:
50
+ config_dict = json.load(f)
51
+ config = ChatbotConfig.from_dict(config_dict)
52
+ logger.info(f"Loaded ChatbotConfig from {config_json}")
53
+ else:
54
+ config = ChatbotConfig()
55
+ logger.warning("No config.json found. Using default ChatbotConfig.")
56
+
57
+ config.neg_samples = NUM_NEG_SAMPLES
58
 
59
+ # Load or initialize tokenizer
60
  try:
61
+ # If the directory has a valid tokenizer
62
+ if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
63
+ logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
64
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
65
+ else:
66
+ # Initialize from base DistilBERT
67
+ logger.info(f"Loading base tokenizer for {config.pretrained_model}")
68
+ tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
69
+
70
+ # Save to disk
71
+ Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
72
+ tokenizer.save_pretrained(TOKENIZER_DIR)
73
+ logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
74
  except Exception as e:
75
+ logger.error(f"Failed to load or create tokenizer: {e}")
76
  sys.exit(1)
77
 
78
+ # Initialize the encoder
79
  try:
80
  encoder = EncoderModel(config=config)
81
  logger.info("EncoderModel initialized successfully.")
82
+
83
+ if LOAD_CUSTOM_MODEL:
84
+ # Load the DistilBERT submodule from 'shared_encoder'
85
+ shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
86
+ if shared_encoder_path.exists():
87
+ logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
88
+ encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
89
+ else:
90
+ logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")
91
+
92
+ # Load top-level custom .weights.h5 (projection, dropout, etc.)
93
+ custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
94
+ if custom_weights_path.exists():
95
+ logger.info(f"Loading custom top-level weights from {custom_weights_path}")
96
+ # Build model layers with a dummy forward pass
97
+ dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
98
+ _ = encoder(dummy_input, training=False)
99
+ encoder.load_weights(str(custom_weights_path))
100
+ logger.info("Custom encoder weights loaded successfully.")
101
+ else:
102
+ logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
103
+ else:
104
+ # Just base DistilBERT with special tokens resized
105
+ logger.info("Using the base DistilBERT without loading custom weights.")
106
+
107
+ # Resize token embeddings in case we added special tokens
108
  encoder.pretrained.resize_token_embeddings(len(tokenizer))
109
  logger.info(f"Token embeddings resized to: {len(tokenizer)}")
110
+
111
  except Exception as e:
112
  logger.error(f"Failed to initialize EncoderModel: {e}")
113
  sys.exit(1)
114
 
115
  # Load JSON dialogues
116
  try:
117
+ if not Path(JSON_TRAINING_DATA_PATH).exists():
118
+ logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}, skipping.")
119
+ dialogues = []
120
+ else:
121
+ dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, debug_samples=None)
122
+ logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
123
  except Exception as e:
124
  logger.error(f"Failed to load dialogues: {e}")
125
  sys.exit(1)
126
 
127
  # Load or initialize query_embeddings_cache
128
+ query_embeddings_cache = {}
129
+ if os.path.exists(CACHE_FILE):
130
+ try:
131
  with open(CACHE_FILE, 'rb') as f:
132
  query_embeddings_cache = pickle.load(f)
133
  logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
134
+ except Exception as e:
135
+ logger.warning(f"Failed to load query embeddings cache: {e}")
136
+ else:
137
+ logger.info("No existing query embeddings cache found. Starting fresh.")
 
 
138
 
139
  # Initialize TFDataPipeline
140
  try:
141
+ # Determine if FAISS index should be loaded or initialized
142
+ if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
143
+ # Load existing index
144
+ logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
145
+ faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
146
+ logger.info("FAISS index loaded successfully.")
147
+ else:
148
+ # Initialize a new FAISS index
149
+ logger.info("No existing FAISS index found. Initializing a new index.")
150
+ dimension = config.embedding_dim # Ensure this matches your encoder's output
151
+ faiss_index = faiss.IndexFlatIP(dimension) # Using Inner Product for cosine similarity
152
+ logger.info(f"Initialized new FAISS index with dimension {dimension}.")
153
+
154
+ # Initialize TFDataPipeline with the FAISS index
155
  data_pipeline = TFDataPipeline(
156
  config=config,
157
  tokenizer=tokenizer,
158
  encoder=encoder,
159
+ index_file_path=FAISS_INDEX_PRODUCTION_PATH,
160
  response_pool=[],
161
  max_length=config.max_context_token_limit,
162
  neg_samples=config.neg_samples,
 
170
  logger.error(f"Failed to initialize TFDataPipeline: {e}")
171
  sys.exit(1)
172
 
173
+ # 7) Collect unique assistant responses from dialogues
174
  try:
175
+ if dialogues:
176
+ response_pool = data_pipeline.collect_responses_with_domain(dialogues)
177
+ data_pipeline.response_pool = response_pool
178
+ logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
179
+ else:
180
+ logger.warning("No dialogues loaded. response_pool remains empty.")
181
  except Exception as e:
182
  logger.error(f"Failed to collect responses: {e}")
183
  sys.exit(1)
184
 
185
+ # 8) Build the FAISS index with response embeddings
186
+ # Instead of manually computing embeddings, we use the pipeline method
187
  try:
188
+ if data_pipeline.response_pool:
189
+ data_pipeline.build_text_to_domain_map()
190
+ logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
191
+ data_pipeline.compute_and_index_response_embeddings()
192
+ logger.info("Response embeddings computed and added to FAISS index.")
193
+
194
+ # Save the updated FAISS index
195
+ data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
196
+
197
+ # Also save the response pool JSON
198
+ response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
199
+ with open(response_pool_path, 'w', encoding='utf-8') as f:
200
+ json.dump(data_pipeline.response_pool, f, indent=2)
201
+ logger.info(f"Response pool saved to {response_pool_path}.")
202
+ else:
203
+ logger.warning("No responses to embed. Skipping FAISS indexing.")
204
 
 
 
 
 
 
 
 
 
 
 
205
  except Exception as e:
206
+ logger.error(f"Failed to compute or add response embeddings: {e}")
207
  sys.exit(1)
208
 
209
+ # 9) Prepare and save training data as TFRecords
210
  try:
211
+ if dialogues:
212
+ logger.info("Starting data preparation and saving as TFRecord...")
213
+ data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
214
+ logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
215
+ else:
216
+ logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
217
  except Exception as e:
218
  logger.error(f"Failed during data preparation and saving: {e}")
219
  sys.exit(1)
220
 
221
+ # 10) Save query embeddings cache
222
  try:
223
  with open(CACHE_FILE, 'wb') as f:
224
  pickle.dump(data_pipeline.query_embeddings_cache, f)
 
227
  logger.error(f"Failed to save query embeddings cache: {e}")
228
  sys.exit(1)
229
 
230
+ # Save Tokenizer
231
  try:
232
  tokenizer.save_pretrained(TOKENIZER_DIR)
233
  logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
 
236
  sys.exit(1)
237
 
238
  logger.info("Data preparation pipeline completed successfully.")
239
+
240
+
241
  if __name__ == "__main__":
242
+ main()
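One detail worth calling out in the new index setup: IndexFlatIP ranks by inner product, so embeddings are L2-normalized before insertion to make the inner product equal cosine similarity. A standalone sketch of that pattern (random vectors stand in for encoder outputs, and the dimension is illustrative):

import faiss
import numpy as np

dim = 256  # stand-in for config.embedding_dim
rng = np.random.default_rng(0)

response_embs = rng.normal(size=(1000, dim)).astype(np.float32)
faiss.normalize_L2(response_embs)      # in-place L2 normalization

index = faiss.IndexFlatIP(dim)         # inner product == cosine on unit vectors
index.add(response_embs)

query = rng.normal(size=(1, dim)).astype(np.float32)
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)   # top-5 nearest responses
print(ids[0], scores[0])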
response_quality_checker.py CHANGED
@@ -9,27 +9,41 @@ if TYPE_CHECKING:
9
  from tf_data_pipeline import TFDataPipeline
10
 
11
  class ResponseQualityChecker:
12
- """Enhanced quality checking with dynamic thresholds."""
 
 
 
 
 
 
13
 
14
  def __init__(
15
  self,
16
  data_pipeline: 'TFDataPipeline',
17
- confidence_threshold: float = 0.6,
18
  diversity_threshold: float = 0.15,
19
  min_response_length: int = 5,
20
- similarity_cap: float = 0.85 # Renamed from max_similarity_ratio and used in diversity calc
21
  ):
 
 
 
 
 
 
 
 
22
  self.confidence_threshold = confidence_threshold
23
  self.diversity_threshold = diversity_threshold
24
  self.min_response_length = min_response_length
25
  self.similarity_cap = similarity_cap
26
- self.data_pipeline = data_pipeline # Reference to TFDataPipeline
27
 
28
- # Dynamic thresholds based on response patterns
29
  self.thresholds = {
30
- 'relevance': 0.35,
31
- 'length_score': 0.85,
32
- 'score_gap': 0.07
33
  }
34
 
35
  def check_response_quality(
@@ -38,14 +52,14 @@ class ResponseQualityChecker:
38
  responses: List[Tuple[str, float]]
39
  ) -> Dict[str, Any]:
40
  """
41
- Evaluate the quality of responses based on various metrics.
42
 
43
  Args:
44
- query: The user's query
45
- responses: List of (response_text, score) tuples
46
 
47
  Returns:
48
- Dict containing quality metrics and confidence assessment
49
  """
50
  if not responses:
51
  return {
@@ -57,98 +71,282 @@ class ResponseQualityChecker:
57
  'top_3_score_gap': 0.0
58
  }
59
 
60
- # Calculate core metrics
61
- metrics = {
62
- 'response_diversity': self.calculate_diversity(responses),
63
- 'query_response_relevance': self.calculate_relevance(query, responses),
64
- 'response_length_score': np.mean([
65
- self._calculate_length_score(response) for response, _ in responses
66
- ]),
67
- 'top_score': responses[0][1],
68
- 'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
69
- }
70
 
71
- # Determine confidence using thresholds
72
  metrics['is_confident'] = self._determine_confidence(metrics)
73
-
74
  logger.info(f"Quality metrics: {metrics}")
75
  return metrics
76
 
77
  def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
78
- """Calculate relevance as weighted similarity between query and responses."""
 
 
 
79
  if not responses:
80
  return 0.0
81
 
82
- # Get embeddings
83
- query_embedding = self.data_pipeline.encode_query(query)
84
- response_texts = [resp for resp, _ in responses]
85
- response_embeddings = self.data_pipeline.encode_responses(response_texts)
86
 
87
- # Compute similarities
88
- similarities = cosine_similarity([query_embedding], response_embeddings)[0]
 
 
89
 
90
- # Apply decreasing weights for later responses
91
- weights = np.array([1.0 / (i + 1) for i in range(len(similarities))])
92
 
93
- return np.average(similarities, weights=weights)
 
 
 
 
 
 
94
 
95
  def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
96
- """Calculate diversity with length normalization and similarity capping."""
97
- if not responses:
98
- return 0.0
 
 
 
99
 
100
- response_texts = [resp for resp, _ in responses]
101
- embeddings = self.data_pipeline.encode_responses(response_texts)
102
- if len(embeddings) < 2:
103
- return 1.0
104
 
105
- # Calculate pairwise cosine similarities
106
- similarity_matrix = cosine_similarity(embeddings)
107
- np.fill_diagonal(similarity_matrix, 0) # Exclude self-similarity
108
 
109
- # Apply similarity cap
110
- similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)
111
 
112
- # Calculate average similarity
113
- sum_similarities = np.sum(similarity_matrix)
114
- num_pairs = len(embeddings) * (len(embeddings) - 1)
115
- avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
116
 
117
- # Diversity is inversely related to average similarity
118
- diversity_score = 1 - avg_similarity
119
- return diversity_score
120
 
121
  def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
122
- """Determine confidence using primary and secondary conditions."""
123
- # Primary conditions (must all be met)
 
124
  primary_conditions = [
125
  metrics['top_score'] >= self.confidence_threshold,
126
  metrics['response_diversity'] >= self.diversity_threshold,
127
  metrics['response_length_score'] >= self.thresholds['length_score']
128
  ]
129
 
130
- # Secondary conditions (majority must be met)
131
  secondary_conditions = [
132
  metrics['query_response_relevance'] >= self.thresholds['relevance'],
133
  metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
134
- metrics['top_score'] >= (self.confidence_threshold * 1.1) # Extra confidence boost
135
  ]
136
 
137
- return all(primary_conditions) and sum(secondary_conditions) >= 2
 
138
 
139
- def _calculate_length_score(self, response: str) -> float:
140
- """Calculate length score with penalty for very short or long responses."""
141
- words = len(response.split())
 
 
 
 
 
142
 
 
 
 
 
 
 
 
143
  if words < self.min_response_length:
144
- return words / self.min_response_length
145
- elif words > 50: # Penalty for very long responses
146
- return min(1.0, 50 / words)
147
  return 1.0
148
 
149
  def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
150
- """Calculate average gap between top N scores."""
151
- if len(scores) < top_n + 1:
 
 
152
  return 0.0
153
- gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
154
- return np.mean(gaps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from tf_data_pipeline import TFDataPipeline
10
 
11
  class ResponseQualityChecker:
12
+ """
13
+ Enhanced quality checking that calculates:
14
+ - Relevance between query & responses
15
+ - Diversity among top responses
16
+ - Response length scoring
17
+ - Confidence determination based on multiple thresholds
18
+ """
19
 
20
  def __init__(
21
  self,
22
  data_pipeline: 'TFDataPipeline',
23
+ confidence_threshold: float = 0.45,
24
  diversity_threshold: float = 0.15,
25
  min_response_length: int = 5,
26
+ similarity_cap: float = 0.85,
27
  ):
28
+ """
29
+ Args:
30
+ data_pipeline: Reference to TFDataPipeline for encoding
31
+ confidence_threshold: Minimum top_score for a 'confident' result
32
+ diversity_threshold: Minimum required diversity among top responses
33
+ min_response_length: Minimum words for a decent response
34
+ similarity_cap: Cap on pairwise similarity for diversity calc
35
+ """
36
  self.confidence_threshold = confidence_threshold
37
  self.diversity_threshold = diversity_threshold
38
  self.min_response_length = min_response_length
39
  self.similarity_cap = similarity_cap
40
+ self.data_pipeline = data_pipeline
41
 
42
+ # Additional thresholds for more refined checks
43
  self.thresholds = {
44
+ 'relevance': 0.30, # Slightly relaxed
45
+ 'length_score': 0.80, # Stricter length requirement
46
+ 'score_gap': 0.05 # Gap between top scores
47
  }
48
 
49
  def check_response_quality(
 
52
  responses: List[Tuple[str, float]]
53
  ) -> Dict[str, Any]:
54
  """
55
+ Evaluate the quality of a set of ranked responses for a given query.
56
 
57
  Args:
58
+ query: The user's original query
59
+ responses: List of (response_text, score) sorted by descending score
60
 
61
  Returns:
62
+ Dictionary of metrics, including 'is_confident' and others
63
  """
64
  if not responses:
65
  return {
 
71
  'top_3_score_gap': 0.0
72
  }
73
 
74
+ # 1) Calculate relevant metrics
75
+ metrics = {}
76
+ metrics['response_diversity'] = self.calculate_diversity(responses)
77
+ metrics['query_response_relevance'] = self.calculate_relevance(query, responses)
78
+ metrics['response_length_score'] = self._average_length_score(responses)
79
+ metrics['top_score'] = responses[0][1]
80
+ metrics['top_3_score_gap'] = self._calculate_score_gap([s for _, s in responses], top_n=3)
 
 
 
81
 
82
+ # 2) Determine confidence
83
  metrics['is_confident'] = self._determine_confidence(metrics)
 
84
  logger.info(f"Quality metrics: {metrics}")
85
  return metrics
86
 
87
  def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
88
+ """
89
+ Compute an overall 'relevance' metric between the query and the top responses.
90
+ Uses an exponential transform on the similarity to penalize weaker matches.
91
+ """
92
  if not responses:
93
  return 0.0
94
 
95
+ # Encode query and responses
96
+ query_emb = self.data_pipeline.encode_query(query)
97
+ resp_texts = [r for r, _ in responses]
98
+ resp_embs = self.data_pipeline.encode_responses(resp_texts)
99
 
100
+ # Normalize embeddings
101
+ query_emb = query_emb / (np.linalg.norm(query_emb) + 1e-12)
102
+ resp_norms = np.linalg.norm(resp_embs, axis=1, keepdims=True) + 1e-12
103
+ resp_embs = resp_embs / resp_norms
104
 
105
+ # Cosine similarity
106
+ sims = cosine_similarity([query_emb], resp_embs)[0]
107
 
108
+ # Exponential transform: higher sims remain close to 1, lower sims drop quickly
109
+ sims = np.exp(sims - 1.0)
110
+
111
+ # Weighted average: give heavier weighting to higher-ranked items
112
+ weights = np.exp(-np.arange(len(sims)) / 2.0)
113
+ weighted_avg = np.average(sims, weights=weights)
114
+ return float(weighted_avg)
115
 
116
  def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
117
+ """
118
+ Calculate how 'different' the top responses are from each other.
119
+ Diversity = 1 - avg_cosine_similarity (capped).
120
+ """
121
+ if len(responses) < 2:
122
+ return 1.0 # Single response is trivially 'unique'
123
 
124
+ resp_texts = [r for r, _ in responses]
125
+ embs = self.data_pipeline.encode_responses(resp_texts)
 
 
126
 
127
+ # Pairwise similarity
128
+ sim_matrix = cosine_similarity(embs, embs)
129
+ np.fill_diagonal(sim_matrix, 0.0)
130
 
131
+ # Cap similarity to avoid outliers
132
+ sim_matrix = np.minimum(sim_matrix, self.similarity_cap)
133
 
134
+ # Mean off-diagonal similarity
135
+ sum_sims = np.sum(sim_matrix)
136
+ num_pairs = len(resp_texts) * (len(resp_texts) - 1)
137
+ avg_sim = sum_sims / num_pairs if num_pairs > 0 else 0.0
138
 
139
+ # Invert to get diversity
140
+ return 1.0 - avg_sim
 
141
 
142
  def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
143
+ """
144
+ Decide if we're 'confident' based on multiple metric thresholds.
145
+ """
146
  primary_conditions = [
147
  metrics['top_score'] >= self.confidence_threshold,
148
  metrics['response_diversity'] >= self.diversity_threshold,
149
  metrics['response_length_score'] >= self.thresholds['length_score']
150
  ]
151
 
 
152
  secondary_conditions = [
153
  metrics['query_response_relevance'] >= self.thresholds['relevance'],
154
  metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
155
+ metrics['top_score'] >= (self.confidence_threshold + 0.05) # Extra buffer
156
  ]
157
 
158
+ # Must pass all primary checks, and at least 2 of the 3 secondary
159
+ return all(primary_conditions) and (sum(secondary_conditions) >= 2)
160
 
161
+ def _average_length_score(self, responses: List[Tuple[str, float]]) -> float:
162
+ """
163
+ Compute an average length score across all responses.
164
+ """
165
+ length_scores = []
166
+ for response, _ in responses:
167
+ length_scores.append(self._length_score(response))
168
+ return float(np.mean(length_scores)) if length_scores else 0.0
169
 
170
+ def _length_score(self, text: str) -> float:
171
+ """
172
+ Calculate how well the text meets our length requirement.
173
+ Scores 1.0 if text is >= min_response_length and not too long,
174
+ else it scales down.
175
+ """
176
+ words = len(text.split())
177
  if words < self.min_response_length:
178
+ return words / float(self.min_response_length)
179
+ elif words > 60:
180
+ return max(0.5, 60.0 / words) # Slight penalty for very long
181
  return 1.0
182
 
183
  def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
184
+ """
185
+ Calculate the average gap between consecutive scores in the top N.
186
+ """
187
+ if len(scores) < 2:
188
  return 0.0
189
+ top_n = min(len(scores), top_n)
190
+ gaps = []
191
+ for i in range(top_n - 1):
192
+ gaps.append(scores[i] - scores[i + 1])
193
+ return float(np.mean(gaps)) if gaps else 0.0
194
+
195
+ # import numpy as np
196
+ # from typing import List, Tuple, Dict, Any, TYPE_CHECKING
197
+ # from sklearn.metrics.pairwise import cosine_similarity
198
+
199
+ # from logger_config import config_logger
200
+ # logger = config_logger(__name__)
201
+
202
+ # if TYPE_CHECKING:
203
+ # from tf_data_pipeline import TFDataPipeline
204
+
205
+ # class ResponseQualityChecker:
206
+ # """Enhanced quality checking with dynamic thresholds."""
207
+
208
+ # def __init__(
209
+ # self,
210
+ # data_pipeline: 'TFDataPipeline',
211
+ # confidence_threshold: float = 0.4,
212
+ # diversity_threshold: float = 0.15,
213
+ # min_response_length: int = 5,
214
+ # similarity_cap: float = 0.85 # Renamed from max_similarity_ratio and used in diversity calc
215
+ # ):
216
+ # self.confidence_threshold = confidence_threshold
217
+ # self.diversity_threshold = diversity_threshold
218
+ # self.min_response_length = min_response_length
219
+ # self.similarity_cap = similarity_cap
220
+ # self.data_pipeline = data_pipeline # Reference to TFDataPipeline
221
+
222
+ # # Dynamic thresholds based on response patterns
223
+ # self.thresholds = {
224
+ # 'relevance': 0.35,
225
+ # 'length_score': 0.85,
226
+ # 'score_gap': 0.04
227
+ # }
228
+
229
+ # def check_response_quality(
230
+ # self,
231
+ # query: str,
232
+ # responses: List[Tuple[str, float]]
233
+ # ) -> Dict[str, Any]:
234
+ # """
235
+ # Evaluate the quality of responses based on various metrics.
236
+
237
+ # Args:
238
+ # query: The user's query
239
+ # responses: List of (response_text, score) tuples
240
+
241
+ # Returns:
242
+ # Dict containing quality metrics and confidence assessment
243
+ # """
244
+ # if not responses:
245
+ # return {
246
+ # 'response_diversity': 0.0,
247
+ # 'query_response_relevance': 0.0,
248
+ # 'is_confident': False,
249
+ # 'top_score': 0.0,
250
+ # 'response_length_score': 0.0,
251
+ # 'top_3_score_gap': 0.0
252
+ # }
253
+
254
+ # # Calculate core metrics
255
+ # metrics = {
256
+ # 'response_diversity': self.calculate_diversity(responses),
257
+ # 'query_response_relevance': self.calculate_relevance(query, responses),
258
+ # 'response_length_score': np.mean([
259
+ # self._calculate_length_score(response) for response, _ in responses
260
+ # ]),
261
+ # 'top_score': responses[0][1],
262
+ # 'top_3_score_gap': self._calculate_score_gap([score for _, score in responses], top_n=3)
263
+ # }
264
+
265
+ # # Determine confidence using thresholds
266
+ # metrics['is_confident'] = self._determine_confidence(metrics)
267
+
268
+ # logger.info(f"Quality metrics: {metrics}")
269
+ # return metrics
270
+
271
+ # def calculate_relevance(self, query: str, responses: List[Tuple[str, float]]) -> float:
272
+ # """Calculate relevance with stricter scoring."""
273
+ # if not responses:
274
+ # return 0.0
275
+
276
+ # query_embedding = self.data_pipeline.encode_query(query)
277
+ # response_texts = [resp for resp, _ in responses]
278
+ # response_embeddings = self.data_pipeline.encode_responses(response_texts)
279
+
280
+ # # Normalize embeddings
281
+ # query_embedding = query_embedding / np.linalg.norm(query_embedding)
282
+ # response_embeddings = response_embeddings / np.linalg.norm(response_embeddings, axis=1)[:, np.newaxis]
283
+
284
+ # # Compute similarities with exponential decay for far matches
285
+ # similarities = cosine_similarity([query_embedding], response_embeddings)[0]
286
+ # similarities = np.exp(similarities - 1) # Penalize lower similarities more strongly
287
+
288
+ # # Apply stronger position weighting
289
+ # weights = np.exp(-np.arange(len(similarities)) / 2)
290
+
291
+ # return float(np.average(similarities, weights=weights))
292
+
293
+ # def calculate_diversity(self, responses: List[Tuple[str, float]]) -> float:
294
+ # """Calculate diversity with length normalization and similarity capping."""
295
+ # if not responses:
296
+ # return 0.0
297
+
298
+ # response_texts = [resp for resp, _ in responses]
299
+ # embeddings = self.data_pipeline.encode_responses(response_texts)
300
+ # if len(embeddings) < 2:
301
+ # return 1.0
302
+
303
+ # # Calculate pairwise cosine similarities
304
+ # similarity_matrix = cosine_similarity(embeddings)
305
+ # np.fill_diagonal(similarity_matrix, 0) # Exclude self-similarity
306
+
307
+ # # Apply similarity cap
308
+ # similarity_matrix = np.minimum(similarity_matrix, self.similarity_cap)
309
+
310
+ # # Calculate average similarity
311
+ # sum_similarities = np.sum(similarity_matrix)
312
+ # num_pairs = len(embeddings) * (len(embeddings) - 1)
313
+ # avg_similarity = sum_similarities / num_pairs if num_pairs > 0 else 0.0
314
+
315
+ # # Diversity is inversely related to average similarity
316
+ # diversity_score = 1 - avg_similarity
317
+ # return diversity_score
318
+
319
+ # def _determine_confidence(self, metrics: Dict[str, float]) -> bool:
320
+ # """Determine confidence using primary and secondary conditions."""
321
+ # # Primary conditions (must all be met)
322
+ # primary_conditions = [
323
+ # metrics['top_score'] >= self.confidence_threshold,
324
+ # metrics['response_diversity'] >= self.diversity_threshold,
325
+ # metrics['response_length_score'] >= self.thresholds['length_score']
326
+ # ]
327
+
328
+ # # Secondary conditions (majority must be met)
329
+ # secondary_conditions = [
330
+ # metrics['query_response_relevance'] >= self.thresholds['relevance'],
331
+ # metrics['top_3_score_gap'] >= self.thresholds['score_gap'],
332
+ # metrics['top_score'] >= (self.confidence_threshold * 1.1) # Extra confidence boost
333
+ # ]
334
+
335
+ # return all(primary_conditions) and sum(secondary_conditions) >= 2
336
+
337
+ # def _calculate_length_score(self, response: str) -> float:
338
+ # """Calculate length score with penalty for very short or long responses."""
339
+ # words = len(response.split())
340
+
341
+ # if words < self.min_response_length:
342
+ # return words / self.min_response_length
343
+ # elif words > 50: # Penalty for very long responses
344
+ # return min(1.0, 50 / words)
345
+ # return 1.0
346
+
347
+ # def _calculate_score_gap(self, scores: List[float], top_n: int = 3) -> float:
348
+ # """Calculate average gap between top N scores."""
349
+ # if len(scores) < top_n + 1:
350
+ # return 0.0
351
+ # gaps = [scores[i] - scores[i + 1] for i in range(min(len(scores) - 1, top_n))]
352
+ # return np.mean(gaps)
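The new relevance metric combines two transforms: an exponential squash so weak similarities count for little, and rank-decayed weights so the top-ranked response dominates the average. Isolated from the class, the arithmetic looks like this (the similarity values are made up):

import numpy as np

# Hypothetical cosine similarities between the query and the ranked responses
sims = np.array([0.82, 0.55, 0.30])

# exp(sim - 1): values near 1.0 stay near 1.0, lower values fall off quickly
transformed = np.exp(sims - 1.0)

# exp(-rank / 2): rank 0 gets weight 1.0, later ranks progressively less
weights = np.exp(-np.arange(len(sims)) / 2.0)

relevance = float(np.average(transformed, weights=weights))
print(transformed.round(3), weights.round(3), round(relevance, 3))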
tf_data_pipeline.py CHANGED
@@ -8,11 +8,12 @@ import math
8
  from tqdm import tqdm
9
  import json
10
  from pathlib import Path
11
- from typing import Union, Optional, List, Tuple, Generator
12
  from transformers import AutoTokenizer
13
  from typing import List, Tuple, Generator
14
  from transformers import AutoTokenizer
15
  from gpu_monitor import GPUMemoryMonitor
 
16
 
17
  from logger_config import config_logger
18
  logger = config_logger(__name__)
@@ -27,7 +28,7 @@ class TFDataPipeline:
27
  response_pool: List[str],
28
  max_length: int,
29
  query_embeddings_cache: dict,
30
- neg_samples: int = 3,
31
  index_type: str = 'IndexFlatIP',
32
  nlist: int = 100,
33
  max_retries: int = 3
@@ -47,6 +48,10 @@ class TFDataPipeline:
47
  self.max_batch_size = 16 if len(response_pool) < 100 else 64
48
  self.memory_monitor = GPUMemoryMonitor()
49
  self.max_retries = max_retries
 
 
 
 
50
 
51
  if os.path.exists(index_file_path):
52
  logger.info(f"Loading existing FAISS index from {index_file_path}...")
@@ -135,21 +140,49 @@ class TFDataPipeline:
135
 
136
  logger.info(f"Loaded {len(dialogues)} dialogues.")
137
  return dialogues
138
-
139
- def collect_responses(self, dialogues: List[dict]) -> List[str]:
140
- """Extract unique assistant responses from dialogues."""
141
- response_set = set()
 
 
 
 
 
142
  for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
 
 
143
  turns = dialogue.get('turns', [])
144
  for turn in turns:
145
  speaker = turn.get('speaker')
146
  text = turn.get('text', '').strip()
147
  if speaker == 'assistant' and text:
148
- # Ensure we don't exclude valid shorter responses
149
  if len(text) <= self.max_length:
150
- response_set.add(text)
151
- logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
152
- return list(response_set)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
155
  """Extract query-response pairs from a dialogue."""
@@ -173,113 +206,101 @@ class TFDataPipeline:
173
 
174
  def compute_and_index_response_embeddings(self):
175
  """
176
- Computes embeddings for the response pool and adds them to the FAISS index with progress bars.
 
177
  """
178
  logger.info("Computing embeddings for the response pool...")
179
 
180
- # Ensure all responses are strings
181
- if not all(isinstance(response, str) for response in self.response_pool):
182
- logger.error("All elements in response_pool must be strings.")
183
- raise ValueError("Invalid data type in response_pool.")
184
-
185
- # Tokenization
186
- logger.info("Tokenizing responses...")
187
- encoded_responses = self.tokenizer(
188
- self.response_pool,
189
- padding=True,
190
- truncation=True,
191
- max_length=self.max_length,
192
- return_tensors='tf'
193
- )
194
- response_ids = encoded_responses['input_ids']
195
-
196
- # Compute embeddings in batches with progress bar
197
- batch_size = getattr(self, 'embedding_batch_size', 64) # Default to 64 if not set
198
- total_responses = len(response_ids)
199
- logger.info(f"Computing embeddings in batches of {batch_size}...")
200
  embeddings = []
201
 
202
- with tqdm(total=total_responses, desc="Computing Embeddings", unit="response") as pbar:
203
- for i in range(0, total_responses, batch_size):
204
- batch_ids = response_ids[i:i + batch_size]
205
- # Compute embeddings
206
- batch_embeddings = self.encoder(batch_ids, training=False).numpy()
207
- # Normalize embeddings for cosine similarity
208
- faiss.normalize_L2(batch_embeddings)
209
- embeddings.append(batch_embeddings)
210
- pbar.update(len(batch_ids))
211
-
212
- if embeddings:
213
- embeddings = np.vstack(embeddings).astype(np.float32)
214
- # Add embeddings to FAISS index with progress bar
215
- logger.info(f"Adding {len(embeddings)} response embeddings to FAISS index...")
216
-
217
- # Determine number of batches for indexing
218
- index_batch_size = getattr(self, 'index_batch_size', 1000) # Adjust as needed
219
- total_embeddings = len(embeddings)
220
- num_index_batches = math.ceil(total_embeddings / index_batch_size)
221
-
222
- with tqdm(total=total_embeddings, desc="Indexing Embeddings", unit="embedding") as pbar_index:
223
- for i in range(0, total_embeddings, index_batch_size):
224
- batch = embeddings[i:i + index_batch_size]
225
- self.index.add(batch)
226
- pbar_index.update(len(batch))
227
-
228
- logger.info("Response embeddings added to FAISS index.")
229
- else:
230
- logger.warning("No embeddings to add to FAISS index.")
231
-
232
- # **Sanity Check:** Verify the number of embeddings in FAISS index
233
- logger.info(f"Total embeddings in FAISS index after addition: {self.index.ntotal}")
234
 
235
  def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
236
- """Find hard negatives for a batch of queries with error handling and retries."""
 
 
 
 
 
 
 
237
  retry_count = 0
238
  total_responses = len(self.response_pool)
239
-
240
- # Set k to be neg_samples + additional candidates to improve negative selection
241
- k = self.neg_samples + 0
242
 
243
  while retry_count < self.max_retries:
244
  try:
245
- # Compute embeddings in sub-batches to manage memory
246
- batch_size = 128 # Example sub-batch size; adjust as needed
247
  query_embeddings = []
248
  for i in range(0, len(queries), batch_size):
249
- sub_queries = queries[i:i + batch_size]
250
- sub_embeddings = np.vstack([
251
- self.query_embeddings_cache[q] for q in sub_queries
252
- ]).astype(np.float32)
253
- faiss.normalize_L2(sub_embeddings)
254
- query_embeddings.append(sub_embeddings)
255
- query_embeddings = np.vstack(query_embeddings)
256
 
257
- # Ensure contiguous memory layout
258
  query_embeddings = np.ascontiguousarray(query_embeddings)
259
 
260
- # Perform FAISS search on CPU
261
  distances, indices = self.index.search(query_embeddings, k)
262
 
263
  all_negatives = []
264
- for query_indices, query, positive in zip(indices, queries, positives):
265
- negatives = []
266
- positive_strip = positive.strip()
267
- seen = {positive_strip}
 
 
 
268
 
 
269
  for idx in query_indices:
270
- if idx >= 0 and idx < total_responses:
271
- candidate = self.response_pool[idx].strip()
272
- if candidate and candidate not in seen:
273
- seen.add(candidate)
274
- negatives.append(candidate)
275
- if len(negatives) >= self.neg_samples:
 
276
  break
277
 
278
- # If not enough negatives are found, pad with a special token
279
- while len(negatives) < self.neg_samples:
280
- negatives.append("<EMPTY_NEGATIVE>") # Use a special token
 
 
 
281
 
282
- all_negatives.append(negatives)
283
 
284
  return all_negatives
285
 
@@ -288,123 +309,236 @@ class TFDataPipeline:
288
  logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
289
  if retry_count == self.max_retries:
290
  logger.error("Max retries reached for hard negative search due to missing embeddings.")
291
- return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
292
- # Perform memory cleanup
293
  gc.collect()
294
  if tf.config.list_physical_devices('GPU'):
295
  tf.keras.backend.clear_session()
 
296
  except Exception as e:
297
  retry_count += 1
298
  logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
299
  if retry_count == self.max_retries:
300
  logger.error("Max retries reached for hard negative search.")
301
- return [["<EMPTY_NEGATIVE>"] * self.neg_samples for _ in queries]
302
- # Perform memory cleanup
303
  gc.collect()
304
  if tf.config.list_physical_devices('GPU'):
305
  tf.keras.backend.clear_session()
306
 
307
- def encode_query(self, query: str, context: Optional[List[Tuple[str, str]]] = None) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  """
309
- Encode a query with optional conversation context into an embedding vector.
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
  Args:
312
- query (str): The user query.
313
- context (Optional[List[Tuple[str, str]]]): Optional conversation history as a list of (user, assistant) tuples.
314
 
315
  Returns:
316
- np.ndarray: The normalized embedding vector for the query.
317
  """
318
- # Prepare query with context
319
  if context:
320
- context_str = ' '.join([
321
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {q} "
322
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
323
- for q, r in context[-self.config.max_context_turns:]
324
- ])
325
- query = f"{context_str} {self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]}" \
326
- f" {query}"
327
  else:
328
- query = f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
329
-
330
- # Tokenize and encode
331
  encodings = self.tokenizer(
332
- [query],
333
  padding='max_length',
334
  truncation=True,
335
  max_length=self.max_length,
336
- return_tensors='np' # Use NumPy arrays for compatibility with FAISS
337
  )
338
  input_ids = encodings['input_ids']
339
 
340
- # Verify token IDs
341
  max_id = np.max(input_ids)
342
- new_vocab_size = len(self.tokenizer)
343
-
344
- if max_id >= new_vocab_size:
345
- logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
346
  raise ValueError("Token ID exceeds vocabulary size.")
347
 
348
- # Get embeddings from the shared encoder
349
  embeddings = self.encoder(input_ids, training=False).numpy()
350
-
351
- # Normalize embeddings for cosine similarity
352
- faiss.normalize_L2(embeddings)
353
-
354
- return embeddings[0] # Return as a 1D array
355
 
356
- def encode_responses(self, responses: List[str], context: Optional[List[Tuple[str, str]]] = None) -> np.ndarray:
357
  """
358
- Encode a list of responses into embedding vectors.
359
 
360
  Args:
361
- responses (List[str]): List of response texts.
362
- context (Optional[List[Tuple[str, str]]]): Optional conversation history as a list of (user, assistant) tuples.
363
 
364
  Returns:
365
- np.ndarray: Array of normalized embedding vectors.
366
  """
367
- # Prepare responses with context if provided
 
368
  if context:
369
- prepared_responses = []
370
- for response in responses:
371
- context_str = ' '.join([
372
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {q} "
373
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
374
- for q, r in context[-self.config.max_context_turns:]
375
- ])
376
- full_response = f"{context_str} {self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {response}"
377
- prepared_responses.append(full_response)
378
  else:
379
- prepared_responses = [
380
- f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
381
- for resp in responses
 
382
  ]
383
-
384
- # Tokenize and encode
385
  encodings = self.tokenizer(
386
- prepared_responses,
387
  padding='max_length',
388
  truncation=True,
389
  max_length=self.max_length,
390
- return_tensors='np' # Use NumPy arrays for compatibility with FAISS
391
  )
392
  input_ids = encodings['input_ids']
393
-
394
- # Verify token IDs
395
  max_id = np.max(input_ids)
396
- new_vocab_size = len(self.tokenizer)
397
-
398
- if max_id >= new_vocab_size:
399
- logger.error(f"Token ID {max_id} exceeds the vocabulary size {new_vocab_size}.")
400
  raise ValueError("Token ID exceeds vocabulary size.")
401
-
402
- # Get embeddings from the shared encoder
403
  embeddings = self.encoder(input_ids, training=False).numpy()
404
-
405
- # Normalize embeddings for cosine similarity
406
- faiss.normalize_L2(embeddings)
407
-
408
  return embeddings.astype('float32')
409
 
410
  def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
 
8
  from tqdm import tqdm
9
  import json
10
  from pathlib import Path
11
+ from typing import Union, Optional, Dict, List, Tuple, Generator
12
  from transformers import AutoTokenizer
13
  from typing import List, Tuple, Generator
14
  from transformers import AutoTokenizer
15
  from gpu_monitor import GPUMemoryMonitor
16
+ import random
17
 
18
  from logger_config import config_logger
19
  logger = config_logger(__name__)
 
28
  response_pool: List[str],
29
  max_length: int,
30
  query_embeddings_cache: dict,
31
+ neg_samples: int = 5,
32
  index_type: str = 'IndexFlatIP',
33
  nlist: int = 100,
34
  max_retries: int = 3
 
48
  self.max_batch_size = 16 if len(response_pool) < 100 else 64
49
  self.memory_monitor = GPUMemoryMonitor()
50
  self.max_retries = max_retries
51
+
52
+ # Build a quick text->domain map for O(1) domain lookups
53
+ self._text_domain_map = {}
54
+ self.build_text_to_domain_map()
55
 
56
  if os.path.exists(index_file_path):
57
  logger.info(f"Loading existing FAISS index from {index_file_path}...")
 
140
 
141
  logger.info(f"Loaded {len(dialogues)} dialogues.")
142
  return dialogues
143
+
144
+ def collect_responses_with_domain(self, dialogues: List[dict]) -> List[Dict[str, str]]:
145
+ """
146
+ Extract unique assistant responses from dialogues, along with the domain.
147
+ Returns a list of dicts: [{'domain': str, 'text': str}, ...]
148
+ """
149
+ response_set = set() # We'll store (domain, text) tuples to keep them unique
150
+ results = []
151
+
152
  for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
153
+ # The domain is stored at the top level of each dialogue in the new JSON format
154
+ domain = dialogue.get('domain', 'other')
155
  turns = dialogue.get('turns', [])
156
  for turn in turns:
157
  speaker = turn.get('speaker')
158
  text = turn.get('text', '').strip()
159
  if speaker == 'assistant' and text:
 
160
  if len(text) <= self.max_length:
161
+ # Use a tuple as a "set" key to ensure uniqueness
162
+ key = (domain, text)
163
+ if key not in response_set:
164
+ response_set.add(key)
165
+ results.append({
166
+ "domain": domain,
167
+ "text": text
168
+ })
169
+
170
+ logger.info(f"Collected {len(results)} unique assistant responses from dialogues.")
171
+ return results
172
+ # def collect_responses(self, dialogues: List[dict]) -> List[str]:
173
+ # """Extract unique assistant responses from dialogues."""
174
+ # response_set = set()
175
+ # for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
176
+ # turns = dialogue.get('turns', [])
177
+ # for turn in turns:
178
+ # speaker = turn.get('speaker')
179
+ # text = turn.get('text', '').strip()
180
+ # if speaker == 'assistant' and text:
181
+ # # Ensure we don't exclude valid shorter responses
182
+ # if len(text) <= self.max_length:
183
+ # response_set.add(text)
184
+ # logger.info(f"Collected {len(response_set)} unique assistant responses from dialogues.")
185
+ # return list(response_set)
186
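To make the new return shape concrete, here is an illustrative input/output pair; the keys match those read above, while the values are invented:

    # Illustrative only -- not part of this commit.
    dialogues = [{
        "domain": "restaurant",
        "turns": [
            {"speaker": "user", "text": "Any cheap places near the station?"},
            {"speaker": "assistant", "text": "Try the corner bistro on Main Street."},
        ],
    }]
    # collect_responses_with_domain(dialogues) ->
    # [{"domain": "restaurant", "text": "Try the corner bistro on Main Street."}]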
 
187
  def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
188
  """Extract query-response pairs from a dialogue."""
 
206
 
207
  def compute_and_index_response_embeddings(self):
208
  """
209
+ Computes embeddings for the response pool and adds them to the FAISS index.
210
+ self.response_pool is now List[Dict[str, str]] with keys "domain" and "text".
211
  """
212
  logger.info("Computing embeddings for the response pool...")
213
 
214
+ # Extract just the assistant text
215
+ texts = [resp["text"] for resp in self.response_pool]
216
+ logger.debug(f"Total texts to embed: {len(texts)}")
217
+
218
+ batch_size = getattr(self, 'embedding_batch_size', 64)
219
  embeddings = []
220
 
221
+ with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
222
+ for i in range(0, len(texts), batch_size):
223
+ batch_texts = texts[i:i+batch_size]
224
+ encodings = self.tokenizer(
225
+ batch_texts,
226
+ padding=True,
227
+ truncation=True,
228
+ max_length=self.max_length,
229
+ return_tensors='tf'
230
+ )
231
+ batch_embeds = self.encoder(encodings['input_ids'], training=False).numpy()
232
+
233
+ embeddings.append(batch_embeds)
234
+ pbar.update(len(batch_texts))
235
+
236
+ # Combine embeddings and add to FAISS
237
+ all_embeddings = np.vstack(embeddings).astype(np.float32)
238
+ logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
239
+ self.index.add(all_embeddings)
240
+
241
+ # For debugging or repeated usage, you might store them:
242
+ self.response_embeddings = all_embeddings
243
+ logger.info(f"FAISS index now has {self.index.ntotal} vectors.")
244
 
245
  def _find_hard_negatives_batch(self, queries: List[str], positives: List[str]) -> List[List[str]]:
246
+ """
247
+ Find hard negatives for a batch of queries using FAISS search.
248
+ Falls back to random negatives if we run out of tries or can't find enough.
249
+ Uses domain-based fallback if possible.
250
+ """
251
+ import random
252
+ import gc
253
+
254
  retry_count = 0
255
  total_responses = len(self.response_pool)
256
+ k = self.neg_samples # Number of negatives to retrieve from FAISS
257
+ batch_size = 128
 
258
 
259
  while retry_count < self.max_retries:
260
  try:
261
+ # 1) Build query embeddings from the cache
 
262
  query_embeddings = []
263
  for i in range(0, len(queries), batch_size):
264
+ sub_queries = queries[i : i + batch_size]
265
+ sub_embeds = [self.query_embeddings_cache[q] for q in sub_queries]
266
+ sub_embeds = np.vstack(sub_embeds).astype(np.float32)
267
+ faiss.normalize_L2(sub_embeds) # If not already normalized
268
+ query_embeddings.append(sub_embeds)
 
 
269
 
270
+ query_embeddings = np.vstack(query_embeddings)
271
  query_embeddings = np.ascontiguousarray(query_embeddings)
272
 
273
+ # 2) Perform FAISS search
274
  distances, indices = self.index.search(query_embeddings, k)
275
 
276
  all_negatives = []
277
+ # For each query, find domain from the corresponding positive if possible
278
+ for query_indices, query_text, pos_text in zip(indices, queries, positives):
279
+ negative_list = []
280
+ seen = {pos_text.strip()}
281
+
282
+ # Attempt to detect the domain of the positive text
283
+ domain_of_positive = self._detect_domain_for_text(pos_text)
284
 
285
+ # Collect hard negatives from FAISS
286
  for idx in query_indices:
287
+ if 0 <= idx < total_responses:
288
+ candidate_dict = self.response_pool[idx] # e.g. {domain, text}
289
+ candidate_text = candidate_dict["text"].strip()
290
+ if candidate_text and candidate_text not in seen:
291
+ seen.add(candidate_text)
292
+ negative_list.append(candidate_text)
293
+ if len(negative_list) >= self.neg_samples:
294
  break
295
 
296
+ # If not enough negatives, fallback to random domain-based
297
+ if len(negative_list) < self.neg_samples:
298
+ needed = self.neg_samples - len(negative_list)
299
+ # Pass in domain_of_positive to your updated `_get_random_negatives(...)`
300
+ random_negatives = self._get_random_negatives(needed, seen, domain=domain_of_positive)
301
+ negative_list.extend(random_negatives)
302
 
303
+ all_negatives.append(negative_list)
304
 
305
  return all_negatives
306
 
 
309
  logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
310
  if retry_count == self.max_retries:
311
  logger.error("Max retries reached for hard negative search due to missing embeddings.")
312
+ return self._fallback_negatives(queries, positives, reason="key_error")
 
313
  gc.collect()
314
  if tf.config.list_physical_devices('GPU'):
315
  tf.keras.backend.clear_session()
316
+
317
  except Exception as e:
318
  retry_count += 1
319
  logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
320
  if retry_count == self.max_retries:
321
  logger.error("Max retries reached for hard negative search.")
322
+ return self._fallback_negatives(queries, positives, reason="generic_error")
 
323
  gc.collect()
324
  if tf.config.list_physical_devices('GPU'):
325
  tf.keras.backend.clear_session()
326
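Roughly how this method is called once the query embedding cache is populated for the batch; the texts and the pipeline name are illustrative:

    # Illustrative call, not part of this commit.
    queries   = ["book a table for two", "what time do you open"]
    positives = ["I have booked a table for two at 7pm.", "We open at 9am."]
    hard_negs = pipeline._find_hard_negatives_batch(queries, positives)
    # hard_negs is a List[List[str]]: self.neg_samples candidates per query, taken from
    # FAISS neighbours first and topped up with domain-matched random picks.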
 
327
+ def _detect_domain_for_text(self, text: str) -> Optional[str]:
328
+ """
329
+ O(1) domain detection by looking up text in our dictionary.
330
+ Returns the domain if found, else None.
331
+ """
332
+ stripped_text = text.strip()
333
+ return self._text_domain_map.get(stripped_text, None)
334
+
335
+ def _get_random_negatives(self, needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
336
+ """
337
+ Return a list of 'needed' random negative texts from the same domain if possible,
338
+ otherwise fallback to all-domain.
339
+ """
340
+ # 1) Filter response_pool for domain if provided
341
+ if domain:
342
+ domain_texts = [r["text"] for r in self.response_pool if r["domain"] == domain]
343
+ # fallback to entire set if insufficient domain_texts
344
+ if len(domain_texts) < needed * 2: # heuristic: require a 2x surplus before trusting the domain pool
345
+ domain_texts = [r["text"] for r in self.response_pool]
346
+ else:
347
+ domain_texts = [r["text"] for r in self.response_pool]
348
+
349
+ negatives = []
350
+ tries = 0
351
+ max_tries = needed * 10
352
+ while len(negatives) < needed and tries < max_tries:
353
+ tries += 1
354
+ candidate = random.choice(domain_texts).strip()
355
+ if candidate and candidate not in seen:
356
+ negatives.append(candidate)
357
+ seen.add(candidate)
358
+
359
+ # If still not enough, we do the best we can
360
+ if len(negatives) < needed:
361
+ logger.warning(f"Could not find enough domain-based random negatives; needed {needed}, got {len(negatives)}.")
362
+
363
+ return negatives
364
+
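A small hand-worked example of the domain-aware fallback, assuming the response pool mixes restaurant and hotel texts (values invented):

    # Illustrative only.
    pos_text = "Try the corner bistro on Main Street."
    domain   = pipeline._detect_domain_for_text(pos_text)   # "restaurant" if mapped, else None
    seen     = {pos_text}
    extras   = pipeline._get_random_negatives(needed=3, seen=seen, domain=domain)
    # extras holds up to 3 distinct same-domain responses; if the domain pool is smaller
    # than 2x the number needed, sampling falls back to the full response pool.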
365
+ def _fallback_negatives(self, queries: List[str], positives: List[str], reason: str) -> List[List[str]]:
366
+ """
367
+ Called if FAISS fails or embeddings are missing.
368
+ We use entirely random negatives for each query, ignoring FAISS,
369
+ but still attempt domain-based selection if possible.
370
  """
371
+ logger.error(f"Falling back to random negatives due to: {reason}")
372
+ all_negatives = []
373
+
374
+ for pos_text in positives:
375
+ # Build a 'seen' set with the positive
376
+ seen = {pos_text.strip()}
377
+
378
+ # Attempt to detect the domain of the positive text
379
+ domain_of_positive = self._detect_domain_for_text(pos_text)
380
+
381
+ # Use domain-based random negatives if available
382
+ negs = self._get_random_negatives(self.neg_samples, seen, domain=domain_of_positive)
383
+ all_negatives.append(negs)
384
 
385
+ return all_negatives
386
+
387
+ def build_text_to_domain_map(self):
388
+ """
389
+ Build an O(1) lookup dict: text -> domain,
390
+ so we don't have to scan the entire self.response_pool each time.
391
+ """
392
+ self._text_domain_map = {}
393
+
394
+ for item in self.response_pool:
395
+ # e.g., item = {"domain": "restaurant", "text": "some text..."}
396
+ stripped_text = item["text"].strip()
397
+ domain = item["domain"]
398
+
399
+ # The same text appearing again under the same domain is harmless.
400
+ # The same text under a different domain is a collision and needs a policy.
401
+ if stripped_text in self._text_domain_map:
402
+ existing_domain = self._text_domain_map[stripped_text]
403
+ if existing_domain != domain:
404
+ # Log a warning or decide on a policy:
405
+ logger.warning(
406
+ f"Collision detected: text '{stripped_text}' found with domains "
407
+ f"'{existing_domain}' and '{domain}'. Keeping the first."
408
+ )
409
+ # Policy: keep the first domain seen and skip overwriting.
410
+ continue
411
+ else:
412
+ # Insert into the dict
413
+ self._text_domain_map[stripped_text] = domain
414
+
415
+ logger.info(f"Built text->domain map with {len(self._text_domain_map)} unique text entries.")
416
+
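The resulting lookup is a flat dict keyed by stripped response text; on a collision the first domain seen wins, as the warning above notes. Illustrative contents:

    # self._text_domain_map (illustrative):
    # {
    #     "Try the corner bistro on Main Street.": "restaurant",
    #     "Checkout is at 11am.": "hotel",
    # }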
417
+ def encode_query(
418
+ self,
419
+ query: str,
420
+ context: Optional[List[Tuple[str, str]]] = None
421
+ ) -> np.ndarray:
422
+ """
423
+ Encode a user query (and optional conversation context) into an embedding vector.
424
+
425
  Args:
426
+ query: The user query.
427
+ context: Optional conversation history as a list of (user_text, assistant_text).
428
 
429
  Returns:
430
+ np.ndarray of shape [embedding_dim], typically L2-normalized already.
431
  """
432
+ # 1) Prepare context (if any) by concatenating user/assistant pairs
433
  if context:
434
+ # Take the last N turns
435
+ relevant_history = context[-self.config.max_context_turns:]
436
+ context_str_parts = []
437
+ for (u_text, a_text) in relevant_history:
438
+ context_str_parts.append(
439
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {u_text} "
440
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
441
+ )
442
+ context_str = " ".join(context_str_parts)
443
+
444
+ # Append the user's new query
445
+ full_query = (
446
+ f"{context_str} "
447
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
448
+ )
449
  else:
450
+ # Just a single user turn
451
+ full_query = (
452
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
453
+ )
454
+
455
+ # 2) Tokenize
456
  encodings = self.tokenizer(
457
+ [full_query],
458
  padding='max_length',
459
  truncation=True,
460
  max_length=self.max_length,
461
+ return_tensors='np' # to keep it compatible with FAISS
462
  )
463
  input_ids = encodings['input_ids']
464
 
465
+ # 3) Check for out-of-vocab IDs
466
  max_id = np.max(input_ids)
467
+ vocab_size = len(self.tokenizer)
468
+ if max_id >= vocab_size:
469
+ logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
 
470
  raise ValueError("Token ID exceeds vocabulary size.")
471
 
472
+ # 4) Get embeddings from the model
473
  embeddings = self.encoder(input_ids, training=False).numpy()
474
+ # The encoder is expected to L2-normalize its final embeddings, so no extra normalization is applied here.
475
+
476
+ # 5) Return the single embedding as 1D array
477
+ return embeddings[0]
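For clarity, the string handed to the tokenizer looks like this, assuming <USER> and <ASSISTANT> are registered as additional special tokens (which the index lookups above require); the texts are invented:

    # With one turn of context, full_query is roughly:
    #   "<USER> any cheap places nearby <ASSISTANT> Try the corner bistro. <USER> is it open late?"
    # Without context:
    #   "<USER> is it open late?"
    query_vec = pipeline.encode_query(
        "is it open late?",
        context=[("any cheap places nearby", "Try the corner bistro.")],
    )
    # query_vec.shape == (embedding_dim,)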
 
478
 
479
+ def encode_responses(
480
+ self,
481
+ responses: List[str],
482
+ context: Optional[List[Tuple[str, str]]] = None
483
+ ) -> np.ndarray:
484
  """
485
+ Encode multiple response texts into embedding vectors.
486
 
487
  Args:
488
+ responses: List of raw assistant responses.
489
+ context: Optional conversation context (last N turns).
490
 
491
  Returns:
492
+ np.ndarray of shape [num_responses, embedding_dim].
493
  """
494
+ # 1) Optionally fold conversation context into the response encoding.
495
+ # Plain retrieval usually skips this; when context is supplied, prepend it:
496
  if context:
497
+ relevant_history = context[-self.config.max_context_turns:]
498
+ prepared = []
499
+ for resp in responses:
500
+ context_str_parts = []
501
+ for (u_text, a_text) in relevant_history:
502
+ context_str_parts.append(
503
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {u_text} "
504
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {a_text}"
505
+ )
506
+ context_str = " ".join(context_str_parts)
507
+
508
+ # Now treat resp as an assistant turn
509
+ full_resp = (
510
+ f"{context_str} "
511
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {resp}"
512
+ )
513
+ prepared.append(full_resp)
514
  else:
515
+ # By default, just mark each response as from the assistant
516
+ prepared = [
517
+ f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {r}"
518
+ for r in responses
519
  ]
520
+
521
+ # 2) Tokenize
522
  encodings = self.tokenizer(
523
+ prepared,
524
  padding='max_length',
525
  truncation=True,
526
  max_length=self.max_length,
527
+ return_tensors='np'
528
  )
529
  input_ids = encodings['input_ids']
530
+
531
+ # 3) Check for out-of-vocab
532
  max_id = np.max(input_ids)
533
+ vocab_size = len(self.tokenizer)
534
+ if max_id >= vocab_size:
535
+ logger.error(f"Token ID {max_id} exceeds tokenizer vocab size {vocab_size}.")
 
536
  raise ValueError("Token ID exceeds vocabulary size.")
537
+
538
+ # 4) Model forward
539
  embeddings = self.encoder(input_ids, training=False).numpy()
540
+ # Already L2-normalized, provided the encoder's final layer normalizes its output.
541
+
 
 
542
  return embeddings.astype('float32')
543
 
544
  def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
validate_model.py CHANGED
@@ -1,16 +1,17 @@
1
  import os
2
  import json
 
3
  from chatbot_model import ChatbotConfig, RetrievalChatbot
4
  from response_quality_checker import ResponseQualityChecker
5
  from chatbot_validator import ChatbotValidator
6
  from plotter import Plotter
7
  from environment_setup import EnvironmentSetup
8
-
9
  from logger_config import config_logger
 
10
  logger = config_logger(__name__)
11
 
12
  def run_interactive_chat(chatbot, quality_checker):
13
- """Separate function for interactive chat loop"""
14
  while True:
15
  try:
16
  user_input = input("You: ")
@@ -18,7 +19,7 @@ def run_interactive_chat(chatbot, quality_checker):
18
  print("\nAssistant: Goodbye!")
19
  break
20
 
21
- if user_input.lower() in ['quit', 'exit', 'bye']:
22
  print("Assistant: Goodbye!")
23
  break
24
 
@@ -26,69 +27,97 @@ def run_interactive_chat(chatbot, quality_checker):
26
  query=user_input,
27
  conversation_history=None,
28
  quality_checker=quality_checker,
29
- top_k=5
30
  )
31
 
32
  print(f"Assistant: {response}")
33
 
34
- if metrics.get('is_confident', False):
 
35
  print("\nAlternative responses:")
36
  for resp, score in candidates[1:4]:
37
  print(f"Score: {score:.4f} - {resp}")
38
  else:
39
  print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
40
 
41
- # TODO:
42
  def validate_chatbot():
43
  # Initialize environment
44
  env = EnvironmentSetup()
45
  env.initialize()
46
 
47
- MODEL_DIR = 'models'
48
- FAISS_INDICES_DIR = os.path.join(MODEL_DIR, 'faiss_indices')
49
- FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
50
- FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
51
- RESPONSE_POOL_PRODUCTION_PATH = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
52
- RESPONSE_POOL_TEST_PATH = FAISS_INDEX_TEST_PATH.replace('.index', '_responses.json')
53
- ENVIRONMENT = 'production' # or 'test'
54
- if ENVIRONMENT == 'test':
55
  FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
56
- RESPONSE_POOL_PATH = RESPONSE_POOL_TEST_PATH
57
  else:
58
  FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
59
- RESPONSE_POOL_PATH = RESPONSE_POOL_PRODUCTION_PATH
60
-
61
- # Load config
62
- config = ChatbotConfig()
63
-
64
- # Initialize RetrievalChatbot in 'inference' mode
65
  try:
66
- chatbot = RetrievalChatbot(config=config, mode='inference')
67
- logger.info("RetrievalChatbot initialized in 'inference' mode.")
68
  except Exception as e:
69
- logger.error(f"Failed to initialize RetrievalChatbot: {e}")
70
  return
71
 
72
- # Ensure FAISS index and response pool are accessible, then load
73
  if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
74
  logger.error("FAISS index or response pool file is missing.")
75
  return
76
 
 
77
  try:
 
78
  chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
79
  logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
80
-
81
- with open(RESPONSE_POOL_PATH, 'r', encoding='utf-8') as f:
82
  chatbot.data_pipeline.response_pool = json.load(f)
83
  logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
84
-
85
  chatbot.data_pipeline.validate_faiss_index()
86
  logger.info("FAISS index and response pool validated successfully.")
 
87
  except Exception as e:
88
- logger.error(f"Failed to load FAISS index: {e}")
89
  return
90
 
91
- # Initialize ResponseQualityChecker and ChatbotValidator
92
  quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
93
  validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
94
  logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
@@ -101,17 +130,17 @@ def validate_chatbot():
101
  logger.error(f"Validation process failed: {e}")
102
  return
103
 
104
- # Plot validation_metrics
105
- try:
106
- plotter = Plotter(save_dir=env.training_dirs['plots'])
107
- plotter.plot_validation_metrics(validation_metrics)
108
- logger.info("Validation metrics plotted successfully.")
109
- except Exception as e:
110
- logger.error(f"Failed to plot validation metrics: {e}")
111
 
112
- # Run interactive chat
113
- logger.info("\nStarting interactive chat session...")
114
- run_interactive_chat(chatbot, quality_checker)
115
 
116
- if __name__ == '__main__':
117
- validate_chatbot()
 
1
  import os
2
  import json
3
+
4
  from chatbot_model import ChatbotConfig, RetrievalChatbot
5
  from response_quality_checker import ResponseQualityChecker
6
  from chatbot_validator import ChatbotValidator
7
  from plotter import Plotter
8
  from environment_setup import EnvironmentSetup
 
9
  from logger_config import config_logger
10
+
11
  logger = config_logger(__name__)
12
 
13
  def run_interactive_chat(chatbot, quality_checker):
14
+ """Separate function for interactive chat loop."""
15
  while True:
16
  try:
17
  user_input = input("You: ")
 
19
  print("\nAssistant: Goodbye!")
20
  break
21
 
22
+ if user_input.lower() in ["quit", "exit", "bye"]:
23
  print("Assistant: Goodbye!")
24
  break
25
 
 
27
  query=user_input,
28
  conversation_history=None,
29
  quality_checker=quality_checker,
30
+ top_k=10
31
  )
32
 
33
  print(f"Assistant: {response}")
34
 
35
+ # Show alternative responses if confident
36
+ if metrics.get("is_confident", False):
37
  print("\nAlternative responses:")
38
  for resp, score in candidates[1:4]:
39
  print(f"Score: {score:.4f} - {resp}")
40
  else:
41
  print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
42
 
 
43
  def validate_chatbot():
44
  # Initialize environment
45
  env = EnvironmentSetup()
46
  env.initialize()
47
 
48
+ MODEL_DIR = "new_iteration/data_prep_iterative_models"
49
+ FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
50
+ FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
51
+ FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
52
+
53
+ # Toggle 'production' or 'test' env
54
+ ENVIRONMENT = "production"
55
+ if ENVIRONMENT == "test":
56
  FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
57
+ RESPONSE_POOL_PATH = FAISS_INDEX_TEST_PATH.replace(".index", "_responses.json")
58
  else:
59
  FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
60
+ RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
61
+
62
+ # Load the config
63
+ config_path = os.path.join(MODEL_DIR, "config.json")
64
+ if os.path.exists(config_path):
65
+ with open(config_path, "r", encoding="utf-8") as f:
66
+ config_dict = json.load(f)
67
+ config = ChatbotConfig.from_dict(config_dict)
68
+ logger.info(f"Loaded ChatbotConfig from {config_path}")
69
+ else:
70
+ config = ChatbotConfig()
71
+ logger.warning("No config.json found. Using default ChatbotConfig.")
72
+
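For reference, a minimal config that this branch would accept might carry the fields referenced elsewhere in this commit; the values are illustrative and ChatbotConfig defines more options than shown:

    # Illustrative only -- field names taken from usages in this commit, values invented.
    config_dict = {
        "embedding_dim": 768,
        "max_context_token_limit": 512,
        "max_context_turns": 5,
    }
    config = ChatbotConfig.from_dict(config_dict)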
73
+ # Load RetrievalChatbot in 'inference' mode using the classmethod
74
+ # This:
75
+ # - Loads shared_encoder submodule
76
+ # - Loads encoder_custom_weights.weights.h5
77
+ # - Loads tokenizer
78
+ # - Prepares the model for inference
79
  try:
80
+ chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
81
+ logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
82
  except Exception as e:
83
+ logger.error(f"Failed to load RetrievalChatbot: {e}")
84
  return
85
 
86
+ # Confirm FAISS index & response pool exist
87
  if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
88
  logger.error("FAISS index or response pool file is missing.")
89
  return
90
 
91
+ # Load specific FAISS index and response pool
92
  try:
93
+ # Even though load_model might auto-load an index, we override here with the specific file
94
  chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
95
  logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
96
+
97
+ print("FAISS dimensions:", chatbot.data_pipeline.index.d)
98
+ print("FAISS index type:", type(chatbot.data_pipeline.index))
99
+ print("FAISS index total vectors:", chatbot.data_pipeline.index.ntotal)
100
+ print("FAISS is_trained:", chatbot.data_pipeline.index.is_trained)
101
+
102
+ with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
103
  chatbot.data_pipeline.response_pool = json.load(f)
104
  logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
105
+
106
+ print("Sample from response pool (first 10):")
107
+ for i, response in enumerate(chatbot.data_pipeline.response_pool[:10]):
108
+ print(f"{i}: {response}")
109
+
110
+ print("\nTotal responses in pool:", len(chatbot.data_pipeline.response_pool))
111
+
112
+ # Validate dimension consistency
113
  chatbot.data_pipeline.validate_faiss_index()
114
  logger.info("FAISS index and response pool validated successfully.")
115
+
116
  except Exception as e:
117
+ logger.error(f"Failed to load or validate FAISS index: {e}")
118
  return
119
 
120
+ # Init QualityChecker and Validator
121
  quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
122
  validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
123
  logger.info("ResponseQualityChecker and ChatbotValidator initialized.")
 
130
  logger.error(f"Validation process failed: {e}")
131
  return
132
 
133
+ # Plot metrics
134
+ # try:
135
+ # plotter = Plotter(save_dir=env.training_dirs["plots"])
136
+ # plotter.plot_validation_metrics(validation_metrics)
137
+ # logger.info("Validation metrics plotted successfully.")
138
+ # except Exception as e:
139
+ # logger.error(f"Failed to plot validation metrics: {e}")
140
 
141
+ # Run interactive chat loop
142
+ # logger.info("\nStarting interactive chat session...")
143
+ # run_interactive_chat(chatbot, quality_checker)
144
 
145
+ if __name__ == "__main__":
146
+ validate_chatbot()
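With the paths above in place, the validation entry point is run directly, e.g. python validate_model.py.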