import json
import os
import sys
import argparse
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

import torch
from tqdm import tqdm

# Add paths so that local modules (utils, ProTrek) resolve regardless of CWD
root_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, "Models/ProTrek"))

from utils.protein_go_analysis import get_go_definition


class GOIntegrationPipeline:
    """Two-stage pipeline that merges GO annotations from InterProScan and
    BLAST hits (stage 1), and optionally re-scores/filters them with the
    ProTrek sequence-text model (stage 2)."""

    def __init__(self,
                 identity_threshold: int = 80,
                 coverage_threshold: int = 80,
                 evalue_threshold: float = 1e-50,
                 topk: int = 2,
                 protrek_threshold: Optional[float] = None,
                 use_protrek: bool = False):
        """
        GO information integration pipeline.

        Args:
            identity_threshold: BLAST identity threshold (0-100), used only
                when ``topk`` <= 0.
            coverage_threshold: BLAST coverage threshold (0-100), used only
                when ``topk`` <= 0.
            evalue_threshold: BLAST E-value threshold, used only when
                ``topk`` <= 0.
            topk: if > 0, take GO ids from the top-k BLAST hits instead of
                applying the identity/coverage/evalue thresholds.
            protrek_threshold: minimum ProTrek similarity score for a GO term
                to survive second-level filtering (None disables it).
            use_protrek: whether to run the second-level ProTrek filtering.
        """
        self.identity_threshold = identity_threshold
        self.coverage_threshold = coverage_threshold
        self.evalue_threshold = evalue_threshold
        self.protrek_threshold = protrek_threshold
        self.use_protrek = use_protrek
        self.topk = topk

        # Load the protein -> GO mapping used to expand BLAST hit ids.
        self._load_protein_go_dict()

        # Only pay the model-loading cost when stage 2 is requested.
        if self.use_protrek:
            self._init_protrek_model()

    def _init_protrek_model(self):
        """Load the ProTrek tri-modal model onto GPU if available, else CPU."""
        from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
        config = {
            "protein_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/esm2_t33_650M_UR50D",
            "text_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
            "structure_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/foldseek_t30_150M",
            "load_protein_pretrained": False,
            "load_text_pretrained": False,
            "from_checkpoint": "Models/ProTrek/weights/ProTrek_650M_UniRef50/ProTrek_650M_UniRef50.pt"
        }
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.protrek_model = ProTrekTrimodalModel(**config).to(self.device).eval()
        print(f"ProTrek模型已加载到设备: {self.device}")

    def _load_protein_go_dict(self):
        """Load the protein->GO mapping from a JSON-lines file.

        On any failure the mapping is left empty so the pipeline can still
        run (BLAST expansion will simply yield no GO ids).
        """
        self.protein_go_dict = {}
        try:
            with open('processed_data/protein_go.json', 'r') as f:
                for line in f:
                    # Skip blank lines (e.g. a trailing newline) so one empty
                    # line does not abort the whole load via JSONDecodeError.
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    self.protein_go_dict[data['protein_id']] = data['GO_id']
            print(f"成功加载蛋白质-GO映射数据,共{len(self.protein_go_dict)}条记录")
        except Exception as e:
            print(f"加载蛋白质-GO映射数据时发生错误: {str(e)}")
            self.protein_go_dict = {}

    def _get_go_from_uniprot_id(self, uniprot_id: str) -> List[str]:
        """
        Look up GO ids for a UniProt id in the preloaded mapping.

        Args:
            uniprot_id: UniProt accession of a BLAST hit.

        Returns:
            List of bare GO ids (any "XX_" prefix stripped); empty list when
            the id is unknown.
        """
        # Normalize ids like "GO_0005515" to "0005515"; ids without an
        # underscore are returned unchanged.
        return [go_id.split("_")[-1] if "_" in go_id else go_id
                for go_id in self.protein_go_dict.get(uniprot_id, [])]

    def extract_blast_go_ids(self, blast_results: List[Dict], protein_id: str) -> List[str]:
        """
        Extract qualifying GO ids from BLAST results.

        Args:
            blast_results: list of BLAST hit dicts (keys: 'ID', 'Identity%',
                'Coverage%', 'E-value').
            protein_id: id of the query protein, skipped to avoid
                self-matches.

        Returns:
            List of GO ids collected from the selected hits (may contain
            duplicates; callers dedupe via a set).
        """
        go_ids = []

        if self.topk > 0:
            # Top-k strategy: take the first k hits unconditionally.
            for result in blast_results[:self.topk]:
                hit_id = result.get('ID', '')
                if hit_id == protein_id:
                    continue
                go_ids.extend(self._get_go_from_uniprot_id(hit_id))
        else:
            # Threshold strategy: keep hits passing all three cutoffs.
            for result in blast_results:
                identity = float(result.get('Identity%', 0))
                coverage = float(result.get('Coverage%', 0))
                evalue = float(result.get('E-value', 1.0))

                if (identity >= self.identity_threshold and
                        coverage >= self.coverage_threshold and
                        evalue <= self.evalue_threshold):
                    hit_id = result.get('ID', '')
                    if hit_id == protein_id:
                        continue
                    go_ids.extend(self._get_go_from_uniprot_id(hit_id))

        return go_ids

    def first_level_filtering(self, interproscan_info: Dict, blast_info: Dict) -> Dict[str, List[str]]:
        """
        First-level filtering: merge InterProScan GO terms with qualifying
        BLAST-derived GO terms.

        Args:
            interproscan_info: InterProScan results keyed by protein id.
            blast_info: BLAST results keyed by protein id.

        Returns:
            Mapping of protein id -> deduplicated list of GO ids.
        """
        protein_go_dict = {}

        for protein_id in interproscan_info.keys():
            go_ids = set()

            # GO terms reported by InterProScan ("GO:0005515" -> "0005515").
            interproscan_gos = interproscan_info[protein_id].get(
                'interproscan_results', {}).get('go_id', [])
            interproscan_gos = [go_id.split(":")[-1] if ":" in go_id else go_id
                                for go_id in interproscan_gos]
            if interproscan_gos:
                go_ids.update(interproscan_gos)

            # GO terms transferred from qualifying BLAST hits.
            if protein_id in blast_info:
                blast_results = blast_info[protein_id].get('blast_results', [])
                blast_gos = self.extract_blast_go_ids(blast_results, protein_id)
                go_ids.update(blast_gos)

            protein_go_dict[protein_id] = list(go_ids)

        return protein_go_dict

    def calculate_protrek_scores(self, protein_sequences: Dict[str, str],
                                 protein_go_dict: Dict[str, List[str]]) -> Dict[str, Dict]:
        """
        Score each (protein, GO term) pair with ProTrek sequence-text
        similarity.

        Args:
            protein_sequences: protein id -> amino-acid sequence.
            protein_go_dict: protein id -> candidate GO ids.

        Returns:
            protein id -> {"protein_id", "GO_id", "Clip_score"} where
            Clip_score maps GO id -> float similarity. GO terms without a
            retrievable definition get no score.
        """
        results = {}

        for protein_id, go_ids in tqdm(protein_go_dict.items(), desc="计算ProTrek分数"):
            if protein_id not in protein_sequences:
                continue

            protein_seq = protein_sequences[protein_id]
            go_scores = {}

            # Fetch textual definitions; terms without one cannot be scored.
            go_definitions = {}
            for go_id in go_ids:
                definition = get_go_definition(go_id)
                if definition:
                    go_definitions[go_id] = definition

            if not go_definitions:
                continue

            try:
                # Inference only: no_grad avoids building the autograd graph
                # (otherwise .numpy() on a grad-requiring tensor raises and
                # memory grows with every protein).
                with torch.no_grad():
                    seq_emb = self.protrek_model.get_protein_repr([protein_seq])

                    definitions = list(go_definitions.values())
                    text_embs = self.protrek_model.get_text_repr(definitions)

                    # Temperature-scaled cosine-style similarity matrix.
                    scores = (seq_emb @ text_embs.T) / self.protrek_model.temperature
                    scores = scores.cpu().numpy().flatten()

                # Map scores back to GO ids (dict preserves insertion order).
                for i, go_id in enumerate(go_definitions.keys()):
                    go_scores[go_id] = float(scores[i])
            except Exception as e:
                print(f"计算 {protein_id} 的ProTrek分数时出错: {str(e)}")
                continue

            results[protein_id] = {
                "protein_id": protein_id,
                "GO_id": go_ids,
                "Clip_score": go_scores
            }

        return results

    def second_level_filtering(self, protrek_results: Dict[str, Dict]) -> Dict[str, List[str]]:
        """
        Second-level filtering: keep GO terms whose ProTrek score reaches
        ``self.protrek_threshold``.

        Args:
            protrek_results: output of :meth:`calculate_protrek_scores`.

        Returns:
            protein id -> surviving GO ids; proteins with no survivors are
            dropped entirely.
        """
        filtered_results = {}

        for protein_id, data in protrek_results.items():
            clip_scores = data.get('Clip_score', {})
            filtered_gos = []

            for go_id, score in clip_scores.items():
                if score >= self.protrek_threshold:
                    filtered_gos.append(go_id)

            if filtered_gos:
                filtered_results[protein_id] = filtered_gos

        return filtered_results

    def generate_filename(self, base_name: str, is_intermediate: bool = False) -> str:
        """Build an output filename that encodes the active parameters."""
        if self.topk > 0:
            # Top-k strategy: only the k value matters.
            params = f"topk{self.topk}"
        else:
            # Threshold strategy: encode all three cutoffs.
            params = f"identity{self.identity_threshold}_coverage{self.coverage_threshold}_evalue{self.evalue_threshold:.0e}"

        if self.use_protrek and self.protrek_threshold is not None:
            params += f"_protrek{self.protrek_threshold}"

        if is_intermediate:
            return f"{base_name}_intermediate_{params}.json"
        else:
            return f"{base_name}_final_{params}.json"

    def run(self,
            interproscan_info: Dict = None,
            blast_info: Dict = None,
            interproscan_file: str = None,
            blast_file: str = None,
            output_dir: str = "output"):
        """
        Run the GO integration pipeline.

        Args:
            interproscan_info: InterProScan results dict (takes precedence
                over ``interproscan_file``).
            blast_info: BLAST results dict (takes precedence over
                ``blast_file``).
            interproscan_file: path to an InterProScan results JSON file.
            blast_file: path to a BLAST results JSON file.
            output_dir: directory for output files (created if missing).

        Returns:
            Output file path, or ``(final_file, intermediate_file)`` when
            second-level filtering produced a final file.

        Raises:
            ValueError: if neither dicts nor file paths provide both inputs.
        """
        # Load inputs from files when dicts were not supplied directly.
        if interproscan_info is None and interproscan_file:
            with open(interproscan_file, 'r') as f:
                interproscan_info = json.load(f)
        if blast_info is None and blast_file:
            with open(blast_file, 'r') as f:
                blast_info = json.load(f)

        if not interproscan_info or not blast_info:
            raise ValueError("必须提供interproscan_info和blast_info数据或文件路径")

        os.makedirs(output_dir, exist_ok=True)

        print("开始第一层筛选...")
        protein_go_dict = self.first_level_filtering(interproscan_info, blast_info)

        if not self.use_protrek:
            # Stage 1 only: write JSONL and stop here.
            output_file = os.path.join(output_dir, self.generate_filename("go_integration"))
            with open(output_file, 'w') as f:
                for protein_id, go_ids in protein_go_dict.items():
                    result = {"protein_id": protein_id, "GO_id": go_ids}
                    f.write(json.dumps(result) + '\n')
            print(f"第一层筛选完成,结果已保存到: {output_file}")
            return output_file

        print("开始第二层筛选...")
        # Sequences live alongside the InterProScan records.
        protein_sequences = {}
        for protein_id, data in interproscan_info.items():
            protein_sequences[protein_id] = data.get('sequence', '')

        protrek_results = self.calculate_protrek_scores(protein_sequences, protein_go_dict)

        # Persist the raw scores so the (expensive) model pass is reusable.
        intermediate_file = os.path.join(
            output_dir, self.generate_filename("go_integration", is_intermediate=True))
        with open(intermediate_file, 'w') as f:
            for result in protrek_results.values():
                f.write(json.dumps(result) + '\n')
        print(f"ProTrek分数计算完成,中间结果已保存到: {intermediate_file}")

        if self.protrek_threshold is not None:
            final_results = self.second_level_filtering(protrek_results)

            final_file = os.path.join(output_dir, self.generate_filename("go_integration"))
            with open(final_file, 'w') as f:
                for protein_id, go_ids in final_results.items():
                    result = {"protein_id": protein_id, "GO_id": go_ids}
                    f.write(json.dumps(result) + '\n')
            print(f"第二层筛选完成,最终结果已保存到: {final_file}")
            return final_file, intermediate_file

        return intermediate_file


