# csc525_retrieval_based_chatbot/chatbot_validator.py
from typing import Dict, List, Tuple, Any, Optional
import numpy as np
from logger_config import config_logger
logger = config_logger(__name__)


class ChatbotValidator:
    """Handles automated validation and performance analysis for the chatbot."""

    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot instance
            quality_checker: ResponseQualityChecker instance
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker

        # Domain-specific test queries aligned with the Taskmaster-1 and Schema-Guided Dialogue datasets
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 people at an Italian place?",
                "Do you have any availability for tomorrow at 7pm?",
                "I need to change my dinner reservation time.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available now?",
                "What's the estimated arrival time?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }

    def run_validation(
        self,
        num_examples: int = 10,
        top_k: int = 10,
        domains: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across the specified domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of specific domains to test

        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")

        # Select the domains to test
        test_domains = domains if domains else list(self.domain_queries.keys())
        metrics_history = []
        domain_metrics = {}

        # Run validation for each domain
        for domain in test_domains:
            domain_metrics[domain] = []
            queries = self.domain_queries[domain][:num_examples]
            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                logger.info(f"\nTest Case {i}:")
                logger.info(f"Query: {query}")

                # Retrieve candidate responses with the configured top_k
                responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)

                # Check response quality in the context of the query
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                # Tag each result with its domain
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)

                # Detailed logging
                self._log_validation_results(query, responses, quality_metrics, i)

        # Calculate and log overall metrics
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Calculate comprehensive aggregate metrics."""
        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean([m.get('top_score', 0) for m in metrics_history]),
            'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
            'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),
            # Additional statistical metrics
            'median_top_score': np.median([m.get('top_score', 0) for m in metrics_history]),
            'score_std': np.std([m.get('top_score', 0) for m in metrics_history]),
            'min_score': np.min([m.get('top_score', 0) for m in metrics_history]),
            'max_score': np.max([m.get('top_score', 0) for m in metrics_history])
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
        """Analyze performance by domain."""
        domain_analysis = {}

        for domain, metrics in domain_metrics.items():
            domain_analysis[domain] = {
                'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
                'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
                'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
                'num_samples': len(metrics)
            }

        return domain_analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Analyze the distribution of confidence scores."""
        scores = [m.get('top_score', 0) for m in metrics_history]
        return {
            'percentile_25': np.percentile(scores, 25),
            'percentile_50': np.percentile(scores, 50),
            'percentile_75': np.percentile(scores, 75),
            'percentile_90': np.percentile(scores, 90)
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
        case_num: int
    ):
        """Log detailed validation results."""
        logger.info(f"\nTest Case {case_num}:")
        logger.info(f"Query: {query}")
        logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
        logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")

        logger.info("\nQuality Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f" {metric}: {value:.4f}")

        logger.info("\nTop Responses:")
        for i, (response, score) in enumerate(responses[:3], 1):
            logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
            if i == 1 and not metrics.get('is_confident', False):
                logger.info(" [Low Confidence]")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """Log comprehensive validation summary."""
        logger.info("\n=== Validation Summary ===")

        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        logger.info("\nDomain Performance:")
        for domain, domain_metrics in metrics['domain_performance'].items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_metrics.items():
                logger.info(f" {metric}: {value:.4f}")

        logger.info("\nConfidence Distribution:")
        for percentile, value in metrics['confidence_analysis'].items():
            logger.info(f"{percentile}: {value:.4f}")
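

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal smoke test showing how ChatbotValidator is driven. The two stub
# classes below are hypothetical stand-ins: the real RetrievalChatbot and
# ResponseQualityChecker live elsewhere in this repo and are assumed to expose
# retrieve_responses_cross_encoder() and check_response_quality() with the
# signatures used above.
if __name__ == "__main__":
    class _StubChatbot:
        def retrieve_responses_cross_encoder(self, query, top_k=10):
            # Return (response, score) pairs in descending score order.
            return [(f"Echo: {query}", 0.9), ("Fallback response.", 0.4)][:top_k]

    class _StubQualityChecker:
        def check_response_quality(self, query, responses):
            top_score = responses[0][1] if responses else 0.0
            return {'top_score': top_score, 'is_confident': top_score > 0.5}

    validator = ChatbotValidator(_StubChatbot(), _StubQualityChecker())
    results = validator.run_validation(num_examples=2, domains=['restaurant'])
    logger.info(f"Aggregate metric keys: {sorted(results.keys())}")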