protein_rag / go_integration_pipeline.py
ericzhang1122's picture
Upload folder using huggingface_hub
5c20520 verified
import json
import os
import sys
import argparse
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import torch
from tqdm import tqdm
# Make the script directory and the vendored "Models/ProTrek" checkout importable,
# so the `utils` and `model.ProTrek` packages imported by this module can be found.
root_path = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([root_path, os.path.join(root_path, "Models/ProTrek")])
from utils.protein_go_analysis import get_go_definition
class GOIntegrationPipeline:
    """Merge GO annotations from InterProScan and BLAST, optionally re-scored by ProTrek.

    Level 1 takes, per protein, the union of InterProScan GO terms and the GO terms
    of selected BLAST hits (either the first ``topk`` hits, or every hit passing the
    identity / coverage / E-value thresholds when ``topk <= 0``).
    Level 2 (only when ``use_protrek`` is set) scores each candidate GO definition
    against the protein sequence with ProTrek and keeps terms whose score is at
    least ``protrek_threshold``.
    """

    # Default JSON-lines file mapping UniProt IDs to GO IDs (resolved relative to the CWD).
    DEFAULT_PROTEIN_GO_FILE = 'processed_data/protein_go.json'

    def __init__(self,
                 identity_threshold: int = 80,
                 coverage_threshold: int = 80,
                 evalue_threshold: float = 1e-50,
                 topk: int = 2,
                 protrek_threshold: Optional[float] = None,
                 use_protrek: bool = False,
                 protein_go_file: Optional[str] = None):
        """
        Configure the GO integration pipeline.

        Args:
            identity_threshold: minimum BLAST identity (0-100); used only when topk <= 0.
            coverage_threshold: minimum BLAST coverage (0-100); used only when topk <= 0.
            evalue_threshold: maximum BLAST E-value; used only when topk <= 0.
            topk: if > 0, take GO terms from the first ``topk`` BLAST hits instead of
                applying the thresholds above.
            protrek_threshold: minimum ProTrek score kept by the second-level filter.
            use_protrek: whether to run the second-level ProTrek filtering (loads the model).
            protein_go_file: JSON-lines UniProt-ID -> GO-ID mapping file; defaults to
                ``DEFAULT_PROTEIN_GO_FILE``.
        """
        self.identity_threshold = identity_threshold
        self.coverage_threshold = coverage_threshold
        self.evalue_threshold = evalue_threshold
        self.protrek_threshold = protrek_threshold
        self.use_protrek = use_protrek
        self.topk = topk
        # Load the protein -> GO mapping used to translate BLAST hits into GO IDs.
        self._load_protein_go_dict(protein_go_file)
        # The ProTrek model is heavy, so it is only loaded when level 2 is requested.
        if self.use_protrek:
            self._init_protrek_model()

    def _init_protrek_model(self):
        """Load the ProTrek tri-modal model onto CUDA if available, otherwise CPU."""
        from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
        config = {
            "protein_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/esm2_t33_650M_UR50D",
            "text_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
            "structure_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/foldseek_t30_150M",
            "load_protein_pretrained": False,
            "load_text_pretrained": False,
            "from_checkpoint": "Models/ProTrek/weights/ProTrek_650M_UniRef50/ProTrek_650M_UniRef50.pt"
        }
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.protrek_model = ProTrekTrimodalModel(**config).to(self.device).eval()
        print(f"ProTrek模型已加载到设备: {self.device}")

    def _load_protein_go_dict(self, path: Optional[str] = None):
        """Load the UniProt-ID -> GO-ID mapping into ``self.protein_go_dict``.

        Each non-blank line of the file must be a JSON object with ``protein_id``
        and ``GO_id`` keys. Loading is best-effort: on a read/parse error a message
        is printed and the mapping is left empty. BLAST hits then contribute no GO
        terms.

        Args:
            path: mapping file; defaults to ``DEFAULT_PROTEIN_GO_FILE``.
        """
        path = path or self.DEFAULT_PROTEIN_GO_FILE
        self.protein_go_dict = {}
        try:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # A stray blank line must not abort (and empty) the whole load.
                        continue
                    data = json.loads(line)
                    self.protein_go_dict[data['protein_id']] = data['GO_id']
            print(f"成功加载蛋白质-GO映射数据,共{len(self.protein_go_dict)}条记录")
        except (OSError, ValueError, KeyError, TypeError) as e:
            print(f"加载蛋白质-GO映射数据时发生错误: {str(e)}")
            self.protein_go_dict = {}

    def _get_go_from_uniprot_id(self, uniprot_id: str) -> List[str]:
        """
        Look up the GO IDs annotated on a UniProt entry.

        Args:
            uniprot_id: UniProt ID (as it appears in the BLAST hit ``ID`` field).

        Returns:
            GO IDs with any ``"<prefix>_"`` stripped (e.g. ``"GO_0005515"`` -> ``"0005515"``);
            an empty list if the ID is not in the loaded mapping.
        """
        return [go_id.split("_")[-1] if "_" in go_id else go_id
                for go_id in self.protein_go_dict.get(uniprot_id, [])]

    def _passes_blast_thresholds(self, result: Dict) -> bool:
        """Return True if a BLAST hit meets the identity, coverage and E-value thresholds."""
        identity = float(result.get('Identity%', 0))
        coverage = float(result.get('Coverage%', 0))
        evalue = float(result.get('E-value', 1.0))
        return (identity >= self.identity_threshold and
                coverage >= self.coverage_threshold and
                evalue <= self.evalue_threshold)

    def extract_blast_go_ids(self, blast_results: List[Dict], protein_id: str) -> List[str]:
        """
        Collect GO IDs from the selected BLAST hits of one protein.

        Args:
            blast_results: BLAST hits for the protein (dicts with ``ID`` and, for the
                threshold strategy, ``Identity%`` / ``Coverage%`` / ``E-value``).
            protein_id: the query protein's ID; hits with this ID (self-matches) are ignored.

        Returns:
            GO IDs of the selected hits, possibly with duplicates.
        """
        if self.topk > 0:
            # Top-k strategy: the first `topk` hits in the given order (presumably
            # BLAST's best-first order). A self-hit inside this window is skipped,
            # not replaced, so fewer than `topk` hits may contribute.
            selected = blast_results[:self.topk]
        else:
            # Threshold strategy: every hit meeting all three thresholds.
            selected = [r for r in blast_results if self._passes_blast_thresholds(r)]

        go_ids = []
        for result in selected:
            hit_id = result.get('ID', '')
            if hit_id == protein_id:
                continue
            go_ids.extend(self._get_go_from_uniprot_id(hit_id))
        return go_ids

    def first_level_filtering(self, interproscan_info: Dict, blast_info: Dict) -> Dict[str, List[str]]:
        """
        Level 1: merge InterProScan GO terms with GO terms from selected BLAST hits.

        Only proteins present in ``interproscan_info`` are emitted.

        Args:
            interproscan_info: InterProScan results keyed by protein ID.
            blast_info: BLAST results keyed by protein ID.

        Returns:
            Mapping of protein ID to a sorted, de-duplicated list of GO IDs
            (numeric part only, e.g. ``"0005515"``).
        """
        protein_go_dict = {}
        for protein_id, info in interproscan_info.items():
            go_ids = set()

            # InterProScan reports "GO:0005515"; keep the part after ":" so it matches
            # the BLAST-side IDs produced by _get_go_from_uniprot_id.
            ipr_results = (info or {}).get('interproscan_results') or {}
            go_ids.update(go_id.split(":")[-1] for go_id in ipr_results.get('go_id') or [])

            blast_entry = blast_info.get(protein_id)
            if blast_entry:
                blast_results = blast_entry.get('blast_results') or []
                go_ids.update(self.extract_blast_go_ids(blast_results, protein_id))

            # Sorted so output files are reproducible across runs (set order is not).
            protein_go_dict[protein_id] = sorted(go_ids)
        return protein_go_dict

    def calculate_protrek_scores(self, protein_sequences: Dict[str, str],
                                 protein_go_dict: Dict[str, List[str]]) -> Dict[str, Dict]:
        """
        Score each candidate GO definition against its protein sequence with ProTrek.

        Proteins without a sequence, without any resolvable GO definition, or whose
        scoring raises an error are omitted from the result.

        Args:
            protein_sequences: protein ID -> amino-acid sequence.
            protein_go_dict: protein ID -> candidate GO IDs.

        Returns:
            protein ID -> ``{"protein_id", "GO_id" (all candidates), "Clip_score"
            (GO ID -> score, only for GO IDs that had a definition)}``.
        """
        results = {}
        for protein_id, go_ids in tqdm(protein_go_dict.items(), desc="计算ProTrek分数"):
            if protein_id not in protein_sequences:
                continue
            protein_seq = protein_sequences[protein_id]

            # get_go_definition is assumed to return a text definition, or a falsy
            # value when none is known -- TODO confirm against utils.protein_go_analysis.
            go_definitions = {}
            for go_id in go_ids:
                definition = get_go_definition(go_id)
                if definition:
                    go_definitions[go_id] = definition
            if not go_definitions:
                continue

            try:
                definitions = list(go_definitions.values())
                # Inference only: without no_grad an autograd graph is built (wasting
                # memory) and .numpy() would fail on tensors that require grad.
                with torch.no_grad():
                    seq_emb = self.protrek_model.get_protein_repr([protein_seq])
                    text_embs = self.protrek_model.get_text_repr(definitions)
                    scores = (seq_emb @ text_embs.T) / self.protrek_model.temperature
                scores = scores.detach().cpu().numpy().flatten()
                # Scores follow the order of `definitions`, i.e. go_definitions' key order.
                go_scores = {go_id: float(score) for go_id, score in zip(go_definitions, scores)}
            except Exception as e:
                # Best-effort per protein: report and move on to the next one.
                print(f"计算 {protein_id} 的ProTrek分数时出错: {str(e)}")
                continue

            results[protein_id] = {
                "protein_id": protein_id,
                "GO_id": go_ids,
                "Clip_score": go_scores
            }
        return results

    def second_level_filtering(self, protrek_results: Dict[str, Dict]) -> Dict[str, List[str]]:
        """
        Level 2: keep GO terms whose ProTrek score is at least ``protrek_threshold``.

        Args:
            protrek_results: output of ``calculate_protrek_scores``.

        Returns:
            protein ID -> kept GO IDs; proteins with no kept GO term are omitted.

        Raises:
            ValueError: if ``protrek_threshold`` is not set.
        """
        if self.protrek_threshold is None:
            raise ValueError("protrek_threshold must be set for second-level filtering")
        filtered_results = {}
        for protein_id, data in protrek_results.items():
            kept = [go_id for go_id, score in data.get('Clip_score', {}).items()
                    if score >= self.protrek_threshold]
            if kept:
                filtered_results[protein_id] = kept
        return filtered_results

    def generate_filename(self, base_name: str, is_intermediate: bool = False) -> str:
        """Build an output file name that encodes the active selection parameters."""
        if self.topk > 0:
            # With top-k selection the BLAST thresholds are unused, so only topk is encoded.
            params = f"topk{self.topk}"
        else:
            params = f"identity{self.identity_threshold}_coverage{self.coverage_threshold}_evalue{self.evalue_threshold:.0e}"
        if self.use_protrek and self.protrek_threshold is not None:
            params += f"_protrek{self.protrek_threshold}"
        if is_intermediate:
            return f"{base_name}_intermediate_{params}.json"
        return f"{base_name}_final_{params}.json"

    @staticmethod
    def _write_jsonl(path: str, records) -> None:
        """Write each record (a JSON-serializable dict) as one line of JSON to ``path``."""
        with open(path, 'w', encoding='utf-8') as f:
            for record in records:
                f.write(json.dumps(record) + '\n')

    def run(self, interproscan_info: Dict = None, blast_info: Dict = None,
            interproscan_file: str = None, blast_file: str = None,
            output_dir: str = "output"):
        """
        Run the GO integration pipeline end to end.

        Args:
            interproscan_info: InterProScan results keyed by protein ID; read from
                ``interproscan_file`` when not given.
            blast_info: BLAST results keyed by protein ID; read from ``blast_file``
                when not given.
            interproscan_file: path to a JSON file with InterProScan results.
            blast_file: path to a JSON file with BLAST results.
            output_dir: directory for output files (created if missing).

        Returns:
            - ``use_protrek`` off: path of the level-1 result file;
            - ``use_protrek`` on with ``protrek_threshold`` set: ``(final_file, intermediate_file)``;
            - ``use_protrek`` on without a threshold: path of the intermediate score file.

        Raises:
            ValueError: if either input is missing or empty.
        """
        if interproscan_info is None and interproscan_file:
            with open(interproscan_file, 'r', encoding='utf-8') as f:
                interproscan_info = json.load(f)
        if blast_info is None and blast_file:
            with open(blast_file, 'r', encoding='utf-8') as f:
                blast_info = json.load(f)
        if not interproscan_info or not blast_info:
            raise ValueError("必须提供interproscan_info和blast_info数据或文件路径")

        os.makedirs(output_dir, exist_ok=True)

        print("开始第一层筛选...")
        protein_go_dict = self.first_level_filtering(interproscan_info, blast_info)

        if not self.use_protrek:
            output_file = os.path.join(output_dir, self.generate_filename("go_integration"))
            self._write_jsonl(output_file, ({"protein_id": protein_id, "GO_id": go_ids}
                                            for protein_id, go_ids in protein_go_dict.items()))
            print(f"第一层筛选完成,结果已保存到: {output_file}")
            return output_file

        print("开始第二层筛选...")
        protein_sequences = {protein_id: data.get('sequence', '')
                             for protein_id, data in interproscan_info.items()}

        protrek_results = self.calculate_protrek_scores(protein_sequences, protein_go_dict)

        # The intermediate file keeps every score so thresholds can be re-tuned offline.
        intermediate_file = os.path.join(output_dir, self.generate_filename("go_integration", is_intermediate=True))
        self._write_jsonl(intermediate_file, protrek_results.values())
        print(f"ProTrek分数计算完成,中间结果已保存到: {intermediate_file}")

        if self.protrek_threshold is not None:
            final_results = self.second_level_filtering(protrek_results)
            final_file = os.path.join(output_dir, self.generate_filename("go_integration"))
            self._write_jsonl(final_file, ({"protein_id": protein_id, "GO_id": go_ids}
                                           for protein_id, go_ids in final_results.items()))
            print(f"第二层筛选完成,最终结果已保存到: {final_file}")
            return final_file, intermediate_file

        return intermediate_file
