Spaces:
Runtime error
Runtime error
import json | |
import os | |
import sys | |
import argparse | |
from typing import Dict, List, Tuple, Optional | |
from collections import defaultdict | |
import torch | |
from tqdm import tqdm | |
# 添加路径 | |
root_path = os.path.dirname((os.path.abspath(__file__))) | |
sys.path.append(root_path) | |
sys.path.append(os.path.join(root_path, "Models/ProTrek")) | |
from utils.protein_go_analysis import get_go_definition | |
class GOIntegrationPipeline: | |
def __init__(self, | |
identity_threshold: int = 80, | |
coverage_threshold: int = 80, | |
evalue_threshold: float = 1e-50, | |
topk: int = 2, | |
protrek_threshold: Optional[float] = None, | |
use_protrek: bool = False): | |
""" | |
GO信息整合管道 | |
Args: | |
identity_threshold: BLAST identity阈值 (0-100) | |
coverage_threshold: BLAST coverage阈值 (0-100) | |
evalue_threshold: BLAST E-value阈值 | |
protrek_threshold: ProTrek分数阈值 | |
use_protrek: 是否使用第二层ProTrek筛选 | |
""" | |
self.identity_threshold = identity_threshold | |
self.coverage_threshold = coverage_threshold | |
self.evalue_threshold = evalue_threshold | |
self.protrek_threshold = protrek_threshold | |
self.use_protrek = use_protrek | |
self.topk = topk | |
# 加载蛋白质-GO映射数据 | |
self._load_protein_go_dict() | |
# 如果使用protrek,初始化模型 | |
if self.use_protrek: | |
self._init_protrek_model() | |
def _init_protrek_model(self): | |
"""初始化ProTrek模型""" | |
from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel | |
config = { | |
"protein_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/esm2_t33_650M_UR50D", | |
"text_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", | |
"structure_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/foldseek_t30_150M", | |
"load_protein_pretrained": False, | |
"load_text_pretrained": False, | |
"from_checkpoint": "Models/ProTrek/weights/ProTrek_650M_UniRef50/ProTrek_650M_UniRef50.pt" | |
} | |
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
self.protrek_model = ProTrekTrimodalModel(**config).to(self.device).eval() | |
print(f"ProTrek模型已加载到设备: {self.device}") | |
def _load_protein_go_dict(self): | |
"""加载蛋白质-GO映射数据""" | |
self.protein_go_dict = {} | |
try: | |
with open('processed_data/protein_go.json', 'r') as f: | |
for line in f: | |
data = json.loads(line) | |
self.protein_go_dict[data['protein_id']] = data['GO_id'] | |
print(f"成功加载蛋白质-GO映射数据,共{len(self.protein_go_dict)}条记录") | |
except Exception as e: | |
print(f"加载蛋白质-GO映射数据时发生错误: {str(e)}") | |
self.protein_go_dict = {} | |
def _get_go_from_uniprot_id(self, uniprot_id: str) -> List[str]: | |
""" | |
从Uniprot ID获取GO ID | |
Args: | |
uniprot_id: Uniprot ID | |
Returns: | |
使用类内部加载的字典 | |
""" | |
# 使用类内部加载的字典 | |
return [go_id.split("_")[-1] if "_" in go_id else go_id | |
for go_id in self.protein_go_dict.get(uniprot_id, [])] | |
def extract_blast_go_ids(self, blast_results: List[Dict],protein_id: str) -> List[str]: | |
""" | |
从BLAST结果中提取符合条件的GO ID | |
Args: | |
blast_results: BLAST结果列表 | |
protein_id: 当前蛋白质ID(避免自身匹配) | |
Returns: | |
符合条件的GO ID列表 | |
""" | |
go_ids = [] | |
if self.topk > 0: | |
# 使用topk策略 | |
for result in blast_results[:self.topk]: | |
hit_id = result.get('ID', '') | |
if hit_id == protein_id: | |
continue | |
go_ids.extend(self._get_go_from_uniprot_id(hit_id)) | |
else: | |
# 使用阈值策略 | |
for result in blast_results: | |
identity = float(result.get('Identity%', 0)) | |
coverage = float(result.get('Coverage%', 0)) | |
evalue = float(result.get('E-value', 1.0)) | |
# 检查是否符合阈值条件 | |
if (identity >= self.identity_threshold and | |
coverage >= self.coverage_threshold and | |
evalue <= self.evalue_threshold): | |
# 获取该hit的protein_id | |
hit_id = result.get('ID', '') | |
if hit_id == protein_id: | |
continue | |
go_ids.extend(self._get_go_from_uniprot_id(hit_id)) | |
return go_ids | |
def first_level_filtering(self, interproscan_info: Dict, blast_info: Dict) -> Dict[str, List[str]]: | |
""" | |
第一层筛选:合并interproscan和符合条件的blast GO信息 | |
Args: | |
interproscan_info: InterProScan结果 | |
blast_info: BLAST结果 | |
Returns: | |
蛋白质ID到GO ID列表的映射 | |
""" | |
protein_go_dict = {} | |
for protein_id in interproscan_info.keys(): | |
go_ids = set() | |
# 添加interproscan的GO信息 | |
interproscan_gos = interproscan_info[protein_id].get('interproscan_results', {}).get('go_id', []) | |
interproscan_gos = [go_id.split(":")[-1] if ":" in go_id else go_id for go_id in interproscan_gos] | |
if interproscan_gos: | |
go_ids.update(interproscan_gos) | |
# 添加符合条件的blast GO信息 | |
if protein_id in blast_info: | |
blast_results = blast_info[protein_id].get('blast_results', []) | |
blast_gos = self.extract_blast_go_ids(blast_results,protein_id) | |
go_ids.update(blast_gos) | |
protein_go_dict[protein_id] = list(go_ids) | |
return protein_go_dict | |
def calculate_protrek_scores(self, protein_sequences: Dict[str, str], | |
protein_go_dict: Dict[str, List[str]]) -> Dict[str, Dict]: | |
""" | |
计算ProTrek分数 | |
Args: | |
protein_sequences: 蛋白质序列字典 | |
protein_go_dict: 蛋白质GO映射 | |
Returns: | |
包含GO分数的字典 | |
""" | |
results = {} | |
for protein_id, go_ids in tqdm(protein_go_dict.items(), desc="计算ProTrek分数"): | |
if protein_id not in protein_sequences: | |
continue | |
protein_seq = protein_sequences[protein_id] | |
go_scores = {} | |
# 获取GO定义 | |
go_definitions = {} | |
for go_id in go_ids: | |
definition = get_go_definition(go_id) | |
if definition: | |
go_definitions[go_id] = definition | |
if not go_definitions: | |
continue | |
try: | |
# 计算蛋白质序列嵌入 | |
seq_emb = self.protrek_model.get_protein_repr([protein_seq]) | |
# 计算文本嵌入和相似度分数 | |
definitions = list(go_definitions.values()) | |
text_embs = self.protrek_model.get_text_repr(definitions) | |
# 计算相似度分数 | |
scores = (seq_emb @ text_embs.T) / self.protrek_model.temperature | |
scores = scores.cpu().numpy().flatten() | |
# 映射回GO ID | |
for i, go_id in enumerate(go_definitions.keys()): | |
go_scores[go_id] = float(scores[i]) | |
except Exception as e: | |
print(f"计算 {protein_id} 的ProTrek分数时出错: {str(e)}") | |
continue | |
results[protein_id] = { | |
"protein_id": protein_id, | |
"GO_id": go_ids, | |
"Clip_score": go_scores | |
} | |
return results | |
def second_level_filtering(self, protrek_results: Dict[str, Dict]) -> Dict[str, List[str]]: | |
""" | |
第二层筛选:根据ProTrek阈值筛选GO | |
Args: | |
protrek_results: ProTrek计算结果 | |
Returns: | |
筛选后的蛋白质GO映射 | |
""" | |
filtered_results = {} | |
for protein_id, data in protrek_results.items(): | |
clip_scores = data.get('Clip_score', {}) | |
filtered_gos = [] | |
for go_id, score in clip_scores.items(): | |
if score >= self.protrek_threshold: | |
filtered_gos.append(go_id) | |
if filtered_gos: | |
filtered_results[protein_id] = filtered_gos | |
return filtered_results | |
def generate_filename(self, base_name: str, is_intermediate: bool = False) -> str: | |
"""生成包含参数信息的文件名""" | |
if self.topk > 0: | |
# 如果使用topk,则只包含topk信息 | |
params = f"topk{self.topk}" | |
else: | |
# 否则使用原有的参数组合 | |
params = f"identity{self.identity_threshold}_coverage{self.coverage_threshold}_evalue{self.evalue_threshold:.0e}" | |
if self.use_protrek and self.protrek_threshold is not None: | |
params += f"_protrek{self.protrek_threshold}" | |
if is_intermediate: | |
return f"{base_name}_intermediate_{params}.json" | |
else: | |
return f"{base_name}_final_{params}.json" | |
def run(self, interproscan_info: Dict = None, blast_info: Dict = None, | |
interproscan_file: str = None, blast_file: str = None, | |
output_dir: str = "output"): | |
""" | |
运行GO整合管道 | |
Args: | |
interproscan_info: InterProScan结果字典 | |
blast_info: BLAST结果字典 | |
interproscan_file: InterProScan结果文件路径 | |
blast_file: BLAST结果文件路径 | |
output_dir: 输出目录 | |
""" | |
# 加载数据 | |
if interproscan_info is None and interproscan_file: | |
with open(interproscan_file, 'r') as f: | |
interproscan_info = json.load(f) | |
if blast_info is None and blast_file: | |
with open(blast_file, 'r') as f: | |
blast_info = json.load(f) | |
if not interproscan_info or not blast_info: | |
raise ValueError("必须提供interproscan_info和blast_info数据或文件路径") | |
# 确保输出目录存在 | |
os.makedirs(output_dir, exist_ok=True) | |
print("开始第一层筛选...") | |
# 第一层筛选 | |
protein_go_dict = self.first_level_filtering(interproscan_info, blast_info) | |
if not self.use_protrek: | |
# 不使用第二层筛选,直接保存结果 | |
output_file = os.path.join(output_dir, self.generate_filename("go_integration")) | |
with open(output_file, 'w') as f: | |
for protein_id, go_ids in protein_go_dict.items(): | |
result = {"protein_id": protein_id, "GO_id": go_ids} | |
f.write(json.dumps(result) + '\n') | |
print(f"第一层筛选完成,结果已保存到: {output_file}") | |
return output_file | |
print("开始第二层筛选...") | |
# 提取蛋白质序列 | |
protein_sequences = {} | |
for protein_id, data in interproscan_info.items(): | |
protein_sequences[protein_id] = data.get('sequence', '') | |
# 计算ProTrek分数 | |
protrek_results = self.calculate_protrek_scores(protein_sequences, protein_go_dict) | |
# 保存中间结果 | |
intermediate_file = os.path.join(output_dir, self.generate_filename("go_integration", is_intermediate=True)) | |
with open(intermediate_file, 'w') as f: | |
for result in protrek_results.values(): | |
f.write(json.dumps(result) + '\n') | |
print(f"ProTrek分数计算完成,中间结果已保存到: {intermediate_file}") | |
# 第二层筛选 | |
if self.protrek_threshold is not None: | |
final_results = self.second_level_filtering(protrek_results) | |
# 保存最终结果 | |
final_file = os.path.join(output_dir, self.generate_filename("go_integration")) | |
with open(final_file, 'w') as f: | |
for protein_id, go_ids in final_results.items(): | |
result = {"protein_id": protein_id, "GO_id": go_ids} | |
f.write(json.dumps(result) + '\n') | |
print(f"第二层筛选完成,最终结果已保存到: {final_file}") | |
return final_file, intermediate_file | |
return intermediate_file | |
def main(): | |
parser = argparse.ArgumentParser(description="GO信息整合管道") | |
parser.add_argument("--interproscan_file", type=str,default="data/processed_data/interproscan_info.json", help="InterProScan结果文件路径") | |
parser.add_argument("--blast_file", type=str, default="data/processed_data/blast_info.json", help="BLAST结果文件路径") | |
parser.add_argument("--identity", type=int, default=80, help="BLAST identity阈值 (0-100)") | |
parser.add_argument("--coverage", type=int, default=80, help="BLAST coverage阈值 (0-100)") | |
parser.add_argument("--evalue", type=float, default=1e-50, help="BLAST E-value阈值") | |
parser.add_argument("--topk", type=int, default=2, help="BLAST topk结果") | |
parser.add_argument("--protrek_threshold", type=float, help="ProTrek分数阈值") | |
parser.add_argument("--use_protrek", action="store_true", help="是否使用第二层ProTrek筛选") | |
parser.add_argument("--output_dir", type=str, default="data/processed_data/go_integration_results", help="输出目录") | |
args = parser.parse_args() | |
# 创建管道实例 | |
pipeline = GOIntegrationPipeline( | |
identity_threshold=args.identity, | |
coverage_threshold=args.coverage, | |
evalue_threshold=args.evalue, | |
topk=args.topk, | |
protrek_threshold=args.protrek_threshold, | |
use_protrek=args.use_protrek | |
) | |
# 运行管道 | |
pipeline.run( | |
interproscan_file=args.interproscan_file, | |
blast_file=args.blast_file, | |
output_dir=args.output_dir | |
) | |
if __name__ == "__main__": | |
main() |