Spaces:
Runtime error
Runtime error
File size: 14,916 Bytes
5c20520 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 |
import json
import os
import sys
import argparse
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import torch
from tqdm import tqdm
# Add project paths so local modules (utils, Models/ProTrek) are importable.
root_path = os.path.dirname((os.path.abspath(__file__)))
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, "Models/ProTrek"))
from utils.protein_go_analysis import get_go_definition
class GOIntegrationPipeline:
def __init__(self,
             identity_threshold: int = 80,
             coverage_threshold: int = 80,
             evalue_threshold: float = 1e-50,
             topk: int = 2,
             protrek_threshold: Optional[float] = None,
             use_protrek: bool = False):
    """
    GO information integration pipeline.

    Args:
        identity_threshold: BLAST identity threshold (0-100).
        coverage_threshold: BLAST coverage threshold (0-100).
        evalue_threshold: BLAST E-value threshold.
        topk: Number of top BLAST hits to use; when > 0 this top-k strategy
            overrides the identity/coverage/E-value threshold strategy.
            (Was previously undocumented.)
        protrek_threshold: ProTrek score threshold for the second-level filter.
        use_protrek: Whether to apply the second-level ProTrek filtering.
    """
    self.identity_threshold = identity_threshold
    self.coverage_threshold = coverage_threshold
    self.evalue_threshold = evalue_threshold
    self.protrek_threshold = protrek_threshold
    self.use_protrek = use_protrek
    self.topk = topk
    # Load the protein -> GO mapping used to expand BLAST hits into GO IDs.
    self._load_protein_go_dict()
    # Only pay the model-loading cost when the second-level filter is requested.
    if self.use_protrek:
        self._init_protrek_model()
def _init_protrek_model(self):
    """Initialize the ProTrek tri-modal model on GPU if available, else CPU."""
    from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel

    if torch.cuda.is_available():
        self.device = "cuda"
    else:
        self.device = "cpu"
    model = ProTrekTrimodalModel(
        protein_config="Models/ProTrek/weights/ProTrek_650M_UniRef50/esm2_t33_650M_UR50D",
        text_config="Models/ProTrek/weights/ProTrek_650M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        structure_config="Models/ProTrek/weights/ProTrek_650M_UniRef50/foldseek_t30_150M",
        load_protein_pretrained=False,
        load_text_pretrained=False,
        from_checkpoint="Models/ProTrek/weights/ProTrek_650M_UniRef50/ProTrek_650M_UniRef50.pt",
    )
    self.protrek_model = model.to(self.device).eval()
    print(f"ProTrek模型已加载到设备: {self.device}")
def _load_protein_go_dict(self):
    """Populate self.protein_go_dict from a JSON-lines mapping file.

    Each line is a JSON object with 'protein_id' and 'GO_id' fields.
    Best-effort: any failure leaves the mapping empty instead of raising.
    """
    self.protein_go_dict = {}
    try:
        with open('processed_data/protein_go.json', 'r') as f:
            records = (json.loads(line) for line in f)
            self.protein_go_dict = {rec['protein_id']: rec['GO_id'] for rec in records}
        print(f"成功加载蛋白质-GO映射数据,共{len(self.protein_go_dict)}条记录")
    except Exception as e:
        # Deliberately broad: a missing/corrupt file must not abort the pipeline.
        print(f"加载蛋白质-GO映射数据时发生错误: {str(e)}")
        self.protein_go_dict = {}
def _get_go_from_uniprot_id(self, uniprot_id: str) -> List[str]:
    """Look up GO IDs for a UniProt accession in the preloaded mapping.

    IDs stored like 'GO_0008150' are normalized to the bare numeric part
    ('0008150'); unknown accessions yield an empty list.
    """
    normalized = []
    for go_id in self.protein_go_dict.get(uniprot_id, []):
        if "_" in go_id:
            normalized.append(go_id.split("_")[-1])
        else:
            normalized.append(go_id)
    return normalized
def extract_blast_go_ids(self, blast_results: List[Dict], protein_id: str) -> List[str]:
    """Collect GO IDs from BLAST hits.

    Two selection strategies:
      * topk > 0: take the first ``topk`` hits unconditionally;
      * otherwise: keep hits passing the identity/coverage/E-value cutoffs.
    The query protein itself is always skipped to avoid self-matches.

    Args:
        blast_results: BLAST hit dicts ('ID', 'Identity%', 'Coverage%', 'E-value').
        protein_id: ID of the query protein (excluded when it appears as a hit).

    Returns:
        GO IDs gathered from the selected hits (duplicates possible).
    """
    if self.topk > 0:
        # Top-k strategy: trust the ranking, ignore the score thresholds.
        candidates = blast_results[:self.topk]
    else:
        # Threshold strategy: a hit must pass all three cutoffs.
        candidates = [
            hit for hit in blast_results
            if float(hit.get('Identity%', 0)) >= self.identity_threshold
            and float(hit.get('Coverage%', 0)) >= self.coverage_threshold
            and float(hit.get('E-value', 1.0)) <= self.evalue_threshold
        ]
    collected: List[str] = []
    for hit in candidates:
        hit_id = hit.get('ID', '')
        if hit_id == protein_id:
            continue  # never use the query protein as its own evidence
        collected.extend(self._get_go_from_uniprot_id(hit_id))
    return collected
