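"""GO annotation integration pipeline.

First-level filtering merges GO terms reported by InterProScan with GO terms
transferred from qualifying BLAST hits. Optionally, a second-level filter
re-scores the candidate GO terms against the protein sequence with ProTrek and
keeps only terms above a score threshold. Results are written as JSON Lines.
"""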
import json
import os
import sys
import argparse
from typing import Dict, List, Optional
import torch
from tqdm import tqdm

# Add project paths so local modules can be imported
root_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, "Models/ProTrek"))

from utils.protein_go_analysis import get_go_definition

class GOIntegrationPipeline:
    def __init__(self, 
                 identity_threshold: int = 80,
                 coverage_threshold: int = 80, 
                 evalue_threshold: float = 1e-50,
                 topk: int = 2,
                 protrek_threshold: Optional[float] = None,
                 use_protrek: bool = False):
        """
        GO信息整合管道
        
        Args:
            identity_threshold: BLAST identity阈值 (0-100)
            coverage_threshold: BLAST coverage阈值 (0-100) 
            evalue_threshold: BLAST E-value阈值
            protrek_threshold: ProTrek分数阈值
            use_protrek: 是否使用第二层ProTrek筛选
        """
        self.identity_threshold = identity_threshold
        self.coverage_threshold = coverage_threshold
        self.evalue_threshold = evalue_threshold
        self.protrek_threshold = protrek_threshold
        self.use_protrek = use_protrek
        self.topk = topk
        
        # Load the protein-to-GO mapping data
        self._load_protein_go_dict()

        # Initialize the ProTrek model only if the second-level filter is enabled
        if self.use_protrek:
            self._init_protrek_model()
    
    def _init_protrek_model(self):
        """初始化ProTrek模型"""
        from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel

        config = {
            "protein_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/esm2_t33_650M_UR50D",
            "text_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
            "structure_config": "Models/ProTrek/weights/ProTrek_650M_UniRef50/foldseek_t30_150M",
            "load_protein_pretrained": False,
            "load_text_pretrained": False,
            "from_checkpoint": "Models/ProTrek/weights/ProTrek_650M_UniRef50/ProTrek_650M_UniRef50.pt"
        }
        
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.protrek_model = ProTrekTrimodalModel(**config).to(self.device).eval()
        print(f"ProTrek模型已加载到设备: {self.device}")
    
    def _load_protein_go_dict(self):
        """加载蛋白质-GO映射数据"""
        self.protein_go_dict = {}
        try:
            with open('processed_data/protein_go.json', 'r') as f:
                for line in f:
                    data = json.loads(line)
                    self.protein_go_dict[data['protein_id']] = data['GO_id']
            print(f"成功加载蛋白质-GO映射数据,共{len(self.protein_go_dict)}条记录")
        except Exception as e:
            print(f"加载蛋白质-GO映射数据时发生错误: {str(e)}")
            self.protein_go_dict = {}
    
    def _get_go_from_uniprot_id(self, uniprot_id: str) -> List[str]:
        """
        从Uniprot ID获取GO ID
        
        Args:
            uniprot_id: Uniprot ID
        
        Returns:
            使用类内部加载的字典
        """
        # 使用类内部加载的字典
        return [go_id.split("_")[-1] if "_" in go_id else go_id 
                for go_id in self.protein_go_dict.get(uniprot_id, [])]
    
    def extract_blast_go_ids(self, blast_results: List[Dict], protein_id: str) -> List[str]:
        """
        Extract qualifying GO IDs from BLAST results.

        Args:
            blast_results: list of BLAST hit records
            protein_id: ID of the query protein (self-hits are skipped)

        Returns:
            List of GO IDs that pass the selected filtering strategy
        """
        go_ids = []
        
        if self.topk > 0:
            # Top-k strategy: take GO IDs from the first topk hits
            for result in blast_results[:self.topk]:
                hit_id = result.get('ID', '')
                if hit_id == protein_id:
                    continue
                go_ids.extend(self._get_go_from_uniprot_id(hit_id))
        else:
            # Threshold strategy: keep hits passing identity/coverage/E-value cutoffs
            for result in blast_results:
                identity = float(result.get('Identity%', 0))
                coverage = float(result.get('Coverage%', 0))
                evalue = float(result.get('E-value', 1.0))
                
                # Check whether the hit passes all threshold conditions
                if (identity >= self.identity_threshold and 
                    coverage >= self.coverage_threshold and 
                    evalue <= self.evalue_threshold):
                    
                    # Get the UniProt ID of this hit
                    hit_id = result.get('ID', '')
                    if hit_id == protein_id:
                        continue
                    go_ids.extend(self._get_go_from_uniprot_id(hit_id))
        
        return go_ids
    
    def first_level_filtering(self, interproscan_info: Dict, blast_info: Dict) -> Dict[str, List[str]]:
        """
        第一层筛选:合并interproscan和符合条件的blast GO信息
        
        Args:
            interproscan_info: InterProScan结果
            blast_info: BLAST结果
            
        Returns:
            蛋白质ID到GO ID列表的映射
        """
        protein_go_dict = {}
        
        for protein_id in interproscan_info.keys():
            go_ids = set()
            
            # Add GO terms reported by InterProScan
            interproscan_gos = interproscan_info[protein_id].get('interproscan_results', {}).get('go_id', [])
            interproscan_gos = [go_id.split(":")[-1] if ":" in go_id else go_id for go_id in interproscan_gos]
            if interproscan_gos:
                go_ids.update(interproscan_gos)
            
            # Add qualifying GO terms transferred from BLAST hits
            if protein_id in blast_info:
                blast_results = blast_info[protein_id].get('blast_results', [])
                blast_gos = self.extract_blast_go_ids(blast_results, protein_id)
                go_ids.update(blast_gos)
            
            protein_go_dict[protein_id] = list(go_ids)
        
        return protein_go_dict
    
    def calculate_protrek_scores(self, protein_sequences: Dict[str, str], 
                                protein_go_dict: Dict[str, List[str]]) -> Dict[str, Dict]:
        """
        计算ProTrek分数
        
        Args:
            protein_sequences: 蛋白质序列字典
            protein_go_dict: 蛋白质GO映射
            
        Returns:
            包含GO分数的字典
        """
        results = {}
        
        for protein_id, go_ids in tqdm(protein_go_dict.items(), desc="Computing ProTrek scores"):
            if protein_id not in protein_sequences:
                continue
                
            protein_seq = protein_sequences[protein_id]
            go_scores = {}
            
            # Look up textual definitions for the candidate GO terms
            go_definitions = {}
            for go_id in go_ids:
                definition = get_go_definition(go_id)
                if definition:
                    go_definitions[go_id] = definition
            
            if not go_definitions:
                continue
            
            try:
                # Embed the protein sequence
                seq_emb = self.protrek_model.get_protein_repr([protein_seq])

                # Embed the GO definition texts
                definitions = list(go_definitions.values())
                text_embs = self.protrek_model.get_text_repr(definitions)

                # Compute temperature-scaled similarity scores
                scores = (seq_emb @ text_embs.T) / self.protrek_model.temperature
                scores = scores.cpu().numpy().flatten()

                # Map each score back to its GO ID
                for i, go_id in enumerate(go_definitions.keys()):
                    go_scores[go_id] = float(scores[i])
                    
            except Exception as e:
                print(f"计算 {protein_id} 的ProTrek分数时出错: {str(e)}")
                continue
            
            results[protein_id] = {
                "protein_id": protein_id,
                "GO_id": go_ids,
                "Clip_score": go_scores
            }
        
        return results
    
    def second_level_filtering(self, protrek_results: Dict[str, Dict]) -> Dict[str, List[str]]:
        """
        第二层筛选:根据ProTrek阈值筛选GO
        
        Args:
            protrek_results: ProTrek计算结果
            
        Returns:
            筛选后的蛋白质GO映射
        """
        filtered_results = {}
        
        for protein_id, data in protrek_results.items():
            clip_scores = data.get('Clip_score', {})
            filtered_gos = []
            
            for go_id, score in clip_scores.items():
                if score >= self.protrek_threshold:
                    filtered_gos.append(go_id)
            
            if filtered_gos:
                filtered_results[protein_id] = filtered_gos
        
        return filtered_results
    
    def generate_filename(self, base_name: str, is_intermediate: bool = False) -> str:
        """生成包含参数信息的文件名"""
        if self.topk > 0:
            # 如果使用topk,则只包含topk信息
            params = f"topk{self.topk}"
        else:
            # 否则使用原有的参数组合
            params = f"identity{self.identity_threshold}_coverage{self.coverage_threshold}_evalue{self.evalue_threshold:.0e}"
        
        if self.use_protrek and self.protrek_threshold is not None:
            params += f"_protrek{self.protrek_threshold}"
        
        if is_intermediate:
            return f"{base_name}_intermediate_{params}.json"
        else:
            return f"{base_name}_final_{params}.json"
    
    def run(self, interproscan_info: Dict = None, blast_info: Dict = None,
            interproscan_file: str = None, blast_file: str = None,
            output_dir: str = "output"):
        """
        运行GO整合管道
        
        Args:
            interproscan_info: InterProScan结果字典
            blast_info: BLAST结果字典  
            interproscan_file: InterProScan结果文件路径
            blast_file: BLAST结果文件路径
            output_dir: 输出目录
        """
        # Load input data
        if interproscan_info is None and interproscan_file:
            with open(interproscan_file, 'r') as f:
                interproscan_info = json.load(f)
        
        if blast_info is None and blast_file:
            with open(blast_file, 'r') as f:
                blast_info = json.load(f)
        
        if not interproscan_info or not blast_info:
            raise ValueError("必须提供interproscan_info和blast_info数据或文件路径")
        
        # Make sure the output directory exists
        os.makedirs(output_dir, exist_ok=True)
        
        print("开始第一层筛选...")
        # 第一层筛选
        protein_go_dict = self.first_level_filtering(interproscan_info, blast_info)
        
        if not self.use_protrek:
            # No second-level filtering: save the merged results directly
            output_file = os.path.join(output_dir, self.generate_filename("go_integration"))
            with open(output_file, 'w') as f:
                for protein_id, go_ids in protein_go_dict.items():
                    result = {"protein_id": protein_id, "GO_id": go_ids}
                    f.write(json.dumps(result) + '\n')
            
            print(f"第一层筛选完成,结果已保存到: {output_file}")
            return output_file
        
        print("开始第二层筛选...")
        # 提取蛋白质序列
        protein_sequences = {}
        for protein_id, data in interproscan_info.items():
            protein_sequences[protein_id] = data.get('sequence', '')
        
        # Compute ProTrek scores
        protrek_results = self.calculate_protrek_scores(protein_sequences, protein_go_dict)
        
        # Save the intermediate ProTrek results
        intermediate_file = os.path.join(output_dir, self.generate_filename("go_integration", is_intermediate=True))
        with open(intermediate_file, 'w') as f:
            for result in protrek_results.values():
                f.write(json.dumps(result) + '\n')
        
        print(f"ProTrek分数计算完成,中间结果已保存到: {intermediate_file}")
        
        # Second-level filtering by ProTrek score threshold
        if self.protrek_threshold is not None:
            final_results = self.second_level_filtering(protrek_results)
            
            # Save the final results
            final_file = os.path.join(output_dir, self.generate_filename("go_integration"))
            with open(final_file, 'w') as f:
                for protein_id, go_ids in final_results.items():
                    result = {"protein_id": protein_id, "GO_id": go_ids}
                    f.write(json.dumps(result) + '\n')
            
            print(f"第二层筛选完成,最终结果已保存到: {final_file}")
            return final_file, intermediate_file
        
        return intermediate_file

def main():
    parser = argparse.ArgumentParser(description="GO annotation integration pipeline")
    parser.add_argument("--interproscan_file", type=str, default="data/processed_data/interproscan_info.json", help="Path to the InterProScan results file")
    parser.add_argument("--blast_file", type=str, default="data/processed_data/blast_info.json", help="Path to the BLAST results file")
    parser.add_argument("--identity", type=int, default=80, help="BLAST identity threshold (0-100)")
    parser.add_argument("--coverage", type=int, default=80, help="BLAST coverage threshold (0-100)")
    parser.add_argument("--evalue", type=float, default=1e-50, help="BLAST E-value threshold")
    parser.add_argument("--topk", type=int, default=2, help="Number of top BLAST hits to use (set to 0 to use the threshold strategy instead)")
    parser.add_argument("--protrek_threshold", type=float, help="ProTrek score threshold")
    parser.add_argument("--use_protrek", action="store_true", help="Apply the second-level ProTrek filtering")
    parser.add_argument("--output_dir", type=str, default="data/processed_data/go_integration_results", help="Output directory")
    
    args = parser.parse_args()
    
    # Create the pipeline instance
    pipeline = GOIntegrationPipeline(
        identity_threshold=args.identity,
        coverage_threshold=args.coverage,
        evalue_threshold=args.evalue,
        topk=args.topk,
        protrek_threshold=args.protrek_threshold,
        use_protrek=args.use_protrek
    )
    
    # Run the pipeline
    pipeline.run(
        interproscan_file=args.interproscan_file,
        blast_file=args.blast_file,
        output_dir=args.output_dir
    )

if __name__ == "__main__":
    main()
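
# Example invocations (illustrative; the script filename is an assumption):
#   python go_integration_pipeline.py --topk 2
#   python go_integration_pipeline.py --topk 0 --identity 80 --coverage 80 \
#       --evalue 1e-50 --use_protrek --protrek_threshold 0.5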