"""Evaluate EC-number predictions against ground-truth EC annotations.

Ground truth is read from a pickle file; predictions are read from per-protein
``<uniprot_id>.json`` files containing an ``[EC_PREDICTION]`` tag. Micro-averaged
precision / recall / F1 are reported at EC levels 1-4.
"""
import argparse
import json
import os
import pickle
import re
import sys
from collections import defaultdict

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Captures the rest of the line after an [EC_PREDICTION] tag. The leading \s*
# may also consume newlines, so a value on the following line is captured too.
_EC_TAG_PATTERN = re.compile(r'\[EC_PREDICTION\]\s*([^\n\r]*)')
# Matches EC numbers with four dot-separated components where the last three may
# be '-' (incomplete EC numbers), e.g. 1.2.3.4, 1.2.3.-, 1.2.-.-, 1.-.-.-
_EC_PATTERN = re.compile(r'\b\d+\.(?:\d+|-)\.(?:\d+|-)\.(?:\d+|-)')

# Maximum number of example proteins printed per error category.
_MAX_EXAMPLES = 10


def load_ground_truth(pkl_file):
    """Load ground-truth EC numbers from a pickled list of protein records.

    Each record is read as a dict with a ``'uniprot_id'`` key and an optional
    ``'ec'`` list whose entries may carry ``entry['reaction']['ecNumber']``.
    Entries missing that path (or with ``None`` in it) are skipped.

    Args:
        pkl_file: path to the pickle file.

    Returns:
        dict mapping uniprot_id -> set of EC number strings (duplicates removed;
        an empty set when the record has no usable EC annotation).
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)

    gt_dict = {}
    for item in data:
        ec_numbers = set()
        for ec_info in item.get('ec') or []:
            reaction = ec_info.get('reaction') or {}
            if 'ecNumber' in reaction:
                ec_numbers.add(reaction['ecNumber'])
        gt_dict[item['uniprot_id']] = ec_numbers
    return gt_dict


def extract_ec_prediction(json_content):
    """Extract EC numbers from the line following the first [EC_PREDICTION] tag.

    Args:
        json_content: raw text of a prediction file.

    Returns:
        list of EC number strings in order of appearance (may contain
        duplicates); an empty list when no tag is present.
    """
    match = _EC_TAG_PATTERN.search(json_content)
    if not match:
        return []
    return _EC_PATTERN.findall(match.group(1).strip())


def load_predictions(predictions_dir):
    """Load EC predictions from every ``*.json`` file in ``predictions_dir``.

    The protein ID is the file name with its trailing ``.json`` removed.
    Unreadable files are reported and skipped (best-effort loading).

    Returns:
        dict mapping uniprot_id -> list of predicted EC number strings.
    """
    suffix = '.json'
    predictions = {}
    for filename in os.listdir(predictions_dir):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also remove
        # '.json' occurring elsewhere in the name.
        uniprot_id = filename[:-len(suffix)]
        filepath = os.path.join(predictions_dir, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"处理文件 {filename} 时出错: {e}")
            continue
        predictions[uniprot_id] = extract_ec_prediction(content)
    return predictions


def calculate_accuracy(ground_truth, predictions, level=4):
    """Top-1 accuracy of EC predictions at a given EC level.

    Only the first predicted EC number of each protein is used. It counts as
    correct when its first ``level`` components equal those of any
    ground-truth EC number. Proteins with no (or an empty) prediction are not
    counted.

    Args:
        ground_truth: dict uniprot_id -> iterable of EC strings.
        predictions: dict uniprot_id -> list of EC strings.
        level: number of leading EC components to compare (1-4).

    Returns:
        (accuracy, correct, total); accuracy is 0 when total is 0.
    """
    correct = 0
    total = 0
    for uniprot_id, gt_ecs in ground_truth.items():
        pred_ecs = predictions.get(uniprot_id)
        if not pred_ecs:
            continue
        pred_prefix = pred_ecs[0].split('.')[:level]
        if any(gt_ec.split('.')[:level] == pred_prefix for gt_ec in gt_ecs):
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    return accuracy, correct, total


def _truncate_ecs(ecs, level):
    """Return the set of EC numbers cut down to their first ``level`` components."""
    return {'.'.join(ec.split('.')[:level]) for ec in ecs}


def calculate_prf1(ground_truth, predictions, level=4):
    """Micro-averaged precision, recall and F1 of EC predictions at a given level.

    EC numbers are truncated to their first ``level`` components before the
    set comparison. Proteins whose ground-truth set is empty are skipped.

    Args:
        ground_truth: dict uniprot_id -> set of EC strings.
        predictions: dict uniprot_id -> list of EC strings.
        level: number of leading EC components to compare (1-4).

    Returns:
        dict with 'precision', 'recall', 'f1_score', 'tp', 'fp', 'fn',
        'evaluated_proteins' and 'incorrect_proteins' (per-category lists of
        per-protein error records).
    """
    total_tp = 0
    total_fp = 0
    total_fn = 0
    incorrect_proteins = {
        'false_positives': [],  # predicted EC numbers absent from the GT
        'false_negatives': [],  # GT EC numbers that were not predicted
        'no_prediction': [],    # protein has GT but no prediction entry at all
        'zero_prediction': [],  # prediction entry exists but holds 0 EC numbers
    }

    for uniprot_id, gt_ecs_set in ground_truth.items():
        # Proteins without any GT EC number are not scored.
        if not gt_ecs_set:
            continue
        level_gt = _truncate_ecs(gt_ecs_set, level)

        if uniprot_id not in predictions:
            total_fn += len(level_gt)
            incorrect_proteins['no_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': list(level_gt),
            })
            continue

        pred_ecs_set = set(predictions[uniprot_id])
        if not pred_ecs_set:
            total_fn += len(level_gt)
            incorrect_proteins['zero_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': list(level_gt),
            })
            continue

        level_pred = _truncate_ecs(pred_ecs_set, level)
        tp = len(level_pred & level_gt)
        fp = len(level_pred) - tp
        fn = len(level_gt) - tp
        total_tp += tp
        total_fp += fp
        total_fn += fn

        if fp > 0:
            incorrect_proteins['false_positives'].append({
                'protein_id': uniprot_id,
                'predicted_ecs': list(level_pred - level_gt),
                'gt_ecs': list(level_gt),
            })
        if fn > 0:
            incorrect_proteins['false_negatives'].append({
                'protein_id': uniprot_id,
                'missed_ecs': list(level_gt - level_pred),
                'predicted_ecs': list(level_pred),
            })

    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Number of proteins that have both a prediction entry and non-empty GT.
    total_proteins_evaluated = sum(
        1 for uid in ground_truth if uid in predictions and ground_truth[uid]
    )

    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'evaluated_proteins': total_proteins_evaluated,
        'incorrect_proteins': incorrect_proteins,
    }


def _print_error_examples(header, items, format_item, limit=_MAX_EXAMPLES):
    """Print a header, up to ``limit`` formatted items, and a count of the rest."""
    if not items:
        return
    print(header)
    for item in items[:limit]:
        print(format_item(item))
    if len(items) > limit:
        print(f" ... 还有 {len(items) - limit} 个")


def _print_count_distribution(title, ec_collections):
    """Print how many proteins have 0, 1, 2, ... EC numbers."""
    counts = defaultdict(int)
    for ecs in ec_collections:
        counts[len(ecs)] += 1
    print(title)
    for count, freq in sorted(counts.items()):
        print(f" {count}个EC号: {freq}个蛋白")


def main():
    """CLI entry point: evaluate predictions at EC levels 1-4 and save metrics as JSON."""
    parser = argparse.ArgumentParser(description='Calculate EC accuracy')
    parser.add_argument('--pkl_file', type=str,
                        default='data/raw_data/difference_20241122_ec_dict_list.pkl')
    parser.add_argument('--predictions_dir', type=str,
                        default='data/clean_test_results_top2go_deepseek-r1')
    parser.add_argument('--output_file', type=str,
                        default='test_results/ec_accuracy_results.json')
    args = parser.parse_args()

    print("正在加载ground truth数据...")
    ground_truth = load_ground_truth(args.pkl_file)
    print(f"加载了 {len(ground_truth)} 个蛋白的ground truth数据")

    print("正在加载预测结果...")
    predictions = load_predictions(args.predictions_dir)
    print(f"加载了 {len(predictions)} 个蛋白的预测结果")

    # Only evaluate proteins that have both a prediction and at least one GT EC number.
    common_ids = ground_truth.keys() & predictions.keys()
    valid_ids = {uid for uid in common_ids if ground_truth[uid]}
    print(f"共同且有GT的蛋白数量: {len(valid_ids)}")

    filtered_gt = {uid: ground_truth[uid] for uid in valid_ids}
    filtered_pred = {uid: predictions[uid] for uid in valid_ids}

    results = {}
    print("\n=== 评估结果 ===")
    for level in [1, 2, 3, 4]:
        metrics = calculate_prf1(filtered_gt, filtered_pred, level=level)
        results[level] = metrics

        print(f"--- EC号前{level}级 ---")
        print(f" Precision: {metrics['precision']:.4f}")
        print(f" Recall: {metrics['recall']:.4f}")
        print(f" F1-Score: {metrics['f1_score']:.4f}")
        print(f" (TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']})")

        incorrect = metrics['incorrect_proteins']
        _print_error_examples(
            f" 假阳性错误 ({len(incorrect['false_positives'])}个蛋白):",
            incorrect['false_positives'],
            lambda item: f" {item['protein_id']}: 错误预测了 {item['predicted_ecs']}, GT是 {item['gt_ecs']}",
        )
        _print_error_examples(
            f" 假阴性错误 ({len(incorrect['false_negatives'])}个蛋白):",
            incorrect['false_negatives'],
            lambda item: f" {item['protein_id']}: 漏掉了 {item['missed_ecs']}, 预测了 {item['predicted_ecs']}",
        )
        _print_error_examples(
            f" 零预测错误 ({len(incorrect['zero_prediction'])}个蛋白):",
            incorrect['zero_prediction'],
            lambda item: f" {item['protein_id']}: GT是 {item['gt_ecs']}, 但预测了0个EC号",
        )
        _print_error_examples(
            f" 无预测错误 ({len(incorrect['no_prediction'])}个蛋白):",
            incorrect['no_prediction'],
            lambda item: f" {item['protein_id']}: GT是 {item['gt_ecs']}, 但没有预测",
        )
        print()  # blank line between levels

    print("\n=== 详细统计信息 ===")
    _print_count_distribution("Ground truth中EC号数量分布:", filtered_gt.values())
    _print_count_distribution("\n预测结果中EC号数量分布:", filtered_pred.values())

    output_file = args.output_file
    # Create the output directory if needed; open(..., 'w') does not create it.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n结果已保存到 {output_file}")


if __name__ == "__main__":
    main()