"""Build text prompts for proteins from Pfam motif, GO and InterPro annotations.

The module can be used as a library (``generate_prompt``,
``save_prompts_parallel``) or run as a script to generate prompts for every
protein id listed in a text file.
"""
import functools
import json
import multiprocessing
import os
import shutil
import sys

# Make the repository root importable so the ``utils.*`` imports below resolve
# when this file is executed directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from jinja2 import Template

try:
    from utils.protein_go_analysis import analyze_protein_go
    from utils.prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT
    from utils.get_motif import get_motif_pfam
except ImportError:
    from protein_go_analysis import analyze_protein_go
    from prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT
    from get_motif import get_motif_pfam
from tqdm import tqdm


class InterProDescriptionManager:
    """Hold the InterPro description data in memory so files are read once."""

    def __init__(self, interpro_data_path, interproscan_info_path):
        """Read both data files immediately.

        Args:
            interpro_data_path: Path to ``interpro_data.json``.
            interproscan_info_path: Path to ``interproscan_info.json``.
        """
        self.interpro_data_path = interpro_data_path
        self.interproscan_info_path = interproscan_info_path
        self.interpro_data = None
        self.interproscan_info = None
        self._load_data()

    def _load_data(self):
        """Load each JSON file that exists; a missing file leaves ``None``."""
        if self.interpro_data_path and os.path.exists(self.interpro_data_path):
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(self.interpro_data_path, 'r', encoding='utf-8') as f:
                self.interpro_data = json.load(f)

        if self.interproscan_info_path and os.path.exists(self.interproscan_info_path):
            with open(self.interproscan_info_path, 'r', encoding='utf-8') as f:
                self.interproscan_info = json.load(f)

    def get_description(self, protein_id, selected_types=None):
        """Return InterPro names/abstracts for one protein.

        Args:
            protein_id: Protein identifier used as key in the interproscan info.
            selected_types: Result types to look up, e.g.
                ``['superfamily', 'panther', 'gene3d']``. ``None`` means none.

        Returns:
            dict: ``{info_type: {ipr_id: {'name': ..., 'abstract': ...}}}``.
            Types without any known IPR id are omitted. Empty when either data
            file was not loaded or the protein is unknown.
        """
        if selected_types is None:
            selected_types = []

        if not self.interpro_data or not self.interproscan_info:
            return {}

        result = {}
        if protein_id not in self.interproscan_info:
            return result

        protein_info = self.interproscan_info[protein_id]
        interproscan_results = protein_info.get('interproscan_results', {})

        for info_type in selected_types:
            if info_type not in interproscan_results:
                continue
            type_descriptions = {}
            # Each entry is a mapping whose values are candidate IPR ids; only
            # ids present in the InterPro data contribute a description.
            for entry in interproscan_results[info_type]:
                for ipr_id in entry.values():
                    if ipr_id and ipr_id in self.interpro_data:
                        ipr_record = self.interpro_data[ipr_id]
                        type_descriptions[ipr_id] = {
                            'name': ipr_record.get('name', ''),
                            'abstract': ipr_record.get('abstract', '')
                        }
            if type_descriptions:
                result[info_type] = type_descriptions

        return result


# Module-level caches: the InterProDescriptionManager instance (with the paths
# it was built from) and the open lmdb environment (with its path).
_interpro_manager = None
_interpro_manager_paths = None
_lmdb_db = None
_lmdb_path = None


def get_interpro_manager(interpro_data_path, interproscan_info_path):
    """Return a cached InterProDescriptionManager for the given paths.

    The cached instance is rebuilt when the paths differ from the ones it was
    created with, so callers never receive a manager for other data files.
    """
    global _interpro_manager, _interpro_manager_paths
    requested_paths = (interpro_data_path, interproscan_info_path)
    if _interpro_manager is None or _interpro_manager_paths != requested_paths:
        _interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)
        _interpro_manager_paths = requested_paths
    return _interpro_manager


def _close_lmdb_connection():
    """Close the cached lmdb environment (if any) and reset the cache."""
    global _lmdb_db, _lmdb_path
    if _lmdb_db is not None:
        _lmdb_db.close()
    _lmdb_db = None
    _lmdb_path = None


