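"""Score LLM-generated answers against ground truth using an LLM judge.

Loads QA result files from a directory, asks an LLM (via call_chatgpt) to
rate each answer from 0 to 100 against its ground truth, retries missing or
invalid scores for up to --max_iterations rounds, then collects all scores
into a single JSON file and prints and saves summary statistics.
"""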
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.prompts import LLM_SCORE_PROMPT
import re

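# Global QA store, populated by load_qa_results_from_dir() before scoring
# starts so the worker processes can look items up by index.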
qa_data = {}

def load_qa_results_from_dir(results_dir):
    """Load all QA results from the results directory."""
    global qa_data
    qa_data = {}
    results_path = Path(results_dir)
    json_files = list(results_path.glob("*.json"))
    print(f"Found {len(json_files)} result files")
    for json_file in tqdm(json_files, desc="Loading QA results"):
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            if ('index' in data and 'protein_id' in data and
                    'ground_truth' in data and 'llm_answer' in data):
                qa_data[data['index']] = data
        except Exception as e:
            print(f"Error loading file {json_file}: {e}")
    print(f"Successfully loaded {len(qa_data)} QA pairs")
    return qa_data

def extract_score_from_response(response):
    """Extract a numeric score from the LLM response."""
    if not response:
        return None
    # Try to parse a JSON-formatted response
    try:
        if isinstance(response, str):
            # First, try to parse an embedded JSON object directly
            json_match = re.search(r'\{[^}]*"score"[^}]*\}', response)
            if json_match:
                json_obj = json.loads(json_match.group())
                return json_obj.get('score')
            # Fall back to extracting the value of a "score" field
            score_match = re.search(r'"score":\s*(\d+(?:\.\d+)?)', response)
            if score_match:
                return float(score_match.group(1))
            # Fall back to the first bare number in the response
            number_match = re.search(r'(\d+(?:\.\d+)?)', response)
            if number_match:
                score = float(number_match.group(1))
                if 0 <= score <= 100:
                    return score
        elif isinstance(response, dict):
            return response.get('score')
    except Exception:
        pass
    return None
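# Illustrative inputs (hypothetical responses) and what the parser returns:
#   extract_score_from_response('{"score": 85}')       -> 85
#   extract_score_from_response('... "score": 72.5 ')  -> 72.5
#   extract_score_from_response('I would rate it 90')  -> 90.0
#   extract_score_from_response({'score': 60})         -> 60
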
def process_single_scoring(process_id, idx, qa_index, writer, save_dir):
    """Score a single QA pair."""
    try:
        qa_item = qa_data[qa_index]
        protein_id = qa_item['protein_id']
        question = qa_item.get('question', '')
        ground_truth = qa_item['ground_truth']
        llm_answer = qa_item['llm_answer']
        # Build the scoring prompt
        scoring_prompt = LLM_SCORE_PROMPT.replace('{{ground_truth}}', str(ground_truth))
        scoring_prompt = scoring_prompt.replace('{{llm_answer}}', str(llm_answer))
        # Call the LLM to score the answer
        score_response = call_chatgpt(scoring_prompt)
        score = extract_score_from_response(score_response)
        # Assemble the result record
        result = {
            'index': qa_index,
            'protein_id': protein_id,
            'question': question,
            'ground_truth': ground_truth,
            'llm_answer': llm_answer,
            'score': score,
            'raw_score_response': score_response
        }
        # Save to file
        save_path = os.path.join(save_dir, f"score_{protein_id}_{qa_index}.json")
        with open(save_path, 'w') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"Error processing QA index {qa_index}: {e}")
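# For reference, LLM_SCORE_PROMPT (imported from utils.prompts and used above)
# is expected to contain literal {{ground_truth}} and {{llm_answer}} placeholders
# and to request a 0-100 score. An illustrative template only, not the actual prompt:
#   LLM_SCORE_PROMPT = (
#       "Reference answer:\n{{ground_truth}}\n\n"
#       "Model answer:\n{{llm_answer}}\n\n"
#       'Rate the model answer from 0 to 100. Reply as JSON: {"score": <number>}'
#   )
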
def get_missing_score_indices(save_dir):
    """Find QA indices whose score file is missing or holds an invalid score."""
    all_qa_indices = list(qa_data.keys())
    problem_qa_indices = set()
    for qa_index in tqdm(all_qa_indices, desc="Checking score files"):
        protein_id = qa_data[qa_index]['protein_id']
        json_file = Path(save_dir) / f"score_{protein_id}_{qa_index}.json"
        if not json_file.exists():
            problem_qa_indices.add(qa_index)
            continue
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            if (data is None or len(data) == 0 or
                    'score' not in data or
                    data.get('score') is None):
                problem_qa_indices.add(qa_index)
                json_file.unlink()
        except Exception:
            problem_qa_indices.add(qa_index)
            try:
                json_file.unlink()
            except OSError:
                pass
    return problem_qa_indices

def collect_scores_to_json(save_dir, output_json):
    """Collect all scoring results and save them to a single JSON file."""
    results = []
    save_path = Path(save_dir)
    score_files = list(save_path.glob("score_*.json"))
    for score_file in tqdm(score_files, desc="Collecting score results"):
        try:
            with open(score_file, 'r') as f:
                data = json.load(f)
            results.append({
                'index': data.get('index'),
                'protein_id': data.get('protein_id'),
                'question': data.get('question', ''),
                'ground_truth': data.get('ground_truth'),
                'llm_answer': data.get('llm_answer'),
                'score': data.get('score')
            })
        except Exception as e:
            print(f"Error reading file {score_file}: {e}")
    # Sort by index (records with a missing index sort first; a None key
    # would otherwise raise a TypeError during comparison)
    results.sort(key=lambda x: x['index'] if x.get('index') is not None else -1)
    # Save as a JSON file
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Score results saved to: {output_json}")
    # Convert to a DataFrame for analysis
    df = pd.DataFrame(results)
    return df
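# Each record in the output JSON has this shape (values illustrative):
#   {"index": 0, "protein_id": "P12345", "question": "...",
#    "ground_truth": "...", "llm_answer": "...", "score": 85.0}
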
def analyze_scores(df):
    """Run summary statistics on the scoring results."""
    print("\n=== Score Statistics ===")
    # Basic statistics
    valid_scores = df[df['score'].notna()]['score']
    if len(valid_scores) == 0:
        print("No valid scores found")
        return None
    print(f"Total samples: {len(df)}")
    print(f"Valid scores: {len(valid_scores)}")
    print(f"Invalid scores: {len(df) - len(valid_scores)}")
    print(f"Valid rate: {len(valid_scores)/len(df)*100:.2f}%")
    print("\nScore statistics:")
    print(f"Mean: {valid_scores.mean():.2f}")
    print(f"Median: {valid_scores.median():.2f}")
    print(f"Std dev: {valid_scores.std():.2f}")
    print(f"Max: {valid_scores.max():.2f}")
    print(f"Min: {valid_scores.min():.2f}")
    # Score distribution
    print("\nScore distribution:")
    bins = [0, 20, 40, 60, 80, 100]
    labels = ['0-20', '21-40', '41-60', '61-80', '81-100']

    def bin_count(low, high, first):
        # Use a strict lower bound after the first bin so boundary values
        # (20, 40, ...) are not counted in two bins
        lower_ok = (valid_scores >= low) if first else (valid_scores > low)
        return int((lower_ok & (valid_scores <= high)).sum())

    for i, (low, high) in enumerate(zip(bins[:-1], bins[1:])):
        count = bin_count(low, high, i == 0)
        percentage = count / len(valid_scores) * 100
        print(f"{labels[i]}: {count} ({percentage:.1f}%)")
    # Quantiles
    print("\nQuantiles:")
    quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
    for q in quantiles:
        print(f"{int(q*100)}th percentile: {valid_scores.quantile(q):.2f}")
    # Per-protein analysis (only meaningful with more than one protein)
    if len(df['protein_id'].unique()) > 1:
        print("\nPer-protein statistics:")
        protein_stats = df[df['score'].notna()].groupby('protein_id')['score'].agg(['count', 'mean', 'std']).round(2)
        print(protein_stats.head(10))
    # Assemble the statistics for saving
    stats_result = {
        "basic_stats": {
            "total_samples": len(df),
            "valid_scores": len(valid_scores),
            "invalid_scores": len(df) - len(valid_scores),
            "valid_rate": len(valid_scores)/len(df)*100,
            "mean_score": float(valid_scores.mean()),
            "median_score": float(valid_scores.median()),
            "std_score": float(valid_scores.std()),
            "max_score": float(valid_scores.max()),
            "min_score": float(valid_scores.min())
        },
        "distribution": {},
        "quantiles": {}
    }
    # Score distribution stats
    for i, (low, high) in enumerate(zip(bins[:-1], bins[1:])):
        count = bin_count(low, high, i == 0)
        percentage = count / len(valid_scores) * 100
        stats_result["distribution"][labels[i]] = {
            "count": count,
            "percentage": percentage
        }
    # Quantile stats
    for q in quantiles:
        stats_result["quantiles"][f"{int(q*100)}%"] = float(valid_scores.quantile(q))
    return stats_result
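# stats_result is serialized to --stats_json in main(); shape for reference:
#   {"basic_stats": {"total_samples": ..., "mean_score": ..., ...},
#    "distribution": {"0-20": {"count": ..., "percentage": ...}, ...},
#    "quantiles": {"10%": ..., "25%": ..., ...}}
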
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str,
                        default="data/evolla_hard_motif_go",
                        help="Directory containing the LLM answer results")
    parser.add_argument("--n_process", type=int, default=32,
                        help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                        default="data/llm_scores",
                        help="Directory for saving scoring results")
    parser.add_argument("--output_json", type=str,
                        default="data/llm_scores_results.json",
                        help="Output JSON file path")
    parser.add_argument("--stats_json", type=str,
                        default="data/llm_scores_stats.json",
                        help="Statistics JSON file path")
    parser.add_argument("--max_iterations", type=int, default=3,
                        help="Maximum number of scoring rounds")
    args = parser.parse_args()
    # Create output directories (guard against a bare filename with no directory part)
    os.makedirs(args.save_dir, exist_ok=True)
    output_dir = os.path.dirname(args.output_json)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # Load QA result data
    load_qa_results_from_dir(args.results_dir)
    if not qa_data:
        print("No valid QA result data found")
        return
    # Score in rounds, retrying any QA pairs still missing a valid score
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\nStarting scoring round {iteration}")
        # Find QA indices that still need scoring
        missing_indices = get_missing_score_indices(args.save_dir)
        if not missing_indices:
            print("All QA pairs have been scored!")
            break
        print(f"Found {len(missing_indices)} QA pairs awaiting scoring")
        missing_indices_list = sorted(list(missing_indices))
        # Score in parallel across multiple processes
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_indices_list,
            do=lambda process_id, idx, qa_index, writer: process_single_scoring(process_id, idx, qa_index, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()
        print(f"Scoring round {iteration} complete")
    # Collect the results and save them as JSON
    df = collect_scores_to_json(args.save_dir, args.output_json)
    # Run the statistical analysis
    stats_result = analyze_scores(df)
    # Save the statistics as JSON (analyze_scores returns None when there are no valid scores)
    if stats_result is not None:
        with open(args.stats_json, 'w', encoding='utf-8') as f:
            json.dump(stats_result, f, indent=2, ensure_ascii=False)
        print(f"Statistics saved to: {args.stats_json}")
    # Check the final state
    final_missing = get_missing_score_indices(args.save_dir)
    if final_missing:
        print(f"\n{len(final_missing)} QA pairs still could not be scored")
    else:
        print(f"\nAll {len(qa_data)} QA pairs have been scored successfully!")

if __name__ == "__main__":
    main()
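# Example invocation (the filename is hypothetical; values shown are the defaults):
#   python llm_score.py \
#       --results_dir data/evolla_hard_motif_go \
#       --save_dir data/llm_scores \
#       --output_json data/llm_scores_results.json \
#       --n_process 32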