import numpy as np

from sklearn.metrics import (
    mean_squared_error, mean_absolute_error, r2_score,
    average_precision_score, roc_auc_score, f1_score,
    precision_score, recall_score, matthews_corrcoef,
    accuracy_score, confusion_matrix, roc_curve, precision_recall_curve
)

def calculate_graph_metrics(preds, labels, threshold=0.5):
    """
    Calculate graph-level metrics for recall prediction.

    Args:
        preds: Predicted recall values (numpy array)
        labels: True recall values (numpy array)
        threshold: Threshold for binary classification (default: 0.5)

    Returns:
        Dictionary of metrics
    """
    # Replace NaN/inf values so the downstream metric calls cannot fail.
    preds = np.nan_to_num(preds, nan=0.0, posinf=1.0, neginf=0.0)
    labels = np.nan_to_num(labels, nan=0.0, posinf=1.0, neginf=0.0)

    # Binarize both predictions and labels at the given threshold.
    pred_binary = (preds > threshold).astype(int)
    label_binary = (labels > threshold).astype(int)

    metrics = {}

    # Classification metrics are only meaningful when both classes occur.
    if len(np.unique(label_binary)) > 1:
        metrics['recall'] = recall_score(label_binary, pred_binary, zero_division=0)
        metrics['precision'] = precision_score(label_binary, pred_binary, zero_division=0)
        metrics['mcc'] = matthews_corrcoef(label_binary, pred_binary)
        metrics['f1'] = f1_score(label_binary, pred_binary, zero_division=0)
        metrics['accuracy'] = accuracy_score(label_binary, pred_binary)
    else:
        metrics['recall'] = 0.0
        metrics['precision'] = 0.0
        metrics['mcc'] = 0.0
        metrics['f1'] = 0.0
        metrics['accuracy'] = 0.0

    # Regression metrics on the raw (continuous) recall values.
    metrics['mse'] = mean_squared_error(labels, preds)
    metrics['mae'] = mean_absolute_error(labels, preds)
    metrics['r2'] = r2_score(labels, preds)

    return metrics
|
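# Illustrative usage of calculate_graph_metrics (the numbers below are
# hypothetical, shown only to document the expected shapes and returned keys;
# a runnable demo lives in the __main__ block at the end of this file):
#
#     m = calculate_graph_metrics(np.array([0.9, 0.2, 0.7]),
#                                 np.array([1.0, 0.1, 0.8]))
#     # m contains 'recall', 'precision', 'mcc', 'f1', 'accuracy' computed at
#     # the 0.5 threshold, plus 'mse', 'mae', 'r2' on the raw values.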
|
def calculate_node_metrics(preds, labels, find_threshold=False, include_curves=False):
    """
    Calculate node-level metrics for epitope prediction.

    Args:
        preds: Predicted probabilities (numpy array)
        labels: True binary labels (numpy array)
        find_threshold: If True, find the threshold that maximizes MCC
        include_curves: If True, include PR and ROC curves for visualization

    Returns:
        Dictionary of metrics including optimal threshold if find_threshold=True
    """
    # Replace NaN/inf values so the downstream metric calls cannot fail.
    preds = np.nan_to_num(preds, nan=0.0, posinf=1.0, neginf=0.0)
    labels = np.nan_to_num(labels, nan=0.0, posinf=1.0, neginf=0.0)

    metrics = {}

    # Ranking metrics (AUROC/AUPRC) require both classes to be present.
    if len(np.unique(labels)) > 1:
        try:
            metrics['auroc'] = roc_auc_score(labels, preds)
            metrics['auprc'] = average_precision_score(labels, preds)

            if include_curves:
                # Raw curve arrays, e.g. for plotting.
                precision_curve, recall_curve, _ = precision_recall_curve(labels, preds)
                metrics['pr_curve'] = {
                    'precision': precision_curve,
                    'recall': recall_curve
                }

                fpr, tpr, _ = roc_curve(labels, preds)
                metrics['roc_curve'] = {
                    'fpr': fpr,
                    'tpr': tpr
                }
            else:
                metrics['pr_curve'] = None
                metrics['roc_curve'] = None
        except ValueError:
            metrics['auroc'] = 0.0
            metrics['auprc'] = 0.0
            metrics['pr_curve'] = None
            metrics['roc_curve'] = None

        # Pick the decision threshold: either the MCC-optimal one or 0.5.
        if find_threshold:
            best_threshold, _ = find_optimal_threshold(preds, labels)
            metrics['best_threshold'] = best_threshold
            threshold = best_threshold
        else:
            threshold = 0.5
            metrics['best_threshold'] = 0.5

        # Threshold-dependent classification metrics.
        pred_binary = (preds > threshold).astype(int)
        metrics['f1'] = f1_score(labels, pred_binary, zero_division=0)
        metrics['mcc'] = matthews_corrcoef(labels, pred_binary)
        metrics['precision'] = precision_score(labels, pred_binary, zero_division=0)
        metrics['recall'] = recall_score(labels, pred_binary, zero_division=0)
        metrics['accuracy'] = accuracy_score(labels, pred_binary)

        # Confusion-matrix counts; unpacking raises ValueError if only one
        # class appears in both labels and predictions.
        try:
            tn, fp, fn, tp = confusion_matrix(labels, pred_binary).ravel()
            metrics['true_positives'] = int(tp)
            metrics['false_positives'] = int(fp)
            metrics['true_negatives'] = int(tn)
            metrics['false_negatives'] = int(fn)
        except ValueError:
            metrics['true_positives'] = 0
            metrics['false_positives'] = 0
            metrics['true_negatives'] = 0
            metrics['false_negatives'] = 0

        metrics['threshold_used'] = threshold

    else:
        # Single-class labels: fall back to zeroed-out metrics.
        metrics['auroc'] = 0.0
        metrics['auprc'] = 0.0
        metrics['f1'] = 0.0
        metrics['mcc'] = 0.0
        metrics['precision'] = 0.0
        metrics['recall'] = 0.0
        metrics['accuracy'] = 0.0
        metrics['best_threshold'] = 0.5
        metrics['threshold_used'] = 0.5
        metrics['true_positives'] = 0
        metrics['false_positives'] = 0
        metrics['true_negatives'] = 0
        metrics['false_negatives'] = 0
        metrics['pr_curve'] = None
        metrics['roc_curve'] = None

    return metrics
|
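# Illustrative usage of calculate_node_metrics (hypothetical inputs; see the
# runnable __main__ demo at the end of this file). With find_threshold=True,
# the threshold-dependent metrics are computed at the MCC-optimal threshold
# instead of 0.5, and include_curves=True additionally returns the raw PR/ROC
# arrays, e.g. for plotting (assuming matplotlib.pyplot is imported as plt):
#
#     m = calculate_node_metrics(preds, labels,
#                                find_threshold=True, include_curves=True)
#     plt.plot(m['roc_curve']['fpr'], m['roc_curve']['tpr'])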
|
def find_optimal_threshold(preds, labels, num_thresholds=100):
    """
    Find the threshold that maximizes the Matthews correlation coefficient (MCC).

    Args:
        preds: Predicted probabilities (numpy array)
        labels: True binary labels (numpy array)
        num_thresholds: Number of thresholds to test

    Returns:
        Tuple of (best_threshold, best_mcc)
    """
    # Evenly spaced candidate thresholds across (0, 1).
    thresholds = np.linspace(0.01, 0.99, num_thresholds)

    # Starting best_mcc at 0.0 means a candidate is only adopted if it
    # achieves a strictly positive MCC; otherwise the 0.5 default is kept.
    best_mcc = 0.0
    best_threshold = 0.5

    for threshold in thresholds:
        pred_binary = (preds > threshold).astype(int)
        mcc = matthews_corrcoef(labels, pred_binary)

        if mcc > best_mcc:
            best_mcc = mcc
            best_threshold = threshold

    return best_threshold, best_mcc
|
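
if __name__ == "__main__":
    # Minimal smoke test on synthetic data. The arrays below are made up
    # purely to exercise the three functions above; they do not come from
    # any real prediction run.
    rng = np.random.default_rng(0)
    labels = rng.integers(0, 2, size=200).astype(float)
    # Noisy scores that loosely track the labels, clipped to [0, 1].
    preds = np.clip(labels * 0.6 + rng.random(200) * 0.5, 0.0, 1.0)

    graph_metrics = calculate_graph_metrics(preds, labels)
    node_metrics = calculate_node_metrics(preds, labels, find_threshold=True)
    best_threshold, best_mcc = find_optimal_threshold(preds, labels)

    print(f"graph MSE:      {graph_metrics['mse']:.4f}")
    print(f"node AUROC:     {node_metrics['auroc']:.4f}")
    print(f"best threshold: {best_threshold:.2f} (MCC={best_mcc:.4f})")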