# protein_rag / calculate_ec_accuracy.py
# Uploaded to Hugging Face by ericzhang1122 via huggingface_hub (commit 5c20520).
import pickle
import json
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import re
from collections import defaultdict
def load_ground_truth(pkl_file):
    """Read ground-truth EC annotations from a pickled list of protein records.

    Each record must provide 'uniprot_id'. If it also has an 'ec' list, every
    entry that carries entry['reaction']['ecNumber'] contributes that EC
    number; entries lacking it are ignored.

    Returns:
        dict mapping uniprot_id -> set of unique EC number strings (possibly
        empty). A later record with the same uniprot_id replaces an earlier one.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at trusted dataset files.
    with open(pkl_file, 'rb') as handle:
        records = pickle.load(handle)

    annotations = {}
    for record in records:
        protein_id = record['uniprot_id']
        ec_entries = record['ec'] if 'ec' in record else []
        # A set comprehension de-duplicates repeated EC numbers.
        annotations[protein_id] = {
            entry['reaction']['ecNumber']
            for entry in ec_entries
            if 'reaction' in entry and 'ecNumber' in entry['reaction']
        }
    return annotations
def extract_ec_prediction(json_content):
    """Return the EC numbers listed after the first [EC_PREDICTION] tag.

    Whitespace right after the tag (including line breaks) is skipped. Only
    the single line that follows is then scanned, which is the rest of the
    tag's own line or the next non-blank line. Incomplete EC numbers with '-'
    placeholders, such as '1.2.3.-', '1.2.-.-' or '1.-.-.-', are accepted.

    Returns:
        list of EC strings in the order they appear; empty if the tag is
        missing or the scanned line contains no EC number.
    """
    tag_match = re.search(r'\[EC_PREDICTION\]\s*([^\n\r]*)', json_content)
    if tag_match is None:
        return []
    predicted_line = tag_match.group(1).strip()
    # The first field is numeric; each of the remaining three is a number or '-'.
    return re.findall(r'\b\d+\.(?:\d+|-)\.(?:\d+|-)\.(?:\d+|-)', predicted_line)
def load_predictions(predictions_dir):
    """Load EC predictions from every ``*.json`` file in *predictions_dir*.

    Each file is read as plain UTF-8 text, not parsed as JSON, and scanned
    with extract_ec_prediction(). The protein ID is the file name with its
    trailing '.json' extension removed.

    Files that cannot be opened or decoded are reported on stdout and skipped,
    so one bad file does not abort the whole evaluation.

    Returns:
        dict mapping protein ID -> list of predicted EC strings (may be empty).
    """
    suffix = '.json'
    predictions = {}
    # sorted() makes the load order (and therefore the dict order) reproducible.
    for filename in sorted(os.listdir(predictions_dir)):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension. str.replace would also remove
        # any '.json' that occurs elsewhere in the name.
        uniprot_id = filename[:-len(suffix)]
        filepath = os.path.join(predictions_dir, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            # Best-effort: report the unreadable file and continue with the rest.
            print(f"处理文件 {filename} 时出错: {e}")
            continue
        predictions[uniprot_id] = extract_ec_prediction(content)
    return predictions
def calculate_accuracy(ground_truth, predictions, level=4):
    """
    Top-1 accuracy of EC predictions, compared on the first *level* fields.

    Only proteins in *ground_truth* that have a non-empty prediction list are
    scored. For each one, the first predicted EC number counts as correct if
    its first *level* dot-separated fields equal those of any ground-truth EC
    number of that protein. A protein with an empty ground-truth set is scored
    as incorrect.

    Args:
        ground_truth: dict protein_id -> iterable of EC strings.
        predictions: dict protein_id -> list of predicted EC strings.
        level: number of leading EC fields to compare (1-4).

    Returns:
        (accuracy, correct, total); accuracy is 0 when nothing was scored.
    """
    hits = 0
    scored = 0
    for protein_id, true_ecs in ground_truth.items():
        if protein_id not in predictions or not predictions[protein_id]:
            continue
        # Only the top-ranked (first) prediction is considered.
        top_prefix = predictions[protein_id][0].split('.')[:level]
        if any(true_ec.split('.')[:level] == top_prefix for true_ec in true_ecs):
            hits += 1
        scored += 1
    accuracy = hits / scored if scored > 0 else 0
    return accuracy, hits, scored
def _truncate_ecs(ecs, level):
    """Return the set of EC numbers in *ecs* cut to their first *level* fields.

    For example, level=3 turns '1.2.3.4' into '1.2.3'. EC numbers that share
    that prefix collapse into a single entry.
    """
    return {'.'.join(ec.split('.')[:level]) for ec in ecs}


def calculate_prf1(ground_truth, predictions, level=4):
    """
    Compute micro-averaged precision, recall and F1 of EC numbers at one level.

    Every EC number, ground truth and predicted, is truncated to its first
    *level* dot-separated fields before comparison, and duplicates produced by
    the truncation are merged. Placeholder fields are compared literally: a
    prediction '1.2.-.-' matches a ground-truth '1.2.3.4' at level 2, but at
    level 4 it only matches a ground-truth '1.2.-.-'.

    Args:
        ground_truth: dict protein_id -> set of EC strings. Proteins whose set
            is empty are skipped entirely.
        predictions: dict protein_id -> list of predicted EC strings.
        level: number of leading EC fields to compare (1-4).

    Returns:
        A dict with these keys:
        - 'precision', 'recall', 'f1_score': 0 when undefined.
        - 'tp', 'fp', 'fn': raw counts.
        - 'evaluated_proteins': proteins with ground truth that also appear
          in *predictions*.
        - 'incorrect_proteins': per-protein error records grouped under
          'false_positives', 'false_negatives', 'no_prediction' (absent from
          *predictions*) and 'zero_prediction' (present but with an empty
          prediction list).
        EC lists inside the records are sorted so the output is reproducible.
    """
    total_tp = 0
    total_fp = 0
    total_fn = 0
    evaluated_proteins = 0
    incorrect_proteins = {
        'false_positives': [],  # predicted ECs that are not in the ground truth
        'false_negatives': [],  # ground-truth ECs that were not predicted
        'no_prediction': [],    # protein has ground truth but no prediction entry
        'zero_prediction': [],  # prediction entry exists but lists no EC number
    }

    for uniprot_id, gt_ecs_set in ground_truth.items():
        # A protein without any ground-truth EC number cannot be scored.
        if not gt_ecs_set:
            continue
        level_gt = _truncate_ecs(gt_ecs_set, level)

        if uniprot_id not in predictions:
            # Every ground-truth EC of an unpredicted protein is a miss.
            total_fn += len(level_gt)
            incorrect_proteins['no_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': sorted(level_gt),
            })
            continue

        evaluated_proteins += 1
        level_pred = _truncate_ecs(predictions[uniprot_id], level)

        if not level_pred:
            # A prediction entry with zero EC numbers: all ground-truth ECs are misses.
            total_fn += len(level_gt)
            incorrect_proteins['zero_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': sorted(level_gt),
            })
            continue

        fp_ecs = level_pred - level_gt
        fn_ecs = level_gt - level_pred
        total_tp += len(level_pred & level_gt)
        total_fp += len(fp_ecs)
        total_fn += len(fn_ecs)

        if fp_ecs:
            incorrect_proteins['false_positives'].append({
                'protein_id': uniprot_id,
                'predicted_ecs': sorted(fp_ecs),
                'gt_ecs': sorted(level_gt),
            })
        if fn_ecs:
            incorrect_proteins['false_negatives'].append({
                'protein_id': uniprot_id,
                'missed_ecs': sorted(fn_ecs),
                'predicted_ecs': sorted(level_pred),
            })

    # Micro-averaging: pool counts over all proteins before taking ratios.
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'evaluated_proteins': evaluated_proteins,
        'incorrect_proteins': incorrect_proteins,
    }
def _print_incorrect(header, items, describe, limit=10):
    """Print *header* with the item count, then up to *limit* described items.

    Nothing is printed when *items* is empty. When there are more than *limit*
    items, a trailing line reports how many were left out.
    """
    if not items:
        return
    print(f" {header} ({len(items)}个蛋白):")
    for item in items[:limit]:
        print(f" {describe(item)}")
    if len(items) > limit:
        print(f" ... 还有 {len(items) - limit} 个")


def main():
    """Command-line entry point: evaluate EC predictions against ground truth.

    Steps:
    - Load the ground-truth pickle and the prediction directory.
    - Restrict both to proteins present in each that have at least one
      ground-truth EC number.
    - Print micro-averaged precision/recall/F1 at EC levels 1-4, with up to
      10 sample errors per category.
    - Print the distributions of EC counts per protein.
    - Write the per-level metrics to a JSON file; missing parent directories
      are created.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Calculate EC accuracy')
    parser.add_argument('--pkl_file', type=str, default='data/raw_data/difference_20241122_ec_dict_list.pkl')
    parser.add_argument('--predictions_dir', type=str, default='data/clean_test_results_top2go_deepseek-r1')
    parser.add_argument('--output_file', type=str, default='test_results/ec_accuracy_results.json',
                        help='path of the JSON file that receives the metrics')
    args = parser.parse_args()
    pkl_file = args.pkl_file
    predictions_dir = args.predictions_dir
    output_file = args.output_file

    print("正在加载ground truth数据...")
    ground_truth = load_ground_truth(pkl_file)
    print(f"加载了 {len(ground_truth)} 个蛋白的ground truth数据")
    print("正在加载预测结果...")
    predictions = load_predictions(predictions_dir)
    print(f"加载了 {len(predictions)} 个蛋白的预测结果")

    # Evaluate only proteins that have both a prediction and at least one GT EC number.
    common_ids = set(ground_truth.keys()) & set(predictions.keys())
    valid_ids = {uid for uid in common_ids if ground_truth[uid]}
    print(f"共同且有GT的蛋白数量: {len(valid_ids)}")
    filtered_gt = {uid: ground_truth[uid] for uid in valid_ids}
    filtered_pred = {uid: predictions[uid] for uid in valid_ids}

    results = {}
    print("\n=== 评估结果 ===")
    for level in [1, 2, 3, 4]:
        metrics = calculate_prf1(filtered_gt, filtered_pred, level=level)
        results[level] = metrics
        print(f"--- EC号前{level}级 ---")
        print(f" Precision: {metrics['precision']:.4f}")
        print(f" Recall: {metrics['recall']:.4f}")
        print(f" F1-Score: {metrics['f1_score']:.4f}")
        print(f" (TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']})")

        # Show a bounded sample of the proteins in each error category.
        incorrect = metrics['incorrect_proteins']
        _print_incorrect(
            '假阳性错误', incorrect['false_positives'],
            lambda item: f"{item['protein_id']}: 错误预测了 {item['predicted_ecs']}, GT是 {item['gt_ecs']}")
        _print_incorrect(
            '假阴性错误', incorrect['false_negatives'],
            lambda item: f"{item['protein_id']}: 漏掉了 {item['missed_ecs']}, 预测了 {item['predicted_ecs']}")
        _print_incorrect(
            '零预测错误', incorrect['zero_prediction'],
            lambda item: f"{item['protein_id']}: GT是 {item['gt_ecs']}, 但预测了0个EC号")
        _print_incorrect(
            '无预测错误', incorrect['no_prediction'],
            lambda item: f"{item['protein_id']}: GT是 {item['gt_ecs']}, 但没有预测")
        print()

    print("\n=== 详细统计信息 ===")
    # Histogram: number of EC numbers per protein -> number of proteins.
    gt_ec_counts = defaultdict(int)
    for ecs in filtered_gt.values():
        gt_ec_counts[len(ecs)] += 1
    print("Ground truth中EC号数量分布:")
    for count, freq in sorted(gt_ec_counts.items()):
        print(f" {count}个EC号: {freq}个蛋白")

    pred_ec_counts = defaultdict(int)
    for ecs in filtered_pred.values():
        pred_ec_counts[len(ecs)] += 1
    print("\n预测结果中EC号数量分布:")
    for count, freq in sorted(pred_ec_counts.items()):
        print(f" {count}个EC号: {freq}个蛋白")

    # Create the parent directory so the dump does not fail on a fresh checkout.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\n结果已保存到 {output_file}")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()