# Spaces:
# Runtime error
# Runtime error
import json
import os
import sys
from functools import lru_cache

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from jinja2 import Template
from tqdm import tqdm

try:
    from utils.protein_go_analysis import analyze_protein_go
    from utils.prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT
    from utils.get_motif import get_motif_pfam
except ImportError:
    from protein_go_analysis import analyze_protein_go
    from prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT
    from get_motif import get_motif_pfam
class InterProDescriptionManager:
    """Hold InterPro description data in memory to avoid re-reading files.

    Both JSON files are read once, in ``__init__``; ``get_description`` then
    answers per-protein lookups from memory.
    """

    def __init__(self, interpro_data_path, interproscan_info_path):
        """Load all required data up front.

        Args:
            interpro_data_path: Path to interpro_data.json, a mapping looked
                up by IPR ID whose entries provide 'name' and 'abstract'.
            interproscan_info_path: Path to interproscan_info.json, a mapping
                from protein ID to a dict holding 'interproscan_results'.

        A falsy or non-existent path leaves the corresponding attribute
        (``interpro_data`` / ``interproscan_info``) as ``None``.
        """
        self.interpro_data_path = interpro_data_path
        self.interproscan_info_path = interproscan_info_path
        self.interpro_data = None
        self.interproscan_info = None
        self._load_data()

    @staticmethod
    def _read_json(path):
        """Return the parsed JSON at *path*, or None if *path* is falsy or missing."""
        if not path or not os.path.exists(path):
            return None
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_data(self):
        """Read both data files; called once from ``__init__``."""
        self.interpro_data = self._read_json(self.interpro_data_path)
        self.interproscan_info = self._read_json(self.interproscan_info_path)

    def get_description(self, protein_id, selected_types=None):
        """Return InterPro names/abstracts for a protein, grouped by result type.

        Args:
            protein_id: Protein ID, looked up in ``interproscan_info``.
            selected_types: Result types to include, e.g.
                ['superfamily', 'panther', 'gene3d']. None means none.

        Returns:
            dict: ``{info_type: {ipr_id: {'name': ..., 'abstract': ...}}}``.
            Only types with at least one IPR ID known to ``interpro_data`` are
            included; missing 'name'/'abstract' fields default to ''. An
            empty dict is returned when either data file was not loaded or
            the protein is unknown.
        """
        if not selected_types or not self.interpro_data or not self.interproscan_info:
            return {}
        if protein_id not in self.interproscan_info:
            return {}
        interproscan_results = self.interproscan_info[protein_id].get('interproscan_results', {})

        result = {}
        for info_type in selected_types:
            type_descriptions = {}
            # A type may be absent (or null) for this protein; treat it as empty.
            for entry in interproscan_results.get(info_type) or ():
                # Entry values are treated as IPR IDs; the keys are not needed.
                for ipr_id in entry.values():
                    if ipr_id and ipr_id in self.interpro_data:
                        ipr_entry = self.interpro_data[ipr_id]
                        type_descriptions[ipr_id] = {
                            'name': ipr_entry.get('name', ''),
                            'abstract': ipr_entry.get('abstract', ''),
                        }
            if type_descriptions:
                result[info_type] = type_descriptions
        return result
# Module-level caches used by get_interpro_manager and get_lmdb_connection:
# the InterProDescriptionManager instance, plus the LMDB environment and the
# path it was opened from.
_interpro_manager = _lmdb_db = _lmdb_path = None
def get_interpro_manager(interpro_data_path, interproscan_info_path):
    """Return the module-level cached InterProDescriptionManager.

    The cached instance is reused only when it was built from the same two
    paths. Otherwise a new manager is created and cached. Previously a call
    with different paths silently returned a manager built from the first
    paths. This now mirrors get_lmdb_connection, which reopens when its path
    changes.

    Args:
        interpro_data_path: Path to interpro_data.json.
        interproscan_info_path: Path to interproscan_info.json.

    Returns:
        InterProDescriptionManager: Manager loaded from the given paths.
    """
    global _interpro_manager
    if (_interpro_manager is None
            or _interpro_manager.interpro_data_path != interpro_data_path
            or _interpro_manager.interproscan_info_path != interproscan_info_path):
        _interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)
    return _interpro_manager
