JoeArmani committed
Commit c7c1b4e · 1 Parent(s): 64e7c31

chat refinements

chatbot_config.py CHANGED
@@ -4,19 +4,23 @@ from typing import Dict
 
 @dataclass
 class ChatbotConfig:
-    """RetrievalChatbot Config"""
-    max_context_token_limit: int = 512
-    embedding_dim: int = 384  # Match Sentence Transformer dimension
+    """
+    All config params for the chatbot
+    """
+    max_context_length: int = 512
+    embedding_dim: int = 384  # Sentence Transformer dim
     learning_rate: float = 0.0005
     min_text_length: int = 3
-    max_context_turns: int = 20
+    max_context_turns: int = 24
     pretrained_model: str = 'sentence-transformers/all-MiniLM-L6-v2'
     cross_encoder_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
     summarizer_model: str = 't5-small'
     embedding_batch_size: int = 64
     search_batch_size: int = 64
     max_batch_size: int = 64
+    neg_samples: int = 10
     max_retries: int = 3
+    nlist: int = 100
 
     def to_dict(self) -> Dict:
         """Convert config to dictionary."""
chatbot_model.py CHANGED
@@ -22,6 +22,9 @@ from tqdm.auto import tqdm
 
 absl.logging.set_verbosity(absl.logging.WARNING)
 logger = config_logger(__name__)
+logger.setLevel("WARNING")
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tqdm(disable=True)
 
 class RetrievalChatbot(DeviceAwareModel):
     """
@@ -59,7 +62,6 @@ class RetrievalChatbot(DeviceAwareModel):
             tokenizer=self.tokenizer,
             encoder=self.encoder,
             response_pool=[],
-            max_length=self.config.max_context_token_limit,
             query_embeddings_cache={},
         )
 
@@ -96,7 +98,7 @@ class RetrievalChatbot(DeviceAwareModel):
         return Summarizer(
             tokenizer=self.tokenizer,
             model_name=self.config.summarizer_model,
-            max_summary_length=self.config.max_context_token_limit // 4,
+            max_summary_length=self.config.max_context_length // 4,
             device=self.device,
             max_summary_rounds=2
         )
@@ -218,7 +220,6 @@ class RetrievalChatbot(DeviceAwareModel):
     ) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses using FAISS and cross-encoder re-ranking.
-
         Args:
             query: The user's input text.
             top_k: Number of responses to return.
@@ -226,7 +227,6 @@ class RetrievalChatbot(DeviceAwareModel):
             summarizer: Optional summarizer for long queries.
             summarize_threshold: Threshold to summarize long queries.
             boost_factor: Factor to boost scores for keyword matches.
-
         Returns:
             List of (response_text, final_score).
         """
@@ -241,18 +241,27 @@ class RetrievalChatbot(DeviceAwareModel):
 
         # Detect domain for query
         detected_domain = self.detect_domain_from_query(query)
+        #logger.info(f"Detected domain: {detected_domain}")
 
-        # Step 1: Retrieve candidates from FAISS
-        logger.info("Retrieving initial candidates from FAISS...")
+        # Retrieve candidates from FAISS
+        #logger.info("Retrieving initial candidates from FAISS...")
         faiss_candidates = self.data_pipeline.retrieve_responses(query, top_k=top_k * 10)
 
         if not faiss_candidates:
             logger.warning("No candidates retrieved from FAISS.")
             return []
 
-        # Step 2: Re-rank candidates using Cross-Encoder
-        logger.info("Re-ranking candidates using Cross-Encoder...")
-        texts = [item[0] for item in faiss_candidates]  # Extract response texts
+        # Filter out-of-domain responses
+        if detected_domain != 'other':
+            in_domain_candidates = [c for c in faiss_candidates if c[0]["domain"] == detected_domain]
+            if in_domain_candidates:
+                faiss_candidates = in_domain_candidates
+            else:
+                logger.info(f"No in-domain responses found for '{query}'. Using all candidates.")
+
+        # Re-rank candidates using Cross-Encoder
+        #logger.info("Re-ranking candidates using Cross-Encoder...")
+        texts = [item[0]["text"] for item in faiss_candidates]  # Extract response texts
         faiss_scores = [item[1] for item in faiss_candidates]
 
         if reranker is None:
@@ -277,9 +286,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
             final_candidates.append((resp_text, length_adjusted_score))
 
-        # Step 3: Sort and return top-k results
+        # Sort and return top-k results
         final_candidates.sort(key=lambda x: x[1], reverse=True)
-        logger.info(f"Returning top-{top_k} re-ranked responses.")
+        #logger.info(f"Returning top-{top_k} re-ranked responses.")
+
         return final_candidates[:top_k]
 
     def extract_keywords(self, query: str) -> List[str]:
@@ -323,7 +333,7 @@ class RetrievalChatbot(DeviceAwareModel):
 
     def detect_domain_from_query(self, query: str) -> str:
         """
-        Detect the domain of the query based on keywords. Used for boosting FAISS search.
+        Detect the domain of the query based on keywords. Used for filtering FAISS search.
         """
         domain_patterns = {
            'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
@@ -348,85 +358,6 @@ class RetrievalChatbot(DeviceAwareModel):
         pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
         return bool(re.match(pattern, text.strip()))
 
-    def faiss_search(
-        self,
-        query: str,
-        domain: str = 'other',
-        top_k: int = 10,
-        boost_factor: float = 1.15
-    ) -> List[Tuple[str, float]]:
-        """
-        Retrieve top-k responses from the FAISS index (IndexFlatIP) given a user query.
-        Args:
-            query (str): The user input text.
-            domain (str): The detected domain from possible domains: ['restaurant', 'movie', 'ride_share', 'coffee', 'pizza', 'auto', 'other']
-            top_k (int): Number of top results to return.
-            boost_factor (float, optional): Factor to boost scores for keyword matches.
-        Returns:
-            List[Tuple[str, float]]: List of (response_text, similarity) sorted by descending similarity.
-        """
-        # Encode the query
-        q_emb = self.data_pipeline.encode_query(query)
-        q_emb_np = q_emb.reshape(1, -1).astype('float32')
-
-        # Search the index
-        distances, indices = self.data_pipeline.index.search(q_emb_np, top_k * 10)
-
-        # IndexFlatIP: 'distances' are inner products (cosine similarities for normalized vectors).
-        candidates = []
-        for rank, idx in enumerate(indices[0]):
-            if idx < 0:
-                continue
-            text_dict = self.data_pipeline.response_pool[idx]
-            text = text_dict.get('text', '').strip()
-            cand_domain = text_dict.get('domain', 'other')
-            score = distances[0][rank]
-
-            # Skip purely numeric or extremely short text (fewer than 3 words):
-            words = text.split()
-            if len(words) < 4:
-                continue
-            if self.is_numeric_response(text):
-                continue
-
-            candidates.append((text, cand_domain, score))
-
-        if not candidates:
-            logger.warning("No valid candidates found after initial numeric/length filtering.")
-            return []
-
-        # Sort candidates by score descending
-        candidates.sort(key=lambda x: x[2], reverse=True)
-
-        # Filter in-domain responses
-        in_domain = [c for c in candidates if c[1] == domain]
-        if not in_domain:
-            logger.info(f"No in-domain responses found for '{domain}'. Using all candidates.")
-            in_domain = candidates
-
-        # Boost responses containing query keywords
-        query_keywords = self.extract_keywords(query)
-        boosted = []
-        for (resp_text, resp_domain, score) in in_domain:
-            new_score = score
-            # If the domain is known AND the response text shares any query keywords, boost it
-            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
-                new_score *= boost_factor
-
-            # Apply length penalty/bonus
-            new_score = self.length_adjust_score(resp_text, new_score)
-
-            boosted.append((resp_text, new_score))
-
-        # Sort boosted responses
-        boosted.sort(key=lambda x: x[1], reverse=True)
-
-        # Debug logging (see FAISS responses)
-        # for resp, score in boosted[:100]:
-        #     logger.debug(f"Candidate: '{resp}' with score {score}")
-
-        return boosted[:top_k]
-
     def introduction_message(self) -> None:
         """Print an introduction message to introduce the chatbot."""
         print(
@@ -453,7 +384,7 @@ class RetrievalChatbot(DeviceAwareModel):
                 print("\nAssistant: Goodbye!")
                 break
 
-            response, candidates, metrics = self.chat(
+            response, candidates, metrics, top_response_score = self.chat(
                 query=user_input,
                 conversation_history=None,
                 quality_checker=quality_checker,
@@ -466,7 +397,7 @@ class RetrievalChatbot(DeviceAwareModel):
                 print("\n Alternative responses:")
                 for resp, score in candidates[1:4]:
                     print(f" Score: {score:.4f} - {resp}")
-            else:
+            elif top_response_score < 0.7:
                 print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
 
     def chat(
@@ -504,10 +435,10 @@ class RetrievalChatbot(DeviceAwareModel):
 
             # if uncertain, ask for clarification
            if not is_confident or top_response_score < 0.5:
-                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics)
+                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics, top_response_score)
 
             # Return the top response
-            return responses[0][0], responses, metrics
+            return responses[0][0], responses, metrics, top_response_score
 
         return get_response(self, query)
 
@@ -535,27 +466,6 @@ class RetrievalChatbot(DeviceAwareModel):
         conversation_parts.append(f"{USER_TOKEN} {query}")
         return "\n".join(conversation_parts)
 
-    # def _build_conversation_context(
-    #     self,
-    #     query: str,
-    #     conversation_history: Optional[List[Tuple[str, str]]]
-    # ) -> str:
-    #     """
-    #     Build conversation context string from conversation history.
-    #     """
-    #     if not conversation_history:
-    #         return f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}"
-
-    #     conversation_parts = []
-    #     for user_txt, assistant_txt in conversation_history:
-    #         conversation_parts.extend([
-    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {user_txt}",
-    #             f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<ASSISTANT>')]} {assistant_txt}"
-    #         ])
-
-    #     conversation_parts.append(f"{self.tokenizer.additional_special_tokens[self.tokenizer.additional_special_tokens.index('<USER>')]} {query}")
-    #     return "\n".join(conversation_parts)
-
     def train_model(
         self,
         tfrecord_file_path: str,
@@ -633,7 +543,7 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info("Using fixed learning rate.")
 
         # Dummy step to force initialization
-        dummy_input = tf.zeros((1, self.config.max_context_token_limit), dtype=tf.int32)
+        dummy_input = tf.zeros((1, self.config.max_context_length), dtype=tf.int32)
         with tf.GradientTape() as tape:
             dummy_output = self.encoder(dummy_input)
             dummy_loss = tf.cast(tf.reduce_mean(dummy_output), tf.float32)
@@ -747,7 +657,7 @@ class RetrievalChatbot(DeviceAwareModel):
         logger.info(f"New validation pairs: {val_size}")
 
         dataset = dataset.map(
-            lambda x: parse_tfrecord_fn(x, self.config.max_context_token_limit, self.data_pipeline.neg_samples),
+            lambda x: parse_tfrecord_fn(x, self.config.max_context_length, self.data_pipeline.neg_samples),
             num_parallel_calls=tf.data.AUTOTUNE
         )
 
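Taken together, the retrieve_and_rerank() changes amount to a three-stage pipeline: FAISS recall (top_k * 10 candidates), domain filtering, then cross-encoder re-ranking. A minimal self-contained sketch of that pattern, assuming candidate dicts with 'text' and 'domain' keys as in the diff; the scoring callback is a stand-in, not the project's CrossEncoderReranker:

    from typing import Callable, Dict, List, Tuple

    def filter_and_rerank(
        candidates: List[Tuple[Dict[str, str], float]],  # (response dict, FAISS score), as in the diff
        detected_domain: str,
        rerank: Callable[[List[str]], List[float]],      # stand-in for the cross-encoder
        top_k: int = 10,
    ) -> List[Tuple[str, float]]:
        # Keep in-domain candidates when a domain was detected; otherwise fall back to all.
        if detected_domain != 'other':
            in_domain = [c for c in candidates if c[0]["domain"] == detected_domain]
            if in_domain:
                candidates = in_domain
        texts = [c[0]["text"] for c in candidates]
        ranked = sorted(zip(texts, rerank(texts)), key=lambda x: x[1], reverse=True)
        return ranked[:top_k]

    # Toy usage with a trivial length-based scorer:
    demo = [({"text": "A table for two at 7pm is available.", "domain": "restaurant"}, 0.71),
            ({"text": "Your ride is arriving now.", "domain": "ride_share"}, 0.65)]
    print(filter_and_rerank(demo, "restaurant", lambda ts: [len(t) / 100 for t in ts], top_k=1))
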
cross_encoder_reranker.py CHANGED
@@ -42,7 +42,8 @@ class CrossEncoderReranker:
             padding=True,
             truncation=True,
             max_length=max_length,
-            return_tensors="tf"
+            return_tensors="tf",
+            verbose=False
         )
 
         # Forward pass, logits shape [batch_size, 1]
run_chatbot_chat.py CHANGED
@@ -1,12 +1,19 @@
 import os
 import json
-from chatbot_model import RetrievalChatbot
+from tqdm.auto import tqdm
 from chatbot_config import ChatbotConfig
+from chatbot_model import RetrievalChatbot
+from sentence_transformers import SentenceTransformer
+from tf_data_pipeline import TFDataPipeline
 from response_quality_checker import ResponseQualityChecker
 from environment_setup import EnvironmentSetup
 from logger_config import config_logger
 
 logger = config_logger(__name__)
+logger.setLevel("WARNING")
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tqdm(disable=True)
 
 def run_chatbot_chat():
     env = EnvironmentSetup()
@@ -37,38 +44,55 @@ def run_chatbot_chat():
         config = ChatbotConfig()
         logger.warning("No config.json found. Using default ChatbotConfig.")
 
-    # Load RetrievalChatbot in 'inference' mode
+    # Init SentenceTransformer
     try:
-        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
+        encoder = SentenceTransformer(config.pretrained_model)
+        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
     except Exception as e:
-        logger.error(f"Failed to load RetrievalChatbot: {e}")
+        logger.error(f"Failed to load SentenceTransformer: {e}")
         return
 
-    # Confirm FAISS index & response pool exist
-    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
-        logger.error("FAISS index or response pool file is missing.")
-        return
-
     # Load FAISS index and response pool
     try:
-        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
+        # Initialize TFDataPipeline
+        data_pipeline = TFDataPipeline(
+            config=config,
+            tokenizer=encoder.tokenizer,
+            encoder=encoder,
+            response_pool=[],
+            query_embeddings_cache={},
+            index_type='IndexFlatIP',
+            faiss_index_file_path=FAISS_INDEX_PATH
+        )
+
+        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
+            logger.error("FAISS index or response pool file is missing.")
+            return
+
+        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
+        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
 
         with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
-            chatbot.data_pipeline.response_pool = json.load(f)
-        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
+            data_pipeline.response_pool = json.load(f)
+        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
+        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")
+
         # Validate dimension consistency
-        chatbot.data_pipeline.validate_faiss_index()
-
+        data_pipeline.validate_faiss_index()
+        logger.info("FAISS index and response pool validated successfully.")
     except Exception as e:
         logger.error(f"Failed to load or validate FAISS index: {e}")
         return
 
-    # Init QualityChecker and Validator
-    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
-
-    # Run interactive chat loop
-    logger.info("\nStarting interactive chat session...")
-    chatbot.run_interactive_chat(quality_checker)
+    # Run interactive chat
+    try:
+        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
+        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
+
+        logger.info("\nStarting interactive chat session...")
+        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=False)
+    except Exception as e:
+        logger.error(f"Interactive chat session failed: {e}")
 
 if __name__ == "__main__":
     run_chatbot_chat()
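
Since chat() now returns a 4-tuple ending in top_response_score (which run_interactive_chat uses for its 0.7 low-confidence warning), external callers must unpack four values. A hedged sketch of a non-interactive caller, assuming a chatbot and quality_checker loaded as in this script:

    def answer_once(chatbot, quality_checker, user_input: str) -> str:
        # chat() now returns (response, candidates, metrics, top_response_score)
        response, candidates, metrics, top_score = chatbot.chat(
            query=user_input,
            conversation_history=None,
            quality_checker=quality_checker,
        )
        if top_score < 0.7:  # same threshold the interactive loop uses for its warning
            print("[Low Confidence]: Consider rephrasing your query for better assistance.")
        return response
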
run_chatbot_validation.py CHANGED
@@ -44,9 +44,8 @@ def run_chatbot_validation():
 
     # Init SentenceTransformer
     try:
-        model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Replace with your chosen model
-        encoder = SentenceTransformer(model_name)
-        logger.info(f"Loaded SentenceTransformer model: {model_name}")
+        encoder = SentenceTransformer(config.pretrained_model)
+        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
     except Exception as e:
         logger.error(f"Failed to load SentenceTransformer: {e}")
         return
@@ -108,18 +107,10 @@ def run_chatbot_validation():
     # Run interactive chat loop
     try:
         logger.info("\nStarting interactive chat session...")
-        while True:
-            user_input = input("You: ")
-            if user_input.lower() in ["exit", "quit"]:
-                logger.info("Exiting chat session.")
-                break
-
-            responses = data_pipeline.retrieve_responses(user_input, top_k=3)
-            print("Top Responses:")
-            for i, (response, score) in enumerate(responses, start=1):
-                print(f"{i}. {response} (Score: {score:.4f})")
-    except KeyboardInterrupt:
-        logger.info("Interactive chat session interrupted by user.")
-
+        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=True)
+    except Exception as e:
+        logger.error(f"Interactive chat session failed: {e}")
+
+
 if __name__ == "__main__":
     run_chatbot_validation()
tf_data_pipeline.py CHANGED
@@ -6,7 +6,7 @@ import h5py
 import math
 import random
 import gc
-from tqdm import tqdm
+from tqdm.auto import tqdm
 import json
 from pathlib import Path
 from typing import Union, Optional, Dict, List, Tuple, Generator
@@ -28,31 +28,25 @@ class TFDataPipeline:
         encoder: SentenceTransformer,
         response_pool: List[str],
         query_embeddings_cache: dict,
-        model_name: str = 'sentence-transformers/all-MiniLM-L6-v2',
-        max_length: int = 512,
-        neg_samples: int = 10,
         index_type: str = 'IndexFlatIP',
         faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
-        dimension: int = 384,
-        nlist: int = 100,
-        max_retries: int = 3
     ):
         self.config = config
         self.tokenizer = tokenizer
         self.encoder = encoder
-        self.model = SentenceTransformer(model_name)
+        self.model = SentenceTransformer(config.pretrained_model)
         self.faiss_index_file_path = faiss_index_file_path
         self.response_pool = response_pool
-        self.max_length = max_length
-        self.neg_samples = neg_samples
         self.query_embeddings_cache = query_embeddings_cache  # In-memory cache for embeddings
-        self.dimension = config.embedding_dim
         self.index_type = index_type
-        self.nlist = nlist
-        self.embedding_batch_size = 16 if len(response_pool) < 100 else 64
-        self.search_batch_size = 16 if len(response_pool) < 100 else 64
-        self.max_batch_size = 16 if len(response_pool) < 100 else 64
-        self.max_retries = max_retries
+        self.neg_samples = config.neg_samples
+        self.nlist = config.nlist
+        self.dimension = config.embedding_dim
+        self.max_context_length = config.max_context_length
+        self.embedding_batch_size = config.embedding_batch_size
+        self.search_batch_size = config.search_batch_size
+        self.max_batch_size = config.max_batch_size
+        self.max_retries = config.max_retries
 
         # Build text -> domain map for O(1) domain lookups (hard negative sampling)
         self._text_domain_map = {}
@@ -159,7 +153,7 @@ class TFDataPipeline:
             speaker = turn.get('speaker')
             text = turn.get('text', '').strip()
             if speaker == 'assistant' and text:
-                if len(text) <= self.max_length:
+                if len(text) <= self.max_context_length:
                     # Use tuple as set key to ensure uniqueness
                     key = (domain, text)
                     if key not in response_set:
@@ -388,7 +382,7 @@ class TFDataPipeline:
                 #     f"Collision detected: text '{stripped_text}' found with domains "
                 #     f"'{existing_domain}' and '{domain}'. Keeping the first."
                 # )
-                # By default, keep the first domain or overwrite. We'll skip overwriting:
+                # By default, keep the first domain or overwrite. Skip overwriting:
                 continue
             else:
                 # Insert into the dict
@@ -434,7 +428,7 @@ class TFDataPipeline:
             prepared,
             padding='max_length',
             truncation=True,
-            max_length=self.max_length,
+            max_length=self.max_context_length,
             return_tensors='np'
         )
         input_ids = encodings['input_ids']
@@ -454,23 +448,19 @@ class TFDataPipeline:
     def retrieve_responses(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
         """
         Retrieve top-k responses for a query using FAISS.
-
-        Args:
-            query: User's query text.
-            top_k: Number of responses to return.
-
-        Returns:
-            List of tuples (response text, similarity score).
         """
         query_embedding = self.encode_query(query).reshape(1, -1).astype("float32")
         distances, indices = self.index.search(query_embedding, top_k)
 
         results = []
-        for idx, dist in zip(indices[0], distances[0]):
+        for idx, dist in tqdm(
+            zip(indices[0], distances[0]),
+            disable=True  # Silence tqdm
+        ):
             if idx < 0:
                 continue
             response = self.response_pool[idx]
-            results.append((response["text"], dist))
+            results.append((response, dist))
 
         return results
 
@@ -496,7 +486,7 @@ class TFDataPipeline:
         for dialogue in batch_dialogues:
             pairs = self._extract_pairs_from_dialogue(dialogue)
             for query, positive in pairs:
-                if len(query) <= self.max_length and len(positive) <= self.max_length:
+                if len(query) <= self.max_context_length and len(positive) <= self.max_context_length:
                     queries.append(query)
                     positives.append(positive)
 
@@ -524,14 +514,14 @@ class TFDataPipeline:
         try:
             encoded_queries = self.tokenizer.batch_encode_plus(
                 queries,
-                max_length=self.config.max_context_token_limit,
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
             )
             encoded_positives = self.tokenizer.batch_encode_plus(
                 positives,
-                max_length=self.config.max_context_token_limit,
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
@@ -547,7 +537,7 @@ class TFDataPipeline:
             flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
             encoded_negatives = self.tokenizer.batch_encode_plus(
                 flattened_negatives,
-                max_length=self.config.max_context_token_limit,
+                max_length=self.config.max_context_length,
                 truncation=True,
                 padding='max_length',
                 return_tensors='tf'
@@ -555,7 +545,7 @@ class TFDataPipeline:
 
             # Reshape to [num_queries, num_negatives, max_length]
             num_negatives = self.config.neg_samples
-            reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_token_limit)
+            reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(-1, num_negatives, self.config.max_context_length)
         except Exception as e:
             logger.error(f"Error during negatives tokenization: {e}")
             pbar.update(1)
@@ -600,7 +590,7 @@ class TFDataPipeline:
                 batch_queries,
                 padding=True,
                 truncation=True,
-                max_length=self.max_length,
+                max_length=self.max_context_length,
                 return_tensors='tf'
             )
             batch_embeddings = self.encoder(encoded['input_ids'], training=False).numpy()
@@ -667,14 +657,14 @@ class TFDataPipeline:
         # Use tf.py_function, limit parallelism
         q_ids, p_ids, n_ids = tf.py_function(
             func=self._tokenize_triple_py,
-            inp=[q, p, n, tf.constant(self.max_length), tf.constant(self.neg_samples)],
+            inp=[q, p, n, tf.constant(self.max_context_length), tf.constant(self.neg_samples)],
             Tout=[tf.int32, tf.int32, tf.int32]
         )
 
         # Set shape info for the output tensors
-        q_ids.set_shape([None, self.max_length])  # [batch_size, max_length]
-        p_ids.set_shape([None, self.max_length])  # [batch_size, max_length]
-        n_ids.set_shape([None, self.neg_samples, self.max_length])  # [batch_size, neg_samples, max_length]
+        q_ids.set_shape([None, self.max_context_length])  # [batch_size, max_length]
+        p_ids.set_shape([None, self.max_context_length])  # [batch_size, max_length]
+        n_ids.set_shape([None, self.neg_samples, self.max_context_length])  # [batch_size, neg_samples, max_length]
 
         return q_ids, p_ids, n_ids
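
For reference, retrieve_responses() assumes an IndexFlatIP built over L2-normalized embeddings, so the returned "distances" are inner products, i.e. cosine similarities (as the removed faiss_search docstring noted). A standalone FAISS sketch of that setup, with random vectors standing in for the encoder output:

    import faiss
    import numpy as np

    dim = 384                                # matches config.embedding_dim
    index = faiss.IndexFlatIP(dim)           # exact inner-product index

    rng = np.random.default_rng(0)
    vectors = rng.standard_normal((1000, dim)).astype("float32")
    faiss.normalize_L2(vectors)              # normalized vectors => inner product == cosine similarity
    index.add(vectors)

    query = rng.standard_normal((1, dim)).astype("float32")
    faiss.normalize_L2(query)
    distances, indices = index.search(query, 10)  # same call shape as retrieve_responses()
    print(indices[0], distances[0])          # row ids map into response_pool in the pipeline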