from typing import Dict, List, Tuple, Any, Optional

import numpy as np

from logger_config import config_logger

logger = config_logger(__name__)


class ChatbotValidator:
    """Handles automated validation and performance analysis for the chatbot."""

    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot instance
            quality_checker: ResponseQualityChecker instance
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker

        # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 people at an Italian place?",
                "Do you have any availability for tomorrow at 7pm?",
                "I need to change my dinner reservation time.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available now?",
                "What's the estimated arrival time?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }

    def run_validation(
        self,
        num_examples: int = 10,
        top_k: int = 10,
        domains: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across specified domains.
        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of specific domains to test

        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")

        # Select domains to test
        test_domains = domains if domains else list(self.domain_queries.keys())

        metrics_history = []
        domain_metrics = {}

        # Run validation for each domain
        for domain in test_domains:
            domain_metrics[domain] = []
            queries = self.domain_queries[domain][:num_examples]

            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                # Get responses with increased top_k
                responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)

                # Enhanced quality checking with context
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                # Add domain info
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)

                # Detailed logging (covers the test-case header, query, and metrics)
                self._log_validation_results(query, responses, quality_metrics, i)

        # Calculate and log overall metrics
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Calculate comprehensive aggregate metrics."""
        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
            'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
            'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),

            # Additional statistical metrics
            'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
            'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
            'min_score': np.min([m.get('top_score', 0) for m in metrics_history]),
            'max_score': np.max([m.get('top_score', 0) for m in metrics_history])
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
        """Analyze performance by domain."""
        domain_analysis = {}
        for domain, metrics in domain_metrics.items():
            domain_analysis[domain] = {
                'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
                'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
                'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
                'num_samples': len(metrics)
            }
        return domain_analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Analyze the distribution of confidence scores."""
        scores = [m.get('top_score', 0) for m in metrics_history]
        return {
            'percentile_25': np.percentile(scores, 25),
            'percentile_50': np.percentile(scores, 50),
            'percentile_75': np.percentile(scores, 75),
            'percentile_90': np.percentile(scores, 90)
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
        case_num: int
    ):
        """Log detailed validation results."""
        logger.info(f"\nTest Case {case_num}:")
        logger.info(f"Query: {query}")
        logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
        logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")

        logger.info("\nQuality Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"  {metric}: {value:.4f}")

        logger.info("\nTop Responses:")
        for i, (response, score) in enumerate(responses[:3], 1):
            logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
            if i == 1 and not metrics.get('is_confident', False):
                logger.info("   [Low Confidence]")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """Log comprehensive validation summary."""
        logger.info("\n=== Validation Summary ===")

        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        logger.info("\nDomain Performance:")
        for domain, domain_metrics in metrics['domain_performance'].items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_metrics.items():
                logger.info(f"  {metric}: {value:.4f}")

        logger.info("\nConfidence Distribution:")
        for percentile, value in metrics['confidence_analysis'].items():
            logger.info(f"{percentile}: {value:.4f}")
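
# Usage sketch (illustrative only, not part of the module's API):
# ChatbotValidator only assumes that `chatbot` exposes
# `retrieve_responses_cross_encoder(query, top_k)` returning (response, score)
# pairs, and that `quality_checker` exposes `check_response_quality(query, responses)`
# returning a metrics dict. A typical call, assuming the RetrievalChatbot and
# ResponseQualityChecker classes named in the docstrings above are importable
# from their own modules (import paths and constructor arguments below are
# hypothetical, project-specific placeholders):
#
#     chatbot = RetrievalChatbot(...)                  # hypothetical constructor args
#     quality_checker = ResponseQualityChecker(...)    # hypothetical constructor args
#     validator = ChatbotValidator(chatbot, quality_checker)
#     results = validator.run_validation(num_examples=5, domains=['restaurant', 'rideshare'])
#     print(results['confidence_rate'])
#     print(results['domain_performance']['restaurant'])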