import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, precision_recall_curve

from .constants import BASE_DIR
from .loading import load_data_split
from ..antigen.antigen import AntigenChain
from .metrics import find_optimal_threshold, calculate_node_metrics


def evaluate_ReCEP(model_path, device_id=0, radius=18.0, threshold=0.5, k=5,
                   verbose=True, split="test", save_results=True, output_dir=None,
                   encoder="esmc"):
    """
    Evaluate ReCEP model on a dataset split using both probability-based and
    voting-based predictions.

    Args:
        model_path: Path to the trained ReCEP model
        device_id: GPU device ID
        radius: Radius for spherical regions
        threshold: Threshold for probability-based predictions
        k: Number of top regions to select
        verbose: Whether to print progress
        split: Dataset split to evaluate ('test', 'val', 'train')
        save_results: Whether to save detailed results to files
        output_dir: Directory to save results (if save_results=True)
        encoder: Embedding encoder passed to AntigenChain.evaluate (default "esmc")

    Returns:
        Dictionary containing evaluation metrics for both prediction methods
    """
    print(f"[INFO] Evaluating ReCEP model from {model_path}")
    print(f"[INFO] Settings:")
    print(f" Radius: {radius}")
    print(f" K: {k}")
    print(f" Split: {split}\n")

    antigens = load_data_split(split, verbose=verbose)

    # Collect data for all proteins
    all_true_labels = []
    all_predicted_probs = []
    all_voted_labels = []
    all_predicted_binary = []
    protein_results = []

    for pdb_id, chain_id in tqdm(antigens, desc=f"Evaluating ReCEP on {split} set",
                                 disable=not verbose):
        try:
            antigen_chain = AntigenChain.from_pdb(chain_id=chain_id, id=pdb_id)
            results = antigen_chain.evaluate(
                model_path=model_path,
                device_id=device_id,
                radius=radius,
                threshold=threshold,
                k=k,
                verbose=False,
                encoder=encoder
            )

            # Get true epitope labels as binary array
            true_epitopes = antigen_chain.get_epitope_residue_numbers()
            true_binary = []
            predicted_probs = []
            voted_binary = []
            predicted_binary = []

            # Convert to aligned arrays based on residue numbers
            for idx in range(len(antigen_chain.residue_index)):
                residue_num = int(antigen_chain.residue_index[idx])
                # True label
                true_binary.append(1 if residue_num in true_epitopes else 0)
                # Predicted probability
                predicted_probs.append(results['predictions'].get(residue_num, 0))
                # Voted prediction
                voted_binary.append(1 if residue_num in results['voted_epitopes'] else 0)
                # Probability-based prediction
                predicted_binary.append(1 if residue_num in results['predicted_epitopes'] else 0)

            # Store for overall evaluation
            all_true_labels.extend(true_binary)
            all_predicted_probs.extend(predicted_probs)
            all_voted_labels.extend(voted_binary)
            all_predicted_binary.extend(predicted_binary)

            length = len(antigen_chain.sequence)
            species = antigen_chain.get_species()
            precision = results['predicted_precision']
            recall = results['predicted_recall']
            f1 = 2 * precision * recall / (precision + recall + 1e-10)

            # Calculate PR-AUC using true_binary and predicted_probs
            if len(set(true_binary)) > 1:  # Check if there are both positive and negative samples
                pr_precision, pr_recall, _ = precision_recall_curve(true_binary, predicted_probs)
                pr_auc = auc(pr_recall, pr_precision)
            else:
                pr_auc = 0.0  # Default value when all labels are the same

            # Store individual protein results
            protein_results.append({
                'pdb_id': pdb_id,
                'chain_id': chain_id,
                'length': length,
                'species': species,
                'predicted_precision': precision,
                'predicted_recall': recall,
                'predicted_f1': f1,
                'pr_auc': pr_auc,
                'voted_precision': results['voted_precision'],
                'voted_recall': results['voted_recall'],
                'num_residues': len(true_binary),
                'num_true_epitopes': sum(true_binary),
                'num_predicted_epitopes': sum(predicted_binary),
                'num_voted_epitopes': sum(voted_binary),
                'true_epitopes': true_binary,
                'predicted_probabilities': predicted_probs
            })

        except Exception as e:
            if verbose:
                print(f"[WARNING] Failed to evaluate {pdb_id}_{chain_id}: {str(e)}")
            continue

    # Convert to numpy arrays
    all_true_labels = np.array(all_true_labels)
    all_predicted_probs = np.array(all_predicted_probs)
    all_voted_labels = np.array(all_voted_labels)
    all_predicted_binary = np.array(all_predicted_binary)

    # Calculate metrics for probability-based predictions
    # (includes both probability and binary metrics)
    prob_metrics = calculate_node_metrics(all_predicted_probs, all_true_labels,
                                          find_threshold=True, include_curves=True)

    # Calculate metrics for voting-based predictions (binary only)
    vote_metrics = calculate_node_metrics(all_voted_labels.astype(float), all_true_labels,
                                          find_threshold=False)

    # Calculate metrics for probability-based binary predictions using original threshold
    pred_metrics = calculate_node_metrics(all_predicted_binary.astype(float), all_true_labels,
                                          find_threshold=False)

    # Additional statistics for comprehensive evaluation
    prediction_stats = {
        'prob_based': {
            'total_predicted_positive': int(np.sum(all_predicted_binary)),
            'prediction_rate': float(np.mean(all_predicted_binary))
        },
        'vote_based': {
            'total_predicted_positive': int(np.sum(all_voted_labels)),
            'prediction_rate': float(np.mean(all_voted_labels))
        }
    }

    # Overall statistics
    overall_stats = {
        'num_proteins': len(protein_results),
        'total_residues': len(all_true_labels),
        'total_true_epitopes': int(np.sum(all_true_labels)),
        'epitope_ratio': float(np.mean(all_true_labels)),
        'avg_protein_size': np.mean([p['num_residues'] for p in protein_results]),
        'avg_epitopes_per_protein': np.mean([p['num_true_epitopes'] for p in protein_results]),
        'prediction_stats': prediction_stats
    }

    if verbose:
        print_evaluation_results(prob_metrics, vote_metrics, pred_metrics, overall_stats, threshold)

    # Prepare results dictionary
    results = {
        'probability_metrics': prob_metrics,
        'voted_metrics': vote_metrics,
        'predicted_metrics': pred_metrics,
        'overall_stats': overall_stats,
        'protein_results': protein_results,
        'threshold': threshold
    }

    if save_results:
        # Handle both string and Path objects
        from pathlib import Path
        model_path_obj = Path(model_path)
        # Derive the model name outside the conditional so it is defined even
        # when an explicit output_dir is provided
        model_name = model_path_obj.name.split("_")[1]
        if output_dir is None:
            timestamp = model_path_obj.parent.name
            output_dir = BASE_DIR / "results" / "ReCEP" / timestamp
        save_evaluation_results(results, output_dir, model_name)

    return results
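

# --- Optional helper (illustrative sketch, not part of the original pipeline) ---
# Aggregates the per-protein entries built by evaluate_ReCEP into a small pandas
# summary table. The column names mirror the keys of the `protein_results`
# dictionaries above; the helper itself is an assumption added for illustration.
def summarize_protein_results(protein_results):
    """Return mean/median/std of per-protein metrics as a DataFrame (sketch)."""
    df = pd.DataFrame(protein_results)
    metric_cols = ['predicted_precision', 'predicted_recall', 'predicted_f1', 'pr_auc']
    return df[metric_cols].agg(['mean', 'median', 'std'])
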

def print_evaluation_results(prob_metrics, vote_metrics, pred_metrics, overall_stats, threshold):
    """Print formatted evaluation results for both prediction modes."""
    print(f"\n{'='*80}")
    print(f"ReCEP MODEL EVALUATION RESULTS")
    print(f"{'='*80}")

    print(f"\nOverall Statistics:")
    print(f" Number of proteins: {overall_stats['num_proteins']}")
    print(f" Total residues: {overall_stats['total_residues']:,}")
    print(f" Total true epitopes: {overall_stats['total_true_epitopes']:,}")
    print(f" Epitope ratio: {overall_stats['epitope_ratio']:.3f}")
    print(f" Average protein size: {overall_stats['avg_protein_size']:.1f}")
    print(f" Average epitopes per protein: {overall_stats['avg_epitopes_per_protein']:.1f}")

    print(f"\n{'-'*40}")
    print(f"PROBABILITY-BASED PREDICTIONS")
    print(f"{'-'*40}")
    print(f"Threshold: {prob_metrics['best_threshold']}")
    print(f"\nProbability Metrics:")
    print(f" AUPRC: {prob_metrics['auprc']:.4f}")
    print(f" AUROC: {prob_metrics['auroc']:.4f}")
    print(f"\nBinary Classification Metrics:")
    print(f" Accuracy: {prob_metrics['accuracy']:.4f}")
    print(f" Precision: {prob_metrics['precision']:.4f}")
    print(f" Recall: {prob_metrics['recall']:.4f}")
    print(f" F1-Score: {prob_metrics['f1']:.4f}")
    print(f" MCC: {prob_metrics['mcc']:.4f}")
    print(f"\nConfusion Matrix:")
    print(f" True Pos: {prob_metrics['true_positives']:>6} | False Pos: {prob_metrics['false_positives']:>6}")
    print(f" False Neg: {prob_metrics['false_negatives']:>6} | True Neg: {prob_metrics['true_negatives']:>6}")

    print(f"\n{'-'*40}")
    print(f"VOTING-BASED PREDICTIONS")
    print(f"{'-'*40}")
    print(f"Binary Classification Metrics:")
    print(f" Accuracy: {vote_metrics['accuracy']:.4f}")
    print(f" Precision: {vote_metrics['precision']:.4f}")
    print(f" Recall: {vote_metrics['recall']:.4f}")
    print(f" F1-Score: {vote_metrics['f1']:.4f}")
    print(f" MCC: {vote_metrics['mcc']:.4f}")
    print(f"\nConfusion Matrix:")
    print(f" True Pos: {vote_metrics['true_positives']:>6} | False Pos: {vote_metrics['false_positives']:>6}")
    print(f" False Neg: {vote_metrics['false_negatives']:>6} | True Neg: {vote_metrics['true_negatives']:>6}")

    print(f"\n{'-'*40}")
    print(f"COMPARISON SUMMARY")
    print(f"{'-'*40}")
    print(f"{'Metric':<12} {'Probability':<12} {'Voting':<12} {'Difference':<12}")
    print(f"{'-'*48}")
    print(f"{'Accuracy':<12} {prob_metrics['accuracy']:<12.4f} {vote_metrics['accuracy']:<12.4f} {prob_metrics['accuracy']-vote_metrics['accuracy']:<12.4f}")
    print(f"{'Precision':<12} {prob_metrics['precision']:<12.4f} {vote_metrics['precision']:<12.4f} {prob_metrics['precision']-vote_metrics['precision']:<12.4f}")
    print(f"{'Recall':<12} {prob_metrics['recall']:<12.4f} {vote_metrics['recall']:<12.4f} {prob_metrics['recall']-vote_metrics['recall']:<12.4f}")
    print(f"{'F1-Score':<12} {prob_metrics['f1']:<12.4f} {vote_metrics['f1']:<12.4f} {prob_metrics['f1']-vote_metrics['f1']:<12.4f}")
    print(f"{'MCC':<12} {prob_metrics['mcc']:<12.4f} {vote_metrics['mcc']:<12.4f} {prob_metrics['mcc']-vote_metrics['mcc']:<12.4f}")

    print(f"\n{'='*80}")


def save_evaluation_results(results, output_dir=None, prefix="evaluation"):
    """
    Save detailed evaluation results to files for further analysis.

    Args:
        results: Dictionary containing all evaluation results
        output_dir: Directory to save results
        prefix: Prefix for output files
    """
    import os
    import json

    if output_dir is None:
        output_dir = BASE_DIR / "results" / "evaluation"

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Save overall results as JSON; copy the probability metrics dict so that
    # converting the curves below does not mutate the caller's results
    results_to_save = {
        'probability_metrics': dict(results['probability_metrics']),
        'voted_metrics': results['voted_metrics'],
        'predicted_metrics': results['predicted_metrics'],
        'overall_stats': results['overall_stats'],
        'threshold': results['threshold']
    }

    # Remove non-serializable items (curves)
    if 'pr_curve' in results_to_save['probability_metrics']:
        if results_to_save['probability_metrics']['pr_curve'] is not None:
            # Convert numpy arrays to lists for JSON serialization
            results_to_save['probability_metrics']['pr_curve'] = {
                'precision': results_to_save['probability_metrics']['pr_curve']['precision'].tolist(),
                'recall': results_to_save['probability_metrics']['pr_curve']['recall'].tolist()
            }
    if 'roc_curve' in results_to_save['probability_metrics']:
        if results_to_save['probability_metrics']['roc_curve'] is not None:
            results_to_save['probability_metrics']['roc_curve'] = {
                'fpr': results_to_save['probability_metrics']['roc_curve']['fpr'].tolist(),
                'tpr': results_to_save['probability_metrics']['roc_curve']['tpr'].tolist()
            }

    # Save main results
    with open(os.path.join(output_dir, f"{prefix}_results.json"), 'w') as f:
        json.dump(results_to_save, f, indent=2)

    # Save protein-level results as CSV
    if 'protein_results' in results:
        df = pd.DataFrame(results['protein_results'])
        df.to_csv(os.path.join(output_dir, f"{prefix}_protein_results.csv"), index=False)

    print(f"\nResults saved to {output_dir}/")
    print(f" - {prefix}_results.json: Overall metrics")
    print(f" - {prefix}_protein_results.csv: Per-protein results")
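

# --- Example usage (illustrative sketch) ---
# Because this module uses relative imports, it is intended to be run as part of
# the package (e.g. with `python -m <package>.<module>`). The checkpoint path
# below is a placeholder, not a file shipped with the project.
if __name__ == "__main__":
    example_model_path = "path/to/ReCEP_checkpoint.pt"  # hypothetical path
    metrics = evaluate_ReCEP(
        model_path=example_model_path,
        device_id=0,
        radius=18.0,
        threshold=0.5,
        k=5,
        split="test",
        save_results=False,  # skip writing files for a quick sanity check
        verbose=True,
    )
    print(f"Probability-based F1: {metrics['probability_metrics']['f1']:.4f}")
    print(f"Probability-based AUPRC: {metrics['probability_metrics']['auprc']:.4f}")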