import pickle
import json
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import re
from collections import defaultdict

def load_ground_truth(pkl_file):
    """加载ground truth数据"""
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)
    
    # Extract each protein's EC numbers
    gt_dict = {}
    for item in data:
        uniprot_id = item['uniprot_id']
        ec_numbers = []
        
        # Extract EC numbers
        if 'ec' in item:
            for ec_info in item['ec']:
                if 'reaction' in ec_info and 'ecNumber' in ec_info['reaction']:
                    ec_numbers.append(ec_info['reaction']['ecNumber'])
        
        gt_dict[uniprot_id] = set(ec_numbers)  # use a set to deduplicate
    
    return gt_dict
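
# Expected structure of the pickle file, inferred from the keys accessed in
# load_ground_truth above (the IDs and EC numbers shown are illustrative only):
#   [{'uniprot_id': 'P12345',
#     'ec': [{'reaction': {'ecNumber': '1.1.1.1'}}, ...]},
#    ...]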

def extract_ec_prediction(json_content):
    """从预测结果中提取EC号"""
    # 查找[EC_PREDICTION]标签后的内容
    pattern = r'\[EC_PREDICTION\]\s*([^\n\r]*)'
    match = re.search(pattern, json_content)
    
    if match:
        line_content = match.group(1).strip()
        # Match EC numbers, including incomplete ones (fields given as '-').
        # Accepted formats: d.d.d.d, d.d.d.-, d.d.-.-, d.-.-.-
        ec_pattern = r'\b\d+\.(?:\d+|-)\.(?:\d+|-)\.(?:\d+|-)'
        ec_numbers = re.findall(ec_pattern, line_content)
        return ec_numbers
    
    return []
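
# Example (hypothetical prediction text; the exact file layout is an assumption):
#   extract_ec_prediction("... [EC_PREDICTION] 1.1.1.1, 2.7.-.- ...")
#   -> ['1.1.1.1', '2.7.-.-']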

def load_predictions(predictions_dir):
    """加载所有预测结果"""
    predictions = {}
    
    for filename in os.listdir(predictions_dir):
        if filename.endswith('.json'):
            uniprot_id = filename.replace('.json', '')
            filepath = os.path.join(predictions_dir, filename)
            
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
                
                # Extract the EC prediction
                predicted_ecs = extract_ec_prediction(content)
                predictions[uniprot_id] = predicted_ecs
                
            except Exception as e:
                print(f"处理文件 {filename} 时出错: {e}")
    
    return predictions

def calculate_accuracy(ground_truth, predictions, level=4):
    """
    Compute the accuracy of EC-number predictions at the given level.
    level: 1-4, the number of leading EC fields to compare.
    """
    correct = 0
    total = 0
    
    for uniprot_id, gt_ecs in ground_truth.items():
        if uniprot_id in predictions and predictions[uniprot_id]:
            # Take the first predicted EC number
            pred_ec = predictions[uniprot_id][0]
            
            # Check whether any ground-truth EC matches the prediction at this level
            is_correct = False
            for gt_ec in gt_ecs:
                # Split the EC numbers into their fields
                gt_parts = gt_ec.split('.')[:level]
                pred_parts = pred_ec.split('.')[:level]
                
                # Compare whether the first 'level' fields are identical
                if gt_parts == pred_parts:
                    is_correct = True
                    break
            
            if is_correct:
                correct += 1
            
            total += 1
    
    accuracy = correct / total if total > 0 else 0
    return accuracy, correct, total
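
# Example of the level semantics used above: at level=2 a prediction of
# '1.1.1.1' counts as correct for a ground-truth EC of '1.1.2.3', because
# only the first two fields ('1.1') are compared.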

def calculate_prf1(ground_truth, predictions, level=4):
    """
    Compute micro-averaged precision, recall, and F1 for EC numbers at the given level.
    level: 1-4, the number of leading EC fields to compare.
    """
    total_tp = 0
    total_fp = 0
    total_fn = 0
    
    # Dictionary for recording incorrect predictions
    incorrect_proteins = {
        'false_positives': [],  # predicted but not in the ground truth
        'false_negatives': [],  # in the ground truth but not predicted
        'no_prediction': [],    # has ground truth but no prediction file
        'zero_prediction': []   # proteins for which zero EC numbers were predicted
    }
    
    for uniprot_id, gt_ecs_set in ground_truth.items():
        if uniprot_id in predictions:
            pred_ecs_set = set(predictions[uniprot_id])
            
            # Skip this protein if its ground truth is empty
            if not gt_ecs_set:
                continue
            
            # Check whether zero EC numbers were predicted
            if not pred_ecs_set:
                level_gt = set('.'.join(ec.split('.')[:level]) for ec in gt_ecs_set)
                fn = len(level_gt)
                total_fn += fn
                
                incorrect_proteins['zero_prediction'].append({
                    'protein_id': uniprot_id,
                    'gt_ecs': list(level_gt)
                })
                continue

            # --- Core computation ---
            # To handle the level, truncate the EC numbers before intersecting,
            # e.g. level=3: '1.2.3.4' -> '1.2.3'
            level_gt = set('.'.join(ec.split('.')[:level]) for ec in gt_ecs_set)
            level_pred = set('.'.join(ec.split('.')[:level]) for ec in pred_ecs_set)
            
            # Count TP, FP, FN
            tp = len(level_pred.intersection(level_gt))
            fp = len(level_pred) - tp
            fn = len(level_gt) - tp
            
            total_tp += tp
            total_fp += fp
            total_fn += fn
            
            # Record protein IDs that have errors
            if fp > 0 or fn > 0:
                fp_ecs = level_pred - level_gt  # false-positive EC numbers
                fn_ecs = level_gt - level_pred  # false-negative EC numbers
                
                if fp > 0:
                    incorrect_proteins['false_positives'].append({
                        'protein_id': uniprot_id,
                        'predicted_ecs': list(fp_ecs),
                        'gt_ecs': list(level_gt)
                    })
                
                if fn > 0:
                    incorrect_proteins['false_negatives'].append({
                        'protein_id': uniprot_id,
                        'missed_ecs': list(fn_ecs),
                        'predicted_ecs': list(level_pred)
                    })
        else:
            # Protein has ground truth but no prediction
            if gt_ecs_set:
                level_gt = set('.'.join(ec.split('.')[:level]) for ec in gt_ecs_set)
                fn = len(level_gt)
                total_fn += fn
                
                incorrect_proteins['no_prediction'].append({
                    'protein_id': uniprot_id,
                    'gt_ecs': list(level_gt)
                })

    # Micro-averaged Precision, Recall, and F1 over all proteins
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    # Number of proteins actually evaluated (for reporting)
    total_proteins_evaluated = sum(1 for uid in ground_truth if uid in predictions and ground_truth[uid])
    
    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'evaluated_proteins': total_proteins_evaluated,
        'incorrect_proteins': incorrect_proteins
    }
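
# Worked example of the micro-averaged metrics returned above, with made-up
# totals TP=8, FP=2, FN=4 summed over all proteins:
#   precision = 8 / (8 + 2) = 0.80
#   recall    = 8 / (8 + 4) ≈ 0.67
#   f1        = 2 * 0.80 * 0.67 / (0.80 + 0.67) ≈ 0.73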

def main():
    # File paths and command-line arguments
    import argparse
    parser = argparse.ArgumentParser(description='Calculate EC accuracy')
    parser.add_argument('--pkl_file', type=str, default='data/raw_data/difference_20241122_ec_dict_list.pkl')
    parser.add_argument('--predictions_dir', type=str, default='data/clean_test_results_top2go_deepseek-r1')
    args = parser.parse_args()
    pkl_file = args.pkl_file
    predictions_dir = args.predictions_dir
    
    print("正在加载ground truth数据...")
    ground_truth = load_ground_truth(pkl_file)
    print(f"加载了 {len(ground_truth)} 个蛋白的ground truth数据")
    
    print("正在加载预测结果...")
    predictions = load_predictions(predictions_dir)
    print(f"加载了 {len(predictions)} 个蛋白的预测结果")

    # print(f"predictions: {predictions}")
    # print(f"ground_truth: {ground_truth}")
    
    # Find protein IDs present in both sets
    common_ids = set(ground_truth.keys()) & set(predictions.keys())
    valid_ids = {uid for uid in common_ids if ground_truth[uid]}  # only evaluate proteins that have GT EC numbers
    print(f"Proteins in both sets with GT EC numbers: {len(valid_ids)}")
    
    # Filter the data to the shared IDs
    filtered_gt = {uid: ground_truth[uid] for uid in valid_ids}
    filtered_pred = {uid: predictions[uid] for uid in valid_ids}

    # Compute PRF1 at each EC level
    results = {}
    print("\n=== Evaluation results ===")
    for level in [1, 2, 3, 4]:
        metrics = calculate_prf1(filtered_gt, filtered_pred, level=level)
        results[level] = metrics
        print(f"--- EC号前{level}级 ---")
        print(f"  Precision: {metrics['precision']:.4f}")
        print(f"  Recall:    {metrics['recall']:.4f}")
        print(f"  F1-Score:  {metrics['f1_score']:.4f}")
        print(f"  (TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']})")
        
        # Print protein IDs with incorrect predictions
        incorrect = metrics['incorrect_proteins']
        
        if incorrect['false_positives']:
            print(f"  False positives ({len(incorrect['false_positives'])} proteins):")
            for item in incorrect['false_positives'][:10]:  # show only the first 10
                print(f"    {item['protein_id']}: incorrectly predicted {item['predicted_ecs']}, GT is {item['gt_ecs']}")
            if len(incorrect['false_positives']) > 10:
                print(f"    ... and {len(incorrect['false_positives']) - 10} more")
        
        if incorrect['false_negatives']:
            print(f"  False negatives ({len(incorrect['false_negatives'])} proteins):")
            for item in incorrect['false_negatives'][:10]:  # show only the first 10
                print(f"    {item['protein_id']}: missed {item['missed_ecs']}, predicted {item['predicted_ecs']}")
            if len(incorrect['false_negatives']) > 10:
                print(f"    ... and {len(incorrect['false_negatives']) - 10} more")
        
        if incorrect['zero_prediction']:
            print(f"  Zero-prediction errors ({len(incorrect['zero_prediction'])} proteins):")
            for item in incorrect['zero_prediction']:
                print(f"    {item['protein_id']}: GT is {item['gt_ecs']}, but zero EC numbers were predicted")
        
        if incorrect['no_prediction']:
            print(f"  Missing-prediction errors ({len(incorrect['no_prediction'])} proteins):")
            for item in incorrect['no_prediction'][:10]:  # show only the first 10
                print(f"    {item['protein_id']}: GT is {item['gt_ecs']}, but no prediction was made")
            if len(incorrect['no_prediction']) > 10:
                print(f"    ... and {len(incorrect['no_prediction']) - 10} more")
        
        print()  # blank line between levels
    
    # Summary statistics
    print("\n=== Detailed statistics ===")
    
    # Distribution of EC-number counts in the ground truth
    gt_ec_counts = defaultdict(int)
    for ecs in filtered_gt.values():
        gt_ec_counts[len(ecs)] += 1
    
    print("Ground truth中EC号数量分布:")
    for count, freq in sorted(gt_ec_counts.items()):
        print(f"  {count}个EC号: {freq}个蛋白")
    
    # Distribution of EC-number counts in the predictions
    pred_ec_counts = defaultdict(int)
    for ecs in filtered_pred.values():
        pred_ec_counts[len(ecs)] += 1
    
    print("\n预测结果中EC号数量分布:")
    for count, freq in sorted(pred_ec_counts.items()):
        print(f"  {count}个EC号: {freq}个蛋白")
    
    # Save the results
    output_file = 'test_results/ec_accuracy_results.json'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)  # make sure the output directory exists
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    
    # # Save the ground truth
    # with open('test_results/ground_truth.json', 'w', encoding='utf-8') as f:
    #     json.dump(filtered_gt, f, indent=2, ensure_ascii=False)
    
    # # Save the predictions
    # with open('test_results/predictions.json', 'w', encoding='utf-8') as f:
    #     json.dump(filtered_pred, f, indent=2, ensure_ascii=False)
    
    print(f"\n结果已保存到 {output_file}")

if __name__ == "__main__":
    main()
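
# Example invocation (the script name is illustrative; the paths are the
# defaults defined above and may need adjusting to your layout):
#   python calculate_ec_accuracy.py \
#       --pkl_file data/raw_data/difference_20241122_ec_dict_list.pkl \
#       --predictions_dir data/clean_test_results_top2go_deepseek-r1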