import random
from typing import Dict, List, Tuple, Any, Optional

import numpy as np

from logger_config import config_logger
from cross_encoder_reranker import CrossEncoderReranker

logger = config_logger(__name__)


class ChatbotValidator:
    """
    Handles automated validation and performance analysis for the chatbot.

    This validator executes domain-specific test queries, obtains candidate
    responses from the chatbot, and evaluates them with a quality checker.
    It aggregates metrics across queries and domains, logs intermediate
    results, and returns a comprehensive summary.
    """

    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot instance used for inference
            quality_checker: ResponseQualityChecker instance used to score responses
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker

        # Curated test queries per domain; together these form the validation set.
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 at an Italian restaurant?",
                "Is there any availability to dine tomorrow at 7pm?",
                "I'd like to cancel my reservation for tonight.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available right now?",
                "What's the estimated arrival time for the driver?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }

    def run_validation(
        self,
        num_examples: int = 5,
        top_k: int = 10,
        domains: Optional[List[str]] = None,
        randomize: bool = False,
        seed: int = 42
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across the specified domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of domain keys to test. If None, test all domains.
            randomize: If True, randomly sample queries from the domain lists
            seed: Random seed for reproducible sampling when randomize=True

        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")

        test_domains = domains if domains else list(self.domain_queries.keys())

        metrics_history = []
        domain_metrics = {}

        # A single cross-encoder re-ranker is shared across all test queries in this run.
        reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")

        # Dedicated RNG so sampling is reproducible and does not disturb global random state.
        rng = random.Random(seed)

        for domain in test_domains:
            if domain not in self.domain_queries:
                logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
                continue

            all_queries = self.domain_queries[domain]
            if randomize:
                queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
            else:
                queries = all_queries[:num_examples]

            domain_metrics[domain] = []

            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                logger.info(f"\nTest Case {i}: {query}")

                # Retrieve top_k candidate responses, re-ranked by the cross-encoder.
                responses = self.chatbot.retrieve_responses_cross_encoder(
                    query, top_k=top_k, reranker=reranker
                )

                # Score the candidates against the query.
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)

                self._log_validation_results(query, responses, quality_metrics, i)

        # Aggregate across all queries, then break results down by domain and confidence.
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Calculate comprehensive aggregate metrics over all tested queries.
        """
        if not metrics_history:
            logger.warning("No metrics to aggregate. Returning empty summary.")
            return {}

        top_scores = [m.get('top_score', 0.0) for m in metrics_history]

        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean(top_scores),
            'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0.0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history]),
            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
                                        for m in metrics_history]),

            'median_top_score': np.median(top_scores),
            'score_std': np.std(top_scores),
            'min_score': np.min(top_scores),
            'max_score': np.max(top_scores)
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
        """
        Analyze performance by domain, returning a nested dict.
        """
        analysis = {}

        for domain, metrics_list in domain_metrics.items():
            if not metrics_list:
                analysis[domain] = {}
                continue

            top_scores = [m.get('top_score', 0.0) for m in metrics_list]

            analysis[domain] = {
                'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
                                            for m in metrics_list]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0.0)
                                          for m in metrics_list]),
                'avg_diversity': np.mean([m.get('response_diversity', 0.0)
                                          for m in metrics_list]),
                'avg_top_score': np.mean(top_scores),
                'num_samples': len(metrics_list)
            }

        return analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Analyze the distribution of top scores to gauge system confidence levels.
        """
        if not metrics_history:
            return {'percentile_25': 0.0, 'percentile_50': 0.0,
                    'percentile_75': 0.0, 'percentile_90': 0.0}

        scores = [m.get('top_score', 0.0) for m in metrics_history]
        return {
            'percentile_25': float(np.percentile(scores, 25)),
            'percentile_50': float(np.percentile(scores, 50)),
            'percentile_75': float(np.percentile(scores, 75)),
            'percentile_90': float(np.percentile(scores, 90))
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
        case_num: int
    ):
        """
        Log detailed validation results for a single test case.
        """
        domain = metrics.get('domain', 'Unknown')
        is_confident = metrics.get('is_confident', False)

        logger.info(f"Domain: {domain} | Confidence: {'Yes' if is_confident else 'No'}")
        logger.info("Quality Metrics:")
        for k, v in metrics.items():
            if isinstance(v, (int, float)):
                logger.info(f"  {k}: {v:.4f}")

        logger.info("Top 3 Responses:")
        for i, (resp_text, score) in enumerate(responses[:3], 1):
            logger.info(f"{i}) Score: {score:.4f} | {resp_text}")
            if i == 1 and not is_confident:
                logger.info("   [Low Confidence on Top Response]")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """
        Log a summary of all validation metrics and domain performance.
        """
        if not metrics:
            logger.info("No metrics to summarize.")
            return

        logger.info("\n=== Validation Summary ===")

        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            # Nested dicts (domain/confidence breakdowns) are logged separately below.
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        domain_perf = metrics.get('domain_performance', {})
        logger.info("\nDomain Performance:")
        for domain, domain_stats in domain_perf.items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_stats.items():
                logger.info(f"  {metric}: {value:.4f}")

        conf_analysis = metrics.get('confidence_analysis', {})
        logger.info("\nConfidence Distribution:")
        for pct, val in conf_analysis.items():
            logger.info(f"  {pct}: {val:.4f}")
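

# Minimal usage sketch (commented out, not executed on import). It assumes a
# RetrievalChatbot and ResponseQualityChecker can be constructed as shown; the
# module paths and constructor arguments below are hypothetical and depend on
# the surrounding project.
#
# from retrieval_chatbot import RetrievalChatbot                # hypothetical module path
# from response_quality_checker import ResponseQualityChecker   # hypothetical module path
#
# if __name__ == "__main__":
#     chatbot = RetrievalChatbot()                  # hypothetical constructor call
#     quality_checker = ResponseQualityChecker()    # hypothetical constructor call
#     validator = ChatbotValidator(chatbot, quality_checker)
#
#     # Validate two domains with randomized query sampling.
#     results = validator.run_validation(
#         num_examples=3,
#         top_k=10,
#         domains=['restaurant', 'rideshare'],
#         randomize=True,
#         seed=42
#     )
#     print(f"Confidence rate: {results.get('confidence_rate', 0.0):.2%}")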