protein_rag / utils /cal_pr.py
ericzhang1122's picture
Upload folder using huggingface_hub
5c20520 verified
import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score
def load_go_data(file_path):
    """Load GO annotations from a JSON Lines file.

    Each non-blank line must be a JSON object with a ``protein_id`` key and a
    ``GO_id`` list. Blank or whitespace-only lines (for example a trailing
    newline at end of file) are skipped instead of crashing ``json.loads``.

    Args:
        file_path: path to the JSON Lines file (read as UTF-8).

    Returns:
        dict mapping each ``protein_id`` to the set of its GO ids. If the same
        ``protein_id`` appears on several lines, the last line wins.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``protein_id`` or ``GO_id``.
    """
    data = {}
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            entry = json.loads(line)
            data[entry["protein_id"]] = set(entry["GO_id"])
    return data
def _mean_or_nan(values):
    """Return the arithmetic mean of *values* as a float, or NaN if empty."""
    values = list(values)
    # np.mean([]) also yields NaN, but with a RuntimeWarning; be explicit.
    return float(np.mean(values)) if values else float('nan')


def _protein_prf(true_gos, pred_gos):
    """Return precision/recall/F1 of one protein's predicted GO set."""
    hits = len(true_gos & pred_gos)
    precision = hits / len(pred_gos) if pred_gos else 0.0
    # No ground-truth terms: nothing can be missed, so recall is defined as 1.
    recall = hits / len(true_gos) if true_gos else 1.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


