"""Generate one ChatGPT response file per protein prompt, retrying missing ones.

The script loads a JSON mapping of ``protein_id -> prompt``, checks which
``<protein_id>.json`` result files are absent, empty or unreadable in the
output directory, and regenerates those in parallel.  The check/generate
cycle repeats until nothing is missing or ``--max_iterations`` is reached.
"""
import argparse
import json
import os
import sys
from pathlib import Path

# Put the parent of this file's directory on sys.path before the ``utils``
# imports below (presumably the project root -- confirm against repo layout).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from tqdm import tqdm

from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.generate_protein_prompt import generate_prompt

# Module-level cache of the prompt mapping.  It is a global so that
# process_single_protein can read it without receiving it as an argument.
prompts = None


def _load_prompts(prompt_path):
    """Load the prompt mapping from ``prompt_path`` once and cache it.

    Subsequent calls return the cached mapping and ignore ``prompt_path``.

    Args:
        prompt_path: Path to a JSON file; its top-level value is used as a
            mapping keyed by protein ID.

    Returns:
        The cached prompt mapping.
    """
    global prompts
    if prompts is None:
        # Use a context manager so the file handle is closed deterministically.
        with open(prompt_path, 'r', encoding="utf-8") as f:
            prompts = json.load(f)
    return prompts


def read_protein_ids(protein_id_path):
    """Read a list of protein IDs, one per line, skipping blank lines.

    Args:
        protein_id_path: Path to a text file with one protein ID per line.

    Returns:
        A list of the stripped, non-empty lines in file order.
    """
    with open(protein_id_path, 'r', encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def _is_usable_result(data):
    """Return True when ``data`` is a non-None, non-empty sized value.

    Values without a length (e.g. a bare number) count as unusable, which
    matches how get_missing_protein_ids has always treated them.
    """
    if data is None:
        return False
    try:
        return len(data) > 0
    except TypeError:
        return False


def _write_id_list(path, protein_ids):
    """Write ``protein_ids`` to ``path``, one ID per line."""
    with open(path, 'w', encoding="utf-8") as f:
        f.writelines(f"{protein_id}\n" for protein_id in protein_ids)


def process_single_protein(process_id, idx, protein_id, writer, save_dir):
    """Query ChatGPT for one protein and save the response as JSON.

    The response is written to ``<save_dir>/<protein_id>.json``.  Any failure
    is reported on stdout and swallowed so that one bad protein does not stop
    the worker; the protein is then picked up again by the next missing-file
    check.

    Args:
        process_id: Worker index supplied by the runner (unused here).
        idx: Item index supplied by the runner (unused here).
        protein_id: Key into the module-level ``prompts`` mapping.
        writer: Writer object supplied by the runner (unused here).
        save_dir: Directory in which the result file is created.
    """
    try:
        prompt = prompts[protein_id]
        response = call_chatgpt(prompt)

        # An empty response would be deleted by the next check anyway, so do
        # not create the file at all.
        if not _is_usable_result(response):
            print(f"Error processing protein {protein_id}: empty response")
            return

        # Serialize before opening the file so that a non-serializable
        # response cannot leave a truncated JSON file behind.
        payload = json.dumps(response, indent=2)
        save_path = os.path.join(save_dir, f"{protein_id}.json")
        with open(save_path, 'w', encoding="utf-8") as f:
            f.write(payload)
    except Exception as e:
        # Deliberate best-effort boundary: report and carry on.
        print(f"Error processing protein {protein_id}: {str(e)}")


def get_missing_protein_ids(save_dir):
    """Find protein IDs whose result file is missing, empty or unreadable.

    Every key of the loaded ``prompts`` mapping is expected to have a file
    ``<save_dir>/<protein_id>.json``.  Files that fail to parse, or that hold
    ``null`` / an empty value, are deleted so they can be regenerated.

    Args:
        save_dir: Directory containing the per-protein result files.

    Returns:
        A set of protein IDs that still need to be generated.

    Raises:
        RuntimeError: If the prompts have not been loaded yet.
    """
    if prompts is None:
        raise RuntimeError("prompts are not loaded; call _load_prompts() first")

    result_dir = Path(save_dir)
    problem_protein_ids = set()

    for protein_id in tqdm(list(prompts), desc="检查蛋白质数据文件"):
        json_file = result_dir / f"{protein_id}.json"

        # A file that was never written is simply missing.
        if not json_file.exists():
            problem_protein_ids.add(protein_id)
            continue

        try:
            with open(json_file, 'r', encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            # Best-effort: any read or parse failure marks the file as bad.
            usable = False
        else:
            usable = _is_usable_result(data)

        if usable:
            continue

        # Remove the bad file so the next generation round starts clean.
        problem_protein_ids.add(protein_id)
        try:
            json_file.unlink()
        except OSError as e:
            print(f"Could not delete {json_file}: {e}")

    return problem_protein_ids


def main():
    """Parse CLI arguments and run the check/generate loop."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str,
                        default="data/processed_data/prompts@clean_test.json",
                        help="Path to the file containing prompts")
    parser.add_argument("--n_process", type=int, default=64,
                        help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                        default="data/clean_test_results_top2",
                        help="Directory to save results")
    parser.add_argument("--max_iterations", type=int, default=3,
                        help="Maximum number of iterations to try generating all proteins")
    args = parser.parse_args()

    # Create the output directory.
    os.makedirs(args.save_dir, exist_ok=True)

    # Load the prompts into the module-level cache.
    _load_prompts(args.prompt_path)
    print(f"已加载 {len(prompts)} 个提示")

    # Repeat check + generate until everything exists or the iteration
    # budget is exhausted.
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\n开始第 {iteration} 轮检查和生成")

        missing_protein_ids = get_missing_protein_ids(args.save_dir)
        if not missing_protein_ids:
            print("所有蛋白质数据已成功生成!")
            break

        print(f"发现 {len(missing_protein_ids)} 个缺失的蛋白质数据,准备生成")

        # Sort for a deterministic work order and a stable record file.
        missing_protein_ids_list = sorted(missing_protein_ids)

        # Keep a per-iteration record of what was missing.
        missing_ids_file = Path(args.save_dir) / f"missing_protein_ids_iteration_{iteration}.txt"
        _write_id_list(missing_ids_file, missing_protein_ids_list)

        # Generate the missing results in parallel worker processes.
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_protein_ids_list,
            do=lambda process_id, idx, protein_id, writer: process_single_protein(
                process_id, idx, protein_id, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()

        print(f"第 {iteration} 轮生成完成")

    # Final verification pass after the loop.
    final_missing_ids = get_missing_protein_ids(args.save_dir)
    if final_missing_ids:
        print(f"经过 {iteration} 轮生成后,仍有 {len(final_missing_ids)} 个蛋白质数据未成功生成")
        # Record the IDs that are still missing.
        final_missing_ids_file = Path(args.save_dir) / "final_missing_protein_ids.txt"
        _write_id_list(final_missing_ids_file, sorted(final_missing_ids))
        print(f"最终缺失的蛋白质ID已保存到: {final_missing_ids_file}")
    else:
        print(f"经过 {iteration} 轮生成,所有蛋白质数据已成功生成!")


if __name__ == "__main__":
    main()