def main():
    """CLI entry point: parse arguments, build the pipeline, and run it."""
    parser = argparse.ArgumentParser(description="GO信息整合管道")
    parser.add_argument("--interproscan_file", type=str,
                        default="data/processed_data/interproscan_info.json",
                        help="InterProScan结果文件路径")
    parser.add_argument("--blast_file", type=str,
                        default="data/processed_data/blast_info.json",
                        help="BLAST结果文件路径")
    parser.add_argument("--identity", type=int, default=80,
                        help="BLAST identity阈值 (0-100)")
    parser.add_argument("--coverage", type=int, default=80,
                        help="BLAST coverage阈值 (0-100)")
    parser.add_argument("--evalue", type=float, default=1e-50,
                        help="BLAST E-value阈值")
    parser.add_argument("--topk", type=int, default=2,
                        help="BLAST topk结果")
    parser.add_argument("--protrek_threshold", type=float,
                        help="ProTrek分数阈值")
    parser.add_argument("--use_protrek", action="store_true",
                        help="是否使用第二层ProTrek筛选")
    parser.add_argument("--output_dir", type=str,
                        default="data/processed_data/go_integration_results",
                        help="输出目录")

    args = parser.parse_args()

    pipeline = GOIntegrationPipeline(
        identity_threshold=args.identity,
        coverage_threshold=args.coverage,
        evalue_threshold=args.evalue,
        topk=args.topk,
        protrek_threshold=args.protrek_threshold,
        use_protrek=args.use_protrek
    )

    pipeline.run(
        interproscan_file=args.interproscan_file,
        blast_file=args.blast_file,
        output_dir=args.output_dir
    )


if __name__ == "__main__":
    main()