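"""Generate LLM answers for a set of protein QA prompts.

Loads QA items from a JSONL prompt file, queries the LLM once per item
across multiple worker processes, and writes one JSON result file per QA
pair. Missing, empty, or corrupted result files are detected, deleted,
and regenerated, for up to --max_iterations passes.
"""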
import sys
import os
# Make the repo root importable so the `utils` package resolves when this
# file is run directly as a script
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from pathlib import Path
from tqdm import tqdm
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.generate_protein_prompt import generate_prompt
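
# Illustrative shape of one line of the prompt JSONL file, inferred from
# the fields read in process_single_qa() (values are examples only):
# {"index": 0, "protein_id": "P12345", "prompt": "...",
#  "question": "...", "ground_truth": "..."}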

qa_data = None

def _load_qa_data(prompt_path):
    """Load QA items from the JSONL prompt file, keyed by their index."""
    global qa_data
    if qa_data is None:
        qa_data = {}
        with open(prompt_path, 'r') as f:
            for line in f:
                if line.strip():
                    item = json.loads(line.strip())
                    qa_data[item['index']] = item
    return qa_data

def process_single_qa(process_id, idx, qa_index, writer, save_dir):
    """Process a single QA pair: query the LLM and save the answer.

    `writer` is part of the MultipleProcessRunnerSimplifier callback
    signature and is unused here.
    """
    try:
        qa_item = qa_data[qa_index]
        protein_id = qa_item['protein_id']
        prompt = qa_item['prompt']
        question = qa_item['question']
        ground_truth = qa_item['ground_truth']
        
        # Call the LLM to generate an answer
        llm_response = call_chatgpt(prompt)
        
        # Assemble the result record
        result = {
            'protein_id': protein_id,
            'index': qa_index,
            'question': question,
            'ground_truth': ground_truth,
            'llm_answer': llm_response
        }
        
        # Save the result; the filename combines protein_id and QA index
        save_path = os.path.join(save_dir, f"{protein_id}_{qa_index}.json")
        with open(save_path, 'w') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        
    except Exception as e:
        print(f"Error processing QA index {qa_index}: {e}")
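
# Illustrative result file written by process_single_qa(); the protein id
# and index values below are examples, not real data. The file name is
# "<save_dir>/{protein_id}_{qa_index}.json":
# {"protein_id": "P12345", "index": 0, "question": "...",
#  "ground_truth": "...", "llm_answer": "..."}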

def get_missing_qa_indices(save_dir):
    """Return the set of QA indices whose result files are missing or unusable."""
    # Every QA index that should have a result file
    all_qa_indices = list(qa_data.keys())
    
    # Indices with problems (missing, empty, or incomplete files)
    problem_qa_indices = set()
    
    # Check every expected result file
    for qa_index in tqdm(all_qa_indices, desc="Checking QA result files"):
        protein_id = qa_data[qa_index]['protein_id']
        json_file = Path(save_dir) / f"{protein_id}_{qa_index}.json"
        
        # File was never generated
        if not json_file.exists():
            problem_qa_indices.add(qa_index)
            continue
            
        # File exists: verify it parses and holds a non-empty answer
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Flag files that are empty or lack a usable 'llm_answer'
            if not data or not data.get('llm_answer'):
                problem_qa_indices.add(qa_index)
                json_file.unlink()  # delete the empty or incomplete file
        except Exception:
            # Unparseable JSON also counts as a problem file
            problem_qa_indices.add(qa_index)
            try:
                json_file.unlink()  # delete the corrupted file
            except OSError:
                pass
    
    return problem_qa_indices

def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str,
                      default="data/processed_data/prompts@clean_test.jsonl",
                      help="Path to the JSONL file containing QA prompts")
    parser.add_argument("--n_process", type=int, default=64,
                      help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                      default="data/clean_test_results_top2",
                      help="Directory to save results")
    parser.add_argument("--max_iterations", type=int, default=3,
                      help="Maximum number of iterations to try generating all QA pairs")
    args = parser.parse_args()

    # Create the output directory
    os.makedirs(args.save_dir, exist_ok=True)
    
    # Load the QA data
    _load_qa_data(args.prompt_path)
    print(f"Loaded {len(qa_data)} QA pairs")
    
    # Check and regenerate in a loop until every QA pair has a valid
    # result file or the iteration budget is exhausted
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\nStarting check-and-generate pass {iteration}")
        
        # Find QA indices whose result files are missing or invalid
        missing_qa_indices = get_missing_qa_indices(args.save_dir)
        
        # Done once nothing is missing
        if not missing_qa_indices:
            print("All QA data generated successfully!")
            break
        
        print(f"Found {len(missing_qa_indices)} missing QA items; regenerating")
        
        # Sort the missing indices for a deterministic processing order
        missing_qa_indices_list = sorted(missing_qa_indices)
        
        # Record this pass's missing indices for later inspection
        missing_ids_file = Path(args.save_dir) / f"missing_qa_indices_iteration_{iteration}.txt"
        with open(missing_ids_file, 'w') as f:
            for qa_index in missing_qa_indices_list:
                protein_id = qa_data[qa_index]['protein_id']
                f.write(f"{protein_id}_{qa_index}\n")
        
        # Regenerate the missing QA data across worker processes
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_qa_indices_list,
            do=lambda process_id, idx, qa_index, writer: process_single_qa(process_id, idx, qa_index, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()
        
        print(f"Pass {iteration} finished")
    
    # One final check after the loop
    final_missing_indices = get_missing_qa_indices(args.save_dir)
    if final_missing_indices:
        print(f"After {iteration} passes, {len(final_missing_indices)} QA items still failed to generate")
        # Save the final list of missing QA indices
        final_missing_ids_file = Path(args.save_dir) / "final_missing_qa_indices.txt"
        with open(final_missing_ids_file, 'w') as f:
            for qa_index in sorted(final_missing_indices):
                protein_id = qa_data[qa_index]['protein_id']
                f.write(f"{protein_id}_{qa_index}\n")
        print(f"Final missing QA indices saved to: {final_missing_ids_file}")
    else:
        print(f"All QA data generated successfully after {iteration} passes!")

if __name__ == "__main__":
    main()