def get_lmdb_connection(lmdb_path):
    """Return a cached read-only lmdb environment for ``lmdb_path``.

    A previously cached environment for a different path is closed first.
    Returns ``None`` when ``lmdb_path`` is empty or does not exist.
    """
    global _lmdb_db, _lmdb_path
    if _lmdb_db is None or _lmdb_path != lmdb_path:
        _close_lmdb_connection()
        if lmdb_path and os.path.exists(lmdb_path):
            import lmdb  # imported lazily: only needed for QA generation
            _lmdb_db = lmdb.open(lmdb_path, readonly=True)
            _lmdb_path = lmdb_path
    return _lmdb_db


def _build_prompt_source(with_question):
    """Return the Jinja source; ``with_question`` selects the QA variant."""
    if with_question:
        source = FUNCTION_PROMPT + "\n"
    else:
        source = ENZYME_PROMPT + "\n"

    # NOTE(review): the line breaks inside this literal were reconstructed from
    # a whitespace-collapsed copy of the file - confirm the intended layout.
    source += """
input information:

{%- if 'motif' in selected_info_types and motif_pfam %}

motif:{% for motif_id, motif_info in motif_pfam.items() %}
{{motif_id}}: {{motif_info}}
{% endfor %}
{%- endif %}

{%- if 'go' in selected_info_types and go_data.status == 'success' %}

GO:{% for go_entry in go_data.go_annotations %}
▢ GO term{{loop.index}}: {{go_entry.go_id}}
• definition: {{ go_data.all_related_definitions.get(go_entry.go_id, 'not found definition') }}
{% endfor %}
{%- endif %}

{%- for info_type in selected_info_types %}
{%- if info_type not in ['motif', 'go'] and interpro_descriptions.get(info_type) %}

{{info_type}}:{% for ipr_id, ipr_info in interpro_descriptions[info_type].items() %}
▢ {{ipr_id}}: {{ipr_info.name}}
• description: {{ipr_info.abstract}}
{% endfor %}
{%- endif %}
{%- endfor %}
"""
    if with_question:
        source += "\n" + "question: \n {{question}}"
    return source


@functools.lru_cache(maxsize=None)
def _compiled_prompt_template(with_question):
    """Compile each template variant once instead of once per protein."""
    return Template(_build_prompt_source(with_question))


def get_prompt_template(selected_info_types=None, lmdb_path=None):
    """Return the Jinja prompt template source.

    Args:
        selected_info_types: Info types such as
            ``['motif', 'go', 'superfamily', 'panther']``. The template text
            does not depend on this value; the selection is applied when the
            template is rendered.
        lmdb_path: When ``None`` the template starts with ``ENZYME_PROMPT``;
            otherwise it starts with ``FUNCTION_PROMPT`` and ends with a
            ``question`` section.
    """
    return _build_prompt_source(lmdb_path is not None)


def get_qa_data(protein_id, lmdb_path):
    """Collect every QA pair stored in the lmdb for ``protein_id``.

    Only records whose key is a digit string and whose JSON value is a list of
    at least three items ``[protein_id, question, ground_truth, ...]`` are
    considered.

    Args:
        protein_id: Protein identifier to match against the stored records.
        lmdb_path: Path of the lmdb database.

    Returns:
        list: Dicts with ``question`` and ``ground_truth`` keys; empty when the
        path is missing or nothing matches.
    """
    if not lmdb_path or not os.path.exists(lmdb_path):
        return []

    qa_pairs = []
    try:
        db = get_lmdb_connection(lmdb_path)
        if db is None:
            return []

        with db.begin() as txn:
            # NOTE: this walks the whole database for every protein; the
            # records are keyed by running index, not by protein id.
            for key, value in txn.cursor():
                try:
                    if not key.decode('utf-8').isdigit():
                        continue
                    data = json.loads(value.decode('utf-8'))
                except ValueError:
                    # Undecodable key/value or invalid JSON: skip the record.
                    continue
                if isinstance(data, list) and len(data) >= 3 and data[0] == protein_id:
                    qa_pairs.append({
                        'question': data[1],
                        'ground_truth': data[2]
                    })
    except Exception as e:
        # Best effort: report the problem and return what was collected so far.
        print(f"Error reading lmdb for protein {protein_id}: {e}")

    return qa_pairs


def generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                    interpro_data_path=None, interproscan_info_path=None, selected_info_types=None,
                    lmdb_path=None, interpro_manager=None, question=None):
    """Render the prompt for one protein.

    Args:
        protein_id: Protein identifier.
        protein2gopath: Passed to ``analyze_protein_go``.
        protein2pfam_path: Passed to ``get_motif_pfam``.
        pfam_descriptions_path: Passed to ``get_motif_pfam``.
        go_info_path: Passed to ``analyze_protein_go``.
        interpro_data_path: Path to ``interpro_data.json``.
        interproscan_info_path: Path to ``interproscan_info.json``.
        selected_info_types: Info types to include, e.g.
            ``['motif', 'go', 'superfamily', 'panther']``; defaults to
            ``['motif', 'go']``.
        lmdb_path: When not ``None`` the QA template variant is used.
        interpro_manager: InterProDescriptionManager to use; takes precedence
            over the two InterPro paths.
        question: Question text rendered into the QA template variant.

    Returns:
        str: The rendered prompt.
    """
    if selected_info_types is None:
        selected_info_types = ['motif', 'go']

    analysis = analyze_protein_go(protein_id, protein2gopath, go_info_path)
    motif_pfam = get_motif_pfam(protein_id, protein2pfam_path, pfam_descriptions_path)

    # InterPro descriptions are only needed for types other than motif/go.
    interpro_descriptions = {}
    other_types = [t for t in selected_info_types if t not in ['motif', 'go']]
    if other_types:
        if interpro_manager:
            interpro_descriptions = interpro_manager.get_description(protein_id, other_types)
        elif interpro_data_path and interproscan_info_path:
            manager = get_interpro_manager(interpro_data_path, interproscan_info_path)
            interpro_descriptions = manager.get_description(protein_id, other_types)

    go_ok = analysis["status"] == "success"
    template_data = {
        "protein_id": protein_id,
        "selected_info_types": selected_info_types,
        "go_data": {
            "status": analysis["status"],
            "go_annotations": analysis["go_annotations"] if go_ok else [],
            "all_related_definitions": analysis["all_related_definitions"] if go_ok else {}
        },
        "motif_pfam": motif_pfam,
        "interpro_descriptions": interpro_descriptions,
        "question": question
    }

    template = _compiled_prompt_template(lmdb_path is not None)
    return template.render(**template_data)


def save_prompts_parallel(protein_ids, output_path, protein2gopath, protein2pfam_path, pfam_descriptions_path,
                          go_info_path, interpro_data_path=None, interproscan_info_path=None,
                          selected_info_types=None, lmdb_path=None, n_process=8):
    """Generate prompts for many proteins in parallel and save them.

    Without ``lmdb_path`` the result is a JSON object ``{protein_id: prompt}``
    written to ``output_path``. With ``lmdb_path`` one JSON line per QA pair
    (``index``, ``protein_id``, ``prompt``, ``question``, ``ground_truth``) is
    written instead. Intermediate output goes to ``output_path + '.tmp'``.

    Args:
        protein_ids: Iterable of protein ids; surrounding whitespace is
            stripped and blank ids are skipped.
        output_path: Destination file.
        n_process: Number of worker processes handed to the runner.
        (remaining arguments are forwarded to ``generate_prompt``)
    """
    try:
        from utils.mpr import MultipleProcessRunnerSimplifier
    except ImportError:
        from mpr import MultipleProcessRunnerSimplifier

    if selected_info_types is None:
        selected_info_types = ['motif', 'go']

    # Build the manager before the workers start so the data is read once here.
    interpro_manager = None
    other_types = [t for t in selected_info_types if t not in ['motif', 'go']]
    if other_types and interpro_data_path and interproscan_info_path:
        interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)

    # Shared counter that hands out a unique index per QA record.
    if lmdb_path:
        global_index = multiprocessing.Value('i', 0)
        index_lock = multiprocessing.Lock()
    else:
        global_index = None
        index_lock = None

    tmp_path = output_path + '.tmp'

    def build_prompt(protein_id, question=None):
        return generate_prompt(protein_id, protein2gopath, protein2pfam_path,
                               pfam_descriptions_path, go_info_path,
                               interpro_data_path, interproscan_info_path,
                               selected_info_types, lmdb_path, interpro_manager, question)

    def process_protein(process_id, idx, protein_id, writer):
        protein_id = protein_id.strip()
        if not protein_id:
            # Blank line in the id list: nothing to generate.
            return

        if not lmdb_path:
            prompt = build_prompt(protein_id)
            if prompt != "" and writer:
                writer.write(json.dumps({protein_id: prompt}) + '\n')
            return

        for qa_pair in get_qa_data(protein_id, lmdb_path):
            question = qa_pair['question']
            prompt = build_prompt(protein_id, question)
            if prompt == "" or not writer:
                continue
            with index_lock:
                current_index = global_index.value
                global_index.value += 1
            record = {
                "index": current_index,
                "protein_id": protein_id,
                "prompt": prompt,
                "question": question,
                "ground_truth": qa_pair['ground_truth']
            }
            writer.write(json.dumps(record) + '\n')

    runner = MultipleProcessRunnerSimplifier(
        data=protein_ids,
        do=process_protein,
        save_path=tmp_path,
        n_process=n_process,
        split_strategy="static"
    )
    try:
        runner.run()
    finally:
        # Release the cached lmdb environment even if the run failed.
        _close_lmdb_connection()

    if not lmdb_path:
        # Merge the per-protein lines into one dict (legacy output format).
        final_results = {}
        with open(tmp_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    final_results.update(json.loads(line))
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=2)
    else:
        # The temporary file already is the final jsonl output.
        shutil.move(tmp_path, output_path)

    if os.path.exists(tmp_path):
        os.remove(tmp_path)


