File size: 6,191 Bytes
5c20520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from pathlib import Path
from tqdm import tqdm
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.generate_protein_prompt import generate_prompt

prompts = None

def _load_prompts(prompt_path):
    global prompts
    if prompts is None:
        prompts = json.load(open(prompt_path, 'r'))
    return prompts

def read_protein_ids(protein_id_path):
    """Read protein IDs from a text file, one per line.

    Surrounding whitespace is stripped and blank lines are skipped.

    Args:
        protein_id_path: Path to the ID list file.

    Returns:
        List of non-empty protein ID strings in file order.
    """
    with open(protein_id_path, 'r') as handle:
        stripped = (raw_line.strip() for raw_line in handle)
        return [entry for entry in stripped if entry]

def process_single_protein(process_id, idx, protein_id, writer, save_dir):
    """Generate a summary for one protein and save it as <protein_id>.json.

    Looks up the prompt in the module-level ``prompts`` cache, sends it to
    the ChatGPT helper, and dumps the response into ``save_dir``. Failures
    are printed and swallowed so the caller's retry loop can regenerate the
    file on a later pass. ``process_id``, ``idx`` and ``writer`` are part of
    the runner's callback signature and are unused here.
    """
    try:
        response = call_chatgpt(prompts[protein_id])
        target = os.path.join(save_dir, f"{protein_id}.json")
        with open(target, 'w') as handle:
            json.dump(response, handle, indent=2)
    except Exception as e:
        print(f"Error processing protein {protein_id}: {str(e)}")

def get_missing_protein_ids(save_dir):
    """Return the set of protein IDs whose result file is missing or invalid.

    An ID counts as a problem when its ``<id>.json`` file in ``save_dir``
    does not exist, contains null/empty data, or cannot be parsed as JSON.
    Empty and corrupted files are deleted so a later pass regenerates them.

    Requires ``_load_prompts`` to have populated the module-level
    ``prompts`` dict; its keys define the full set of expected IDs.

    Args:
        save_dir: Directory holding the per-protein JSON result files.

    Returns:
        Set of protein IDs that still need (re)generation.
    """
    def _discard(path):
        # Best-effort delete; regeneration will overwrite the file anyway.
        try:
            path.unlink()
        except OSError:
            pass

    all_protein_ids = list(prompts.keys())
    problem_protein_ids = set()

    for protein_id in tqdm(all_protein_ids, desc="检查蛋白质数据文件"):
        json_file = Path(save_dir) / f"{protein_id}.json"

        # Never generated at all.
        if not json_file.exists():
            problem_protein_ids.add(protein_id)
            continue

        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Null or empty payload means the generation effectively failed.
            if data is None or len(data) == 0:
                problem_protein_ids.add(protein_id)
                _discard(json_file)
        except (json.JSONDecodeError, OSError, TypeError):
            # Narrowed from the original blanket Exception: JSONDecodeError
            # for corrupt files, OSError for read failures, TypeError for
            # len() on scalar payloads (e.g. a bare number).
            problem_protein_ids.add(protein_id)
            _discard(json_file)

    return problem_protein_ids

def main():
    """Entry point: repeatedly generate per-protein summaries until complete.

    Loads the prompt mapping, then runs up to ``--max_iterations`` rounds of
    (1) scanning ``--save_dir`` for missing or invalid result files and
    (2) regenerating those files with ``--n_process`` parallel workers.
    Each round's missing IDs are logged to a per-iteration file; anything
    still missing after the final round goes into a report file.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str,
                      default="data/processed_data/prompts@clean_test.json",
                      help="Path to the file containing prompts")
    parser.add_argument("--n_process", type=int, default=64,
                      help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                      default="data/clean_test_results_top2",
                      help="Directory to save results")
    parser.add_argument("--max_iterations", type=int, default=3,
                      help="Maximum number of iterations to try generating all proteins")
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    # Populate the module-level prompt cache before any worker needs it.
    _load_prompts(args.prompt_path)
    print(f"已加载 {len(prompts)} 个提示")

    # Retry loop: scan for gaps and regenerate until clean or out of budget.
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\n开始第 {iteration} 轮检查和生成")

        missing_protein_ids = get_missing_protein_ids(args.save_dir)
        if not missing_protein_ids:
            print("所有蛋白质数据已成功生成!")
            break

        print(f"发现 {len(missing_protein_ids)} 个缺失的蛋白质数据,准备生成")

        # sorted() accepts any iterable; the original's list() wrapper was
        # redundant.
        missing_protein_ids_list = sorted(missing_protein_ids)

        # Record this round's missing IDs for post-hoc inspection.
        missing_ids_file = Path(args.save_dir) / f"missing_protein_ids_iteration_{iteration}.txt"
        with open(missing_ids_file, 'w') as f:
            for protein_id in missing_protein_ids_list:
                f.write(f"{protein_id}\n")

        # Fan the work out across processes; each worker writes its own
        # <protein_id>.json, so no shared writer state is needed.
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_protein_ids_list,
            do=lambda process_id, idx, protein_id, writer: process_single_protein(process_id, idx, protein_id, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()

        print(f"第 {iteration} 轮生成完成")

    # Final sweep: report anything that still failed after all iterations.
    final_missing_ids = get_missing_protein_ids(args.save_dir)
    if final_missing_ids:
        print(f"经过 {iteration} 轮生成后,仍有 {len(final_missing_ids)} 个蛋白质数据未成功生成")
        final_missing_ids_file = Path(args.save_dir) / "final_missing_protein_ids.txt"
        with open(final_missing_ids_file, 'w') as f:
            for protein_id in sorted(final_missing_ids):
                f.write(f"{protein_id}\n")
        print(f"最终缺失的蛋白质ID已保存到: {final_missing_ids_file}")
    else:
        print(f"经过 {iteration} 轮生成,所有蛋白质数据已成功生成!")

# Standard script guard: run the generation pipeline only when executed
# directly, not when imported.
if __name__ == "__main__":
    main()