import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from pathlib import Path
from tqdm import tqdm
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.generate_protein_prompt import generate_prompt
prompts = None

def _load_prompts(prompt_path):
    global prompts
    if prompts is None:
        # Load the prompt mapping once per process and cache it globally.
        with open(prompt_path, 'r') as f:
            prompts = json.load(f)
    return prompts
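
# Note (assumption): worker processes created by MultipleProcessRunnerSimplifier
# are expected to inherit the module-level `prompts` dict via fork; on platforms
# that spawn fresh interpreters, each worker would need to call _load_prompts
# itself before processing.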

def read_protein_ids(protein_id_path):
    """Read the list of protein IDs (one ID per line)."""
    with open(protein_id_path, 'r') as f:
        protein_ids = [line.strip() for line in f if line.strip()]
    return protein_ids

def process_single_protein(process_id, idx, protein_id, writer, save_dir):
    """Process a single protein's motif information and generate a summary."""
    try:
        # prompt = generate_prompt(protein_id)
        prompt = prompts[protein_id]
        response = call_chatgpt(prompt)
        # Write the response to its own JSON file.
        save_path = os.path.join(save_dir, f"{protein_id}.json")
        with open(save_path, 'w') as f:
            json.dump(response, f, indent=2)
    except Exception as e:
        print(f"Error processing protein {protein_id}: {str(e)}")

def get_missing_protein_ids(save_dir):
    """Check which protein IDs have not yet been generated successfully."""
    # All protein IDs that should have been generated.
    all_protein_ids = list(prompts.keys())
    # with open(all_protein_ids_path, 'r') as f:
    #     all_protein_ids = set(line.strip() for line in f if line.strip())

    # Problematic protein IDs (empty files as well as files never generated).
    problem_protein_ids = set()
    # Check every protein ID that should exist.
    for protein_id in tqdm(all_protein_ids, desc="Checking protein data files"):
        json_file = Path(save_dir) / f"{protein_id}.json"
        # If the file does not exist, add it to the problem set.
        if not json_file.exists():
            problem_protein_ids.add(protein_id)
            continue
        # Inspect the file contents.
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Treat empty or null content as a problem.
            if data is None or len(data) == 0:
                problem_protein_ids.add(protein_id)
                json_file.unlink()  # Remove the empty file.
        except Exception:
            # A JSON parse failure also marks the file as problematic.
            problem_protein_ids.add(protein_id)
            try:
                json_file.unlink()  # Remove the corrupted file.
            except OSError:
                pass
    return problem_protein_ids
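
# --- Hedged sketch (not used by the pipeline) --------------------------------
# The validity check above, factored into a standalone predicate for clarity.
def _is_valid_result(json_file):
    """Return True if json_file exists and holds non-empty, parseable JSON."""
    if not json_file.exists():
        return False
    try:
        with open(json_file, 'r') as f:
            data = json.load(f)
    except (json.JSONDecodeError, OSError):
        return False
    return data is not None and len(data) > 0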

def main():
    import argparse
    parser = argparse.ArgumentParser()
    # parser.add_argument("--all_protein_ids_path", type=str,
    #                     default="/zhuangkai/projects/TTS4Protein/data/processed_data/protein_id@1024_go@10_covermotif_go.txt",
    #                     help="Path to the file containing all protein IDs that should be generated")
    parser.add_argument("--prompt_path", type=str,
                        default="data/processed_data/prompts@clean_test.json",
                        help="Path to the file containing prompts")
    parser.add_argument("--n_process", type=int, default=64,
                        help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                        default="data/clean_test_results_top2",
                        help="Directory to save results")
    parser.add_argument("--max_iterations", type=int, default=3,
                        help="Maximum number of iterations to try generating all proteins")
    args = parser.parse_args()

    # Create the output directory.
    os.makedirs(args.save_dir, exist_ok=True)

    # Load the prompts.
    _load_prompts(args.prompt_path)
    print(f"Loaded {len(prompts)} prompts")

    # Run check-and-generate rounds until every protein has been generated
    # or the maximum number of iterations is reached.
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\nStarting check-and-generate round {iteration}")

        # Collect the missing protein IDs.
        missing_protein_ids = get_missing_protein_ids(args.save_dir)

        # If nothing is missing, we are done.
        if not missing_protein_ids:
            print("All protein data generated successfully!")
            break

        print(f"Found {len(missing_protein_ids)} missing protein entries; generating them now")

        # Sort the missing protein IDs into a list.
        missing_protein_ids_list = sorted(missing_protein_ids)

        # Record this round's missing protein IDs for bookkeeping.
        missing_ids_file = Path(args.save_dir) / f"missing_protein_ids_iteration_{iteration}.txt"
        with open(missing_ids_file, 'w') as f:
            for protein_id in missing_protein_ids_list:
                f.write(f"{protein_id}\n")

        # Generate the missing protein data with multiple processes.
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_protein_ids_list,
            do=lambda process_id, idx, protein_id, writer: process_single_protein(
                process_id, idx, protein_id, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()

        print(f"Round {iteration} finished")

    # Final check.
    final_missing_ids = get_missing_protein_ids(args.save_dir)
    if final_missing_ids:
        print(f"After {iteration} rounds, {len(final_missing_ids)} protein entries are still missing")
        # Save the final list of missing protein IDs.
        final_missing_ids_file = Path(args.save_dir) / "final_missing_protein_ids.txt"
        with open(final_missing_ids_file, 'w') as f:
            for protein_id in sorted(final_missing_ids):
                f.write(f"{protein_id}\n")
        print(f"Final missing protein IDs saved to: {final_missing_ids_file}")
    else:
        print(f"After {iteration} rounds, all protein data generated successfully!")


if __name__ == "__main__":
    main()
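
# Example invocation (script name and paths are illustrative):
#   python generate_summaries.py \
#       --prompt_path data/processed_data/prompts@clean_test.json \
#       --n_process 64 \
#       --save_dir data/clean_test_results_top2 \
#       --max_iterations 3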