def first_level_filtering(self, interproscan_info: Dict, blast_info: Dict) -> Dict[str, List[str]]:
    """First-level filtering: merge InterProScan GO terms with qualifying BLAST GO terms.

    Args:
        interproscan_info: Per-protein InterProScan results.
        blast_info: Per-protein BLAST results.

    Returns:
        Mapping from protein ID to a deduplicated list of GO IDs.
    """
    merged: Dict[str, List[str]] = {}
    for protein_id, info in interproscan_info.items():
        collected = set()
        # GO terms reported by InterProScan, normalized to bare numeric IDs
        # ('GO:0008150' -> '0008150').
        raw_gos = info.get('interproscan_results', {}).get('go_id', [])
        collected.update(g.split(":")[-1] if ":" in g else g for g in raw_gos)
        # GO terms inherited from qualifying BLAST hits.
        if protein_id in blast_info:
            hits = blast_info[protein_id].get('blast_results', [])
            collected.update(self.extract_blast_go_ids(hits, protein_id))
        merged[protein_id] = list(collected)
    return merged
def calculate_protrek_scores(self, protein_sequences: Dict[str, str],
                             protein_go_dict: Dict[str, List[str]]) -> Dict[str, Dict]:
    """
    Compute ProTrek sequence-text similarity scores for each protein's GO terms.

    Args:
        protein_sequences: Mapping protein ID -> amino-acid sequence.
        protein_go_dict: Mapping protein ID -> candidate GO IDs.

    Returns:
        Mapping protein ID -> {"protein_id", "GO_id", "Clip_score"}, where
        Clip_score maps each GO ID that has a retrievable definition to a
        float score. Proteins with no sequence, no scorable terms, or a
        scoring error are omitted.
    """
    results = {}
    for protein_id, go_ids in tqdm(protein_go_dict.items(), desc="计算ProTrek分数"):
        if protein_id not in protein_sequences:
            continue
        protein_seq = protein_sequences[protein_id]
        go_scores = {}
        # Resolve GO textual definitions; terms without one cannot be scored.
        go_definitions = {}
        for go_id in go_ids:
            definition = get_go_definition(go_id)
            if definition:
                go_definitions[go_id] = definition
        if not go_definitions:
            continue
        try:
            # Inference only: disable autograd so no computation graph is
            # built and activation memory is not retained (fixes unbounded
            # memory growth over many proteins).
            with torch.no_grad():
                seq_emb = self.protrek_model.get_protein_repr([protein_seq])
                definitions = list(go_definitions.values())
                text_embs = self.protrek_model.get_text_repr(definitions)
                # Temperature-scaled similarity (CLIP-style).
                scores = (seq_emb @ text_embs.T) / self.protrek_model.temperature
            scores = scores.cpu().numpy().flatten()
            # dicts preserve insertion order, so scores align with the keys.
            for i, go_id in enumerate(go_definitions.keys()):
                go_scores[go_id] = float(scores[i])
        except Exception as e:
            print(f"计算 {protein_id} 的ProTrek分数时出错: {str(e)}")
            continue
        results[protein_id] = {
            "protein_id": protein_id,
            "GO_id": go_ids,
            "Clip_score": go_scores
        }
    return results
def second_level_filtering(self, protrek_results: Dict[str, Dict]) -> Dict[str, List[str]]:
    """Second-level filtering: keep GO terms whose ProTrek score meets the threshold.

    Args:
        protrek_results: Output of calculate_protrek_scores.

    Returns:
        Mapping protein ID -> GO IDs passing the threshold; proteins with
        no surviving terms are omitted entirely.
    """
    surviving: Dict[str, List[str]] = {}
    for protein_id, entry in protrek_results.items():
        scores = entry.get('Clip_score', {})
        kept = [go_id for go_id, score in scores.items() if score >= self.protrek_threshold]
        if kept:
            surviving[protein_id] = kept
    return surviving
def generate_filename(self, base_name: str, is_intermediate: bool = False) -> str:
    """Build an output filename that encodes the pipeline parameters."""
    if self.topk > 0:
        # Top-k strategy: the BLAST score thresholds are irrelevant.
        tag = f"topk{self.topk}"
    else:
        tag = (f"identity{self.identity_threshold}"
               f"_coverage{self.coverage_threshold}"
               f"_evalue{self.evalue_threshold:.0e}")
    if self.use_protrek and self.protrek_threshold is not None:
        tag = f"{tag}_protrek{self.protrek_threshold}"
    stage = "intermediate" if is_intermediate else "final"
    return f"{base_name}_{stage}_{tag}.json"
