from typing import Dict, List, Tuple, Any, Optional
import numpy as np
from logger_config import config_logger

logger = config_logger(__name__)

class ChatbotValidator:
    """Handles automated validation and performance analysis for the chatbot."""
    
    def __init__(self, chatbot, quality_checker):
        """
        Initialize the validator.
        
        Args:
            chatbot: RetrievalChatbot instance
            quality_checker: ResponseQualityChecker instance
        """
        self.chatbot = chatbot
        self.quality_checker = quality_checker
        
        # Domain-specific test queries aligned with Taskmaster-1 and Schema-Guided
        self.domain_queries = {
            'restaurant': [
                "I'd like to make a reservation for dinner tonight.",
                "Can you book a table for 4 people at an Italian place?",
                "Do you have any availability for tomorrow at 7pm?",
                "I need to change my dinner reservation time.",
                "What's the wait time for a table right now?"
            ],
            'movie_tickets': [
                "I want to buy tickets for the new Marvel movie.",
                "Are there any showings of Avatar after 6pm?",
                "Can I get 3 tickets for the 8pm show?",
                "What movies are playing this weekend?",
                "Do you have any matinee showtimes available?"
            ],
            'rideshare': [
                "I need a ride from the airport to downtown.",
                "How much would it cost to get to the mall?",
                "Can you book a car for tomorrow morning?",
                "Is there a driver available now?",
                "What's the estimated arrival time?"
            ],
            'services': [
                "I need to schedule an oil change for my car.",
                "When can I bring my car in for maintenance?",
                "Do you have any openings for auto repair today?",
                "How long will the service take?",
                "Can I get an estimate for brake repair?"
            ],
            'events': [
                "I need tickets to the concert this weekend.",
                "What events are happening near me?",
                "Can I book seats for the basketball game?",
                "Are there any comedy shows tonight?",
                "How much are tickets to the theater?"
            ]
        }

    def run_validation(
        self,
        num_examples: int = 10,
        top_k: int = 10,
        domains: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Run comprehensive validation across specified domains.
        
        Args:
            num_examples: Number of test queries per domain
            top_k: Number of responses to retrieve for each query
            domains: Optional list of specific domains to test
            
        Returns:
            Dict containing detailed validation metrics and domain-specific performance
        """
        logger.info("\n=== Running Enhanced Automatic Validation ===")
        
        # Select domains to test
        test_domains = domains if domains else list(self.domain_queries.keys())
        metrics_history = []
        domain_metrics = {}
        
        # Run validation for each domain
        for domain in test_domains:
            if domain not in self.domain_queries:
                logger.warning(f"Unknown domain '{domain}' requested; skipping.")
                continue

            domain_metrics[domain] = []
            queries = self.domain_queries[domain][:num_examples]

            logger.info(f"\n=== Testing {domain.title()} Domain ===")

            for i, query in enumerate(queries, 1):
                
                # Get responses with increased top_k
                responses = self.chatbot.retrieve_responses_cross_encoder(query, top_k=top_k)
                
                # Enhanced quality checking with context
                quality_metrics = self.quality_checker.check_response_quality(query, responses)
                
                # Add domain info
                quality_metrics['domain'] = domain
                metrics_history.append(quality_metrics)
                domain_metrics[domain].append(quality_metrics)
                
                # Detailed logging
                self._log_validation_results(query, responses, quality_metrics, i)

        # Calculate and log overall metrics
        aggregate_metrics = self._calculate_aggregate_metrics(metrics_history)
        domain_analysis = self._analyze_domain_performance(domain_metrics)
        confidence_analysis = self._analyze_confidence_distribution(metrics_history)
        
        aggregate_metrics.update({
            'domain_performance': domain_analysis,
            'confidence_analysis': confidence_analysis
        })
        
        self._log_validation_summary(aggregate_metrics)
        return aggregate_metrics

    def _calculate_aggregate_metrics(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Calculate comprehensive aggregate metrics."""
        top_scores = [m.get('top_score', 0) for m in metrics_history]

        metrics = {
            'num_queries_tested': len(metrics_history),
            'avg_top_response_score': np.mean(top_scores),
            'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics_history]),
            'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics_history]),
            'avg_length_score': np.mean([m.get('response_length_score', 0) for m in metrics_history]),
            'avg_score_gap': np.mean([m.get('top_3_score_gap', 0) for m in metrics_history]),
            'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics_history]),

            # Additional statistics on the top response score
            'median_top_score': np.median(top_scores),
            'score_std': np.std(top_scores),
            'min_score': np.min(top_scores),
            'max_score': np.max(top_scores)
        }
        return metrics

    def _analyze_domain_performance(self, domain_metrics: Dict[str, List[Dict]]) -> Dict[str, Dict]:
        """Analyze performance by domain."""
        domain_analysis = {}
        
        for domain, metrics in domain_metrics.items():
            domain_analysis[domain] = {
                'confidence_rate': np.mean([m.get('is_confident', False) for m in metrics]),
                'avg_relevance': np.mean([m.get('query_response_relevance', 0) for m in metrics]),
                'avg_diversity': np.mean([m.get('response_diversity', 0) for m in metrics]),
                'avg_top_score': np.mean([m.get('top_score', 0) for m in metrics]),
                'num_samples': len(metrics)
            }
        
        return domain_analysis

    def _analyze_confidence_distribution(self, metrics_history: List[Dict]) -> Dict[str, float]:
        """Analyze the distribution of confidence scores."""
        scores = [m.get('top_score', 0) for m in metrics_history]
        
        return {
            'percentile_25': np.percentile(scores, 25),
            'percentile_50': np.percentile(scores, 50),
            'percentile_75': np.percentile(scores, 75),
            'percentile_90': np.percentile(scores, 90)
        }

    def _log_validation_results(
        self, 
        query: str, 
        responses: List[Tuple[str, float]], 
        metrics: Dict[str, Any],
        case_num: int
    ):
        """Log detailed validation results."""
        logger.info(f"\nTest Case {case_num}:")
        logger.info(f"Query: {query}")
        logger.info(f"Domain: {metrics.get('domain', 'Unknown')}")
        logger.info(f"Confidence: {'Yes' if metrics.get('is_confident', False) else 'No'}")
        logger.info("\nQuality Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"  {metric}: {value:.4f}")
        
        logger.info("\nTop Responses:")
        for i, (response, score) in enumerate(responses[:3], 1):
            logger.info(f"{i}. Score: {score:.4f}. Response: {response}")
            if i == 1 and not metrics.get('is_confident', False):
                logger.info("   [Low Confidence]")

    def _log_validation_summary(self, metrics: Dict[str, Any]):
        """Log comprehensive validation summary."""
        logger.info("\n=== Validation Summary ===")
        
        logger.info("\nOverall Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                logger.info(f"{metric}: {value:.4f}")
        
        logger.info("\nDomain Performance:")
        for domain, domain_metrics in metrics['domain_performance'].items():
            logger.info(f"\n{domain.title()}:")
            for metric, value in domain_metrics.items():
                logger.info(f"  {metric}: {value:.4f}")
        
        logger.info("\nConfidence Distribution:")
        for percentile, value in metrics['confidence_analysis'].items():
            logger.info(f"{percentile}: {value:.4f}")