from typing import Dict, List, Tuple, Any, Optional

import numpy as np

from logger_config import config_logger

logger = config_logger(__name__)


class ChatbotValidator:
    """Handles automated validation and performance analysis for the chatbot."""

    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot instance
            quality_checker: ResponseQualityChecker instance
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker

        # Domain-specific test queries used to exercise each supported domain.
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 people at an Italian place?",
                "Do you have any availability for tomorrow at 7pm?",
                "I need to change my dinner reservation time.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available now?",
                "What's the estimated arrival time?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }

    def run_validation(
        self,
        num_examples: int = 10,
        top_k: int = 10,
        domains: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across the specified domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of specific domains to test

        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")

        # Default to all configured domains when none are specified.
        test_domains = domains if domains else list(self.domain_queries.keys())
        metrics_history = []
        domain_metrics = {}

        for domain in test_domains:
            domain_metrics[domain] = []
            queries = self.domain_queries[domain][:num_examples]

            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                # Retrieve top-k candidate responses, re-ranked by the cross-encoder.
                responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)

                # Score the retrieved responses against the query.
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                # Track metrics both globally and per domain.
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)

                self._log_validation_results(query, responses, quality_metrics, i)

        # Aggregate results across all domains and test cases.
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Calculate comprehensive aggregate metrics."""
        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
            'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
            'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),

            # Distribution statistics for the top response score.
            'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
            'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
            'min_score': np.min([m.get('top_score', 0) for m in metrics_history]),
            'max_score': np.max([m.get('top_score', 0) for m in metrics_history])
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
        """Analyze performance by domain."""
        domain_analysis = {}

        for domain, metrics in domain_metrics.items():
            domain_analysis[domain] = {
                'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
                'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
                'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
                'num_samples': len(metrics)
            }

        return domain_analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Analyze the distribution of confidence scores."""
        scores = [m.get('top_score', 0) for m in metrics_history]

        return {
            'percentile_25': np.percentile(scores, 25),
            'percentile_50': np.percentile(scores, 50),
            'percentile_75': np.percentile(scores, 75),
            'percentile_90': np.percentile(scores, 90)
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
        case_num: int
    ):
        """Log detailed validation results."""
        logger.info(f"\nTest Case {case_num}:")
        logger.info(f"Query: {query}")
        logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
        logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")
        logger.info("\nQuality Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"  {metric}: {value:.4f}")

        logger.info("\nTop Responses:")
        for i, (response, score) in enumerate(responses[:3], 1):
            logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
            if i == 1 and not metrics.get('is_confident', False):
                logger.info(" [Low Confidence]")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """Log comprehensive validation summary."""
        logger.info("\n=== Validation Summary ===")

        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        logger.info("\nDomain Performance:")
        for domain, domain_metrics in metrics['domain_performance'].items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_metrics.items():
                logger.info(f"  {metric}: {value:.4f}")

        logger.info("\nConfidence Distribution:")
        for percentile, value in metrics['confidence_analysis'].items():
            logger.info(f"{percentile}: {value:.4f}")
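

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): shows the two-method
    # interface ChatbotValidator relies on. In a real run you would pass a
    # RetrievalChatbot and ResponseQualityChecker; the stub classes below are
    # hypothetical stand-ins that only implement the methods called above.
    class _StubChatbot:
        def retrieve_responses_cross_encoder(self, query, top_k=10):
            # Return (response, score) pairs, best first.
            return [(f"Stub response {i} to: {query}", 1.0 - 0.1 * i) for i in range(top_k)]

    class _StubQualityChecker:
        def check_response_quality(self, query, responses):
            # Return the metric keys the validator aggregates, with placeholder values.
            top_score = responses[0][1] if responses else 0.0
            return {
                'top_score': top_score,
                'response_diversity': 0.5,
                'query_response_relevance': 0.5,
                'response_length_score': 0.5,
                'top_3_score_gap': 0.1,
                'is_confident': top_score >= 0.7,
            }

    validator = ChatbotValidator(_StubChatbot(), _StubQualityChecker())
    summary = validator.run_validation(num_examples=2, domains=['restaurant', 'events'])
    print(f"Confidence rate: {summary['confidence_rate']:.2f}")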