def main():
    """Parse command-line options and run the GO integration pipeline once."""
    cli = argparse.ArgumentParser(description="GO信息整合管道")

    # Input files
    cli.add_argument("--interproscan_file", type=str,
                     default="data/processed_data/interproscan_info.json",
                     help="InterProScan结果文件路径")
    cli.add_argument("--blast_file", type=str,
                     default="data/processed_data/blast_info.json",
                     help="BLAST结果文件路径")

    # BLAST hit selection (the three thresholds are only consulted when --topk <= 0)
    cli.add_argument("--identity", type=int, default=80, help="BLAST identity阈值 (0-100)")
    cli.add_argument("--coverage", type=int, default=80, help="BLAST coverage阈值 (0-100)")
    cli.add_argument("--evalue", type=float, default=1e-50, help="BLAST E-value阈值")
    cli.add_argument("--topk", type=int, default=2, help="BLAST topk结果")

    # Optional second-level ProTrek filtering
    cli.add_argument("--protrek_threshold", type=float, help="ProTrek分数阈值")
    cli.add_argument("--use_protrek", action="store_true", help="是否使用第二层ProTrek筛选")

    # Output
    cli.add_argument("--output_dir", type=str,
                     default="data/processed_data/go_integration_results",
                     help="输出目录")

    opts = cli.parse_args()

    # Map CLI option names onto the pipeline's constructor parameters.
    pipeline_kwargs = {
        "identity_threshold": opts.identity,
        "coverage_threshold": opts.coverage,
        "evalue_threshold": opts.evalue,
        "topk": opts.topk,
        "protrek_threshold": opts.protrek_threshold,
        "use_protrek": opts.use_protrek,
    }
    GOIntegrationPipeline(**pipeline_kwargs).run(
        interproscan_file=opts.interproscan_file,
        blast_file=opts.blast_file,
        output_dir=opts.output_dir,
    )


if __name__ == "__main__":
    main()