# Spaces: Runtime error
# File size: 5,730 Bytes (5c20520)
import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score
def load_go_data(file_path):
    """Load per-protein GO annotations from a JSON Lines file.

    Every non-blank line is parsed as a JSON object that must contain a
    ``"protein_id"`` key and a ``"GO_id"`` key; the ``"GO_id"`` value is
    converted with ``set()``, so duplicate ids on one line collapse.

    Args:
        file_path: Path of the JSON Lines file to read.

    Returns:
        A dict mapping each protein id to the set of its GO ids. When a
        protein id occurs on several lines, the last occurrence wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an entry lacks ``"protein_id"`` or ``"GO_id"``.
    """
    data = {}
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a doubled newline at EOF) carry no entry.
                continue
            entry = json.loads(line)
            data[entry["protein_id"]] = set(entry["GO_id"])
    return data
def calculate_pr_metrics(true_go_file, pred_go_file, scores_file=None,
                         plot_path='pr_curve.png'):
    """Compute precision/recall/F1 of GO predictions and optionally a PR curve.

    Proteins are evaluated only when they appear in both the ground-truth and
    the prediction file. Per-protein metrics are computed from set overlap and
    then averaged over proteins (macro average).

    When ``scores_file`` is given, every (protein, GO id) pair from the union
    of true and predicted GO ids becomes one sample for a pooled
    precision-recall curve; a GO id without a score is scored ``0.0``. The
    curve is plotted and written to ``plot_path``.

    Args:
        true_go_file: JSON Lines file with the true GO ids
            (read via ``load_go_data``).
        pred_go_file: JSON Lines file with the predicted GO ids
            (read via ``load_go_data``).
        scores_file: Optional JSON Lines file whose entries hold a
            ``"protein_id"`` and an optional ``"GO_scores"`` mapping of
            GO id to score.
        plot_path: Where the PR-curve image is saved. Defaults to
            ``'pr_curve.png'``, the path that was previously hard-coded.

    Returns:
        A dict with ``"average_precision"``, ``"average_recall"`` and
        ``"average_f1"`` (plain floats; ``0.0`` when no protein is shared by
        both files), ``"protein_metrics"`` (protein id -> precision/recall/f1)
        and, when a curve was computed, ``"pr_curve"`` holding the curve
        arrays as lists plus ``"best_threshold"`` and ``"best_f1"``.
    """
    # Load ground-truth and predicted GO annotations.
    true_go_data = load_go_data(true_go_file)
    pred_go_data = load_go_data(pred_go_file)

    # Load the per-GO scores when a scores file was supplied.
    scores = {}
    if scores_file:
        with open(scores_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines carry no entry; json.loads would reject them.
                    continue
                entry = json.loads(line)
                scores[entry["protein_id"]] = dict(entry.get("GO_scores", {}))

    # Pooled labels/scores for the PR curve.
    all_true = []
    all_scores = []
    protein_metrics = {}

    # Sort (by string form, so mixed id types cannot break the sort) to make
    # the order of "protein_metrics" reproducible between runs.
    common_proteins = sorted(
        true_go_data.keys() & pred_go_data.keys(), key=str
    )

    for protein_id in common_proteins:
        true_gos = true_go_data[protein_id]
        pred_gos = pred_go_data[protein_id]
        n_correct = len(true_gos & pred_gos)

        precision = n_correct / len(pred_gos) if pred_gos else 0.0
        # With no true GO ids there is nothing to miss, so recall is 1.
        recall = n_correct / len(true_gos) if true_gos else 1.0
        pr_sum = precision + recall
        f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0.0

        protein_metrics[protein_id] = {
            "precision": precision,
            "recall": recall,
            "f1": f1
        }

        # Collect one PR-curve sample per GO id seen for this protein.
        if scores_file:
            protein_scores = scores.get(protein_id, {})
            for go in true_gos | pred_gos:
                all_true.append(1 if go in true_gos else 0)
                all_scores.append(protein_scores.get(go, 0.0))

    # Macro averages. np.mean([]) would yield NaN plus a RuntimeWarning (and
    # NaN is not valid JSON), so an empty overlap reports 0.0 instead.
    if protein_metrics:
        per_protein = list(protein_metrics.values())
        avg_precision = float(np.mean([m["precision"] for m in per_protein]))
        avg_recall = float(np.mean([m["recall"] for m in per_protein]))
        avg_f1 = float(np.mean([m["f1"] for m in per_protein]))
    else:
        avg_precision = avg_recall = avg_f1 = 0.0

    results = {
        "average_precision": avg_precision,
        "average_recall": avg_recall,
        "average_f1": avg_f1,
        "protein_metrics": protein_metrics
    }

    # Build and plot the PR curve when scored samples exist.
    if scores_file and all_true:
        y_true = np.asarray(all_true)
        # Force float so integer-only scores cannot produce integer arrays.
        y_score = np.asarray(all_scores, dtype=float)
        precision, recall, thresholds = precision_recall_curve(y_true, y_score)
        # Kept separate from the macro-averaged precision above.
        ap_score = average_precision_score(y_true, y_score)

        # F1 at every threshold. precision/recall carry one extra trailing
        # point without a threshold, hence the slice. The result array is
        # explicitly float: inheriting the dtype of `thresholds` would
        # truncate every F1 to 0 for integer scores.
        p_at = precision[:len(thresholds)]
        r_at = recall[:len(thresholds)]
        pr_sum_at = p_at + r_at
        f1_scores = np.zeros(len(thresholds), dtype=float)
        valid = pr_sum_at > 0
        f1_scores[valid] = 2 * p_at[valid] * r_at[valid] / pr_sum_at[valid]

        # Threshold that maximises F1.
        best_f1_idx = np.argmax(f1_scores)
        best_threshold = thresholds[best_f1_idx]
        best_precision = precision[best_f1_idx]
        best_recall = recall[best_f1_idx]
        best_f1 = f1_scores[best_f1_idx]

        # Plot the curve; always release the figure, even if saving fails.
        fig = plt.figure(figsize=(10, 8))
        try:
            plt.plot(recall, precision, label=f'平均精确率 = {ap_score:.3f}')
            plt.scatter(best_recall, best_precision, color='red',
                        label=f'最佳F1 = {best_f1:.3f} (阈值 = {best_threshold:.3f})')
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title('Precision-Recall 曲线')
            plt.legend()
            plt.grid(True)
            plt.savefig(plot_path, dpi=300)
        finally:
            plt.close(fig)

        results["pr_curve"] = {
            "precision": precision.tolist(),
            "recall": recall.tolist(),
            "thresholds": thresholds.tolist(),
            "best_threshold": float(best_threshold),
            "best_f1": float(best_f1)
        }

    return results
def main():
    """Command-line entry point.

    Parses ``--true``, ``--pred``, optional ``--scores`` and ``--output``,
    runs ``calculate_pr_metrics``, writes the results as JSON to the output
    path and prints a short summary to stdout.
    """
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(description='计算GO预测的Precision和Recall并绘制PR曲线')
    parser.add_argument('--true', required=True, help='真实GO的JSON文件路径')
    parser.add_argument('--pred', required=True, help='预测GO的JSON文件路径')
    parser.add_argument('--scores', help='GO分数的JSON文件路径(可选)')
    parser.add_argument('--output', default='test_results/pr_results.json', help='输出结果的JSON文件路径')
    args = parser.parse_args()

    results = calculate_pr_metrics(args.true, args.pred, args.scores)

    # Save the results. The default output lives in a sub-directory that may
    # not exist yet, so create missing parent directories first.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"平均精确率: {results['average_precision']:.4f}")
    print(f"平均召回率: {results['average_recall']:.4f}")
    print(f"平均F1分数: {results['average_f1']:.4f}")
    if 'pr_curve' in results:
        print(f"最佳F1分数: {results['pr_curve']['best_f1']:.4f} (阈值: {results['pr_curve']['best_threshold']:.4f})")
        print("PR曲线已保存为 pr_curve.png")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|