def get_lmdb_connection(lmdb_path):
    """Return a cached read-only LMDB environment for *lmdb_path*.

    The environment is kept in module globals and reused while the same path
    is requested. Requesting a different path closes the cached environment
    first. If *lmdb_path* is falsy or does not exist, the cache is cleared and
    None is returned.

    Args:
        lmdb_path: Path to the LMDB database.

    Returns:
        The open lmdb environment, or None.
    """
    global _lmdb_db, _lmdb_path
    if _lmdb_db is not None and _lmdb_path == lmdb_path:
        return _lmdb_db

    if _lmdb_db is not None:
        _lmdb_db.close()
    # Clear the cache before (re)opening. If lmdb.open() raises, a closed
    # environment must not stay cached under its old path, where a later call
    # for that path would hand it back.
    _lmdb_db = None
    _lmdb_path = None

    if lmdb_path and os.path.exists(lmdb_path):
        import lmdb
        _lmdb_db = lmdb.open(lmdb_path, readonly=True)
        _lmdb_path = lmdb_path
    return _lmdb_db
def get_prompt_template(selected_info_types=None, lmdb_path=None):
    """Build the Jinja2 prompt template source.

    Args:
        selected_info_types: Info types to include, e.g.
            ['motif', 'go', 'superfamily', 'panther']. NOTE: it is only
            normalized here; the template itself reads ``selected_info_types``
            from the render context.
        lmdb_path: When None, the enzyme prompt header is used. Otherwise the
            function (QA) prompt header is used, and a trailing question slot
            is appended.

    Returns:
        str: Template source to be rendered with Jinja2.
    """
    if selected_info_types is None:
        selected_info_types = ['motif', 'go']  # default: motif and GO info
    qa_mode = lmdb_path is not None

    header = (FUNCTION_PROMPT if qa_mode else ENZYME_PROMPT) + "\n"
    body = """
input information:
{%- if 'motif' in selected_info_types and motif_pfam %}
motif:{% for motif_id, motif_info in motif_pfam.items() %}
{{motif_id}}: {{motif_info}}
{% endfor %}
{%- endif %}
{%- if 'go' in selected_info_types and go_data.status == 'success' %}
GO:{% for go_entry in go_data.go_annotations %}
▢ GO term{{loop.index}}: {{go_entry.go_id}}
• definition: {{ go_data.all_related_definitions.get(go_entry.go_id, 'not found definition') }}
{% endfor %}
{%- endif %}
{%- for info_type in selected_info_types %}
{%- if info_type not in ['motif', 'go'] and interpro_descriptions.get(info_type) %}
{{info_type}}:{% for ipr_id, ipr_info in interpro_descriptions[info_type].items() %}
▢ {{ipr_id}}: {{ipr_info.name}}
• description: {{ipr_info.abstract}}
{% endfor %}
{%- endif %}
{%- endfor %}
"""
    pieces = [header, body]
    if qa_mode:
        pieces.append("\n" + "question: \n {{question}}")
    return "".join(pieces)
def get_qa_data(protein_id, lmdb_path):
    """Collect every QA pair stored in the LMDB database for *protein_id*.

    Only records whose key decodes to an all-digit string are considered.
    Their value must be a JSON list whose first three items are
    ``[protein_id, question, ground_truth]``. Records that cannot be decoded
    are skipped.

    NOTE: this iterates over the whole database on every call.

    Args:
        protein_id: Protein ID compared against the first list item.
        lmdb_path: Path to the LMDB database.

    Returns:
        list[dict]: QA pairs with 'question' and 'ground_truth' keys, in
        cursor iteration order. The list is empty if the path is missing or
        no environment could be obtained. If reading fails midway, the error
        is printed and the pairs collected so far are returned.
    """
    if not lmdb_path or not os.path.exists(lmdb_path):
        return []

    qa_pairs = []
    try:
        db = get_lmdb_connection(lmdb_path)
        if db is None:
            return []
        with db.begin() as txn:
            for key, value in txn.cursor():
                try:
                    key_str = key.decode('utf-8')
                    if not key_str.isdigit():
                        continue  # only numerically-keyed records hold QA data
                    data = json.loads(value.decode('utf-8'))
                except ValueError:
                    # UnicodeDecodeError and json.JSONDecodeError are both
                    # ValueError subclasses: skip undecodable records.
                    continue
                if isinstance(data, list) and len(data) >= 3 and data[0] == protein_id:
                    qa_pairs.append({
                        'question': data[1],
                        'ground_truth': data[2]
                    })
    except Exception as e:
        # Best effort: report the failure and keep whatever was collected.
        print(f"Error reading lmdb for protein {protein_id}: {e}")
    return qa_pairs
@lru_cache(maxsize=8)
def _compile_prompt_template(template_text):
    """Compile *template_text* into a Jinja2 Template once per distinct source."""
    return Template(template_text)


def generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                    interpro_data_path=None, interproscan_info_path=None, selected_info_types=None, lmdb_path=None,
                    interpro_manager=None, question=None):
    """Render the prompt for one protein.

    Args:
        protein_id: Protein ID.
        protein2gopath: Path passed to analyze_protein_go.
        protein2pfam_path: Path passed to get_motif_pfam.
        pfam_descriptions_path: Path passed to get_motif_pfam.
        go_info_path: Path passed to analyze_protein_go.
        interpro_data_path: Path to interpro_data.json.
        interproscan_info_path: Path to interproscan_info.json.
        selected_info_types: Info types to include, e.g.
            ['motif', 'go', 'superfamily', 'panther']; defaults to
            ['motif', 'go'].
        lmdb_path: When not None, the QA-style template (with a question slot)
            is used.
        interpro_manager: InterProDescriptionManager to use. When given, it
            takes precedence over the two InterPro paths.
        question: Question text for QA tasks.

    Returns:
        str: The rendered prompt.
    """
    if selected_info_types is None:
        selected_info_types = ['motif', 'go']

    analysis = analyze_protein_go(protein_id, protein2gopath, go_info_path)
    motif_pfam = get_motif_pfam(protein_id, protein2pfam_path, pfam_descriptions_path)

    # InterPro descriptions are only needed for types other than motif/GO.
    interpro_descriptions = {}
    other_types = [t for t in selected_info_types if t not in ('motif', 'go')]
    if other_types:
        if interpro_manager is not None:
            interpro_descriptions = interpro_manager.get_description(protein_id, other_types)
        elif interpro_data_path and interproscan_info_path:
            # Fall back to the module-level cached manager.
            manager = get_interpro_manager(interpro_data_path, interproscan_info_path)
            interpro_descriptions = manager.get_description(protein_id, other_types)

    go_ok = analysis["status"] == "success"
    template_data = {
        "protein_id": protein_id,
        "selected_info_types": selected_info_types,
        "go_data": {
            "status": analysis["status"],
            "go_annotations": analysis["go_annotations"] if go_ok else [],
            "all_related_definitions": analysis["all_related_definitions"] if go_ok else {}
        },
        "motif_pfam": motif_pfam,
        "interpro_descriptions": interpro_descriptions,
        "question": question
    }

    # Compiling a Jinja2 template is far costlier than rendering it, and this
    # function runs once per protein / QA pair, so compiled templates are cached.
    template = _compile_prompt_template(get_prompt_template(selected_info_types, lmdb_path))
    return template.render(**template_data)
