# protein_rag/utils/generate_protein_prompt.py
# Uploaded by ericzhang1122 via huggingface_hub (commit 5c20520, verified).
import functools
import json
import os
import sys

# Make the package root importable so `utils.*` resolves when run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from jinja2 import Template
from tqdm import tqdm

try:
    from utils.protein_go_analysis import analyze_protein_go
    from utils.prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT
    from utils.get_motif import get_motif_pfam
except ImportError:
    from protein_go_analysis import analyze_protein_go
    from prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT
    from get_motif import get_motif_pfam
class InterProDescriptionManager:
    """Load InterPro description data once and look it up per protein.

    Two JSON files are read when the instance is created:

    * ``interpro_data_path``: mapping of IPR accession -> entry dict, from which
      the ``name`` and ``abstract`` fields are used.
    * ``interproscan_info_path``: mapping of protein ID -> record whose
      ``interproscan_results`` maps an annotation type (e.g. ``'superfamily'``,
      ``'panther'``, ``'gene3d'``) to a list of dicts; the string values of
      those dicts are treated as candidate IPR accessions.

    Keeping the parsed data on the instance avoids re-reading the files for
    every protein.
    """

    def __init__(self, interpro_data_path, interproscan_info_path):
        """Store the file paths and load both data files immediately.

        Args:
            interpro_data_path: path to interpro_data.json (may be None).
            interproscan_info_path: path to interproscan_info.json (may be None).
        """
        self.interpro_data_path = interpro_data_path
        self.interproscan_info_path = interproscan_info_path
        self.interpro_data = None
        self.interproscan_info = None
        self._load_data()

    def _load_data(self):
        """Load the data files; called once from ``__init__``.

        A falsy or non-existent path leaves the corresponding attribute as None.
        """
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        if self.interpro_data_path and os.path.exists(self.interpro_data_path):
            with open(self.interpro_data_path, 'r', encoding='utf-8') as f:
                self.interpro_data = json.load(f)
        if self.interproscan_info_path and os.path.exists(self.interproscan_info_path):
            with open(self.interproscan_info_path, 'r', encoding='utf-8') as f:
                self.interproscan_info = json.load(f)

    def get_description(self, protein_id, selected_types=None):
        """Return InterPro names/abstracts for a protein, grouped by annotation type.

        Args:
            protein_id: protein ID to look up in the interproscan info.
            selected_types: annotation types to include, e.g.
                ``['superfamily', 'panther', 'gene3d']``. None means none.

        Returns:
            dict: ``{info_type: {ipr_id: {'name': str, 'abstract': str}}}``.
            Types with no resolvable IPR IDs are omitted. The result is empty
            when either data file was not loaded or the protein is unknown.
        """
        if selected_types is None:
            selected_types = []
        if not self.interpro_data or not self.interproscan_info:
            return {}

        result = {}
        if protein_id not in self.interproscan_info:
            return result
        interproscan_results = self.interproscan_info[protein_id].get('interproscan_results', {})

        for info_type in selected_types:
            if info_type not in interproscan_results:
                continue
            type_descriptions = {}
            for entry in interproscan_results[info_type]:
                for ipr_id in entry.values():
                    # JSON object keys are always strings, so only a non-empty
                    # str can match; this also avoids TypeError on list values.
                    if not isinstance(ipr_id, str) or not ipr_id:
                        continue
                    ipr_entry = self.interpro_data.get(ipr_id)
                    if ipr_entry is None:
                        continue
                    type_descriptions[ipr_id] = {
                        'name': ipr_entry.get('name', ''),
                        'abstract': ipr_entry.get('abstract', ''),
                    }
            if type_descriptions:
                result[info_type] = type_descriptions
        return result
# Module-level caches: the InterProDescriptionManager built by
# get_interpro_manager(), and the LMDB environment plus the path it was opened
# from, kept by get_lmdb_connection().
_interpro_manager = None
_lmdb_db = None
_lmdb_path = None


