import pickle
import json
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import re
from collections import defaultdict
def load_ground_truth(pkl_file):
    """Load the ground-truth data from a pickle file."""
    with open(pkl_file, 'rb') as f:
        data = pickle.load(f)
    # Extract the EC numbers for each protein
    gt_dict = {}
    for item in data:
        uniprot_id = item['uniprot_id']
        ec_numbers = []
        # Collect EC numbers from the nested reaction records
        if 'ec' in item:
            for ec_info in item['ec']:
                if 'reaction' in ec_info and 'ecNumber' in ec_info['reaction']:
                    ec_numbers.append(ec_info['reaction']['ecNumber'])
        gt_dict[uniprot_id] = set(ec_numbers)  # use a set to deduplicate
    return gt_dict
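
# Illustrative sketch of the record layout load_ground_truth() expects. The field
# names ('uniprot_id', 'ec', 'reaction', 'ecNumber') come from the code above;
# the concrete values below are made up for illustration only.
#
#   example_item = {
#       'uniprot_id': 'P00000',          # hypothetical accession
#       'ec': [
#           {'reaction': {'ecNumber': '1.1.1.1'}},
#           {'reaction': {'ecNumber': '1.1.1.1'}},  # duplicates collapse via set()
#       ],
#   }
#   # -> gt_dict['P00000'] == {'1.1.1.1'}
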
def extract_ec_prediction(json_content):
    """Extract EC numbers from a prediction result."""
    # Look for the content that follows the [EC_PREDICTION] tag
    pattern = r'\[EC_PREDICTION\]\s*([^\n\r]*)'
    match = re.search(pattern, json_content)
    if match:
        line_content = match.group(1).strip()
        # EC-number pattern that also accepts incomplete EC numbers (with '-'):
        # digit.digit.digit.digit, digit.digit.digit.-, digit.digit.-.- or digit.-.-.-
        ec_pattern = r'\b\d+\.(?:\d+|-)\.(?:\d+|-)\.(?:\d+|-)'
        ec_numbers = re.findall(ec_pattern, line_content)
        return ec_numbers
    return []
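
# Worked example of the extraction above (the input string is invented; the tag
# and the EC pattern are exactly the ones the function searches for). Incomplete
# EC numbers such as '1.14.13.-' are kept because the pattern accepts '-' fields.
#
#   extract_ec_prediction("... [EC_PREDICTION] 1.14.13.- or 2.7.11.1 ...")
#   # -> ['1.14.13.-', '2.7.11.1']
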
def load_predictions(predictions_dir):
    """Load all prediction results."""
    predictions = {}
    for filename in os.listdir(predictions_dir):
        if filename.endswith('.json'):
            uniprot_id = filename.replace('.json', '')
            filepath = os.path.join(predictions_dir, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Extract the EC prediction
                predicted_ecs = extract_ec_prediction(content)
                predictions[uniprot_id] = predicted_ecs
            except Exception as e:
                print(f"Error while processing file {filename}: {e}")
    return predictions
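
# Assumed directory layout for load_predictions(): one file per protein, named
# '<uniprot_id>.json', whose text contains the '[EC_PREDICTION]' tag somewhere.
# The file names here are hypothetical; the directory is the argparse default.
#
#   data/clean_test_results_top2go_deepseek-r1/
#       P00000.json   # contains a line like "[EC_PREDICTION] 2.7.11.1"
#       Q99999.json
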
def calculate_accuracy(ground_truth, predictions, level=4):
    """
    Compute the accuracy of EC predictions at the given level.
    level: 1-4, i.e. how many leading fields of the EC number are compared.
    """
    correct = 0
    total = 0
    for uniprot_id, gt_ecs in ground_truth.items():
        if uniprot_id in predictions and predictions[uniprot_id]:
            # Take the first predicted EC number
            pred_ec = predictions[uniprot_id][0]
            # Check whether any ground-truth EC number matches the prediction at this level
            is_correct = False
            for gt_ec in gt_ecs:
                # Split the EC numbers into their components
                gt_parts = gt_ec.split('.')[:level]
                pred_parts = pred_ec.split('.')[:level]
                # Compare the first `level` components
                if gt_parts == pred_parts:
                    is_correct = True
                    break
            if is_correct:
                correct += 1
            total += 1
    accuracy = correct / total if total > 0 else 0
    return accuracy, correct, total
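
# Level semantics, illustrated with invented EC numbers and a hypothetical ID 'X':
# with GT {'2.7.11.1'} and first prediction '2.7.11.22', the prefixes agree on the
# first three fields but not the fourth, so the protein counts as correct at
# level <= 3 and wrong at level 4. Only the first predicted EC number is scored.
#
#   calculate_accuracy({'X': {'2.7.11.1'}}, {'X': ['2.7.11.22']}, level=3)  # -> (1.0, 1, 1)
#   calculate_accuracy({'X': {'2.7.11.1'}}, {'X': ['2.7.11.22']}, level=4)  # -> (0.0, 0, 1)
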
def calculate_prf1(ground_truth, predictions, level=4):
    """
    Compute precision, recall and F1 of EC predictions at the given level (micro-averaged).
    level: 1-4, i.e. how many leading fields of the EC number are compared.
    """
    total_tp = 0
    total_fp = 0
    total_fn = 0
    # Dictionary used to record incorrect predictions
    incorrect_proteins = {
        'false_positives': [],  # predicted ECs that are absent from the GT
        'false_negatives': [],  # GT ECs that were not predicted
        'no_prediction': [],    # proteins with GT but no prediction entry at all
        'zero_prediction': []   # proteins whose prediction contains zero EC numbers
    }
    for uniprot_id, gt_ecs_set in ground_truth.items():
        if uniprot_id in predictions:
            pred_ecs_set = set(predictions[uniprot_id])
            # Skip this protein if its GT is empty
            if not gt_ecs_set:
                continue
            # Handle the case where zero EC numbers were predicted
            if not pred_ecs_set:
                level_gt = set('.'.join(ec.split('.')[:level]) for ec in gt_ecs_set)
                fn = len(level_gt)
                total_fn += fn
                incorrect_proteins['zero_prediction'].append({
                    'protein_id': uniprot_id,
                    'gt_ecs': list(level_gt)
                })
                continue
            # --- Core computation ---
            # To honour the requested level, truncate before taking the intersection,
            # e.g. at level 3: '1.2.3.4' -> '1.2.3'
            level_gt = set('.'.join(ec.split('.')[:level]) for ec in gt_ecs_set)
            level_pred = set('.'.join(ec.split('.')[:level]) for ec in pred_ecs_set)
            # Count TP, FP, FN
            tp = len(level_pred.intersection(level_gt))
            fp = len(level_pred) - tp
            fn = len(level_gt) - tp
            total_tp += tp
            total_fp += fp
            total_fn += fn
            # Record the IDs of proteins with errors
            if fp > 0 or fn > 0:
                fp_ecs = level_pred - level_gt  # false-positive EC numbers
                fn_ecs = level_gt - level_pred  # false-negative EC numbers
                if fp > 0:
                    incorrect_proteins['false_positives'].append({
                        'protein_id': uniprot_id,
                        'predicted_ecs': list(fp_ecs),
                        'gt_ecs': list(level_gt)
                    })
                if fn > 0:
                    incorrect_proteins['false_negatives'].append({
                        'protein_id': uniprot_id,
                        'missed_ecs': list(fn_ecs),
                        'predicted_ecs': list(level_pred)
                    })
        else:
            # Protein has GT but no prediction entry at all
            if gt_ecs_set:
                level_gt = set('.'.join(ec.split('.')[:level]) for ec in gt_ecs_set)
                fn = len(level_gt)
                total_fn += fn
                incorrect_proteins['no_prediction'].append({
                    'protein_id': uniprot_id,
                    'gt_ecs': list(level_gt)
                })
    # Micro-averaged precision, recall and F1 over all proteins
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Number of proteins actually evaluated (for reporting)
    total_proteins_evaluated = sum(1 for uid in ground_truth if uid in predictions and ground_truth[uid])
    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'evaluated_proteins': total_proteins_evaluated,
        'incorrect_proteins': incorrect_proteins
    }
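
# Micro-averaging, illustrated with invented counts: suppose the loop above
# accumulates TP=3, FP=1, FN=2 across the whole dataset at some level. Then
#   precision = 3 / (3 + 1) = 0.75
#   recall    = 3 / (3 + 2) = 0.60
#   f1        = 2 * 0.75 * 0.60 / (0.75 + 0.60) = 0.9 / 1.35 ≈ 0.667
# i.e. every (protein, EC prefix) pair carries the same weight, rather than
# averaging per-protein scores (macro-averaging).
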
def main():
    # File paths
    import argparse
    parser = argparse.ArgumentParser(description='Calculate EC accuracy')
    parser.add_argument('--pkl_file', type=str, default='data/raw_data/difference_20241122_ec_dict_list.pkl')
    parser.add_argument('--predictions_dir', type=str, default='data/clean_test_results_top2go_deepseek-r1')
    args = parser.parse_args()
    pkl_file = args.pkl_file
    predictions_dir = args.predictions_dir
    print("Loading ground truth data...")
    ground_truth = load_ground_truth(pkl_file)
    print(f"Loaded ground truth for {len(ground_truth)} proteins")
    print("Loading predictions...")
    predictions = load_predictions(predictions_dir)
    print(f"Loaded predictions for {len(predictions)} proteins")
    # print(f"predictions: {predictions}")
    # print(f"ground_truth: {ground_truth}")
    # Find the protein IDs shared by both sets
    common_ids = set(ground_truth.keys()) & set(predictions.keys())
    valid_ids = {uid for uid in common_ids if ground_truth[uid]}  # only evaluate proteins that have GT EC numbers
    print(f"Proteins present in both sets with GT EC numbers: {len(valid_ids)}")
    # Filter the data
    filtered_gt = {uid: ground_truth[uid] for uid in valid_ids}
    filtered_pred = {uid: predictions[uid] for uid in valid_ids}
    # Compute PRF1 at each level
    results = {}
    print("\n=== Evaluation results ===")
    for level in [1, 2, 3, 4]:
        metrics = calculate_prf1(filtered_gt, filtered_pred, level=level)
        results[level] = metrics
        print(f"--- First {level} field(s) of the EC number ---")
        print(f"  Precision: {metrics['precision']:.4f}")
        print(f"  Recall:    {metrics['recall']:.4f}")
        print(f"  F1-Score:  {metrics['f1_score']:.4f}")
        print(f"  (TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']})")
        # Print the IDs of proteins with incorrect predictions
        incorrect = metrics['incorrect_proteins']
        if incorrect['false_positives']:
            print(f"  False positives ({len(incorrect['false_positives'])} proteins):")
            for item in incorrect['false_positives'][:10]:  # show the first 10 only
                print(f"    {item['protein_id']}: wrongly predicted {item['predicted_ecs']}, GT is {item['gt_ecs']}")
            if len(incorrect['false_positives']) > 10:
                print(f"    ... and {len(incorrect['false_positives']) - 10} more")
        if incorrect['false_negatives']:
            print(f"  False negatives ({len(incorrect['false_negatives'])} proteins):")
            for item in incorrect['false_negatives'][:10]:  # show the first 10 only
                print(f"    {item['protein_id']}: missed {item['missed_ecs']}, predicted {item['predicted_ecs']}")
            if len(incorrect['false_negatives']) > 10:
                print(f"    ... and {len(incorrect['false_negatives']) - 10} more")
        if incorrect['zero_prediction']:
            print(f"  Zero-prediction errors ({len(incorrect['zero_prediction'])} proteins):")
            for item in incorrect['zero_prediction']:
                print(f"    {item['protein_id']}: GT is {item['gt_ecs']}, but 0 EC numbers were predicted")
        if incorrect['no_prediction']:
            print(f"  No-prediction errors ({len(incorrect['no_prediction'])} proteins):")
            for item in incorrect['no_prediction'][:10]:  # show the first 10 only
                print(f"    {item['protein_id']}: GT is {item['gt_ecs']}, but no prediction was made")
            if len(incorrect['no_prediction']) > 10:
                print(f"    ... and {len(incorrect['no_prediction']) - 10} more")
        print()  # blank line between levels
    # Summary statistics
    print("\n=== Detailed statistics ===")
    # Distribution of the number of EC numbers per protein in the ground truth
    gt_ec_counts = defaultdict(int)
    for ecs in filtered_gt.values():
        gt_ec_counts[len(ecs)] += 1
    print("Distribution of EC-number counts in the ground truth:")
    for count, freq in sorted(gt_ec_counts.items()):
        print(f"  {count} EC number(s): {freq} protein(s)")
    # Distribution of the number of EC numbers per protein in the predictions
    pred_ec_counts = defaultdict(int)
    for ecs in filtered_pred.values():
        pred_ec_counts[len(ecs)] += 1
    print("\nDistribution of EC-number counts in the predictions:")
    for count, freq in sorted(pred_ec_counts.items()):
        print(f"  {count} EC number(s): {freq} protein(s)")
    # Save the results (make sure the output directory exists first)
    output_file = 'test_results/ec_accuracy_results.json'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    # # Save the ground truth
    # with open('test_results/ground_truth.json', 'w', encoding='utf-8') as f:
    #     json.dump(filtered_gt, f, indent=2, ensure_ascii=False)
    # # Save the predictions
    # with open('test_results/predictions.json', 'w', encoding='utf-8') as f:
    #     json.dump(filtered_pred, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to {output_file}")


if __name__ == "__main__":
    main()
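
# Typical invocation (the script filename is hypothetical; both flags fall back
# to the defaults defined in main() when omitted):
#
#   python calculate_ec_accuracy.py \
#       --pkl_file data/raw_data/difference_20241122_ec_dict_list.pkl \
#       --predictions_dir data/clean_test_results_top2go_deepseek-r1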