JoeArmani committed
Commit a763857 · 1 Parent(s): 4aec49f

implement chat features

chatbot_model.py CHANGED
@@ -546,6 +546,48 @@ class RetrievalChatbot(DeviceAwareModel):
 
        return boosted[:top_k]
 
+    def introduction_message(self) -> None:
+        """Print an introduction message to introduce the chatbot."""
+        print(
+            "\nAssistant: Hello! I'm a simple chatbot assistant. I've been trained to answer "
+            "basic questions about topics including restaurants, movies, ride sharing, coffee, and pizza. "
+            "Please ask me a question and I'll do my best to assist you."
+        )
+
+    def run_interactive_chat(self, quality_checker, show_alternatives=False):
+        """Separate function for interactive chat loop."""
+
+        # Chatbot introduction
+        self.introduction_message()
+
+        # Chat loop
+        while True:
+            try:
+                user_input = input("\nYou: ")
+            except (KeyboardInterrupt, EOFError):
+                print("\nAssistant: Goodbye!")
+                break
+
+            if user_input.lower() in ["quit", "exit", "bye"]:
+                print("\nAssistant: Goodbye!")
+                break
+
+            response, candidates, metrics = self.chat(
+                query=user_input,
+                conversation_history=None,
+                quality_checker=quality_checker,
+                top_k=10
+            )
+
+            print(f"\nAssistant: {response}")
+
+            if show_alternatives and candidates and metrics.get("is_confident", False):
+                print("\n Alternative responses:")
+                for resp, score in candidates[1:4]:
+                    print(f" Score: {score:.4f} - {resp}")
+            else:
+                print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
+
     def chat(
         self,
         query: str,
@@ -562,7 +604,7 @@ class RetrievalChatbot(DeviceAwareModel):
            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)
 
            # Retrieve and re-rank
-            responses = self_arg.retrieve_responses(
+            responses = self_arg.retrieve_responses(
                query=conversation_str,
                top_k=top_k,
                reranker=self_arg.reranker,
@@ -571,13 +613,20 @@ class RetrievalChatbot(DeviceAwareModel):
            )
 
            # Handle low confidence or empty responses
-            if not results:
+            if not responses:
                return ("I'm sorry, but I couldn't find a relevant response.", [], {})
 
-            metrics = quality_checker.check_response_quality(query_arg, results)
-            if not metrics.get('is_confident', False):
-                return ("I need more information to provide a good answer. Could you please clarify?", results, metrics)
-            return results[0][0], results, metrics
+            # Analyze is_confident and computed score when returning the top response
+            metrics = quality_checker.check_response_quality(query_arg, responses)
+            is_confident = metrics.get('is_confident', False)
+            top_response_score = responses[0][1]
+
+            # if uncertain, ask for clarification
+            if not is_confident or top_response_score < 0.5:
+                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics)
+
+            # Return the top response
+            return responses[0][0], responses, metrics
 
        return get_response(self, query)
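The reworked chat() gates the top-ranked candidate on two signals: the quality checker's is_confident flag and the raw score of the best response, falling back to a clarification prompt when either check fails. Below is a minimal standalone sketch of that gating rule; the 0.5 cutoff and the fallback strings come from the diff above, while the function name and the sample candidate are illustrative only.

from typing import Dict, List, Tuple

CLARIFY = "I need more information to provide a good answer. Could you please clarify?"

def select_response(responses: List[Tuple[str, float]], metrics: Dict) -> str:
    """Return the top response only when both the confidence flag and the score pass."""
    if not responses:
        return "I'm sorry, but I couldn't find a relevant response."
    is_confident = metrics.get("is_confident", False)
    top_response, top_score = responses[0]
    # Same rule as the diff: either signal failing triggers a clarification prompt
    if not is_confident or top_score < 0.5:
        return CLARIFY
    return top_response

# Example: a confident checker but a weak top score still asks for clarification
print(select_response([("Sure, which cuisine?", 0.42)], {"is_confident": True}))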
 
chatbot_validator.py CHANGED
@@ -12,7 +12,7 @@ class ChatbotValidator:
    Handles automated validation and performance analysis for the chatbot.
    This testing module executes domain-specific queries, obtains chatbot responses, and evaluates them with a quality checker.
    """
-
+
    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.
@@ -22,7 +22,7 @@ class ChatbotValidator:
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker
-
+
        # Domain-specific test queries (aligns with Taskmaster-1 dataset)
        self.domain_queries = {
            'restaurant': [
@@ -56,7 +56,7 @@ class ChatbotValidator:
                "My Jeep needs a repair. Can you help me with that?",
            ],
        }
-
+
    def run_validation(
        self,
        num_examples: int = 3,
@@ -77,64 +77,65 @@ class ChatbotValidator:
            Dict with validation metrics
        """
        logger.info("\n=== Running Automatic Validation ===")
-
+
        # Select which domains to test
        test_domains = domains if domains else list(self.domain_queries.keys())
-
+
        # Initialize results
        metrics_history = []
        domain_metrics = {}
 
        # Init the cross-encoder reranker to pass to the chatbot
        reranker = CrossEncoderReranker(model_name=self.chatbot.config.cross_encoder_model)
-
+
        # Prepare random selection if needed
        rng = random.Random(seed)
-
+
        # Run validation for each domain
        for domain in test_domains:
            # Avoid errors if domain key missing
            if domain not in self.domain_queries:
                logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
                continue
-
+
            all_queries = self.domain_queries[domain]
            if randomize:
                queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
            else:
                queries = all_queries[:num_examples]
-
+
            # Store domain-level metrics
            domain_metrics[domain] = []
-
-            logger.info(f"\n=== Testing {domain.title()} Domain ===")
-
+
+            logger.info(f"\n=== Testing {domain.title()} Domain ===\n")
+
            for i, query in enumerate(queries, 1):
-                logger.info(f"\nTest Case {i}: {query}")
-
+                logger.info(f"TEST CASE {i}: QUERY: {query}")
+
                # Retrieve top_k responses, then evaluate with quality checker
                responses = self.chatbot.retrieve_responses(query, top_k=top_k, reranker=reranker)
                quality_metrics = self.quality_checker.check_response_quality(query, responses)
-
+
                # Aggregate metrics and log
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)
                self._log_validation_results(query, responses, quality_metrics)
-
+                logger.info(f"Quality metrics: {quality_metrics}\n")
+
        # Final aggregation
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)
-
+
        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })
-
+
        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics
-
+
    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Calculate aggregate metrics over tested queries.
@@ -142,9 +143,9 @@ class ChatbotValidator:
        if not metrics_history:
            logger.warning("No metrics to aggregate. Returning empty summary.")
            return {}
-
+
        top_scores = [m.get('top_score', 0.0) for m in metrics_history]
-
+
        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean(top_scores),
@@ -159,20 +160,20 @@ class ChatbotValidator:
            'max_score': np.max(top_scores)
        }
        return metrics
-
+
    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
        """
        Analyze performance by domain, returning a nested dict.
        """
        analysis = {}
-
+
        for domain, metrics_list in domain_metrics.items():
            if not metrics_list:
                analysis[domain] = {}
                continue
-
+
            top_scores = [m.get('top_score', 0.0) for m in metrics_list]
-
+
            analysis[domain] = {
                'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_list]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_list]),
@@ -180,9 +181,9 @@ class ChatbotValidator:
                'avg_top_score': np.mean(top_scores),
                'num_samples': len(metrics_list)
            }
-
+
        return analysis
-
+
    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Analyze the distribution of top scores to gauge system confidence levels.
@@ -190,7 +191,7 @@ class ChatbotValidator:
        if not metrics_history:
            return {'percentile_25': 0.0, 'percentile_50': 0.0,
                    'percentile_75': 0.0, 'percentile_90': 0.0}
-
+
        scores = [m.get('top_score', 0.0) for m in metrics_history]
        return {
            'percentile_25': float(np.percentile(scores, 25)),
@@ -198,7 +199,7 @@ class ChatbotValidator:
            'percentile_75': float(np.percentile(scores, 75)),
            'percentile_90': float(np.percentile(scores, 90))
        }
-
+
    def _log_validation_results(
        self,
        query: str,
@@ -210,19 +211,18 @@ class ChatbotValidator:
        """
        domain = metrics.get('domain', 'Unknown')
        is_confident = metrics.get('is_confident', False)
-
-        logger.info(f"Domain: {domain} | Confidence: {'Yes' if is_confident else 'No'}")
-        # logger.info("Quality Metrics:")
-        # for k, v in metrics.items():
-        #     if isinstance(v, (int, float)):
-        #         logger.info(f" {k}: {v:.4f}")
-
-        logger.info("Top 3 Responses:")
+
+        logger.info(f"DOMAIN: {domain} | CONFIDENCE: {'Yes' if is_confident else 'No'}")
+
+        if is_confident or responses[0][1] >= 0.5:
+            logger.info(f"SELECTED RESPONSE: '{responses[0][0]}'")
+        else:
+            logger.info("SELECTED RESPONSE: NONE (Low Confidence)")
+
+        logger.info(" Top 3 Responses:")
        for i, (resp_text, score) in enumerate(responses[:3], 1):
-            logger.info(f"{i}) Score: {score:.4f} | {resp_text}")
-            if i == 1 and not is_confident:
-                logger.info(" [Low Confidence on Top Response]")
-
+            logger.info(f" {i}) Score: {score:.4f} | {resp_text}")
+
    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """
        Log a summary of all validation metrics and domain performance.
@@ -230,16 +230,16 @@ class ChatbotValidator:
        if not metrics:
            logger.info("No metrics to summarize.")
            return
-
+
        logger.info("\n=== Validation Summary ===")
-
+
        # Overall
        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            # Skip sub-dicts here
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")
-
+
        # Domain performance
        domain_perf = metrics.get('domain_performance', {})
        logger.info("\nDomain Performance:")
@@ -247,9 +247,10 @@ class ChatbotValidator:
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_stats.items():
                logger.info(f" {metric}: {value:.4f}")
-
+
        # Confidence distribution
        conf_analysis = metrics.get('confidence_analysis', {})
        logger.info("\nConfidence Distribution:")
        for pct, val in conf_analysis.items():
            logger.info(f" {pct}: {val:.4f}")
+
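run_validation() logs each test case with its selected response and quality metrics, then summarizes confidence as percentiles over the per-query top_score values. The short sketch below reproduces that percentile summary on made-up scores to show the shape of the resulting confidence_analysis dict; the sample numbers are illustrative only.

import numpy as np

metrics_history = [{"top_score": s} for s in (0.31, 0.48, 0.52, 0.61, 0.74)]  # illustrative scores
scores = [m.get("top_score", 0.0) for m in metrics_history]

confidence_analysis = {
    "percentile_25": float(np.percentile(scores, 25)),
    "percentile_50": float(np.percentile(scores, 50)),
    "percentile_75": float(np.percentile(scores, 75)),
    "percentile_90": float(np.percentile(scores, 90)),
}
print(confidence_analysis)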
response_quality_checker.py CHANGED
@@ -18,10 +18,10 @@ class ResponseQualityChecker:
    def __init__(
        self,
        data_pipeline: "TFDataPipeline",
-        confidence_threshold: float = 0.40,
-        diversity_threshold: float = 0.15,
+        confidence_threshold: float = 0.45,
+        diversity_threshold: float = 0.10,
        min_response_length: int = 5,
-        similarity_cap: float = 0.85
+        similarity_cap: float = 0.90
    ):
        """
        Args:
@@ -74,10 +74,8 @@ class ResponseQualityChecker:
        metrics['response_length_score'] = self._calc_length_score(responses)
        metrics['top_score'] = responses[0][1]
        metrics['top_3_score_gap'] = self._calc_score_gap([score for _, score in responses])
-
        metrics['is_confident'] = self._determine_confidence(metrics)
 
-        logger.info(f"Quality metrics: {metrics}")
        return metrics
 
    def _calc_diversity(self, responses: List[Tuple[str, float]]) -> float:
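The new defaults tighten the score gate (confidence_threshold 0.40 to 0.45), loosen the diversity requirement (diversity_threshold 0.15 to 0.10), and raise similarity_cap to 0.90, while the per-call metrics logging moves out of check_response_quality() and into the validator. The diff does not show _determine_confidence(), so the sketch below is only an assumed way thresholds of this kind might be combined; the function body and the response_diversity key are hypothetical.

def determine_confidence_sketch(metrics: dict,
                                confidence_threshold: float = 0.45,
                                diversity_threshold: float = 0.10) -> bool:
    """Hypothetical gate; the real _determine_confidence() may weigh metrics differently."""
    return (
        metrics.get("top_score", 0.0) >= confidence_threshold          # key shown in the diff
        and metrics.get("response_diversity", 0.0) >= diversity_threshold  # assumed key
        and metrics.get("response_length_score", 0.0) > 0.0            # key shown in the diff
    )

print(determine_confidence_sketch(
    {"top_score": 0.5, "response_diversity": 0.2, "response_length_score": 1.0}
))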
run_chatbot_chat.py ADDED
@@ -0,0 +1,73 @@
+import os
+import json
+from chatbot_model import ChatbotConfig, RetrievalChatbot
+from response_quality_checker import ResponseQualityChecker
+from environment_setup import EnvironmentSetup
+from logger_config import config_logger
+
+logger = config_logger(__name__)
+
+def run_chatbot_chat():
+    env = EnvironmentSetup()
+    env.initialize()
+
+    MODEL_DIR = "models"
+    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
+    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
+    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")
+
+    # Toggle 'production' or 'test' env
+    ENVIRONMENT = "production"
+    if ENVIRONMENT == "test":
+        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
+        RESPONSE_POOL_PATH = FAISS_INDEX_TEST_PATH.replace(".index", "_responses.json")
+    else:
+        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
+        RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
+
+    # Load the config
+    config_path = os.path.join(MODEL_DIR, "config.json")
+    if os.path.exists(config_path):
+        with open(config_path, "r", encoding="utf-8") as f:
+            config_dict = json.load(f)
+        config = ChatbotConfig.from_dict(config_dict)
+        logger.info(f"Loaded ChatbotConfig from {config_path}")
+    else:
+        config = ChatbotConfig()
+        logger.warning("No config.json found. Using default ChatbotConfig.")
+
+    # Load RetrievalChatbot in 'inference' mode
+    try:
+        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
+    except Exception as e:
+        logger.error(f"Failed to load RetrievalChatbot: {e}")
+        return
+
+    # Confirm FAISS index & response pool exist
+    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
+        logger.error("FAISS index or response pool file is missing.")
+        return
+
+    # Load FAISS index and response pool
+    try:
+        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
+
+        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
+            chatbot.data_pipeline.response_pool = json.load(f)
+        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
+        # Validate dimension consistency
+        chatbot.data_pipeline.validate_faiss_index()
+
+    except Exception as e:
+        logger.error(f"Failed to load or validate FAISS index: {e}")
+        return
+
+    # Init QualityChecker and Validator
+    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
+
+    # Run interactive chat loop
+    logger.info("\nStarting interactive chat session...")
+    chatbot.run_interactive_chat(quality_checker)
+
+if __name__ == "__main__":
+    run_chatbot_chat()
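run_chatbot_chat.py hard-codes the ENVIRONMENT toggle and derives the response-pool path from the FAISS index path with str.replace. A small sketch of that naming convention follows, using the paths from the script; the helper name is invented for illustration.

import os

def paths_for(environment: str, model_dir: str = "models"):
    """Return (index_path, response_pool_path) following the script's naming convention."""
    name = "faiss_index_test" if environment == "test" else "faiss_index_production"
    index_path = os.path.join(model_dir, "faiss_indices", f"{name}.index")
    return index_path, index_path.replace(".index", "_responses.json")

print(paths_for("test"))
# e.g. on POSIX: ('models/faiss_indices/faiss_index_test.index',
#                 'models/faiss_indices/faiss_index_test_responses.json')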
run_chatbot_validation.py CHANGED
@@ -8,36 +8,6 @@ from environment_setup import EnvironmentSetup
 from logger_config import config_logger
 
 logger = config_logger(__name__)
-
-def run_interactive_chat(chatbot, quality_checker):
-    """Separate function for interactive chat loop."""
-    while True:
-        try:
-            user_input = input("You: ")
-        except (KeyboardInterrupt, EOFError):
-            print("\nAssistant: Goodbye!")
-            break
-
-        if user_input.lower() in ["quit", "exit", "bye"]:
-            print("Assistant: Goodbye!")
-            break
-
-        response, candidates, metrics = chatbot.chat(
-            query=user_input,
-            conversation_history=None,
-            quality_checker=quality_checker,
-            top_k=10
-        )
-
-        print(f"Assistant: {response}")
-
-        # Show alternative responses if confident
-        if metrics.get("is_confident", False):
-            print("\nAlternative responses:")
-            for resp, score in candidates[1:4]:
-                print(f"Score: {score:.4f} - {resp}")
-        else:
-            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
 
 def run_chatbot_validation():
     # Initialize environment
@@ -118,16 +88,16 @@ def run_chatbot_validation():
         return
 
     # Plot metrics
-    # try:
-    #     plotter = Plotter(save_dir=env.training_dirs["plots"])
-    #     plotter.plot_validation_metrics(validation_metrics)
-    #     logger.info("Validation metrics plotted successfully.")
-    # except Exception as e:
-    #     logger.error(f"Failed to plot validation metrics: {e}")
+    try:
+        plotter = Plotter(save_dir=env.training_dirs["plots"])
+        plotter.plot_validation_metrics(validation_metrics)
+        logger.info("Validation metrics plotted successfully.")
+    except Exception as e:
+        logger.error(f"Failed to plot validation metrics: {e}")
 
     # Run interactive chat loop
     logger.info("\nStarting interactive chat session...")
-    run_interactive_chat(chatbot, quality_checker)
+    chatbot.run_interactive_chat(quality_checker, show_alternatives=True)
 
 if __name__ == "__main__":
     run_chatbot_validation()