def get_interpro_manager(interpro_data_path, interproscan_info_path):
    """Return a cached InterProDescriptionManager for the given file paths.

    The instance is kept in the module global ``_interpro_manager`` so the JSON
    files are only loaded once. If it is requested with paths different from
    the ones the cached instance was built from, a new manager is created
    instead of silently returning data for the old files. This mirrors how
    get_lmdb_connection() reopens its environment when the path changes.

    Args:
        interpro_data_path: path to interpro_data.json.
        interproscan_info_path: path to interproscan_info.json.

    Returns:
        InterProDescriptionManager: the (possibly newly created) cached instance.
    """
    global _interpro_manager
    cached = _interpro_manager
    if (
        cached is None
        or cached.interpro_data_path != interpro_data_path
        or cached.interproscan_info_path != interproscan_info_path
    ):
        _interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)
    return _interpro_manager
def get_lmdb_connection(lmdb_path):
    """Return a cached read-only LMDB environment for ``lmdb_path``.

    The environment is stored in the module globals ``_lmdb_db`` and
    ``_lmdb_path``. It is reused while the same path is requested; asking for a
    different path closes the cached environment first. The ``lmdb`` package is
    imported lazily, only when an existing path is actually opened.

    Args:
        lmdb_path: path to the LMDB database, or None.

    Returns:
        The opened environment, or None if ``lmdb_path`` is falsy or does not
        exist.
    """
    global _lmdb_db, _lmdb_path
    if _lmdb_db is not None and _lmdb_path == lmdb_path:
        return _lmdb_db

    if _lmdb_db is not None:
        _lmdb_db.close()
        # Forget the closed handle immediately: if opening the new path fails
        # below, later calls must not be handed this closed environment.
        _lmdb_db = None
        _lmdb_path = None

    if lmdb_path and os.path.exists(lmdb_path):
        import lmdb
        _lmdb_db = lmdb.open(lmdb_path, readonly=True)
        _lmdb_path = lmdb_path
    else:
        _lmdb_db = None
        _lmdb_path = None
    return _lmdb_db
def get_prompt_template(selected_info_types=None,lmdb_path=None):
    """Build the Jinja2 source used to render a protein prompt.

    The result is an instruction prompt followed by an ``input information:``
    section. Which sub-sections (motif, GO, InterPro types) are emitted is
    decided at render time from the render context (``selected_info_types``,
    ``motif_pfam``, ``go_data``, ``interpro_descriptions``), not from this
    function's ``selected_info_types`` argument.

    Args:
        selected_info_types: accepted for API compatibility; the template reads
            the list from the render context instead.
        lmdb_path: None selects ENZYME_PROMPT; any other value selects
            FUNCTION_PROMPT and appends a ``question:`` slot rendered from
            ``{{question}}``.

    Returns:
        str: Jinja2 template source.
    """
    qa_mode = lmdb_path is not None
    instruction = FUNCTION_PROMPT if qa_mode else ENZYME_PROMPT
    input_section = """
input information:
{%- if 'motif' in selected_info_types and motif_pfam %}
motif:{% for motif_id, motif_info in motif_pfam.items() %}
{{motif_id}}: {{motif_info}}
{% endfor %}
{%- endif %}
{%- if 'go' in selected_info_types and go_data.status == 'success' %}
GO:{% for go_entry in go_data.go_annotations %}
▢ GO term{{loop.index}}: {{go_entry.go_id}}
• definition: {{ go_data.all_related_definitions.get(go_entry.go_id, 'not found definition') }}
{% endfor %}
{%- endif %}
{%- for info_type in selected_info_types %}
{%- if info_type not in ['motif', 'go'] and interpro_descriptions.get(info_type) %}
{{info_type}}:{% for ipr_id, ipr_info in interpro_descriptions[info_type].items() %}
▢ {{ipr_id}}: {{ipr_info.name}}
• description: {{ipr_info.abstract}}
{% endfor %}
{%- endif %}
{%- endfor %}
"""
    template_source = instruction + "\n" + input_section
    if qa_mode:
        template_source += "\n" + "question: \n {{question}}"
    return template_source
