import json
import os

import numpy as np


def load_go_data(file_path):
    """Load per-protein GO annotations from a JSON-lines file.

    Each non-blank line must be a JSON object with a ``"protein_id"`` key and a
    ``"GO_id"`` key holding a list of GO term IDs. A single string is also
    accepted and treated as a one-element list.

    Args:
        file_path: Path to the JSON-lines file.

    Returns:
        dict mapping protein_id -> set of GO IDs. If a protein_id appears on
        several lines, the last line wins.
    """
    data = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines, which json.loads rejects.
                continue
            entry = json.loads(line)
            go_ids = entry["GO_id"]
            if isinstance(go_ids, str):
                # set("GO:0001") would split the ID into single characters.
                go_ids = [go_ids]
            data[entry["protein_id"]] = set(go_ids)
    return data


def _load_go_scores(scores_file):
    """Load per-protein GO scores from a JSON-lines file: protein_id -> {GO ID: score}."""
    scores = {}
    with open(scores_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            scores[entry["protein_id"]] = dict(entry.get("GO_scores", {}))
    return scores


def _protein_prf(true_gos, pred_gos):
    """Return (precision, recall, f1) for one protein's true vs. predicted GO sets."""
    hits = len(true_gos & pred_gos)
    # An empty prediction set scores precision 0.
    precision = hits / len(pred_gos) if pred_gos else 0.0
    # No true GO terms: nothing can be missed, so recall is defined as 1.
    recall = hits / len(true_gos) if true_gos else 1.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1


def _mean(values):
    """Arithmetic mean as a float; 0.0 for an empty sequence (keeps NaN out of the JSON output)."""
    return float(np.mean(values)) if values else 0.0


def _best_f1_threshold(precision, recall, thresholds):
    """Return (best_index, f1_scores) over the thresholds of a PR curve.

    ``precision``/``recall`` as returned by ``precision_recall_curve`` have one
    more element than ``thresholds`` (the final P=1, R=0 point), so only the
    first ``len(thresholds)`` entries are scored. F1 is computed in float
    whatever the thresholds' dtype, and is 0 where precision + recall == 0.
    ``thresholds`` must be non-empty.
    """
    n = len(thresholds)
    p = np.asarray(precision, dtype=float)[:n]
    r = np.asarray(recall, dtype=float)[:n]
    denom = p + r
    f1_scores = np.divide(2 * p * r, denom, out=np.zeros_like(denom), where=denom > 0)
    return int(np.argmax(f1_scores)), f1_scores


def _plot_pr_curve(precision, recall, avg_precision, best_recall, best_precision,
                   best_f1, best_threshold, plot_path):
    """Draw the PR curve with the best-F1 point highlighted and save it to ``plot_path``."""
    # Imported lazily so the metrics can be computed without matplotlib installed.
    import matplotlib.pyplot as plt

    plot_dir = os.path.dirname(plot_path)
    if plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.plot(recall, precision, label=f'平均精确率 = {avg_precision:.3f}')
        ax.scatter(best_recall, best_precision, color='red',
                   label=f'最佳F1 = {best_f1:.3f} (阈值 = {best_threshold:.3f})')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall 曲线')
        ax.legend()
        ax.grid(True)
        fig.savefig(plot_path, dpi=300)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def _compute_pr_curve(y_true, y_score, plot_path):
    """Compute the pooled PR curve and best-F1 threshold, plot it, and return a summary dict."""
    # Imported lazily: only needed when a scores file is supplied.
    from sklearn.metrics import average_precision_score, precision_recall_curve

    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    ap_score = average_precision_score(y_true, y_score)

    best_idx, f1_scores = _best_f1_threshold(precision, recall, thresholds)
    best_threshold = thresholds[best_idx]
    best_f1 = f1_scores[best_idx]

    _plot_pr_curve(precision, recall, ap_score, recall[best_idx], precision[best_idx],
                   best_f1, best_threshold, plot_path)

    return {
        "precision": precision.tolist(),
        "recall": recall.tolist(),
        "thresholds": thresholds.tolist(),
        "best_threshold": float(best_threshold),
        "best_f1": float(best_f1),
        # sklearn average-precision (area under the PR curve), previously only shown in the legend.
        "average_precision_score": float(ap_score),
    }


def calculate_pr_metrics(true_go_file, pred_go_file, scores_file=None, plot_path='pr_curve.png'):
    """Compute per-protein precision/recall/F1 and, optionally, a pooled PR curve.

    Only proteins present in BOTH the true and predicted files are evaluated;
    proteins missing from the prediction file are not counted.

    Args:
        true_go_file: JSON-lines file of true annotations (see ``load_go_data``).
        pred_go_file: JSON-lines file of predicted annotations.
        scores_file: Optional JSON-lines file with ``"protein_id"`` and a
            ``"GO_scores"`` dict of GO ID -> score. When given, a PR curve is
            computed over the (protein, GO) pairs in true ∪ predicted for each
            protein. Pairs without a score get 0.0. The curve is plotted to
            ``plot_path``.
        plot_path: Where to save the PR-curve image (default ``pr_curve.png``).

    Returns:
        dict with ``average_precision`` / ``average_recall`` / ``average_f1``
        (macro means over proteins, 0.0 when no proteins are shared) and
        ``protein_metrics`` (protein_id -> precision/recall/f1). When a PR
        curve was computed, it also contains a ``pr_curve`` key.
    """
    true_go_data = load_go_data(true_go_file)
    pred_go_data = load_go_data(pred_go_file)
    scores = _load_go_scores(scores_file) if scores_file else {}

    all_true = []
    all_scores = []
    protein_metrics = {}

    # Sorted so that output order is reproducible across runs (set order is not).
    common_proteins = sorted(true_go_data.keys() & pred_go_data.keys(), key=str)
    for protein_id in common_proteins:
        true_gos = true_go_data[protein_id]
        pred_gos = pred_go_data[protein_id]

        precision, recall, f1 = _protein_prf(true_gos, pred_gos)
        protein_metrics[protein_id] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

        if scores_file:
            protein_scores = scores.get(protein_id, {})
            for go in sorted(true_gos | pred_gos, key=str):
                all_true.append(1 if go in true_gos else 0)
                all_scores.append(protein_scores.get(go, 0.0))

    metric_values = list(protein_metrics.values())
    results = {
        "average_precision": _mean([m["precision"] for m in metric_values]),
        "average_recall": _mean([m["recall"] for m in metric_values]),
        "average_f1": _mean([m["f1"] for m in metric_values]),
        "protein_metrics": protein_metrics,
    }

    if scores_file and all_true:
        y_true = np.array(all_true)
        y_score = np.array(all_scores, dtype=float)
        # A PR curve is undefined without at least one positive label (recall would be NaN).
        if y_true.any():
            results["pr_curve"] = _compute_pr_curve(y_true, y_score, plot_path)

    return results


def main():
    """Command-line entry point: compute metrics, write them as JSON, print a summary."""
    import argparse
    parser = argparse.ArgumentParser(description='计算GO预测的Precision和Recall并绘制PR曲线')
    parser.add_argument('--true', required=True, help='真实GO的JSON文件路径')
    parser.add_argument('--pred', required=True, help='预测GO的JSON文件路径')
    parser.add_argument('--scores', help='GO分数的JSON文件路径(可选)')
    parser.add_argument('--output', default='test_results/pr_results.json', help='输出结果的JSON文件路径')
    parser.add_argument('--plot', default='pr_curve.png', help='PR曲线图像的保存路径')
    args = parser.parse_args()

    results = calculate_pr_metrics(args.true, args.pred, args.scores, plot_path=args.plot)

    # The default output lives in a subdirectory that may not exist yet.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"平均精确率: {results['average_precision']:.4f}")
    print(f"平均召回率: {results['average_recall']:.4f}")
    print(f"平均F1分数: {results['average_f1']:.4f}")

    if 'pr_curve' in results:
        print(f"最佳F1分数: {results['pr_curve']['best_f1']:.4f} (阈值: {results['pr_curve']['best_threshold']:.4f})")
        print(f"PR曲线已保存为 {args.plot}")


if __name__ == "__main__":
    main()