from typing import Dict, List, Tuple, Any, Optional

import numpy as np

from logger_config import config_logger

logger = config_logger(__name__)
class ChatbotValidator:
    """Handles automated validation and performance analysis for the chatbot."""

    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot instance
            quality_checker: ResponseQualityChecker instance
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker

        # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 people at an Italian place?",
                "Do you have any availability for tomorrow at 7pm?",
                "I need to change my dinner reservation time.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available now?",
                "What's the estimated arrival time?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }
    def run_validation(
        self,
        num_examples: int = 10,
        top_k: int = 10,
        domains: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across specified domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of specific domains to test

        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")

        # Select domains to test
        test_domains = domains if domains else list(self.domain_queries.keys())
        metrics_history = []
        domain_metrics = {}

        # Run validation for each domain
        for domain in test_domains:
            domain_metrics[domain] = []
            queries = self.domain_queries[domain][:num_examples]
            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                logger.info(f"\nTest Case {i}:")
                logger.info(f"Query: {query}")

                # Get responses with increased top_k
                responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)

                # Enhanced quality checking with context
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                # Add domain info
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)

                # Detailed logging
                self._log_validation_results(query, responses, quality_metrics, i)

        # Calculate and log overall metrics
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics
    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Calculate comprehensive aggregate metrics."""
        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
            'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
            'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),

            # Additional statistical metrics
            'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
            'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
            'min_score': np.min([m.get('top_score', 0) for m in metrics_history]),
            'max_score': np.max([m.get('top_score', 0) for m in metrics_history])
        }
        return metrics
    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
        """Analyze performance by domain."""
        domain_analysis = {}

        for domain, metrics in domain_metrics.items():
            domain_analysis[domain] = {
                'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
                'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
                'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
                'num_samples': len(metrics)
            }
        return domain_analysis
    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Analyze the distribution of confidence scores."""
        scores = [m.get('top_score', 0) for m in metrics_history]
        return {
            'percentile_25': np.percentile(scores, 25),
            'percentile_50': np.percentile(scores, 50),
            'percentile_75': np.percentile(scores, 75),
            'percentile_90': np.percentile(scores, 90)
        }
    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
        case_num: int
    ):
        """Log detailed validation results."""
logger.info(f"\nTest Case {case_num}:")
logger.info(f"Query: {query}")
logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")
logger.info("\nQuality Metrics:")
for metric, value in metrics.items():
if isinstance(value, (int, float)):
logger.info(f" {metric}: {value:.4f}")
logger.info("\nTop Responses:")
for i, (response, score) in enumerate(responses[:3], 1):
logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
if i == 1 and not metrics.get('is_confident', False):
logger.info(" [Low Confidence]")
    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """Log comprehensive validation summary."""
        logger.info("\n=== Validation Summary ===")

        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        logger.info("\nDomain Performance:")
        for domain, domain_metrics in metrics['domain_performance'].items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_metrics.items():
                logger.info(f"  {metric}: {value:.4f}")

        logger.info("\nConfidence Distribution:")
        for percentile, value in metrics['confidence_analysis'].items():
            logger.info(f"{percentile}: {value:.4f}")