def get_qa_data(protein_id, lmdb_path):
    """Collect every QA pair stored in the LMDB database for one protein.

    Only records under purely numeric keys are considered. Their value is
    decoded as UTF-8 JSON and, when it is a list of at least three items, read
    as ``[protein_id, question, ground_truth, ...]``. Records whose key or value
    cannot be decoded/parsed are skipped.

    Args:
        protein_id: protein ID compared against the first list element.
        lmdb_path: path to the LMDB database.

    Returns:
        list[dict]: ``{'question': ..., 'ground_truth': ...}`` dicts in cursor
        order. Empty if the path is falsy or missing, or no environment could be
        obtained. If reading fails part-way, the error is reported on stderr and
        the pairs collected so far are returned.
    """
    if not lmdb_path or not os.path.exists(lmdb_path):
        return []

    qa_pairs = []
    try:
        db = get_lmdb_connection(lmdb_path)
        if db is None:
            return []
        # NOTE(review): this walks the whole database for every protein
        # (O(#records) per call); an index keyed by protein_id would avoid it.
        with db.begin() as txn:
            for key, value in txn.cursor():
                try:
                    key_str = key.decode('utf-8')
                    if not key_str.isdigit():
                        continue
                    data = json.loads(value.decode('utf-8'))
                except ValueError:
                    # UnicodeDecodeError and json.JSONDecodeError are both
                    # ValueErrors: skip a malformed record, keep scanning.
                    continue
                if isinstance(data, list) and len(data) >= 3 and data[0] == protein_id:
                    qa_pairs.append({
                        'question': data[1],
                        'ground_truth': data[2],
                    })
    except Exception as e:
        print(f"Error reading lmdb for protein {protein_id}: {e}", file=sys.stderr)
    return qa_pairs
@functools.lru_cache(maxsize=8)
def _compile_prompt_template(template_source):
    """Compile ``template_source`` into a jinja2 Template, memoised per process.

    generate_prompt() runs once per protein (or per QA pair) but the template
    source only varies with a couple of options, so recompiling the same source
    on every call is wasted work.
    """
    return Template(template_source)


def generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                    interpro_data_path=None, interproscan_info_path=None, selected_info_types=None, lmdb_path=None, interpro_manager=None, question=None):
    """Render the prompt text for one protein.

    Args:
        protein_id: protein ID.
        protein2gopath, go_info_path: passed to analyze_protein_go().
        protein2pfam_path, pfam_descriptions_path: passed to get_motif_pfam().
        interpro_data_path: path to interpro_data.json; used together with
            ``interproscan_info_path`` only when ``interpro_manager`` is None.
        interproscan_info_path: path to interproscan_info.json.
        selected_info_types: info types to include, e.g.
            ``['motif', 'go', 'superfamily', 'panther']``; defaults to
            ``['motif', 'go']``. Types other than motif/go are looked up as
            InterPro annotation types.
        lmdb_path: selects the QA-style template when not None.
        interpro_manager: InterProDescriptionManager to use; takes precedence
            over the path arguments.
        question: question text for QA prompts.

    Returns:
        str: the rendered prompt.
    """
    if selected_info_types is None:
        selected_info_types = ['motif', 'go']

    analysis = analyze_protein_go(protein_id, protein2gopath, go_info_path)
    motif_pfam = get_motif_pfam(protein_id, protein2pfam_path, pfam_descriptions_path)

    # InterPro descriptions are only needed for types beyond motif/go.
    interpro_descriptions = {}
    other_types = [t for t in selected_info_types if t not in ('motif', 'go')]
    if other_types:
        if interpro_manager is not None:
            interpro_descriptions = interpro_manager.get_description(protein_id, other_types)
        elif interpro_data_path and interproscan_info_path:
            # Fall back to the module-level cached manager.
            manager = get_interpro_manager(interpro_data_path, interproscan_info_path)
            interpro_descriptions = manager.get_description(protein_id, other_types)

    go_ok = analysis["status"] == "success"
    template_data = {
        "protein_id": protein_id,
        "selected_info_types": selected_info_types,
        "go_data": {
            "status": analysis["status"],
            "go_annotations": analysis["go_annotations"] if go_ok else [],
            "all_related_definitions": analysis["all_related_definitions"] if go_ok else {},
        },
        "motif_pfam": motif_pfam,
        "interpro_descriptions": interpro_descriptions,
        "question": question,
    }
    template_source = get_prompt_template(selected_info_types, lmdb_path)
    return _compile_prompt_template(template_source).render(**template_data)