def save_prompts_parallel(protein_ids, output_path, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                          interpro_data_path=None, interproscan_info_path=None, selected_info_types=None, lmdb_path=None,
                          n_process=8):
    """Generate prompts for many proteins in parallel and save them.

    Without ``lmdb_path`` the output file is one JSON object mapping protein
    ID to prompt. With ``lmdb_path`` every QA pair stored for a protein
    produces one JSON line with keys index / protein_id / prompt / question /
    ground_truth.

    Args:
        protein_ids: Iterable of protein ID strings. Surrounding whitespace is
            stripped, and blank entries are skipped.
        output_path: Destination file. ``output_path + '.tmp'`` is used as
            scratch space.
        protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
        interpro_data_path, interproscan_info_path, selected_info_types,
        lmdb_path: Forwarded to generate_prompt (see there).
        n_process: Number of worker processes.
    """
    global _lmdb_db, _lmdb_path
    try:
        from utils.mpr import MultipleProcessRunnerSimplifier
    except ImportError:
        from mpr import MultipleProcessRunnerSimplifier

    if selected_info_types is None:
        selected_info_types = ['motif', 'go']

    # Build the InterPro manager before the workers start, so they use this
    # already-loaded instance instead of each loading the JSON files
    # (presumably inherited via fork; depends on MultipleProcessRunnerSimplifier).
    interpro_manager = None
    other_types = [t for t in selected_info_types if t not in ('motif', 'go')]
    if other_types and interpro_data_path and interproscan_info_path:
        interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)

    # Shared counter that gives every QA record a unique index across workers.
    if lmdb_path:
        import multiprocessing
        global_index = multiprocessing.Value('i', 0)
    else:
        global_index = None

    def process_protein(process_id, idx, protein_id, writer):
        # process_id / idx are supplied by the runner and not needed here.
        protein_id = protein_id.strip()
        if not protein_id:
            # Blank lines in the ID list would otherwise yield a prompt keyed by "".
            return

        if not lmdb_path:
            prompt = generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path,
                                     go_info_path, interpro_data_path, interproscan_info_path,
                                     selected_info_types, lmdb_path, interpro_manager)
            if prompt and writer:
                writer.write(json.dumps({protein_id: prompt}) + '\n')
            return

        # QA mode: one prompt per stored QA pair. get_qa_data opens (or
        # reuses) this worker's LMDB connection itself.
        for qa_pair in get_qa_data(protein_id, lmdb_path):
            question = qa_pair['question']
            ground_truth = qa_pair['ground_truth']
            prompt = generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path,
                                     go_info_path, interpro_data_path, interproscan_info_path,
                                     selected_info_types, lmdb_path, interpro_manager, question)
            if not prompt or not writer:
                continue
            # Read-and-increment must be atomic across processes.
            with global_index.get_lock():
                current_index = global_index.value
                global_index.value += 1
            result = {
                "index": current_index,
                "protein_id": protein_id,
                "prompt": prompt,
                "question": question,
                "ground_truth": ground_truth
            }
            writer.write(json.dumps(result) + '\n')

    tmp_path = output_path + '.tmp'
    runner = MultipleProcessRunnerSimplifier(
        data=protein_ids,
        do=process_protein,
        save_path=tmp_path,
        n_process=n_process,
        split_strategy="static"
    )
    runner.run()

    # Release any LMDB environment cached in this (parent) process.
    if _lmdb_db is not None:
        _lmdb_db.close()
    _lmdb_db = None
    _lmdb_path = None

    if lmdb_path:
        # QA output is already JSONL; move it into place unchanged.
        import shutil
        shutil.move(tmp_path, output_path)
    else:
        # Merge the per-protein {id: prompt} lines into one JSON object
        # (legacy output format).
        final_results = {}
        with open(tmp_path, 'r') as f:
            for line in f:
                if line.strip():  # ignore blank lines
                    final_results.update(json.loads(line))
        with open(output_path, 'w') as f:
            json.dump(final_results, f, indent=2)

    # Remove the scratch file if it is still there.
    if os.path.exists(tmp_path):
        os.remove(tmp_path)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Generate protein prompt')
    parser.add_argument('--protein_path', type=str, default='data/raw_data/protein_ids_clean.txt')
    parser.add_argument('--protein2pfam_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--pfam_descriptions_path', type=str, default='data/raw_data/all_pfam_descriptions.json')
    parser.add_argument('--protein2gopath', type=str, default='data/processed_data/go_integration_final_topk2.json')
    parser.add_argument('--go_info_path', type=str, default='data/raw_data/go.json')
    parser.add_argument('--interpro_data_path', type=str, default='data/raw_data/interpro_data.json')
    parser.add_argument('--interproscan_info_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--lmdb_path', type=str, default=None)
    parser.add_argument('--output_path', type=str, default='data/processed_data/prompts@clean_test.json')
    parser.add_argument('--selected_info_types', type=str, nargs='+', default=['motif', 'go'],
                        help='选择要包含的信息类型,如: motif go superfamily panther gene3d')
    parser.add_argument('--n_process', type=int, default=32)
    args = parser.parse_args()

    # Encode the selected info types in the output file name, e.g.
    # prompts@clean_test.json -> prompts@clean_test_motif_go.json.
    # splitext only touches the final extension. str.replace('.json', ...)
    # would also rewrite '.json' inside directory names, and would add
    # nothing when the path has no '.json'.
    root, ext = os.path.splitext(args.output_path)
    args.output_path = root + '_' + '_'.join(args.selected_info_types) + ext
    print(args)

    with open(args.protein_path, 'r') as file:
        protein_ids = file.readlines()

    save_prompts_parallel(
        protein_ids=protein_ids,
        output_path=args.output_path,
        n_process=args.n_process,
        protein2gopath=args.protein2gopath,
        protein2pfam_path=args.protein2pfam_path,
        pfam_descriptions_path=args.pfam_descriptions_path,
        go_info_path=args.go_info_path,
        interpro_data_path=args.interpro_data_path,
        interproscan_info_path=args.interproscan_info_path,
        selected_info_types=args.selected_info_types,
        lmdb_path=args.lmdb_path
    )

    # Example: generate a single prompt.
    # protein_id = 'A8CF74'
    # prompt = generate_prompt(protein_id, 'data/processed_data/go_integration_final_topk2.json',
    #                          'data/processed_data/interproscan_info.json', 'data/raw_data/all_pfam_descriptions.json',
    #                          'data/raw_data/go.json', 'data/raw_data/interpro_data.json',
    #                          'data/processed_data/interproscan_info.json',
    #                          ['motif', 'go', 'superfamily', 'panther'])
    # print(prompt)