import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, precision_recall_curve

from .constants import BASE_DIR
from .loading import load_data_split
from ..antigen.antigen import AntigenChain
from .metrics import find_optimal_threshold, calculate_node_metrics


def evaluate_ReCEP(model_path, device_id=0, radius=18.0, threshold=0.5, k=5,
                   verbose=True, split="test", save_results=True, output_dir=None, encoder="esmc"):
""" |
|
Evaluate ReCEP model on a dataset split using both probability-based and voting-based predictions. |
|
|
|
Args: |
|
model_path: Path to the trained ReCEP model |
|
device_id: GPU device ID |
|
radius: Radius for spherical regions |
|
threshold: Threshold for probability-based predictions |
|
k: Number of top regions to select |
|
verbose: Whether to print progress |
|
split: Dataset split to evaluate ('test', 'val', 'train') |
|
save_results: Whether to save detailed results to files |
|
output_dir: Directory to save results (if save_results=True) |
|
|
|
Returns: |
|
Dictionary containing evaluation metrics for both prediction methods |
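
    Example (illustrative; the checkpoint path below is a placeholder):
        results = evaluate_ReCEP(
            model_path="models/ReCEP/20250101_120000/best_mcc_model.bin",
            split="test",
        )
        print(results["probability_metrics"]["f1"])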
""" |
|
print(f"[INFO] Evaluating ReCEP model from {model_path}") |
|
print(f"[INFO] Settings:") |
|
print(f" Radius: {radius}") |
|
print(f" K: {k}") |
|
print(f" Split: {split}\n") |
|
|
|
antigens = load_data_split(split, verbose=verbose) |
|
|
|
|
|
all_true_labels = [] |
|
all_predicted_probs = [] |
|
all_voted_labels = [] |
|
all_predicted_binary = [] |
|
|
|
protein_results = [] |
|
|
|
    for pdb_id, chain_id in tqdm(antigens, desc=f"Evaluating ReCEP on {split} set", disable=not verbose):
        try:
            antigen_chain = AntigenChain.from_pdb(chain_id=chain_id, id=pdb_id)
            results = antigen_chain.evaluate(
                model_path=model_path,
                device_id=device_id,
                radius=radius,
                threshold=threshold,
                k=k,
                verbose=False,
                encoder=encoder
            )

            true_epitopes = antigen_chain.get_epitope_residue_numbers()
            true_binary = []
            predicted_probs = []
            voted_binary = []
            predicted_binary = []

            # Build per-residue labels: ground truth, predicted probability,
            # vote-based call, and threshold-based call.
            for idx in range(len(antigen_chain.residue_index)):
                residue_num = int(antigen_chain.residue_index[idx])
                true_binary.append(1 if residue_num in true_epitopes else 0)
                predicted_probs.append(results['predictions'].get(residue_num, 0))
                voted_binary.append(1 if residue_num in results['voted_epitopes'] else 0)
                predicted_binary.append(1 if residue_num in results['predicted_epitopes'] else 0)

            all_true_labels.extend(true_binary)
            all_predicted_probs.extend(predicted_probs)
            all_voted_labels.extend(voted_binary)
            all_predicted_binary.extend(predicted_binary)

            length = len(antigen_chain.sequence)
            species = antigen_chain.get_species()
            precision = results['predicted_precision']
            recall = results['predicted_recall']
            f1 = 2 * precision * recall / (precision + recall + 1e-10)

            # PR-AUC is undefined when a protein contains only one class;
            # report 0.0 in that case.
            if len(set(true_binary)) > 1:
                pr_precision, pr_recall, _ = precision_recall_curve(true_binary, predicted_probs)
                pr_auc = auc(pr_recall, pr_precision)
            else:
                pr_auc = 0.0

            protein_results.append({
                'pdb_id': pdb_id,
                'chain_id': chain_id,
                'length': length,
                'species': species,
                'predicted_precision': precision,
                'predicted_recall': recall,
                'predicted_f1': f1,
                'pr_auc': pr_auc,
                'voted_precision': results['voted_precision'],
                'voted_recall': results['voted_recall'],
                'num_residues': len(true_binary),
                'num_true_epitopes': sum(true_binary),
                'num_predicted_epitopes': sum(predicted_binary),
                'num_voted_epitopes': sum(voted_binary),
                'true_epitopes': true_binary,
                'predicted_probabilities': predicted_probs
            })

        except Exception as e:
            if verbose:
                print(f"[WARNING] Failed to evaluate {pdb_id}_{chain_id}: {str(e)}")
            continue

    all_true_labels = np.array(all_true_labels)
    all_predicted_probs = np.array(all_predicted_probs)
    all_voted_labels = np.array(all_voted_labels)
    all_predicted_binary = np.array(all_predicted_binary)

    # Pooled metrics for the three prediction modes: raw probabilities
    # (with threshold search), vote-based labels, and threshold-based labels.
    prob_metrics = calculate_node_metrics(all_predicted_probs, all_true_labels, find_threshold=True, include_curves=True)
    vote_metrics = calculate_node_metrics(all_voted_labels.astype(float), all_true_labels, find_threshold=False)
    pred_metrics = calculate_node_metrics(all_predicted_binary.astype(float), all_true_labels, find_threshold=False)

    prediction_stats = {
        'prob_based': {
            'total_predicted_positive': int(np.sum(all_predicted_binary)),
            'prediction_rate': float(np.mean(all_predicted_binary))
        },
        'vote_based': {
            'total_predicted_positive': int(np.sum(all_voted_labels)),
            'prediction_rate': float(np.mean(all_voted_labels))
        }
    }

    overall_stats = {
        'num_proteins': len(protein_results),
        'total_residues': len(all_true_labels),
        'total_true_epitopes': int(np.sum(all_true_labels)),
        'epitope_ratio': float(np.mean(all_true_labels)),
        'avg_protein_size': np.mean([p['num_residues'] for p in protein_results]),
        'avg_epitopes_per_protein': np.mean([p['num_true_epitopes'] for p in protein_results]),
        'prediction_stats': prediction_stats
    }

    if verbose:
        print_evaluation_results(prob_metrics, vote_metrics, pred_metrics, overall_stats, threshold)

    results = {
        'probability_metrics': prob_metrics,
        'voted_metrics': vote_metrics,
        'predicted_metrics': pred_metrics,
        'overall_stats': overall_stats,
        'protein_results': protein_results,
        'threshold': threshold
    }

    if save_results:
        from pathlib import Path

        model_path_obj = Path(model_path)
        # Derive a short model tag from the checkpoint filename; fall back to the
        # file stem if the name contains no underscore. This is computed outside
        # the branch below because it is needed even when output_dir is supplied.
        name_parts = model_path_obj.name.split("_")
        model_name = name_parts[1] if len(name_parts) > 1 else model_path_obj.stem
        if output_dir is None:
            timestamp = model_path_obj.parent.name
            output_dir = BASE_DIR / "results" / "ReCEP" / timestamp
        save_evaluation_results(results, output_dir, model_name)

    return results


def print_evaluation_results(prob_metrics, vote_metrics, pred_metrics, overall_stats, threshold):
    """Print formatted evaluation results for both prediction modes."""
    print(f"\n{'='*80}")
    print("ReCEP MODEL EVALUATION RESULTS")
    print(f"{'='*80}")

    print("\nOverall Statistics:")
    print(f"  Number of proteins: {overall_stats['num_proteins']}")
    print(f"  Total residues: {overall_stats['total_residues']:,}")
    print(f"  Total true epitopes: {overall_stats['total_true_epitopes']:,}")
    print(f"  Epitope ratio: {overall_stats['epitope_ratio']:.3f}")
    print(f"  Average protein size: {overall_stats['avg_protein_size']:.1f}")
    print(f"  Average epitopes per protein: {overall_stats['avg_epitopes_per_protein']:.1f}")

    print(f"\n{'-'*40}")
    print("PROBABILITY-BASED PREDICTIONS")
    print(f"{'-'*40}")
    print(f"Input threshold: {threshold}")
    print(f"Optimal threshold: {prob_metrics['best_threshold']}")
    print("\nProbability Metrics:")
    print(f"  AUPRC: {prob_metrics['auprc']:.4f}")
    print(f"  AUROC: {prob_metrics['auroc']:.4f}")
    print("\nBinary Classification Metrics:")
    print(f"  Accuracy: {prob_metrics['accuracy']:.4f}")
    print(f"  Precision: {prob_metrics['precision']:.4f}")
    print(f"  Recall: {prob_metrics['recall']:.4f}")
    print(f"  F1-Score: {prob_metrics['f1']:.4f}")
    print(f"  MCC: {prob_metrics['mcc']:.4f}")
    print("\nConfusion Matrix:")
    print(f"  True Pos: {prob_metrics['true_positives']:>6} | False Pos: {prob_metrics['false_positives']:>6}")
    print(f"  False Neg: {prob_metrics['false_negatives']:>6} | True Neg: {prob_metrics['true_negatives']:>6}")

print(f"\n{'-'*40}") |
|
print(f"VOTING-BASED PREDICTIONS") |
|
print(f"{'-'*40}") |
|
print(f"Binary Classification Metrics:") |
|
print(f" Accuracy: {vote_metrics['accuracy']:.4f}") |
|
print(f" Precision: {vote_metrics['precision']:.4f}") |
|
print(f" Recall: {vote_metrics['recall']:.4f}") |
|
print(f" F1-Score: {vote_metrics['f1']:.4f}") |
|
print(f" MCC: {vote_metrics['mcc']:.4f}") |
|
print(f"\nConfusion Matrix:") |
|
print(f" True Pos: {vote_metrics['true_positives']:>6} | False Pos: {vote_metrics['false_positives']:>6}") |
|
print(f" False Neg: {vote_metrics['false_negatives']:>6} | True Neg: {vote_metrics['true_negatives']:>6}") |
|
|
|
print(f"\n{'-'*40}") |
|
print(f"COMPARISON SUMMARY") |
|
print(f"{'-'*40}") |
|
print(f"{'Metric':<12} {'Probability':<12} {'Voting':<12} {'Difference':<12}") |
|
print(f"{'-'*48}") |
|
print(f"{'Accuracy':<12} {prob_metrics['accuracy']:<12.4f} {vote_metrics['accuracy']:<12.4f} {prob_metrics['accuracy']-vote_metrics['accuracy']:<12.4f}") |
|
print(f"{'Precision':<12} {prob_metrics['precision']:<12.4f} {vote_metrics['precision']:<12.4f} {prob_metrics['precision']-vote_metrics['precision']:<12.4f}") |
|
print(f"{'Recall':<12} {prob_metrics['recall']:<12.4f} {vote_metrics['recall']:<12.4f} {prob_metrics['recall']-vote_metrics['recall']:<12.4f}") |
|
print(f"{'F1-Score':<12} {prob_metrics['f1']:<12.4f} {vote_metrics['f1']:<12.4f} {prob_metrics['f1']-vote_metrics['f1']:<12.4f}") |
|
print(f"{'MCC':<12} {prob_metrics['mcc']:<12.4f} {vote_metrics['mcc']:<12.4f} {prob_metrics['mcc']-vote_metrics['mcc']:<12.4f}") |
|
|
|
print(f"\n{'='*80}") |
|
|
|
|
|
def save_evaluation_results(results, output_dir=None, prefix="evaluation"):
    """
    Save detailed evaluation results to files for further analysis.

    Args:
        results: Dictionary containing all evaluation results
        output_dir: Directory to save results
        prefix: Prefix for output files
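
    Example (illustrative; ``results`` is the dictionary returned by
    evaluate_ReCEP and the output directory below is a placeholder):
        save_evaluation_results(results, output_dir="results/ReCEP/demo", prefix="test")
        # -> results/ReCEP/demo/test_results.json
        #    results/ReCEP/demo/test_protein_results.csv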
""" |
|
    import os
    import json

    if output_dir is None:
        output_dir = BASE_DIR / "results" / "evaluation"

    os.makedirs(output_dir, exist_ok=True)

    results_to_save = {
        'probability_metrics': results['probability_metrics'],
        'voted_metrics': results['voted_metrics'],
        'predicted_metrics': results['predicted_metrics'],
        'overall_stats': results['overall_stats'],
        'threshold': results['threshold']
    }

    # Convert numpy arrays in the stored curves to plain lists so they are
    # JSON-serializable.
    if 'pr_curve' in results_to_save['probability_metrics']:
        if results_to_save['probability_metrics']['pr_curve'] is not None:
            results_to_save['probability_metrics']['pr_curve'] = {
                'precision': results_to_save['probability_metrics']['pr_curve']['precision'].tolist(),
                'recall': results_to_save['probability_metrics']['pr_curve']['recall'].tolist()
            }

    if 'roc_curve' in results_to_save['probability_metrics']:
        if results_to_save['probability_metrics']['roc_curve'] is not None:
            results_to_save['probability_metrics']['roc_curve'] = {
                'fpr': results_to_save['probability_metrics']['roc_curve']['fpr'].tolist(),
                'tpr': results_to_save['probability_metrics']['roc_curve']['tpr'].tolist()
            }

    with open(os.path.join(output_dir, f"{prefix}_results.json"), 'w') as f:
        json.dump(results_to_save, f, indent=2)

    if 'protein_results' in results:
        df = pd.DataFrame(results['protein_results'])
        df.to_csv(os.path.join(output_dir, f"{prefix}_protein_results.csv"), index=False)

    print(f"\nResults saved to {output_dir}/")
    print(f"  - {prefix}_results.json: Overall metrics")
    print(f"  - {prefix}_protein_results.csv: Per-protein results")