def save_prompts_parallel(protein_ids, output_path, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                          interpro_data_path=None, interproscan_info_path=None, selected_info_types=None, lmdb_path=None, n_process=8):
    """Generate prompts for many proteins in parallel and save them to ``output_path``.

    Workers write JSON lines to ``output_path + '.tmp'`` through
    MultipleProcessRunnerSimplifier. The temp file is then finalised:

    * without ``lmdb_path``: merged into one JSON object ``{protein_id: prompt}``;
    * with ``lmdb_path``: moved as-is (JSON Lines). Each line has keys
      ``index``, ``protein_id``, ``prompt``, ``question`` and ``ground_truth``,
      where ``index`` comes from a counter shared between worker processes.

    Blank entries in ``protein_ids`` (after stripping) are skipped.

    Args:
        protein_ids: iterable of protein IDs (surrounding whitespace is stripped).
        output_path: final output file path.
        protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
        interpro_data_path, interproscan_info_path, selected_info_types,
        lmdb_path: forwarded to generate_prompt() / get_qa_data().
        n_process: number of worker processes.
    """
    global _lmdb_db, _lmdb_path
    try:
        from utils.mpr import MultipleProcessRunnerSimplifier
    except ImportError:
        from mpr import MultipleProcessRunnerSimplifier

    if selected_info_types is None:
        selected_info_types = ['motif', 'go']
    tmp_path = output_path + '.tmp'

    # Build the InterPro manager once, before the workers start, so the JSON
    # files are not re-read for every protein.
    interpro_manager = None
    other_types = [t for t in selected_info_types if t not in ('motif', 'go')]
    if other_types and interpro_data_path and interproscan_info_path:
        interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)

    # Shared counter (plus lock) for the "index" field in QA mode.
    if lmdb_path:
        import multiprocessing
        global_index = multiprocessing.Value('i', 0)
        index_lock = multiprocessing.Lock()
    else:
        global_index = None
        index_lock = None

    def process_protein(process_id, idx, protein_id, writer):
        """Runner callback: write the prompt record(s) for one protein.

        ``process_id`` and ``idx`` are supplied by the runner and unused here.
        """
        protein_id = protein_id.strip()
        if not protein_id:
            # Blank lines in the ID list would otherwise yield prompts for "".
            return

        if not lmdb_path:
            # Legacy mode: one {protein_id: prompt} line per protein.
            prompt = generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                                     interpro_data_path, interproscan_info_path, selected_info_types, lmdb_path, interpro_manager)
            if prompt != "" and writer:
                writer.write(json.dumps({protein_id: prompt}) + '\n')
            return

        # QA mode: initialise this worker process's LMDB connection.
        get_lmdb_connection(lmdb_path)
        for qa_pair in get_qa_data(protein_id, lmdb_path):
            question = qa_pair['question']
            ground_truth = qa_pair['ground_truth']
            prompt = generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                                     interpro_data_path, interproscan_info_path, selected_info_types, lmdb_path, interpro_manager, question)
            if prompt == "" or not writer:
                continue
            # Reserve the next index under the lock shared by all workers.
            with index_lock:
                current_index = global_index.value
                global_index.value += 1
            record = {
                "index": current_index,
                "protein_id": protein_id,
                "prompt": prompt,
                "question": question,
                "ground_truth": ground_truth,
            }
            writer.write(json.dumps(record) + '\n')

    runner = MultipleProcessRunnerSimplifier(
        data=protein_ids,
        do=process_protein,
        save_path=tmp_path,
        n_process=n_process,
        split_strategy="static"
    )
    runner.run()

    # Release this process's cached LMDB environment, keeping both globals in sync.
    if _lmdb_db is not None:
        _lmdb_db.close()
    _lmdb_db = None
    _lmdb_path = None

    if lmdb_path:
        # QA mode: the temp file is already JSON Lines; just rename it.
        import shutil
        shutil.move(tmp_path, output_path)
    else:
        # Legacy mode: merge the per-protein lines into a single JSON object.
        final_results = {}
        with open(tmp_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    final_results.update(json.loads(line))
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=2)

    if os.path.exists(tmp_path):
        os.remove(tmp_path)
