# protein_rag/cal_llm_score.py
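"""Score LLM answers against their ground-truth annotations with an LLM judge.

The script loads per-QA JSON results from --results_dir, asks the judge model
(call_chatgpt with LLM_SCORE_PROMPT) to rate each answer on a 0-100 scale,
retries missing or invalid scores for up to --max_iterations passes, and
finally writes the collected scores and summary statistics to JSON files.
"""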
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.prompts import LLM_SCORE_PROMPT
import re

# Global store of QA records keyed by 'index'; filled by load_qa_results_from_dir
# and read by the scoring workers.
qa_data = {}

def load_qa_results_from_dir(results_dir):
    """Load all QA results from the results directory."""
    global qa_data
    qa_data = {}

    results_path = Path(results_dir)
    json_files = list(results_path.glob("*.json"))
    print(f"Found {len(json_files)} result files")

    for json_file in tqdm(json_files, desc="Loading QA results"):
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Keep only records that carry all the fields needed for scoring
            if ('index' in data and 'protein_id' in data and
                    'ground_truth' in data and 'llm_answer' in data):
                qa_data[data['index']] = data
        except Exception as e:
            print(f"Error loading file {json_file}: {e}")

    print(f"Successfully loaded {len(qa_data)} QA pairs")
    return qa_data

def extract_score_from_response(response):
    """Extract a numeric score from the LLM response."""
    if not response:
        return None

    # Try to parse a JSON-formatted response first
    try:
        if isinstance(response, str):
            # Try an embedded JSON object that contains a "score" field
            json_match = re.search(r'\{[^}]*"score"[^}]*\}', response)
            if json_match:
                json_obj = json.loads(json_match.group())
                return json_obj.get('score')

            # Fall back to the number following a "score": key
            score_match = re.search(r'"score":\s*(\d+(?:\.\d+)?)', response)
            if score_match:
                return float(score_match.group(1))

            # Finally, accept a bare number if it looks like a 0-100 score
            number_match = re.search(r'(\d+(?:\.\d+)?)', response)
            if number_match:
                score = float(number_match.group(1))
                if 0 <= score <= 100:
                    return score
        elif isinstance(response, dict):
            return response.get('score')
    except Exception:
        pass

    return None

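# Examples of responses the parser above accepts (illustrative, based on the
# regexes in extract_score_from_response):
#   '{"score": 85, "comment": "good"}'        -> 85    (embedded JSON object)
#   'final "score": 72.5'                     -> 72.5  ("score": key without braces)
#   'I would rate this answer 90 out of 100.' -> 90.0  (bare number within 0-100)
#   {'score': 60}                             -> 60    (dict responses return the value directly)
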
def process_single_scoring(process_id, idx, qa_index, writer, save_dir):
    """Score a single QA pair (writer is passed by the process runner and unused here)."""
    try:
        qa_item = qa_data[qa_index]
        protein_id = qa_item['protein_id']
        question = qa_item.get('question', '')
        ground_truth = qa_item['ground_truth']
        llm_answer = qa_item['llm_answer']

        # Build the scoring prompt; LLM_SCORE_PROMPT is expected to contain
        # the {{ground_truth}} and {{llm_answer}} placeholders
        scoring_prompt = LLM_SCORE_PROMPT.replace('{{ground_truth}}', str(ground_truth))
        scoring_prompt = scoring_prompt.replace('{{llm_answer}}', str(llm_answer))

        # Ask the LLM judge for a score
        score_response = call_chatgpt(scoring_prompt)
        score = extract_score_from_response(score_response)

        # Assemble the result record
        result = {
            'index': qa_index,
            'protein_id': protein_id,
            'question': question,
            'ground_truth': ground_truth,
            'llm_answer': llm_answer,
            'score': score,
            'raw_score_response': score_response
        }

        # Save to file
        save_path = os.path.join(save_dir, f"score_{protein_id}_{qa_index}.json")
        with open(save_path, 'w') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"Error processing QA index {qa_index}: {str(e)}")

def get_missing_score_indices(save_dir):
    """Check which QA indices have not been scored yet."""
    all_qa_indices = list(qa_data.keys())
    problem_qa_indices = set()

    for qa_index in tqdm(all_qa_indices, desc="Checking score files"):
        protein_id = qa_data[qa_index]['protein_id']
        json_file = Path(save_dir) / f"score_{protein_id}_{qa_index}.json"

        if not json_file.exists():
            problem_qa_indices.add(qa_index)
            continue

        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Treat empty or score-less files as failures and delete them
            # so the next pass regenerates them
            if (data is None or len(data) == 0 or
                    'score' not in data or
                    data.get('score') is None):
                problem_qa_indices.add(qa_index)
                json_file.unlink()
        except Exception:
            problem_qa_indices.add(qa_index)
            try:
                json_file.unlink()
            except OSError:
                pass

    return problem_qa_indices

def collect_scores_to_json(save_dir, output_json):
    """Collect all scoring results and save them to a single JSON file."""
    results = []
    save_path = Path(save_dir)
    score_files = list(save_path.glob("score_*.json"))

    for score_file in tqdm(score_files, desc="Collecting scoring results"):
        try:
            with open(score_file, 'r') as f:
                data = json.load(f)
            results.append({
                'index': data.get('index'),
                'protein_id': data.get('protein_id'),
                'question': data.get('question', ''),
                'ground_truth': data.get('ground_truth'),
                'llm_answer': data.get('llm_answer'),
                'score': data.get('score')
            })
        except Exception as e:
            print(f"Error reading file {score_file}: {e}")

    # Sort by index
    results.sort(key=lambda x: x.get('index', 0))

    # Save as a JSON file
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Scoring results saved to: {output_json}")

    # Convert to a DataFrame for analysis
    df = pd.DataFrame(results)
    return df

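# Illustrative shape of one record in --output_json (field names come from the
# dict built above; the values here are made up):
#   {
#     "index": 0,
#     "protein_id": "P12345",
#     "question": "...",
#     "ground_truth": "...",
#     "llm_answer": "...",
#     "score": 85.0
#   }
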
def analyze_scores(df):
    """Run summary statistics on the scoring results."""
    print("\n=== Score statistics ===")

    # Basic statistics
    valid_scores = df[df['score'].notna()]['score']
    if len(valid_scores) == 0:
        print("No valid scores found")
        return

    print(f"Total samples: {len(df)}")
    print(f"Valid scores: {len(valid_scores)}")
    print(f"Invalid scores: {len(df) - len(valid_scores)}")
    print(f"Valid rate: {len(valid_scores)/len(df)*100:.2f}%")

    print(f"\nScore statistics:")
    print(f"Mean: {valid_scores.mean():.2f}")
    print(f"Median: {valid_scores.median():.2f}")
    print(f"Std: {valid_scores.std():.2f}")
    print(f"Max: {valid_scores.max():.2f}")
    print(f"Min: {valid_scores.min():.2f}")

    # Score distribution; the lower bound is inclusive only for the first bin
    # so boundary scores are not counted in two adjacent bins
    print(f"\nScore distribution:")
    bins = [0, 20, 40, 60, 80, 100]
    labels = ['0-20', '21-40', '41-60', '61-80', '81-100']
    for i, (low, high) in enumerate(zip(bins[:-1], bins[1:])):
        in_bin = (valid_scores >= low if i == 0 else valid_scores > low) & (valid_scores <= high)
        count = int(in_bin.sum())
        percentage = count / len(valid_scores) * 100
        print(f"{labels[i]}: {count} ({percentage:.1f}%)")

    # Quantiles
    print(f"\nQuantiles:")
    quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
    for q in quantiles:
        print(f"{int(q*100)}th percentile: {valid_scores.quantile(q):.2f}")

    # Per-protein analysis (only when more than one protein is present)
    if len(df['protein_id'].unique()) > 1:
        print(f"\nPer-protein analysis:")
        protein_stats = df[df['score'].notna()].groupby('protein_id')['score'].agg(['count', 'mean', 'std']).round(2)
        print(protein_stats.head(10))

    # Collect the statistics for saving
    stats_result = {
        "basic_stats": {
            "total_samples": len(df),
            "valid_scores": len(valid_scores),
            "invalid_scores": len(df) - len(valid_scores),
            "valid_rate": len(valid_scores)/len(df)*100,
            "mean_score": float(valid_scores.mean()),
            "median_score": float(valid_scores.median()),
            "std_score": float(valid_scores.std()),
            "max_score": float(valid_scores.max()),
            "min_score": float(valid_scores.min())
        },
        "distribution": {},
        "quantiles": {}
    }

    # Score distribution statistics (same binning as above)
    for i, (low, high) in enumerate(zip(bins[:-1], bins[1:])):
        in_bin = (valid_scores >= low if i == 0 else valid_scores > low) & (valid_scores <= high)
        count = int(in_bin.sum())
        percentage = count / len(valid_scores) * 100
        stats_result["distribution"][labels[i]] = {
            "count": count,
            "percentage": percentage
        }

    # Quantile statistics
    for q in quantiles:
        stats_result["quantiles"][f"{int(q*100)}%"] = float(valid_scores.quantile(q))

    return stats_result

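# Illustrative layout of --stats_json built from the dict above (keys come from
# stats_result; the numbers are placeholders):
#   {
#     "basic_stats": {"total_samples": 100, "valid_scores": 98, "mean_score": 71.3, ...},
#     "distribution": {"0-20": {"count": 5, "percentage": 5.1}, ...},
#     "quantiles": {"10%": 40.0, "25%": 55.0, "50%": 72.0, "75%": 88.0, "90%": 95.0}
#   }
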
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str,
                        default="data/evolla_hard_motif_go",
                        help="Directory containing the LLM answer results")
    parser.add_argument("--n_process", type=int, default=32,
                        help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                        default="data/llm_scores",
                        help="Directory for the per-QA scoring results")
    parser.add_argument("--output_json", type=str,
                        default="data/llm_scores_results.json",
                        help="Path of the output JSON file")
    parser.add_argument("--stats_json", type=str,
                        default="data/llm_scores_stats.json",
                        help="Path of the statistics JSON file")
    parser.add_argument("--max_iterations", type=int, default=3,
                        help="Maximum number of scoring passes")
    args = parser.parse_args()

    # Create the output directories
    os.makedirs(args.save_dir, exist_ok=True)
    output_dir = os.path.dirname(args.output_json)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load the QA result data
    load_qa_results_from_dir(args.results_dir)
    if not qa_data:
        print("No valid QA result data found")
        return

    # Score in passes until everything is done or max_iterations is reached
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\nStarting scoring pass {iteration}")

        # Find the QA indices that still lack a valid score
        missing_indices = get_missing_score_indices(args.save_dir)
        if not missing_indices:
            print("All QA pairs have been scored!")
            break

        print(f"Found {len(missing_indices)} QA pairs awaiting scoring")
        missing_indices_list = sorted(list(missing_indices))

        # Score them with multiple worker processes
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_indices_list,
            do=lambda process_id, idx, qa_index, writer: process_single_scoring(
                process_id, idx, qa_index, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()

        print(f"Scoring pass {iteration} finished")

    # Collect the results and save them as JSON
    df = collect_scores_to_json(args.save_dir, args.output_json)

    # Run the statistical analysis
    stats_result = analyze_scores(df)

    # Save the statistics as JSON (analyze_scores returns None when there are
    # no valid scores)
    if stats_result is not None:
        with open(args.stats_json, 'w', encoding='utf-8') as f:
            json.dump(stats_result, f, indent=2, ensure_ascii=False)
        print(f"Statistics saved to: {args.stats_json}")

    # Check the final status
    final_missing = get_missing_score_indices(args.save_dir)
    if final_missing:
        print(f"\n{len(final_missing)} QA pairs still failed to be scored")
    else:
        print(f"\nAll {len(qa_data)} QA pairs have been scored successfully!")


if __name__ == "__main__":
    main()
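
# Example invocation (argument defaults shown above; adjust paths to your setup):
#   python cal_llm_score.py \
#       --results_dir data/evolla_hard_motif_go \
#       --save_dir data/llm_scores \
#       --output_json data/llm_scores_results.json \
#       --stats_json data/llm_scores_stats.json \
#       --n_process 32 \
#       --max_iterations 3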