File size: 5,730 Bytes
5c20520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

def load_go_data(file_path):
    """Load per-protein GO annotations from a JSON Lines file.

    Every non-blank line must be a JSON object with a ``"protein_id"`` key
    and a ``"GO_id"`` key holding the GO identifiers of that protein.

    Args:
        file_path: Path of the JSON Lines file to read.

    Returns:
        A dict mapping each protein id to the set of its GO ids. When a
        protein id occurs on several lines, the last line wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an entry lacks ``"protein_id"`` or ``"GO_id"``.
    """
    data = {}
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            entry = json.loads(line)
            go_ids = entry["GO_id"]
            if isinstance(go_ids, str):
                # set("GO:0001") would split the id into single characters.
                go_ids = [go_ids]
            data[entry["protein_id"]] = set(go_ids)
    return data

def _f1_score(precision, recall):
    """Return the harmonic mean of precision and recall (0.0 when both are 0)."""
    total = precision + recall
    return 2 * precision * recall / total if total > 0 else 0.0


def _load_go_scores(scores_file):
    """Read ``{protein_id: {go_id: score}}`` from a JSON Lines scores file."""
    scores = {}
    with open(scores_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            entry = json.loads(line)
            # A missing or null "GO_scores" means "no scores for this protein".
            scores[entry["protein_id"]] = dict(entry.get("GO_scores") or {})
    return scores


def calculate_pr_metrics(true_go_file, pred_go_file, scores_file=None,
                         plot_path='pr_curve.png'):
    """Compute precision/recall/F1 of GO predictions and optionally a PR curve.

    Only proteins present in BOTH the ground-truth and the prediction file
    are evaluated. Per protein:

    * precision = |true & pred| / |pred|, or 0.0 when nothing was predicted;
    * recall    = |true & pred| / |true|, or 1.0 when there are no true terms;
    * f1        = harmonic mean of the two, or 0 when both are 0.

    The reported averages are unweighted means over the evaluated proteins
    (0.0 when no protein is common to both files).

    When ``scores_file`` is given, every (protein, GO term) pair from the
    union of the true and predicted terms becomes one sample of a pooled
    precision-recall curve: label 1 if the term is a true term, score taken
    from the scores file (0.0 when absent). The curve is plotted and saved
    to ``plot_path``.

    Args:
        true_go_file: JSON Lines file with the ground-truth GO terms.
        pred_go_file: JSON Lines file with the predicted GO terms.
        scores_file: Optional JSON Lines file whose entries carry
            ``"protein_id"`` and a ``"GO_scores"`` mapping of GO id to score.
        plot_path: Where the PR-curve image is written (only used when a
            curve is produced).

    Returns:
        A dict with ``"average_precision"``, ``"average_recall"``,
        ``"average_f1"`` and ``"protein_metrics"`` (per-protein dicts with
        ``"precision"``, ``"recall"``, ``"f1"``). When a curve was computed
        it also holds ``"pr_curve"`` with the ``"precision"``, ``"recall"``
        and ``"thresholds"`` lists, ``"best_threshold"``, ``"best_f1"`` and
        ``"average_precision_score"``.
    """
    true_go_data = load_go_data(true_go_file)
    pred_go_data = load_go_data(pred_go_file)
    scores = _load_go_scores(scores_file) if scores_file else {}

    # Pooled samples for the PR curve: one (label, score) pair per GO term.
    all_true = []
    all_scores = []
    protein_metrics = {}

    # Sorted (by string form, so mixed id types cannot break the sort) to
    # make the order of "protein_metrics" reproducible between runs.
    common_proteins = sorted(true_go_data.keys() & pred_go_data.keys(), key=str)

    for protein_id in common_proteins:
        true_gos = true_go_data[protein_id]
        pred_gos = pred_go_data[protein_id]
        hits = len(true_gos & pred_gos)

        precision = hits / len(pred_gos) if pred_gos else 0.0
        # With no true terms there is nothing to miss, so recall counts as 1.
        recall = hits / len(true_gos) if true_gos else 1.0

        protein_metrics[protein_id] = {
            "precision": precision,
            "recall": recall,
            "f1": _f1_score(precision, recall),
        }

        if scores_file:
            protein_scores = scores.get(protein_id, {})
            for go in true_gos | pred_gos:
                all_true.append(1 if go in true_gos else 0)
                # Terms without a score are treated as scored 0.0.
                all_scores.append(protein_scores.get(go, 0.0))

    if protein_metrics:
        per_protein = list(protein_metrics.values())
        avg_precision = float(np.mean([m["precision"] for m in per_protein]))
        avg_recall = float(np.mean([m["recall"] for m in per_protein]))
        avg_f1 = float(np.mean([m["f1"] for m in per_protein]))
    else:
        # np.mean([]) would yield NaN (plus a RuntimeWarning), and NaN is
        # not valid JSON; report zeros when nothing could be evaluated.
        avg_precision = avg_recall = avg_f1 = 0.0

    results = {
        "average_precision": avg_precision,
        "average_recall": avg_recall,
        "average_f1": avg_f1,
        "protein_metrics": protein_metrics
    }

    if scores_file and all_true:
        y_true = np.asarray(all_true)
        # Force float so integer-only scores cannot truncate derived values.
        y_score = np.asarray(all_scores, dtype=float)

        precision, recall, thresholds = precision_recall_curve(y_true, y_score)
        ap_score = average_precision_score(y_true, y_score)

        # precision/recall have one more entry than thresholds (the final
        # point has no threshold), so F1 is evaluated on the first
        # len(thresholds) points only.
        p_at_thr = precision[:-1]
        r_at_thr = recall[:-1]
        denom = p_at_thr + r_at_thr
        f1_scores = np.divide(
            2 * p_at_thr * r_at_thr,
            denom,
            out=np.zeros_like(denom, dtype=float),
            where=denom > 0,
        )

        best_f1_idx = int(np.argmax(f1_scores))
        best_threshold = thresholds[best_f1_idx]
        best_precision = precision[best_f1_idx]
        best_recall = recall[best_f1_idx]
        best_f1 = f1_scores[best_f1_idx]

        fig = plt.figure(figsize=(10, 8))
        try:
            plt.plot(recall, precision, label=f'平均精确率 = {ap_score:.3f}')
            plt.scatter(best_recall, best_precision, color='red',
                        label=f'最佳F1 = {best_f1:.3f} (阈值 = {best_threshold:.3f})')

            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title('Precision-Recall 曲线')
            plt.legend()
            plt.grid(True)

            plt.savefig(plot_path, dpi=300)
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

        results.update({
            "pr_curve": {
                "precision": precision.tolist(),
                "recall": recall.tolist(),
                "thresholds": thresholds.tolist(),
                "best_threshold": float(best_threshold),
                "best_f1": float(best_f1),
                "average_precision_score": float(ap_score)
            }
        })

    return results

def main(argv=None):
    """Command-line entry point: evaluate GO predictions and save the results.

    Parses ``--true``, ``--pred``, the optional ``--scores`` and ``--output``,
    runs ``calculate_pr_metrics``, writes the result dict as JSON to the
    output path (creating its parent directory when needed) and prints a
    short summary.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(description='计算GO预测的Precision和Recall并绘制PR曲线')
    parser.add_argument('--true', required=True, help='真实GO的JSON文件路径')
    parser.add_argument('--pred', required=True, help='预测GO的JSON文件路径')
    parser.add_argument('--scores', help='GO分数的JSON文件路径(可选)')
    parser.add_argument('--output', default='test_results/pr_results.json', help='输出结果的JSON文件路径')

    args = parser.parse_args(argv)

    results = calculate_pr_metrics(args.true, args.pred, args.scores)

    # The default output lives in a sub-directory that may not exist yet;
    # without this, open() fails with FileNotFoundError after all the work.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"平均精确率: {results['average_precision']:.4f}")
    print(f"平均召回率: {results['average_recall']:.4f}")
    print(f"平均F1分数: {results['average_f1']:.4f}")

    if 'pr_curve' in results:
        print(f"最佳F1分数: {results['pr_curve']['best_f1']:.4f} (阈值: {results['pr_curve']['best_threshold']:.4f})")
        print("PR曲线已保存为 pr_curve.png")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()