def build_output_path(output_path, selected_info_types):
    """Insert ``_<type1>_<type2>...`` before the final extension of ``output_path``.

    Only the last extension is considered, so directory names are never
    rewritten, e.g. ``('data/prompts.json', ['motif', 'go'])`` ->
    ``'data/prompts_motif_go.json'``. A path without an extension gets the
    suffix appended at the end.
    """
    root, ext = os.path.splitext(output_path)
    return root + '_' + '_'.join(selected_info_types) + ext


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Generate protein prompt')
    parser.add_argument('--protein_path', type=str, default='data/raw_data/protein_ids_clean.txt')
    parser.add_argument('--protein2pfam_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--pfam_descriptions_path', type=str, default='data/raw_data/all_pfam_descriptions.json')
    parser.add_argument('--protein2gopath', type=str, default='data/processed_data/go_integration_final_topk2.json')
    parser.add_argument('--go_info_path', type=str, default='data/raw_data/go.json')
    parser.add_argument('--interpro_data_path', type=str, default='data/raw_data/interpro_data.json')
    parser.add_argument('--interproscan_info_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--lmdb_path', type=str, default=None)
    parser.add_argument('--output_path', type=str, default='data/processed_data/prompts@clean_test.json')
    parser.add_argument('--selected_info_types', type=str, nargs='+', default=['motif', 'go'],
                        help='选择要包含的信息类型,如: motif go superfamily panther gene3d')
    parser.add_argument('--n_process', type=int, default=32)
    args = parser.parse_args()

    # Tag the output file name with the selected info types.
    args.output_path = build_output_path(args.output_path, args.selected_info_types)
    print(args)

    with open(args.protein_path, 'r') as file:
        protein_ids = file.readlines()

    save_prompts_parallel(
        protein_ids=protein_ids,
        output_path=args.output_path,
        n_process=args.n_process,
        protein2gopath=args.protein2gopath,
        protein2pfam_path=args.protein2pfam_path,
        pfam_descriptions_path=args.pfam_descriptions_path,
        go_info_path=args.go_info_path,
        interpro_data_path=args.interpro_data_path,
        interproscan_info_path=args.interproscan_info_path,
        selected_info_types=args.selected_info_types,
        lmdb_path=args.lmdb_path
    )

    # Manual test example:
    # protein_id = 'A8CF74'
    # prompt = generate_prompt(protein_id, 'data/processed_data/go_integration_final_topk2.json',
    #                          'data/processed_data/interproscan_info.json', 'data/raw_data/all_pfam_descriptions.json',
    #                          'data/raw_data/go.json', 'data/raw_data/interpro_data.json',
    #                          'data/processed_data/interproscan_info.json',
    #                          ['motif', 'go', 'superfamily', 'panther'])
    # print(prompt)