def run(self, interproscan_info: Dict = None, blast_info: Dict = None,
        interproscan_file: str = None, blast_file: str = None,
        output_dir: str = "output"):
    """
    Execute the GO integration pipeline.

    Either the in-memory dicts or the corresponding file paths must be
    supplied for both InterProScan and BLAST data (dicts take precedence).

    Args:
        interproscan_info: InterProScan results dict.
        blast_info: BLAST results dict.
        interproscan_file: Path to an InterProScan results JSON file.
        blast_file: Path to a BLAST results JSON file.
        output_dir: Directory where result files are written.

    Returns:
        The path of the written result file; when both filtering levels ran,
        the tuple (final_file, intermediate_file).

    Raises:
        ValueError: If either input dataset is missing or empty.
    """
    # Fall back to files when in-memory dicts were not provided.
    if interproscan_info is None and interproscan_file:
        with open(interproscan_file, 'r') as f:
            interproscan_info = json.load(f)
    if blast_info is None and blast_file:
        with open(blast_file, 'r') as f:
            blast_info = json.load(f)
    if not interproscan_info or not blast_info:
        raise ValueError("必须提供interproscan_info和blast_info数据或文件路径")

    os.makedirs(output_dir, exist_ok=True)

    print("开始第一层筛选...")
    protein_go_dict = self.first_level_filtering(interproscan_info, blast_info)

    if not self.use_protrek:
        # Single-level mode: persist merged GO sets as JSON lines and stop.
        output_file = os.path.join(output_dir, self.generate_filename("go_integration"))
        with open(output_file, 'w') as f:
            f.writelines(
                json.dumps({"protein_id": pid, "GO_id": gos}) + '\n'
                for pid, gos in protein_go_dict.items()
            )
        print(f"第一层筛选完成,结果已保存到: {output_file}")
        return output_file

    print("开始第二层筛选...")
    # Sequences are needed to embed each protein for ProTrek scoring.
    protein_sequences = {pid: entry.get('sequence', '') for pid, entry in interproscan_info.items()}
    protrek_results = self.calculate_protrek_scores(protein_sequences, protein_go_dict)

    # Always persist raw scores so a different threshold can be re-applied later.
    intermediate_file = os.path.join(output_dir, self.generate_filename("go_integration", is_intermediate=True))
    with open(intermediate_file, 'w') as f:
        f.writelines(json.dumps(entry) + '\n' for entry in protrek_results.values())
    print(f"ProTrek分数计算完成,中间结果已保存到: {intermediate_file}")

    if self.protrek_threshold is not None:
        final_results = self.second_level_filtering(protrek_results)
        final_file = os.path.join(output_dir, self.generate_filename("go_integration"))
        with open(final_file, 'w') as f:
            f.writelines(
                json.dumps({"protein_id": pid, "GO_id": gos}) + '\n'
                for pid, gos in final_results.items()
            )
        print(f"第二层筛选完成,最终结果已保存到: {final_file}")
        return final_file, intermediate_file
    return intermediate_file
def main():
    """CLI entry point: parse arguments, build the pipeline, and run it."""
    parser = argparse.ArgumentParser(description="GO信息整合管道")
    # (flag, type, default, help) — registered in the original CLI order.
    value_args = [
        ("--interproscan_file", str, "data/processed_data/interproscan_info.json", "InterProScan结果文件路径"),
        ("--blast_file", str, "data/processed_data/blast_info.json", "BLAST结果文件路径"),
        ("--identity", int, 80, "BLAST identity阈值 (0-100)"),
        ("--coverage", int, 80, "BLAST coverage阈值 (0-100)"),
        ("--evalue", float, 1e-50, "BLAST E-value阈值"),
        ("--topk", int, 2, "BLAST topk结果"),
    ]
    for flag, flag_type, default, help_text in value_args:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    parser.add_argument("--protrek_threshold", type=float, help="ProTrek分数阈值")
    parser.add_argument("--use_protrek", action="store_true", help="是否使用第二层ProTrek筛选")
    parser.add_argument("--output_dir", type=str, default="data/processed_data/go_integration_results", help="输出目录")
    args = parser.parse_args()

    pipeline = GOIntegrationPipeline(
        identity_threshold=args.identity,
        coverage_threshold=args.coverage,
        evalue_threshold=args.evalue,
        topk=args.topk,
        protrek_threshold=args.protrek_threshold,
        use_protrek=args.use_protrek,
    )
    pipeline.run(
        interproscan_file=args.interproscan_file,
        blast_file=args.blast_file,
        output_dir=args.output_dir,
    )
# Script entry point: run the CLI when executed directly.
if __name__ == "__main__":
main()