def _with_info_type_suffix(output_path, selected_info_types):
    """Insert ``_<type>_<type>...`` before the extension of ``output_path``."""
    root, ext = os.path.splitext(output_path)
    return root + '_' + '_'.join(selected_info_types) + ext


def main():
    """Command-line entry point: generate prompts for every listed protein."""
    import argparse
    parser = argparse.ArgumentParser(description='Generate protein prompt')
    parser.add_argument('--protein_path', type=str, default='data/raw_data/protein_ids_clean.txt')
    parser.add_argument('--protein2pfam_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--pfam_descriptions_path', type=str, default='data/raw_data/all_pfam_descriptions.json')
    parser.add_argument('--protein2gopath', type=str, default='data/processed_data/go_integration_final_topk2.json')
    parser.add_argument('--go_info_path', type=str, default='data/raw_data/go.json')
    parser.add_argument('--interpro_data_path', type=str, default='data/raw_data/interpro_data.json')
    parser.add_argument('--interproscan_info_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--lmdb_path', type=str, default=None)
    parser.add_argument('--output_path', type=str, default='data/processed_data/prompts@clean_test.json')
    parser.add_argument('--selected_info_types', type=str, nargs='+',
                        default=['motif', 'go'],
                        help='选择要包含的信息类型,如: motif go superfamily panther gene3d')
    parser.add_argument('--n_process', type=int, default=32)
    args = parser.parse_args()

    # The output file name records which info types were selected. Only the
    # file extension is touched, never a '.json' elsewhere in the path.
    args.output_path = _with_info_type_suffix(args.output_path, args.selected_info_types)
    print(args)

    with open(args.protein_path, 'r', encoding='utf-8') as file:
        protein_ids = file.readlines()

    save_prompts_parallel(
        protein_ids=protein_ids,
        output_path=args.output_path,
        n_process=args.n_process,
        protein2gopath=args.protein2gopath,
        protein2pfam_path=args.protein2pfam_path,
        pfam_descriptions_path=args.pfam_descriptions_path,
        go_info_path=args.go_info_path,
        interpro_data_path=args.interpro_data_path,
        interproscan_info_path=args.interproscan_info_path,
        selected_info_types=args.selected_info_types,
        lmdb_path=args.lmdb_path
    )

    # Example of generating a single prompt:
    # protein_id = 'A8CF74'
    # prompt = generate_prompt(protein_id, 'data/processed_data/go_integration_final_topk2.json',
    #                          'data/processed_data/interproscan_info.json', 'data/raw_data/all_pfam_descriptions.json',
    #                          'data/raw_data/go.json', 'data/raw_data/interpro_data.json',
    #                          'data/processed_data/interproscan_info.json',
    #                          ['motif', 'go', 'superfamily', 'panther'])
    # print(prompt)


if __name__ == "__main__":
    main()