protein_rag / utils /generate_llm_answers.py
ericzhang1122's picture
Upload folder using huggingface_hub
5c20520 verified
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from pathlib import Path
from tqdm import tqdm
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.generate_protein_prompt import generate_prompt
qa_data = None
def _load_qa_data(prompt_path):
global qa_data
if qa_data is None:
qa_data = {}
with open(prompt_path, 'r') as f:
for line in f:
if line.strip():
item = json.loads(line.strip())
qa_data[item['index']] = item
return qa_data
def process_single_qa(process_id, idx, qa_index, writer, save_dir):
"""处理单个QA对并生成答案"""
try:
qa_item = qa_data[qa_index]
protein_id = qa_item['protein_id']
prompt = qa_item['prompt']
question = qa_item['question']
ground_truth = qa_item['ground_truth']
# 调用LLM生成答案
llm_response = call_chatgpt(prompt)
# 构建结果数据
result = {
'protein_id': protein_id,
'index': qa_index,
'question': question,
'ground_truth': ground_truth,
'llm_answer': llm_response
}
# 保存文件,文件名使用protein_id和index
save_path = os.path.join(save_dir, f"{protein_id}_{qa_index}.json")
with open(save_path, 'w') as f:
json.dump(result, f, indent=2, ensure_ascii=False)
except Exception as e:
print(f"Error processing QA index {qa_index}: {str(e)}")
def get_missing_qa_indices(save_dir):
"""检查哪些QA索引尚未成功生成数据"""
# 获取所有应该生成的qa索引
all_qa_indices = list(qa_data.keys())
# 存储问题qa索引(包括空文件和未生成的文件)
problem_qa_indices = set()
# 检查每个应该存在的qa索引
for qa_index in tqdm(all_qa_indices, desc="检查QA数据文件"):
protein_id = qa_data[qa_index]['protein_id']
json_file = Path(save_dir) / f"{protein_id}_{qa_index}.json"
# 如果文件不存在,加入问题列表
if not json_file.exists():
problem_qa_indices.add(qa_index)
continue
# 检查文件内容
try:
with open(json_file, 'r') as f:
data = json.load(f)
# 检查文件内容是否为空或缺少必要字段
if (data is None or len(data) == 0 or
'llm_answer' not in data or
data.get('llm_answer') is None or
data.get('llm_answer') == ''):
problem_qa_indices.add(qa_index)
json_file.unlink() # 删除空文件或不完整文件
except (json.JSONDecodeError, Exception) as e:
# 如果JSON解析失败,也认为是问题文件
problem_qa_indices.add(qa_index)
try:
json_file.unlink() # 删除损坏的文件
except:
pass
return problem_qa_indices
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_path", type=str,
default="data/processed_data/prompts@clean_test.jsonl",
help="Path to the JSONL file containing QA prompts")
parser.add_argument("--n_process", type=int, default=64,
help="Number of parallel processes")
parser.add_argument("--save_dir", type=str,
default="data/clean_test_results_top2",
help="Directory to save results")
parser.add_argument("--max_iterations", type=int, default=3,
help="Maximum number of iterations to try generating all QA pairs")
args = parser.parse_args()
# 创建保存目录
os.makedirs(args.save_dir, exist_ok=True)
# 加载QA数据
_load_qa_data(args.prompt_path)
print(f"已加载 {len(qa_data)} 个QA对")
# 循环检查和生成,直到所有QA对都已生成或达到最大迭代次数
iteration = 0
while iteration < args.max_iterations:
iteration += 1
print(f"\n开始第 {iteration} 轮检查和生成")
# 获取缺失的QA索引
missing_qa_indices = get_missing_qa_indices(args.save_dir)
# 如果没有缺失的QA索引,则完成
if not missing_qa_indices:
print("所有QA数据已成功生成!")
break
print(f"发现 {len(missing_qa_indices)} 个缺失的QA数据,准备生成")
# 将缺失的QA索引列表转换为列表
missing_qa_indices_list = sorted(list(missing_qa_indices))
# 保存当前缺失的QA索引列表,用于记录
missing_ids_file = Path(args.save_dir) / f"missing_qa_indices_iteration_{iteration}.txt"
with open(missing_ids_file, 'w') as f:
for qa_index in missing_qa_indices_list:
protein_id = qa_data[qa_index]['protein_id']
f.write(f"{protein_id}_{qa_index}\n")
# 使用多进程处理生成缺失的QA数据
mprs = MultipleProcessRunnerSimplifier(
data=missing_qa_indices_list,
do=lambda process_id, idx, qa_index, writer: process_single_qa(process_id, idx, qa_index, writer, args.save_dir),
n_process=args.n_process,
split_strategy="static"
)
mprs.run()
print(f"第 {iteration} 轮生成完成")
# 最后检查一次
final_missing_indices = get_missing_qa_indices(args.save_dir)
if final_missing_indices:
print(f"经过 {iteration} 轮生成后,仍有 {len(final_missing_indices)} 个QA数据未成功生成")
# 保存最终缺失的QA索引列表
final_missing_ids_file = Path(args.save_dir) / "final_missing_qa_indices.txt"
with open(final_missing_ids_file, 'w') as f:
for qa_index in sorted(final_missing_indices):
protein_id = qa_data[qa_index]['protein_id']
f.write(f"{protein_id}_{qa_index}\n")
print(f"最终缺失的QA索引已保存到: {final_missing_ids_file}")
else:
print(f"经过 {iteration} 轮生成,所有QA数据已成功生成!")
if __name__ == "__main__":
main()