Spaces:
Runtime error
Runtime error
import pickle | |
import json | |
import os | |
import sys | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
import re | |
from collections import defaultdict | |
def load_ground_truth(pkl_file):
    """Load ground-truth EC annotations from a pickle file.

    Args:
        pkl_file: Path to a pickle holding an iterable of per-protein records.
            Every record must have a 'uniprot_id' key. An optional 'ec' key
            holds entries whose ``entry['reaction']['ecNumber']`` values are
            collected; entries lacking 'reaction' or 'ecNumber' are ignored.

    Returns:
        dict mapping uniprot_id -> set of EC number strings (duplicates removed,
        possibly empty). If a uniprot_id occurs in several records, the last
        record overwrites the earlier ones.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted,
    # locally produced files.
    with open(pkl_file, 'rb') as handle:
        records = pickle.load(handle)

    truth = {}
    for record in records:
        protein = record['uniprot_id']
        entries = record['ec'] if 'ec' in record else []
        # A set comprehension both collects and de-duplicates the EC numbers.
        truth[protein] = {
            entry['reaction']['ecNumber']
            for entry in entries
            if 'reaction' in entry and 'ecNumber' in entry['reaction']
        }
    return truth
def _iter_json_strings(node):
    """Yield every string value inside a decoded JSON document, in document order.

    Only values are walked; dict keys are not yielded.
    """
    if isinstance(node, str):
        yield node
    elif isinstance(node, dict):
        for value in node.values():
            yield from _iter_json_strings(value)
    elif isinstance(node, list):
        for value in node:
            yield from _iter_json_strings(value)


def extract_ec_prediction(json_content):
    """Extract EC numbers from the line that follows an ``[EC_PREDICTION]`` tag.

    ``json_content`` is the raw text of a prediction file. If it parses as
    JSON, the first decoded string value containing the tag is searched
    instead of the raw text. Inside a JSON file, newlines in strings are stored
    as the two-character escape ``\\n``. Searching the raw text therefore runs
    past the end of the tagged line and picks up EC numbers mentioned later in
    the response. It also misses an EC that directly follows an escaped
    newline, because ``n1`` has no word boundary. Non-JSON input, or JSON
    without the tag in any string value, is searched as-is.

    Only the text after the first tag, up to the end of that line, is scanned.
    Recognised forms: full ``a.b.c.d``; partial ``a.b.c.-``, ``a.b.-.-`` and
    ``a.-.-.-``; preliminary ``a.b.c.nX``.

    Returns:
        list of EC strings in order of appearance (duplicates kept), or ``[]``
        when no tag is found.
    """
    text = json_content
    try:
        decoded = json.loads(json_content)
    except (ValueError, TypeError):
        # Not JSON: fall back to searching the raw text.
        decoded = None
    if decoded is not None:
        text = next(
            (s for s in _iter_json_strings(decoded) if '[EC_PREDICTION]' in s),
            json_content,
        )

    # \s* also skips a line break, so ECs placed on the line after the tag
    # are still captured.
    match = re.search(r'\[EC_PREDICTION\]\s*([^\n\r]*)', text)
    if not match:
        return []
    line_content = match.group(1).strip()
    # Each of fields 2-4 may be '-' (incomplete EC). The 4th field may also
    # carry an 'n' prefix (preliminary EC number).
    ec_pattern = r'\b\d+\.(?:\d+|-)\.(?:\d+|-)\.(?:n?\d+|-)'
    return re.findall(ec_pattern, line_content)
def load_predictions(predictions_dir):
    """Load EC predictions from every ``*.json`` file in a directory.

    The file name minus its trailing ``.json`` is used as the protein's
    UniProt ID. The EC list comes from ``extract_ec_prediction`` applied to
    the file's text.

    Args:
        predictions_dir: Directory containing one prediction file per protein.

    Returns:
        dict mapping uniprot_id -> list of predicted EC strings. Files that
        fail to load are reported on stdout and skipped.
    """
    suffix = '.json'
    predictions = {}
    # sorted() makes the order of any error messages deterministic.
    for filename in sorted(os.listdir(predictions_dir)):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also delete any
        # '.json' occurring earlier in the name.
        uniprot_id = filename[:-len(suffix)]
        filepath = os.path.join(predictions_dir, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            predictions[uniprot_id] = extract_ec_prediction(content)
        except Exception as e:
            # Best-effort: one bad file should not abort the whole evaluation.
            print(f"处理文件 {filename} 时出错: {e}")
    return predictions
def calculate_accuracy(ground_truth, predictions, level=4):
    """Compute top-1 EC accuracy, comparing only the first `level` EC fields.

    A protein counts only if it appears in ``predictions`` with a non-empty
    list. For such a protein, the FIRST predicted EC is correct when its first
    `level` dot-separated fields equal those of any ground-truth EC.

    Args:
        ground_truth: dict uniprot_id -> iterable of EC strings.
        predictions: dict uniprot_id -> list of EC strings (ranked).
        level: Number of leading EC fields to compare (1-4).

    Returns:
        (accuracy, correct, total). accuracy is correct / total, or 0 when
        no protein was counted.
    """
    hits = 0
    considered = 0
    for protein, truth_ecs in ground_truth.items():
        # Skip proteins with no prediction entry, or with an empty one.
        if protein not in predictions or not predictions[protein]:
            continue
        top_guess = predictions[protein][0]
        matched = any(
            truth.split('.')[:level] == top_guess.split('.')[:level]
            for truth in truth_ecs
        )
        if matched:
            hits += 1
        considered += 1

    if considered > 0:
        return hits / considered, hits, considered
    return 0, hits, considered
def _truncate_ecs(ecs, level):
    """Return the set of EC numbers cut to their first `level` dot-separated fields."""
    return {'.'.join(ec.split('.')[:level]) for ec in ecs}


def calculate_prf1(ground_truth, predictions, level=4):
    """Compute micro-averaged precision, recall and F1 of EC predictions.

    EC numbers are first truncated to their first `level` fields (for example,
    level 3 turns '1.2.3.4' into '1.2.3') and de-duplicated. TP/FP/FN are then
    counted per protein as set intersection/differences and summed over all
    proteins. Proteins whose ground-truth set is empty are skipped entirely.
    A protein with ground truth but no key in ``predictions`` counts as
    'no_prediction'. One whose prediction list is empty counts as
    'zero_prediction'. In both cases every ground-truth EC is a false negative.

    Args:
        ground_truth: dict uniprot_id -> set of EC strings.
        predictions: dict uniprot_id -> iterable of predicted EC strings.
        level: Number of leading EC fields to compare (1-4).

    Returns:
        dict with 'precision', 'recall', 'f1_score' (0 when undefined), summed
        'tp'/'fp'/'fn', 'evaluated_proteins' (proteins with ground truth AND a
        prediction entry), and 'incorrect_proteins' listing per-protein errors.
        EC lists in those records are sorted, so output is reproducible.
    """
    total_tp = total_fp = total_fn = 0
    evaluated = 0
    incorrect_proteins = {
        'false_positives': [],   # predicted ECs absent from the ground truth
        'false_negatives': [],   # ground-truth ECs that were not predicted
        'no_prediction': [],     # ground truth present, no prediction entry
        'zero_prediction': [],   # prediction entry present but empty
    }

    for uniprot_id, gt_ecs_set in ground_truth.items():
        # Proteins without any ground-truth EC cannot be scored.
        if not gt_ecs_set:
            continue
        level_gt = _truncate_ecs(gt_ecs_set, level)

        if uniprot_id not in predictions:
            total_fn += len(level_gt)
            incorrect_proteins['no_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': sorted(level_gt),
            })
            continue

        evaluated += 1
        level_pred = _truncate_ecs(predictions[uniprot_id], level)
        if not level_pred:
            total_fn += len(level_gt)
            incorrect_proteins['zero_prediction'].append({
                'protein_id': uniprot_id,
                'gt_ecs': sorted(level_gt),
            })
            continue

        tp = len(level_pred & level_gt)
        fp = len(level_pred) - tp
        fn = len(level_gt) - tp
        total_tp += tp
        total_fp += fp
        total_fn += fn

        if fp > 0:
            incorrect_proteins['false_positives'].append({
                'protein_id': uniprot_id,
                'predicted_ecs': sorted(level_pred - level_gt),
                'gt_ecs': sorted(level_gt),
            })
        if fn > 0:
            incorrect_proteins['false_negatives'].append({
                'protein_id': uniprot_id,
                'missed_ecs': sorted(level_gt - level_pred),
                'predicted_ecs': sorted(level_pred),
            })

    # Micro-average: pool counts across proteins before forming the ratios.
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'evaluated_proteins': evaluated,
        'incorrect_proteins': incorrect_proteins,
    }
def _print_capped(items, header, describe, limit=10):
    """Print `header` with the item count, then at most `limit` described items.

    If more than `limit` items exist, a final line reports how many were
    omitted. Nothing is printed for an empty list.
    """
    if not items:
        return
    print(f"{header} ({len(items)}个蛋白):")
    for item in items[:limit]:
        print(describe(item))
    if len(items) > limit:
        print(f" ... 还有 {len(items) - limit} 个")


def main():
    """Command-line entry point: score EC predictions against ground truth.

    Steps:
    1. Load ground truth (pickle) and per-protein prediction files.
    2. Keep only proteins present in both that have at least one ground-truth
       EC number.
    3. Print precision/recall/F1 at EC levels 1-4, with error examples.
    4. Print EC-count distributions.
    5. Write the per-level metrics to a JSON file, creating its directory if
       needed.
    """
    # argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(description='Calculate EC accuracy')
    parser.add_argument('--pkl_file', type=str,
                        default='data/raw_data/difference_20241122_ec_dict_list.pkl')
    parser.add_argument('--predictions_dir', type=str,
                        default='data/clean_test_results_top2go_deepseek-r1')
    # Previously hard-coded; now an option with the same default.
    parser.add_argument('--output_file', type=str,
                        default='test_results/ec_accuracy_results.json')
    args = parser.parse_args()
    pkl_file = args.pkl_file
    predictions_dir = args.predictions_dir
    output_file = args.output_file

    print("正在加载ground truth数据...")
    ground_truth = load_ground_truth(pkl_file)
    print(f"加载了 {len(ground_truth)} 个蛋白的ground truth数据")

    print("正在加载预测结果...")
    predictions = load_predictions(predictions_dir)
    print(f"加载了 {len(predictions)} 个蛋白的预测结果")

    # Evaluate only proteins with a prediction file AND at least one GT EC.
    # Sorting makes the evaluation order, and therefore the "first 10" error
    # examples printed below, deterministic across runs.
    common_ids = set(ground_truth) & set(predictions)
    valid_ids = sorted(uid for uid in common_ids if ground_truth[uid])
    print(f"共同且有GT的蛋白数量: {len(valid_ids)}")

    filtered_gt = {uid: ground_truth[uid] for uid in valid_ids}
    filtered_pred = {uid: predictions[uid] for uid in valid_ids}

    results = {}
    print("\n=== 评估结果 ===")
    for level in [1, 2, 3, 4]:
        metrics = calculate_prf1(filtered_gt, filtered_pred, level=level)
        results[level] = metrics
        print(f"--- EC号前{level}级 ---")
        print(f" Precision: {metrics['precision']:.4f}")
        print(f" Recall: {metrics['recall']:.4f}")
        print(f" F1-Score: {metrics['f1_score']:.4f}")
        print(f" (TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']})")

        # Every error group is capped at 10 examples. Previously the
        # zero-prediction group printed all of its entries.
        incorrect = metrics['incorrect_proteins']
        _print_capped(
            incorrect['false_positives'], " 假阳性错误",
            lambda item: f" {item['protein_id']}: 错误预测了 {item['predicted_ecs']}, GT是 {item['gt_ecs']}",
        )
        _print_capped(
            incorrect['false_negatives'], " 假阴性错误",
            lambda item: f" {item['protein_id']}: 漏掉了 {item['missed_ecs']}, 预测了 {item['predicted_ecs']}",
        )
        _print_capped(
            incorrect['zero_prediction'], " 零预测错误",
            lambda item: f" {item['protein_id']}: GT是 {item['gt_ecs']}, 但预测了0个EC号",
        )
        _print_capped(
            incorrect['no_prediction'], " 无预测错误",
            lambda item: f" {item['protein_id']}: GT是 {item['gt_ecs']}, 但没有预测",
        )
        print()  # blank line between levels

    print("\n=== 详细统计信息 ===")
    # Histogram: number of EC numbers per protein -> number of proteins.
    gt_ec_counts = defaultdict(int)
    for ecs in filtered_gt.values():
        gt_ec_counts[len(ecs)] += 1
    print("Ground truth中EC号数量分布:")
    for count, freq in sorted(gt_ec_counts.items()):
        print(f" {count}个EC号: {freq}个蛋白")

    pred_ec_counts = defaultdict(int)
    for ecs in filtered_pred.values():
        pred_ec_counts[len(ecs)] += 1
    print("\n预测结果中EC号数量分布:")
    for count, freq in sorted(pred_ec_counts.items()):
        print(f" {count}个EC号: {freq}个蛋白")

    # Create the output directory first. A missing 'test_results/' used to
    # crash here after all metrics had already been computed.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # json.dump converts the integer level keys to strings.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"\n结果已保存到 {output_file}")


if __name__ == "__main__":
    main()