Spaces:
Runtime error
Runtime error
import json | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.metrics import precision_recall_curve, average_precision_score | |
def load_go_data(file_path):
    """Load per-protein GO annotations from a JSON Lines file.

    Each non-blank line must be a JSON object with a ``"protein_id"`` key and
    a ``"GO_id"`` key holding the protein's GO term identifiers (a list, or a
    single identifier string).  Blank lines are skipped.  If a protein appears
    on several lines, the last line wins.

    Args:
        file_path: Path to the JSON Lines file (read as UTF-8).

    Returns:
        dict mapping protein_id -> set of GO term identifiers.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``"protein_id"`` or ``"GO_id"``.
    """
    data = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank/trailing empty lines
                continue
            entry = json.loads(line)
            go_ids = entry["GO_id"]
            # A bare string would otherwise be split into a set of characters.
            if isinstance(go_ids, str):
                go_ids = [go_ids]
            data[entry["protein_id"]] = set(go_ids)
    return data
def _load_go_scores(scores_file):
    """Read a JSON Lines scores file into {protein_id: {GO term: score}}.

    Each non-blank line is a JSON object with ``"protein_id"`` and an optional
    ``"GO_scores"`` mapping; a missing mapping yields an empty dict.
    """
    scores = {}
    with open(scores_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank/trailing empty lines
                continue
            entry = json.loads(line)
            scores[entry["protein_id"]] = dict(entry.get("GO_scores", {}))
    return scores


def _f1(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    total = precision + recall
    return 2 * precision * recall / total if total > 0 else 0.0


def _mean(values):
    """Arithmetic mean as a Python float, or 0.0 for an empty list (avoids NaN)."""
    return float(np.mean(values)) if values else 0.0


def _plot_pr_curve(recall, precision, ap_score, best_recall, best_precision,
                   best_f1, best_threshold, plot_path):
    """Draw the PR curve with the best-F1 point highlighted and save it."""
    # NOTE(review): the Chinese labels need a CJK-capable matplotlib font,
    # otherwise glyphs may render as boxes — confirm the deployment font setup.
    fig = plt.figure(figsize=(10, 8))
    try:
        plt.plot(recall, precision, label=f'平均精确率 = {ap_score:.3f}')
        plt.scatter(best_recall, best_precision, color='red',
                    label=f'最佳F1 = {best_f1:.3f} (阈值 = {best_threshold:.3f})')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall 曲线')
        plt.legend()
        plt.grid(True)
        plt.savefig(plot_path, dpi=300)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)


def calculate_pr_metrics(true_go_file, pred_go_file, scores_file=None,
                         plot_path='pr_curve.png'):
    """Compute per-protein precision/recall/F1 and, optionally, a PR curve.

    Only proteins present in BOTH the ground-truth and prediction files are
    evaluated.  For each such protein, precision is
    ``len(true & pred) / len(pred)`` (0.0 when nothing is predicted) and recall
    is ``len(true & pred) / len(true)`` (1.0 when there are no true terms).

    When ``scores_file`` is given, every GO term in ``true | pred`` of each
    evaluated protein becomes one (label, score) pair: label 1 if the term is
    a true annotation, and score 0.0 if the scores file has no score for it.
    A precision-recall curve over those pairs is computed, plotted to
    ``plot_path`` and summarized under ``results["pr_curve"]``.

    Args:
        true_go_file: JSON Lines file of ground-truth annotations (the format
            read by ``load_go_data``).
        pred_go_file: JSON Lines file of predicted annotations (same format).
        scores_file: Optional JSON Lines file whose records carry
            ``"protein_id"`` and a ``"GO_scores"`` mapping GO term -> score.
        plot_path: Where to save the PR-curve image; only written when a
            curve is computed.  Defaults to ``'pr_curve.png'``.

    Returns:
        dict with:
            - ``"average_precision"``, ``"average_recall"``, ``"average_f1"``:
              macro means over the evaluated proteins (0.0 when no protein is
              shared by both files).  Note that ``"average_precision"`` is the
              mean per-protein precision, not the PR-curve AP score.
            - ``"protein_metrics"``: protein_id -> {"precision", "recall", "f1"}.
            - ``"pr_curve"`` (only when a curve was computed): the
              ``precision``/``recall``/``thresholds`` arrays as lists, plus
              ``"best_threshold"``, ``"best_f1"`` and ``"average_precision_score"``.
    """
    true_go_data = load_go_data(true_go_file)
    pred_go_data = load_go_data(pred_go_file)
    scores = _load_go_scores(scores_file) if scores_file else {}

    # NOTE(review): proteins with ground truth but no prediction record are
    # excluded here, which raises average recall compared with evaluations
    # that average over every annotated protein — confirm this is intended.
    # Sorted (by str, to tolerate mixed id types) for deterministic output.
    common_proteins = sorted(set(true_go_data) & set(pred_go_data), key=str)

    all_true = []    # 1 if the (protein, GO term) pair is a true annotation
    all_scores = []  # predicted confidence for the same pair
    protein_metrics = {}
    for protein_id in common_proteins:
        true_gos = true_go_data[protein_id]
        pred_gos = pred_go_data[protein_id]
        n_hits = len(true_gos & pred_gos)

        precision = n_hits / len(pred_gos) if pred_gos else 0.0
        # Convention: a protein with no true GO terms gets full recall.
        recall = n_hits / len(true_gos) if true_gos else 1.0
        protein_metrics[protein_id] = {
            "precision": precision,
            "recall": recall,
            "f1": _f1(precision, recall),
        }

        if scores_file:
            protein_scores = scores.get(protein_id, {})
            for go in true_gos | pred_gos:
                all_true.append(1 if go in true_gos else 0)
                all_scores.append(protein_scores.get(go, 0.0))

    metrics = list(protein_metrics.values())
    results = {
        "average_precision": _mean([m["precision"] for m in metrics]),
        "average_recall": _mean([m["recall"] for m in metrics]),
        "average_f1": _mean([m["f1"] for m in metrics]),
        "protein_metrics": protein_metrics,
    }

    if scores_file and all_true:
        y_true = np.asarray(all_true)
        # Force float so integer scores cannot give integer-typed thresholds.
        y_score = np.asarray(all_scores, dtype=float)
        precision, recall, thresholds = precision_recall_curve(y_true, y_score)
        ap_score = average_precision_score(y_true, y_score)

        # precision/recall have one more entry than thresholds (the final
        # point has no threshold), so align them before computing F1.
        p, r = precision[:-1], recall[:-1]
        denom = p + r
        f1_scores = np.divide(2 * p * r, denom,
                              out=np.zeros_like(denom, dtype=float),
                              where=denom > 0)

        best_idx = int(np.argmax(f1_scores))
        best_threshold = float(thresholds[best_idx])
        best_f1 = float(f1_scores[best_idx])

        _plot_pr_curve(recall, precision, ap_score,
                       recall[best_idx], precision[best_idx],
                       best_f1, best_threshold, plot_path)

        results["pr_curve"] = {
            "precision": precision.tolist(),
            "recall": recall.tolist(),
            "thresholds": thresholds.tolist(),
            "best_threshold": best_threshold,
            "best_f1": best_f1,
            "average_precision_score": float(ap_score),
        }

    return results
def main():
    """Command-line entry point: evaluate GO predictions and save the results.

    Parses ``--true`` and ``--pred`` (required), ``--scores`` (optional) and
    ``--output`` (default ``test_results/pr_results.json``).  It then runs
    ``calculate_pr_metrics`` and writes the results dict as indented JSON to
    ``--output``, creating the parent directory if needed.  Finally it prints
    the macro-averaged metrics, plus the best F1/threshold when a PR curve
    was produced.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(description='计算GO预测的Precision和Recall并绘制PR曲线')
    parser.add_argument('--true', required=True, help='真实GO的JSON文件路径')
    parser.add_argument('--pred', required=True, help='预测GO的JSON文件路径')
    parser.add_argument('--scores', help='GO分数的JSON文件路径(可选)')
    parser.add_argument('--output', default='test_results/pr_results.json', help='输出结果的JSON文件路径')
    args = parser.parse_args()

    results = calculate_pr_metrics(args.true, args.pred, args.scores)

    # The default output directory may not exist yet; without this, open()
    # raises FileNotFoundError after all the metrics have been computed.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"平均精确率: {results['average_precision']:.4f}")
    print(f"平均召回率: {results['average_recall']:.4f}")
    print(f"平均F1分数: {results['average_f1']:.4f}")
    if 'pr_curve' in results:
        print(f"最佳F1分数: {results['pr_curve']['best_f1']:.4f} (阈值: {results['pr_curve']['best_threshold']:.4f})")
        print("PR曲线已保存为 pr_curve.png")


if __name__ == "__main__":
    main()