def _load_go_scores(scores_file):
    """Read {protein_id: {GO id: score}} from a JSON Lines scores file."""
    scores = {}
    with open(scores_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            entry = json.loads(line)
            scores[entry["protein_id"]] = dict(entry.get("GO_scores", {}))
    return scores


def _plot_pr_curve(precision, recall, ap_score, best_precision, best_recall,
                   best_f1, best_threshold, plot_path):
    """Draw the PR curve with the best-F1 point highlighted and save it."""
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.plot(recall, precision, label=f'平均精确率 = {ap_score:.3f}')
        ax.scatter(best_recall, best_precision, color='red',
                   label=f'最佳F1 = {best_f1:.3f} (阈值 = {best_threshold:.3f})')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall 曲线')
        ax.legend()
        ax.grid(True)
        fig.savefig(plot_path, dpi=300)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def _pr_curve_summary(y_true, y_score, plot_path):
    """Compute the pooled PR curve and its best-F1 point, plot it, summarize."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    ap_score = average_precision_score(y_true, y_score)
    # precision/recall have one extra trailing point (recall 0, precision 1)
    # that has no threshold; drop it so index i matches thresholds[i].
    p = precision[:-1]
    r = recall[:-1]
    denom = p + r
    # Always float: the old zeros_like(thresholds) became an integer array
    # for integer scores and truncated every F1 value.
    f1_scores = np.zeros(len(thresholds), dtype=float)
    np.divide(2 * p * r, denom, out=f1_scores, where=denom > 0)
    best_idx = int(np.argmax(f1_scores))
    best_threshold = float(thresholds[best_idx])
    best_f1 = float(f1_scores[best_idx])
    _plot_pr_curve(precision, recall, ap_score, p[best_idx], r[best_idx],
                   best_f1, best_threshold, plot_path)
    return {
        "precision": precision.tolist(),
        "recall": recall.tolist(),
        "thresholds": thresholds.tolist(),
        "best_threshold": best_threshold,
        "best_f1": best_f1,
        "average_precision_score": float(ap_score),
    }


def calculate_pr_metrics(true_go_file, pred_go_file, scores_file=None,
                         plot_path='pr_curve.png'):
    """Compute per-protein precision/recall/F1 and, optionally, a PR curve.

    Only proteins present in *both* GO files are evaluated. Proteins that have
    ground truth but no prediction entry are skipped entirely.
    NOTE(review): skipping them can inflate average recall; confirm this is
    the intended protocol.

    Args:
        true_go_file: JSON Lines file of ground-truth GO terms (see
            ``load_go_data``).
        pred_go_file: JSON Lines file of predicted GO terms.
        scores_file: optional JSON Lines file whose records hold
            ``protein_id`` and a ``GO_scores`` mapping of GO id to score.
            When given, one pooled PR curve is built over every
            (protein, GO) pair in the union of true and predicted terms.
            Terms missing from the scores count as 0.0.
        plot_path: where the PR curve image is written (defaults to the
            previously hard-coded ``'pr_curve.png'``).

    Returns:
        dict with these keys:

        - ``average_precision``, ``average_recall``, ``average_f1``: macro
          averages over evaluated proteins, or NaN when none are shared.
        - ``protein_metrics``: per-protein metrics.
        - ``pr_curve``: present only when a curve was computed, i.e. a scores
          file was given and at least one positive label exists.
    """
    true_go_data = load_go_data(true_go_file)
    pred_go_data = load_go_data(pred_go_file)
    scores = _load_go_scores(scores_file) if scores_file else {}

    all_true = []
    all_scores = []
    protein_metrics = {}
    # Walk in ground-truth file order (not set/hash order) so protein_metrics,
    # and the JSON report built from it, is reproducible across runs.
    common_proteins = [pid for pid in true_go_data if pid in pred_go_data]
    for protein_id in common_proteins:
        true_gos = true_go_data[protein_id]
        pred_gos = pred_go_data[protein_id]
        protein_metrics[protein_id] = _protein_prf(true_gos, pred_gos)
        if scores_file:
            protein_scores = scores.get(protein_id, {})
            for go in true_gos | pred_gos:
                all_true.append(1 if go in true_gos else 0)
                all_scores.append(protein_scores.get(go, 0.0))

    metrics = protein_metrics.values()
    results = {
        "average_precision": _mean_or_nan(m["precision"] for m in metrics),
        "average_recall": _mean_or_nan(m["recall"] for m in metrics),
        "average_f1": _mean_or_nan(m["f1"] for m in metrics),
        "protein_metrics": protein_metrics,
    }

    # A PR curve needs at least one positive label; otherwise recall is undefined.
    if scores_file and any(all_true):
        results["pr_curve"] = _pr_curve_summary(
            np.asarray(all_true),
            np.asarray(all_scores, dtype=float),
            plot_path,
        )
    return results
def main():
    """Command-line entry point.

    Parses ``--true``, ``--pred``, the optional ``--scores`` and ``--output``,
    and runs ``calculate_pr_metrics``. Writes the result dict to ``--output``
    as indented JSON, then prints the averaged metrics. When a PR curve was
    computed, it also prints the best-F1 threshold.
    """
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(description='计算GO预测的Precision和Recall并绘制PR曲线')
    parser.add_argument('--true', required=True, help='真实GO的JSON文件路径')
    parser.add_argument('--pred', required=True, help='预测GO的JSON文件路径')
    parser.add_argument('--scores', help='GO分数的JSON文件路径(可选)')
    parser.add_argument('--output', default='test_results/pr_results.json', help='输出结果的JSON文件路径')
    args = parser.parse_args()

    results = calculate_pr_metrics(args.true, args.pred, args.scores)

    # Save the results. The default directory (test_results/) may not exist
    # yet; create it rather than crash after all metrics were computed.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"平均精确率: {results['average_precision']:.4f}")
    print(f"平均召回率: {results['average_recall']:.4f}")
    print(f"平均F1分数: {results['average_f1']:.4f}")
    if 'pr_curve' in results:
        print(f"最佳F1分数: {results['pr_curve']['best_f1']:.4f} (阈值: {results['pr_curve']['best_threshold']:.4f})")
        # The curve image is only written when a PR curve was computed.
        print("PR曲线已保存为 pr_curve.png")
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()