from typing import Dict, List, Tuple, Any, Optional
import numpy as np
import random

from logger_config import config_logger
from cross_encoder_reranker import CrossEncoderReranker

logger = config_logger(__name__)


class ChatbotValidator:
    """
    Handles automated validation and performance analysis for the chatbot.

    This validator executes domain-specific test queries, obtains candidate responses
    via the chatbot, then evaluates them with a quality checker. It aggregates metrics
    across queries and domains, logs intermediate results, and returns a comprehensive summary.
    """

    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot instance for inference
            quality_checker: ResponseQualityChecker instance
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker

        # Basic domain-specific test queries (easy examples)
        # Taskmaster-1 and Schema-Guided style
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 at an Italian restaurant?",
                "Is there any availability to dine tomorrow at 7pm?",
                "I'd like to cancel my reservation for tonight.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available right now?",
                "What's the estimated arrival time for the driver?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }

    def run_validation(
        self,
        num_examples: int = 5,
        top_k: int = 10,
        domains: Optional[List[str]] = None,
        randomize: bool = False,
        seed: int = 42
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across specified domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of domain keys to test. If None, test all.
            randomize: If True, randomly select queries from the domain lists
            seed: Random seed for consistent sampling if randomize=True

        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")

        # Select which domains to test
        test_domains = domains if domains else list(self.domain_queries.keys())

        # Initialize results
        metrics_history = []
        domain_metrics = {}
        reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")

        # Prepare random selection if needed
        rng = random.Random(seed)

        # Run validation for each domain
        for domain in test_domains:
            # Avoid errors if domain key missing
            if domain not in self.domain_queries:
                logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
                continue

            all_queries = self.domain_queries[domain]
            if randomize:
                queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
            else:
                queries = all_queries[:num_examples]

            # Store domain-level metrics
            domain_metrics[domain] = []

            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                logger.info(f"\nTest Case {i}: {query}")

                # Retrieve top_k responses (including cross-encoder re-ranking if available)
                responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k, reranker=reranker)

                # Evaluate with quality checker
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                # Save domain info
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)

                # Detailed logging
                self._log_validation_results(query, responses, quality_metrics, i)

        # Final aggregation
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        # Combine into one dictionary
        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Calculate comprehensive aggregate metrics over all tested queries.
        """
        if not metrics_history:
            logger.warning("No metrics to aggregate. Returning empty summary.")
            return {}

        top_scores = [m.get('top_score', 0.0) for m in metrics_history]

        # The length-based metrics are robust to missing or zero-length data
        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean(top_scores),
            'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0.0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history]),
            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_history]),

            # Additional statistical metrics
            'median_top_score': np.median(top_scores),
            'score_std': np.std(top_scores),
            'min_score': np.min(top_scores),
            'max_score': np.max(top_scores)
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
        """
        Analyze performance by domain, returning a nested dict.
        """
        analysis = {}

        for domain, metrics_list in domain_metrics.items():
            if not metrics_list:
                analysis[domain] = {}
                continue

            top_scores = [m.get('top_score', 0.0) for m in metrics_list]

            analysis[domain] = {
                'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0 for m in metrics_list]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_list]),
                'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_list]),
                'avg_top_score': np.mean(top_scores),
                'num_samples': len(metrics_list)
            }

        return analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Analyze the distribution of top scores to gauge system confidence levels.
        """
        if not metrics_history:
            return {'percentile_25': 0.0, 'percentile_50': 0.0, 'percentile_75': 0.0, 'percentile_90': 0.0}

        scores = [m.get('top_score', 0.0) for m in metrics_history]

        return {
            'percentile_25': float(np.percentile(scores, 25)),
            'percentile_50': float(np.percentile(scores, 50)),
            'percentile_75': float(np.percentile(scores, 75)),
            'percentile_90': float(np.percentile(scores, 90))
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
        case_num: int
    ):
        """
        Log detailed validation results for each test case.
        """
        domain = metrics.get('domain', 'Unknown')
        is_confident = metrics.get('is_confident', False)

        logger.info(f"Domain: {domain} | Confidence: {'Yes' if is_confident else 'No'}")
        logger.info("Quality Metrics:")
        for k, v in metrics.items():
            if isinstance(v, (int, float)):
                logger.info(f"  {k}: {v:.4f}")

        logger.info("Top 3 Responses:")
        for i, (resp_text, score) in enumerate(responses[:3], 1):
            logger.info(f"{i}) Score: {score:.4f} | {resp_text}")
            if i == 1 and not is_confident:
                logger.info("  [Low Confidence on Top Response]")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """
        Log a summary of all validation metrics and domain performance.
        """
        if not metrics:
            logger.info("No metrics to summarize.")
            return

        logger.info("\n=== Validation Summary ===")

        # Overall
        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            # Skip sub-dicts here
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        # Domain performance
        domain_perf = metrics.get('domain_performance', {})
        logger.info("\nDomain Performance:")
        for domain, domain_stats in domain_perf.items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_stats.items():
                logger.info(f"  {metric}: {value:.4f}")

        # Confidence distribution
        conf_analysis = metrics.get('confidence_analysis', {})
        logger.info("\nConfidence Distribution:")
        for pct, val in conf_analysis.items():
            logger.info(f"  {pct}: {val:.4f}")