import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import auc, precision_recall_curve

from .constants import BASE_DIR
from .loading import load_data_split
from ..antigen.antigen import AntigenChain
from .metrics import calculate_node_metrics

def evaluate_ReCEP(model_path, device_id=0, radius=18.0, threshold=0.5, k=5, 
                   verbose=True, split="test", save_results=True, output_dir=None, encoder="esmc"):
    """
    Evaluate ReCEP model on a dataset split using both probability-based and voting-based predictions.
    
    Args:
        model_path: Path to the trained ReCEP model
        device_id: GPU device ID
        radius: Radius for spherical regions
        threshold: Threshold for probability-based predictions
        k: Number of top regions to select
        verbose: Whether to print progress
        split: Dataset split to evaluate ('test', 'val', 'train')
        save_results: Whether to save detailed results to files
        output_dir: Directory to save results (if save_results=True)
        encoder: Name of the sequence encoder passed through to AntigenChain.evaluate (default: "esmc")
        
    Returns:
        Dictionary with keys 'probability_metrics', 'voted_metrics', 'predicted_metrics',
        'overall_stats', 'protein_results', and 'threshold'.
    """
    print(f"[INFO] Evaluating ReCEP model from {model_path}")
    print(f"[INFO] Settings:")
    print(f"  Radius: {radius}")
    print(f"  K: {k}")
    print(f"  Split: {split}\n")

    antigens = load_data_split(split, verbose=verbose)
    
    # Collect data for all proteins
    all_true_labels = []
    all_predicted_probs = []
    all_voted_labels = []
    all_predicted_binary = []
    
    protein_results = []
    
    for pdb_id, chain_id in tqdm(antigens, desc=f"Evaluating ReCEP on {split} set", disable=not verbose):
        try:
            antigen_chain = AntigenChain.from_pdb(chain_id=chain_id, id=pdb_id)
            results = antigen_chain.evaluate(
                model_path=model_path,
                device_id=device_id,
                radius=radius,
                threshold=threshold,
                k=k,
                verbose=False,
                encoder=encoder
            )
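            # `results` from AntigenChain.evaluate is expected to expose (judging from the
            # keys used below): 'predictions' (residue number -> probability),
            # 'voted_epitopes' and 'predicted_epitopes' (collections of residue numbers),
            # and the per-protein 'predicted_precision', 'predicted_recall',
            # 'voted_precision' and 'voted_recall' values.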
            
            # Get true epitope labels as binary array
            true_epitopes = antigen_chain.get_epitope_residue_numbers()
            true_binary = []
            predicted_probs = []
            voted_binary = []
            predicted_binary = []
            
            # Convert to aligned arrays based on residue numbers
            for idx in range(len(antigen_chain.residue_index)):
                residue_num = int(antigen_chain.residue_index[idx])
                
                # True label
                true_binary.append(1 if residue_num in true_epitopes else 0)
                
                # Predicted probability
                predicted_probs.append(results['predictions'].get(residue_num, 0))
                
                # Voted prediction
                voted_binary.append(1 if residue_num in results['voted_epitopes'] else 0)
                
                # Probability-based prediction
                predicted_binary.append(1 if residue_num in results['predicted_epitopes'] else 0)
            
            # Store for overall evaluation
            all_true_labels.extend(true_binary)
            all_predicted_probs.extend(predicted_probs)
            all_voted_labels.extend(voted_binary)
            all_predicted_binary.extend(predicted_binary)
            
            length = len(antigen_chain.sequence)
            species = antigen_chain.get_species()
            precision = results['predicted_precision']
            recall = results['predicted_recall']
            f1 = 2 * precision * recall / (precision + recall + 1e-10)
            
            # Calculate PR-AUC using true_binary and predicted_probs
            if len(set(true_binary)) > 1:  # Check if there are both positive and negative samples
                pr_precision, pr_recall, _ = precision_recall_curve(true_binary, predicted_probs)
                pr_auc = auc(pr_recall, pr_precision)
            else:
                pr_auc = 0.0  # Default value when all labels are the same
            
            # Store individual protein results
            protein_results.append({
                'pdb_id': pdb_id,
                'chain_id': chain_id,
                'length': length,
                'species': species,
                'predicted_precision': precision,
                'predicted_recall': recall,
                'predicted_f1': f1,
                'pr_auc': pr_auc,
                'voted_precision': results['voted_precision'],
                'voted_recall': results['voted_recall'],
                'num_residues': len(true_binary),
                'num_true_epitopes': sum(true_binary),
                'num_predicted_epitopes': sum(predicted_binary),
                'num_voted_epitopes': sum(voted_binary),
                'true_epitopes': true_binary,
                'predicted_probabilities': predicted_probs
            })
            
        except Exception as e:
            if verbose:
                print(f"[WARNING] Failed to evaluate {pdb_id}_{chain_id}: {str(e)}")
            continue
    
    # Convert to numpy arrays
    all_true_labels = np.array(all_true_labels)
    all_predicted_probs = np.array(all_predicted_probs)
    all_voted_labels = np.array(all_voted_labels)
    all_predicted_binary = np.array(all_predicted_binary)
    
    # Calculate metrics for probability-based predictions (includes both probability and binary metrics)
    prob_metrics = calculate_node_metrics(all_predicted_probs, all_true_labels, find_threshold=True, include_curves=True)
    
    # Calculate metrics for voting-based predictions (binary only)
    vote_metrics = calculate_node_metrics(all_voted_labels.astype(float), all_true_labels, find_threshold=False)
    
    # Calculate metrics for probability-based binary predictions using original threshold
    pred_metrics = calculate_node_metrics(all_predicted_binary.astype(float), all_true_labels, find_threshold=False)
    
    # Additional statistics for comprehensive evaluation
    prediction_stats = {
        'prob_based': {
            'total_predicted_positive': int(np.sum(all_predicted_binary)),
            'prediction_rate': float(np.mean(all_predicted_binary))
        },
        'vote_based': {
            'total_predicted_positive': int(np.sum(all_voted_labels)),
            'prediction_rate': float(np.mean(all_voted_labels))
        }
    }
    
    # Overall statistics
    overall_stats = {
        'num_proteins': len(protein_results),
        'total_residues': len(all_true_labels),
        'total_true_epitopes': int(np.sum(all_true_labels)),
        'epitope_ratio': float(np.mean(all_true_labels)),
        'avg_protein_size': np.mean([p['num_residues'] for p in protein_results]),
        'avg_epitopes_per_protein': np.mean([p['num_true_epitopes'] for p in protein_results]),
        'prediction_stats': prediction_stats
    }
    
    if verbose:
        print_evaluation_results(prob_metrics, vote_metrics, pred_metrics, overall_stats, threshold)
    
    # Prepare results dictionary
    results = {
        'probability_metrics': prob_metrics,
        'voted_metrics': vote_metrics,
        'predicted_metrics': pred_metrics,
        'overall_stats': overall_stats,
        'protein_results': protein_results,
        'threshold': threshold
    }
    
    if save_results:
        # Handle both string and Path objects
        from pathlib import Path
        model_path_obj = Path(model_path)
        # The results file prefix is taken from the second underscore-separated
        # field of the checkpoint file name; fall back to the stem if there is none.
        name_parts = model_path_obj.name.split("_")
        model_name = name_parts[1] if len(name_parts) > 1 else model_path_obj.stem
        if output_dir is None:
            timestamp = model_path_obj.parent.name
            output_dir = BASE_DIR / "results" / "ReCEP" / timestamp
        save_evaluation_results(results, output_dir, model_name)
    
    return results
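
# Example usage (a minimal sketch; the checkpoint path is a placeholder, not a file
# shipped with this module):
#
#     results = evaluate_ReCEP(
#         model_path="models/ReCEP/20250101_000000/best_mcc_model.bin",
#         device_id=0,
#         radius=18.0,
#         threshold=0.5,
#         k=5,
#         split="test",
#         save_results=False,
#     )
#     print(results['probability_metrics']['auprc'], results['voted_metrics']['f1'])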



def print_evaluation_results(prob_metrics, vote_metrics, pred_metrics, overall_stats, threshold):
    """Print formatted evaluation results for both prediction modes."""
    print(f"\n{'='*80}")
    print(f"ReCEP MODEL EVALUATION RESULTS")
    print(f"{'='*80}")
    
    print(f"\nOverall Statistics:")
    print(f"  Number of proteins: {overall_stats['num_proteins']}")
    print(f"  Total residues: {overall_stats['total_residues']:,}")
    print(f"  Total true epitopes: {overall_stats['total_true_epitopes']:,}")
    print(f"  Epitope ratio: {overall_stats['epitope_ratio']:.3f}")
    print(f"  Average protein size: {overall_stats['avg_protein_size']:.1f}")
    print(f"  Average epitopes per protein: {overall_stats['avg_epitopes_per_protein']:.1f}")
    
    print(f"\n{'-'*40}")
    print(f"PROBABILITY-BASED PREDICTIONS")
    print(f"{'-'*40}")
    print(f"Threshold: {prob_metrics['best_threshold']}")
    print(f"\nProbability Metrics:")
    print(f"  AUPRC: {prob_metrics['auprc']:.4f}")
    print(f"  AUROC: {prob_metrics['auroc']:.4f}")
    print(f"\nBinary Classification Metrics:")
    print(f"  Accuracy:  {prob_metrics['accuracy']:.4f}")
    print(f"  Precision: {prob_metrics['precision']:.4f}")
    print(f"  Recall:    {prob_metrics['recall']:.4f}")
    print(f"  F1-Score:  {prob_metrics['f1']:.4f}")
    print(f"  MCC:       {prob_metrics['mcc']:.4f}")
    print(f"\nConfusion Matrix:")
    print(f"  True Pos:  {prob_metrics['true_positives']:>6} | False Pos: {prob_metrics['false_positives']:>6}")
    print(f"  False Neg: {prob_metrics['false_negatives']:>6} | True Neg:  {prob_metrics['true_negatives']:>6}")
    
    print(f"\n{'-'*40}")
    print(f"VOTING-BASED PREDICTIONS")
    print(f"{'-'*40}")
    print(f"Binary Classification Metrics:")
    print(f"  Accuracy:  {vote_metrics['accuracy']:.4f}")
    print(f"  Precision: {vote_metrics['precision']:.4f}")
    print(f"  Recall:    {vote_metrics['recall']:.4f}")
    print(f"  F1-Score:  {vote_metrics['f1']:.4f}")
    print(f"  MCC:       {vote_metrics['mcc']:.4f}")
    print(f"\nConfusion Matrix:")
    print(f"  True Pos:  {vote_metrics['true_positives']:>6} | False Pos: {vote_metrics['false_positives']:>6}")
    print(f"  False Neg: {vote_metrics['false_negatives']:>6} | True Neg:  {vote_metrics['true_negatives']:>6}")
    
    print(f"\n{'-'*40}")
    print(f"COMPARISON SUMMARY")
    print(f"{'-'*40}")
    print(f"{'Metric':<12} {'Probability':<12} {'Voting':<12} {'Difference':<12}")
    print(f"{'-'*48}")
    print(f"{'Accuracy':<12} {prob_metrics['accuracy']:<12.4f} {vote_metrics['accuracy']:<12.4f} {prob_metrics['accuracy']-vote_metrics['accuracy']:<12.4f}")
    print(f"{'Precision':<12} {prob_metrics['precision']:<12.4f} {vote_metrics['precision']:<12.4f} {prob_metrics['precision']-vote_metrics['precision']:<12.4f}")
    print(f"{'Recall':<12} {prob_metrics['recall']:<12.4f} {vote_metrics['recall']:<12.4f} {prob_metrics['recall']-vote_metrics['recall']:<12.4f}")
    print(f"{'F1-Score':<12} {prob_metrics['f1']:<12.4f} {vote_metrics['f1']:<12.4f} {prob_metrics['f1']-vote_metrics['f1']:<12.4f}")
    print(f"{'MCC':<12} {prob_metrics['mcc']:<12.4f} {vote_metrics['mcc']:<12.4f} {prob_metrics['mcc']-vote_metrics['mcc']:<12.4f}")
    
    print(f"\n{'='*80}")
    

def save_evaluation_results(results, output_dir=None, prefix="evaluation"):
    """
    Save detailed evaluation results to files for further analysis.
    
    Args:
        results: Dictionary containing all evaluation results
        output_dir: Directory to save results
        prefix: Prefix for output files
    """
    import os
    import json
    
    if output_dir is None:
        output_dir = BASE_DIR / "results" / "evaluation"
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Save overall results as JSON
    results_to_save = {
        'probability_metrics': results['probability_metrics'],
        'voted_metrics': results['voted_metrics'], 
        'predicted_metrics': results['predicted_metrics'],
        'overall_stats': results['overall_stats'],
        'threshold': results['threshold']
    }
    
    # Convert the numpy arrays in the PR/ROC curves to lists so they are JSON-serializable
    if 'pr_curve' in results_to_save['probability_metrics']:
        if results_to_save['probability_metrics']['pr_curve'] is not None:
            results_to_save['probability_metrics']['pr_curve'] = {
                'precision': results_to_save['probability_metrics']['pr_curve']['precision'].tolist(),
                'recall': results_to_save['probability_metrics']['pr_curve']['recall'].tolist()
            }
    
    if 'roc_curve' in results_to_save['probability_metrics']:
        if results_to_save['probability_metrics']['roc_curve'] is not None:
            results_to_save['probability_metrics']['roc_curve'] = {
                'fpr': results_to_save['probability_metrics']['roc_curve']['fpr'].tolist(),
                'tpr': results_to_save['probability_metrics']['roc_curve']['tpr'].tolist()
            }
    
    # Save main results
    with open(os.path.join(output_dir, f"{prefix}_results.json"), 'w') as f:
        json.dump(results_to_save, f, indent=2)
    
    # Save protein-level results as CSV
    if 'protein_results' in results:
        df = pd.DataFrame(results['protein_results'])
        df.to_csv(os.path.join(output_dir, f"{prefix}_protein_results.csv"), index=False)
    
    print(f"\nResults saved to {output_dir}/")
    print(f"  - {prefix}_results.json: Overall metrics")
    print(f"  - {prefix}_protein_results.csv: Per-protein results")