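"""Generate LLM answers for protein QA prompts, with multi-round retries.

This script loads QA prompts from a JSONL file, answers each one with an
LLM via call_chatgpt, writes one JSON result file per QA pair, and keeps
re-running generation for missing or corrupted results until everything
succeeds or --max_iterations is reached.
"""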
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import json
from pathlib import Path
from tqdm import tqdm
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.generate_protein_prompt import generate_prompt
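
# Module-level cache of QA items keyed by their 'index' field; filled once
# by _load_qa_data() before any worker process is started.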
qa_data = None


def _load_qa_data(prompt_path):
    """Load QA items from a JSONL file into the global qa_data dict, once."""
    global qa_data
    if qa_data is None:
        qa_data = {}
        with open(prompt_path, 'r') as f:
            for line in f:
                if line.strip():
                    item = json.loads(line.strip())
                    qa_data[item['index']] = item
    return qa_data
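
# Each JSONL line is expected to carry at least the fields read below
# (the values here are illustrative placeholders, not real data):
# {"index": 0, "protein_id": "P12345", "prompt": "...", "question": "...",
#  "ground_truth": "..."}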


def process_single_qa(process_id, idx, qa_index, writer, save_dir):
    """Process a single QA pair: query the LLM and save the answer to disk."""
    try:
        qa_item = qa_data[qa_index]
        protein_id = qa_item['protein_id']
        prompt = qa_item['prompt']
        question = qa_item['question']
        ground_truth = qa_item['ground_truth']
        # Call the LLM to generate an answer
        llm_response = call_chatgpt(prompt)
        # Assemble the result record
        result = {
            'protein_id': protein_id,
            'index': qa_index,
            'question': question,
            'ground_truth': ground_truth,
            'llm_answer': llm_response
        }
        # Save the result; the filename combines protein_id and the QA index
        save_path = os.path.join(save_dir, f"{protein_id}_{qa_index}.json")
        with open(save_path, 'w') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"Error processing QA index {qa_index}: {str(e)}")


def get_missing_qa_indices(save_dir):
    """Return the set of QA indices whose result files are missing or invalid."""
    # All QA indices that should have been generated
    all_qa_indices = list(qa_data.keys())
    # Problem indices (missing, empty, or incomplete result files)
    problem_qa_indices = set()
    # Check every QA index that should exist
    for qa_index in tqdm(all_qa_indices, desc="Checking QA data files"):
        protein_id = qa_data[qa_index]['protein_id']
        json_file = Path(save_dir) / f"{protein_id}_{qa_index}.json"
        # If the file does not exist, record it as a problem
        if not json_file.exists():
            problem_qa_indices.add(qa_index)
            continue
        # Inspect the file contents
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Flag files that are empty or lack a usable answer
            if (data is None or len(data) == 0 or
                    'llm_answer' not in data or
                    data.get('llm_answer') is None or
                    data.get('llm_answer') == ''):
                problem_qa_indices.add(qa_index)
                json_file.unlink()  # Remove the empty or incomplete file
        except Exception:
            # Unparseable JSON (or any other read error) also counts as a problem
            problem_qa_indices.add(qa_index)
            try:
                json_file.unlink()  # Remove the corrupted file
            except OSError:
                pass
    return problem_qa_indices


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str,
                        default="data/processed_data/prompts@clean_test.jsonl",
                        help="Path to the JSONL file containing QA prompts")
    parser.add_argument("--n_process", type=int, default=64,
                        help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                        default="data/clean_test_results_top2",
                        help="Directory to save results")
    parser.add_argument("--max_iterations", type=int, default=3,
                        help="Maximum number of iterations to try generating all QA pairs")
    args = parser.parse_args()
    # Create the save directory
    os.makedirs(args.save_dir, exist_ok=True)
    # Load the QA data
    _load_qa_data(args.prompt_path)
    print(f"Loaded {len(qa_data)} QA pairs")
    # Check and generate repeatedly until every QA pair exists or the
    # iteration limit is reached
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\nStarting iteration {iteration}: checking and generating")
        # Find the QA indices that are still missing
        missing_qa_indices = get_missing_qa_indices(args.save_dir)
        # Done if nothing is missing
        if not missing_qa_indices:
            print("All QA data generated successfully!")
            break
        print(f"Found {len(missing_qa_indices)} missing QA items; generating them now")
        # Sort the missing indices for a deterministic processing order
        missing_qa_indices_list = sorted(missing_qa_indices)
        # Record this iteration's missing QA indices for later inspection
        missing_ids_file = Path(args.save_dir) / f"missing_qa_indices_iteration_{iteration}.txt"
        with open(missing_ids_file, 'w') as f:
            for qa_index in missing_qa_indices_list:
                protein_id = qa_data[qa_index]['protein_id']
                f.write(f"{protein_id}_{qa_index}\n")
        # Generate the missing QA data with multiple worker processes
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_qa_indices_list,
            do=lambda process_id, idx, qa_index, writer: process_single_qa(
                process_id, idx, qa_index, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()
        print(f"Iteration {iteration} complete")
    # Final check after the loop
    final_missing_indices = get_missing_qa_indices(args.save_dir)
    if final_missing_indices:
        print(f"After {iteration} iteration(s), {len(final_missing_indices)} "
              f"QA items still failed to generate")
        # Save the final list of missing QA indices
        final_missing_ids_file = Path(args.save_dir) / "final_missing_qa_indices.txt"
        with open(final_missing_ids_file, 'w') as f:
            for qa_index in sorted(final_missing_indices):
                protein_id = qa_data[qa_index]['protein_id']
                f.write(f"{protein_id}_{qa_index}\n")
        print(f"Final missing QA indices saved to: {final_missing_ids_file}")
    else:
        print(f"After {iteration} iteration(s), all QA data was generated successfully!")


if __name__ == "__main__":
    main()
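
# Usage sketch (the script filename is an assumption; substitute this file's
# actual name):
#   python generate_qa_answers.py \
#       --prompt_path data/processed_data/prompts@clean_test.jsonl \
#       --n_process 64 \
#       --save_dir data/clean_test_results_top2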