# csc525_retrieval_based_chatbot/chatbot_validator.py
from typing import Dict, List, Tuple, Any, Optional
import numpy as np
import random
from logger_config import config_logger
from cross_encoder_reranker import CrossEncoderReranker
logger = config_logger(__name__)


class ChatbotValidator:
    """
    Handles automated validation and performance analysis for the chatbot.

    This validator executes domain-specific test queries, obtains candidate
    responses via the chatbot, then evaluates them with a quality checker.
    It aggregates metrics across queries and domains, logs intermediate
    results, and returns a comprehensive summary.
    """

    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.

        Args:
            chatbot: RetrievalChatbot instance for inference
            quality_checker: ResponseQualityChecker instance
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker

        # Basic domain-specific test queries (easy examples)
        # Taskmaster-1 and Schema-Guided style
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 at an Italian restaurant?",
                "Is there any availability to dine tomorrow at 7pm?",
                "I'd like to cancel my reservation for tonight.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available right now?",
                "What's the estimated arrival time for the driver?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }
    def run_validation(
        self,
        num_examples: int = 5,
        top_k: int = 10,
        domains: Optional[List[str]] = None,
        randomize: bool = False,
        seed: int = 42
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across specified domains.

        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of domain keys to test. If None, test all.
            randomize: If True, randomly select queries from the domain lists
            seed: Random seed for consistent sampling if randomize=True

        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")

        # Select which domains to test
        test_domains = domains if domains else list(self.domain_queries.keys())

        # Initialize results
        metrics_history = []
        domain_metrics = {}
        reranker = CrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-12-v2")

        # Prepare random selection if needed
        rng = random.Random(seed)

        # Run validation for each domain
        for domain in test_domains:
            # Avoid errors if domain key missing
            if domain not in self.domain_queries:
                logger.warning(f"Domain '{domain}' not found in domain_queries. Skipping.")
                continue

            all_queries = self.domain_queries[domain]
            if randomize:
                queries = rng.sample(all_queries, min(num_examples, len(all_queries)))
            else:
                queries = all_queries[:num_examples]

            # Store domain-level metrics
            domain_metrics[domain] = []
            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                logger.info(f"\nTest Case {i}: {query}")

                # Retrieve top_k responses (including cross-encoder re-ranking if available)
                responses = self.chatbot.retrieve_responses_cross_encoder(
                    query, top_k=top_k, reranker=reranker
                )

                # Evaluate with quality checker
                quality_metrics = self.quality_checker.check_response_quality(query, responses)

                # Save domain info
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)

                # Detailed logging
                self._log_validation_results(query, responses, quality_metrics, i)

        # Final aggregation
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)

        # Combine into one dictionary
        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })

        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics
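    # The helpers below aggregate the per-query quality metrics. They assume each
    # quality_metrics dict produced by ResponseQualityChecker may contain
    # 'top_score', 'response_diversity', 'query_response_relevance',
    # 'response_length_score', 'top_3_score_gap', and 'is_confident'; any missing
    # key falls back to a neutral default (0.0 or False) via .get().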
    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Calculate comprehensive aggregate metrics over all tested queries.
        """
        if not metrics_history:
            logger.warning("No metrics to aggregate. Returning empty summary.")
            return {}

        top_scores = [m.get('top_score', 0.0) for m in metrics_history]

        # The length-based metrics are robust to missing or zero-length data
        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean(top_scores),
            'avg_diversity': np.mean([m.get('response_diversity', 0.0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0.0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0.0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0.0) for m in metrics_history]),
            'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
                                        for m in metrics_history]),
            # Additional statistical metrics
            'median_top_score': np.median(top_scores),
            'score_std': np.std(top_scores),
            'min_score': np.min(top_scores),
            'max_score': np.max(top_scores)
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict[str, float]]:
        """
        Analyze performance by domain, returning a nested dict.
        """
        analysis = {}
        for domain, metrics_list in domain_metrics.items():
            if not metrics_list:
                analysis[domain] = {}
                continue

            top_scores = [m.get('top_score', 0.0) for m in metrics_list]
            analysis[domain] = {
                'confidence_rate': np.mean([1.0 if m.get('is_confident', False) else 0.0
                                            for m in metrics_list]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0.0)
                                          for m in metrics_list]),
                'avg_diversity': np.mean([m.get('response_diversity', 0.0)
                                          for m in metrics_list]),
                'avg_top_score': np.mean(top_scores),
                'num_samples': len(metrics_list)
            }
        return analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """
        Analyze the distribution of top scores to gauge system confidence levels.
        """
        if not metrics_history:
            return {'percentile_25': 0.0, 'percentile_50': 0.0,
                    'percentile_75': 0.0, 'percentile_90': 0.0}

        scores = [m.get('top_score', 0.0) for m in metrics_history]
        return {
            'percentile_25': float(np.percentile(scores, 25)),
            'percentile_50': float(np.percentile(scores, 50)),
            'percentile_75': float(np.percentile(scores, 75)),
            'percentile_90': float(np.percentile(scores, 90))
        }

    def _log_validation_results(
        self,
        query: str,
        responses: List[Tuple[str, float]],
        metrics: Dict[str, Any],
        case_num: int
    ):
        """
        Log detailed validation results for each test case.
        """
        domain = metrics.get('domain', 'Unknown')
        is_confident = metrics.get('is_confident', False)

        logger.info(f"Domain: {domain} | Confidence: {'Yes' if is_confident else 'No'}")
        logger.info("Quality Metrics:")
        for k, v in metrics.items():
            if isinstance(v, (int, float)):
                logger.info(f" {k}: {v:.4f}")

        logger.info("Top 3 Responses:")
        for i, (resp_text, score) in enumerate(responses[:3], 1):
            logger.info(f"{i}) Score: {score:.4f} | {resp_text}")
            if i == 1 and not is_confident:
                logger.info(" [Low Confidence on Top Response]")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """
        Log a summary of all validation metrics and domain performance.
        """
        if not metrics:
            logger.info("No metrics to summarize.")
            return

        logger.info("\n=== Validation Summary ===")

        # Overall
        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            # Skip sub-dicts here
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")

        # Domain performance
        domain_perf = metrics.get('domain_performance', {})
        logger.info("\nDomain Performance:")
        for domain, domain_stats in domain_perf.items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_stats.items():
                logger.info(f" {metric}: {value:.4f}")

        # Confidence distribution
        conf_analysis = metrics.get('confidence_analysis', {})
        logger.info("\nConfidence Distribution:")
        for pct, val in conf_analysis.items():
            logger.info(f" {pct}: {val:.4f}")