AlienChen committed on
Commit dbeb56d · verified · 1 Parent(s): f900c86

Delete muppit

muppit/.gitkeep DELETED
File without changes
muppit/__init__.py DELETED
File without changes
muppit/calculate_steps.py DELETED
@@ -1,72 +0,0 @@
- import math
-
-
- # def calculate_steps_per_epoch(total_samples, batch_size_per_gpu, num_gpus, scheduling):
- #     # Calculate total batch size across all GPUs
- #     total_batch_size = batch_size_per_gpu * num_gpus
- #
- #     # Calculate total batches per epoch
- #     batches_per_epoch = math.ceil(total_samples / total_batch_size)
- #
- #     steps_per_epoch = []
- #     current_accumulation_factor = 1  # Default accumulation factor
- #
- #     for epoch in range(max(scheduling.keys()) + 1):
- #         # Update accumulation factor if it's defined for the current epoch
- #         if epoch in scheduling:
- #             current_accumulation_factor = scheduling[epoch]
- #
- #         effective_steps = math.ceil(batches_per_epoch / current_accumulation_factor)
- #         steps_per_epoch.append(effective_steps)
- #
- #     return steps_per_epoch
-
- def calculate_total_steps(total_samples, batch_size, num_gpus, accumulation_schedule, max_epochs):
-     total_steps = 0
-
-     for epoch in range(max_epochs):
-         # Determine the accumulation steps for the current epoch
-         for start_epoch, steps in accumulation_schedule.items():
-             if start_epoch > epoch:
-                 break
-             accumulation_steps = steps
-
-         effective_batch_size = batch_size * num_gpus * accumulation_steps
-         steps_per_epoch = (total_samples + effective_batch_size - 1) // effective_batch_size
-
-         total_steps += steps_per_epoch
-         print(f'Epoch {epoch}: {steps_per_epoch} steps (accumulation_steps={accumulation_steps})')
-
-     return total_steps
-
-
- total_samples = 4804  # Replace with the actual number of samples in your dataset
- batch_size = 32
- num_gpus = 1
- accumulation_schedule = {0: 4, 3: 3, 10: 2}
- max_epochs = 20
-
- total_steps = calculate_total_steps(total_samples, batch_size, num_gpus, accumulation_schedule, max_epochs)
- print(f"Total Steps: {total_steps}")
-
- # total_samples = 309503  # Replace with the actual number of samples in your dataset
- # batch_size = 32
- # num_gpus = 7
- # accumulation_schedule = {0: 4, 2: 2, 7: 1}
- # max_epochs = 10
- #
- # total_steps = calculate_total_steps(total_samples, batch_size, num_gpus, accumulation_schedule, max_epochs)
- # print(f"Total Steps: {total_steps}")
-
- #
- # # Example usage
- # total_samples = 309503
- # batch_size_per_gpu = 16
- # num_gpus = 7
- # scheduling = {0: 4, 5: 3, 10: 2, 13: 1}
- #
- # steps_per_epoch = calculate_steps_per_epoch(total_samples, batch_size_per_gpu, num_gpus, scheduling)
- # for epoch, steps in enumerate(steps_per_epoch):
- #     print(f"Epoch {epoch}: {steps} steps")
- #
- # print(f"Total steps: {sum(steps_per_epoch)}")
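Editorial note (hedged): the example settings above appear to explain the hard-coded scheduler budget in finetune.py, whose configure_optimizers uses total_steps=1231 together with the same {0: 4, 3: 3, 10: 2} accumulation schedule passed to GradientAccumulationScheduler. Walking that schedule epoch by epoch with 4804 samples, batch size 32 and one GPU reproduces the number exactly; a minimal check, assuming the function defined in the deleted file above:

    # Sketch only, reusing calculate_total_steps from the deleted file above.
    # Epochs 0-2:   ceil(4804 / (32 * 1 * 4)) = 38 steps -> 3 * 38  = 114
    # Epochs 3-9:   ceil(4804 / (32 * 1 * 3)) = 51 steps -> 7 * 51  = 357
    # Epochs 10-19: ceil(4804 / (32 * 1 * 2)) = 76 steps -> 10 * 76 = 760
    assert calculate_total_steps(4804, 32, 1, {0: 4, 3: 3, 10: 2}, 20) == 114 + 357 + 760 == 1231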
 
muppit/finetune.py DELETED
@@ -1,386 +0,0 @@
1
- import pdb
2
- from pytorch_lightning.strategies import DDPStrategy
3
- import torch
4
- import torch.nn.functional as F
5
- from torch.utils.data import DataLoader, DistributedSampler
6
- from datasets import load_from_disk
7
- import pytorch_lightning as pl
8
- from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
9
- Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
10
- from pytorch_lightning.loggers import WandbLogger
11
- from torch.optim.lr_scheduler import _LRScheduler
12
- from transformers.optimization import get_cosine_schedule_with_warmup
13
- from argparse import ArgumentParser
14
- import os
15
- import uuid
16
- import numpy as np
17
- import torch.distributed as dist
18
- from models import *
19
- from torch.nn.utils.rnn import pad_sequence
20
- from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
21
- from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
22
- from torch.optim import Adam, AdamW
23
- from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
24
- import gc
25
-
26
-
27
- os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
28
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
29
- os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
30
-
31
-
32
- def collate_fn(batch):
33
- # Unpack the batch
34
- anchors = []
35
- positives = []
36
- # negatives = []
37
- binding_sites = []
38
-
39
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
40
-
41
- for b in batch:
42
- anchors.append(b['anchors'])
43
- positives.append(b['positives'])
44
- binding_sites.append(b['binding_site'])
45
-
46
- # Collate the tensors using torch's pad_sequence
47
- anchor_input_ids = torch.nn.utils.rnn.pad_sequence(
48
- [torch.Tensor(item['input_ids']).squeeze(0) for item in anchors], batch_first=True, padding_value=tokenizer.pad_token_id)
49
- anchor_attention_mask = torch.nn.utils.rnn.pad_sequence(
50
- [torch.Tensor(item['attention_mask']).squeeze(0) for item in anchors], batch_first=True, padding_value=0)
51
-
52
- positive_input_ids = torch.nn.utils.rnn.pad_sequence(
53
- [torch.Tensor(item['input_ids']).squeeze(0) for item in positives], batch_first=True, padding_value=tokenizer.pad_token_id)
54
- positive_attention_mask = torch.nn.utils.rnn.pad_sequence(
55
- [torch.Tensor(item['attention_mask']).squeeze(0) for item in positives], batch_first=True, padding_value=0)
56
-
57
- n, max_length = anchor_input_ids.shape[0], anchor_input_ids.shape[1]
58
- site = torch.zeros(n, max_length)
59
- for i in range(len(binding_sites)):
60
- binding_site = binding_sites[i]
61
- site[i, binding_site] = 1
62
-
63
- # Return the collated batch
64
- return {
65
- 'anchor_input_ids': anchor_input_ids.int(),
66
- 'anchor_attention_mask': anchor_attention_mask.int(),
67
- 'positive_input_ids': positive_input_ids.int(),
68
- 'positive_attention_mask': positive_attention_mask.int(),
69
- 'binding_site': site
70
- }
71
-
72
-
73
- class CustomDataModule(pl.LightningDataModule):
74
- def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
75
- super().__init__()
76
- self.train_dataset = train_dataset
77
- self.val_dataset = val_dataset
78
- self.batch_size = batch_size
79
- self.tokenizer = tokenizer
80
-
81
- def train_dataloader(self):
82
- return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
83
- num_workers=8, pin_memory=True)
84
-
85
- def val_dataloader(self):
86
- return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
87
- pin_memory=True)
88
-
89
- def setup(self, stage=None):
90
- if stage == 'test' or stage is None:
91
- self.test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_test')
92
- self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
93
- num_workers=8, pin_memory=True)
94
-
95
-
96
- class CosineAnnealingWithWarmup(_LRScheduler):
97
- def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
98
- self.warmup_steps = warmup_steps
99
- self.total_steps = total_steps
100
- self.base_lr = base_lr
101
- self.max_lr = max_lr
102
- self.min_lr = min_lr
103
- super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
104
- print(f"SELF BASE LRS = {self.base_lrs}")
105
-
106
- def get_lr(self):
107
- if self.last_epoch < self.warmup_steps:
108
- # Linear warmup phase from base_lr to max_lr
109
- return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps) for base_lr in self.base_lrs]
110
-
111
- # Cosine annealing phase from max_lr to min_lr
112
- progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
113
- cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
114
- decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
115
-
116
- return [decayed_lr for base_lr in self.base_lrs]
117
-
118
- class PeptideModel(pl.LightningModule):
119
- def __init__(self, n_layers, d_model, n_head,
120
- d_k, d_v, d_inner, dropout=0.2,
121
- learning_rate=0.00001, max_epochs=15):
122
- super(PeptideModel, self).__init__()
123
-
124
- self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
125
- # freeze all the esm_model parameters
126
- for param in self.esm_model.parameters():
127
- param.requires_grad = False
128
-
129
- self.repeated_module = RepeatedModule2(n_layers, d_model,
130
- n_head, d_k, d_v, d_inner, dropout=dropout)
131
-
132
- self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
133
- d_k, d_v, dropout=dropout)
134
-
135
- self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
136
-
137
- self.output_projection_prot = nn.Linear(d_model, 1)
138
-
139
- self.learning_rate = learning_rate
140
- self.max_epochs = max_epochs
141
-
142
- self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial threshold
143
- self.historical_memory = 0.9
144
- self.class_weights = torch.tensor([3.6625710315221727, 0.5790496079007189]) # binding-site weights, non-binding-site weights
145
-
146
- def forward(self, binder_tokens, target_tokens):
147
- peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
148
- protein_sequence = self.esm_model(**target_tokens).last_hidden_state
149
-
150
- prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
151
- prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
152
- protein_sequence)
153
-
154
- prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
155
-
156
- prot_enc = self.final_ffn(prot_enc)
157
-
158
- prot_enc = self.output_projection_prot(prot_enc)
159
-
160
- return prot_enc
161
-
162
- def training_step(self, batch, batch_idx):
163
- opt = self.optimizers()
164
- lr = opt.param_groups[0]['lr']
165
- self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
166
-
167
- target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
168
- 'attention_mask': batch['anchor_attention_mask'].to(self.device)}
169
- binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
170
- 'attention_mask': batch['positive_attention_mask'].to(self.device)}
171
- binding_site = batch['binding_site'].to(self.device)
172
- mask = target_tokens['attention_mask']
173
-
174
- outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)
175
-
176
- weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
177
- loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')
178
-
179
- masked_loss = loss * mask
180
- mean_loss = masked_loss.sum() / mask.sum()
181
-
182
- # print('logging')
183
- self.log('train_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
184
- return mean_loss
185
-
186
- def validation_step(self, batch, batch_idx):
187
- target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
188
- 'attention_mask': batch['anchor_attention_mask'].to(self.device)}
189
- binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
190
- 'attention_mask': batch['positive_attention_mask'].to(self.device)}
191
- binding_site = batch['binding_site'].to(self.device)
192
- mask = target_tokens['attention_mask']
193
-
194
- outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)
195
-
196
- weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
197
- loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')
198
-
199
- # Apply the mask to the loss
200
- masked_loss = loss * mask
201
-
202
- # Compute the mean loss only over the valid positions
203
- mean_loss = masked_loss.sum() / mask.sum()
204
-
205
- # Calculate predictions and apply mask
206
- sigmoid_outputs = torch.sigmoid(outputs_nodes)
207
- total = mask.sum()
208
-
209
- self.update_class_thresholds(sigmoid_outputs, binding_site, mask)
210
- self.log('threshold', self.classification_threshold, on_epoch=True)
211
-
212
- predict = (sigmoid_outputs >= self.classification_threshold).float()
213
- correct = ((predict == binding_site) * mask).sum()
214
- accuracy = correct / total
215
-
216
- # Compute AUC
217
- outputs_nodes_flat = sigmoid_outputs[mask.bool()].float().cpu().detach().numpy().flatten()
218
- binding_site_flat = binding_site[mask.bool()].float().cpu().detach().numpy().flatten()
219
- predictions_flat = predict[mask.bool()].float().cpu().detach().numpy().flatten()
220
-
221
- auc = roc_auc_score(binding_site_flat, outputs_nodes_flat)
222
- f1 = f1_score(binding_site_flat, predictions_flat)
223
- mcc = matthews_corrcoef(binding_site_flat, predictions_flat)
224
-
225
- self.log('val_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
226
- self.log('val_accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
227
- self.log('val_auc', auc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
228
- self.log('val_f1', f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
229
- self.log('val_mcc', mcc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
230
-
231
- def configure_optimizers(self):
232
- print(f"MAX EPOCHS = {self.max_epochs}")
233
- optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
234
- # schedulers = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=0.1*self.max_epochs,
235
- # max_epochs=self.max_epochs,
236
- # warmup_start_lr=5e-4,
237
- # eta_min=0.1 * self.learning_rate)
238
-
239
- base_lr = 0
240
- max_lr = self.learning_rate
241
- min_lr = 0.1 * self.learning_rate
242
-
243
- schedulers = CosineAnnealingWithWarmup(optimizer, warmup_steps=76, total_steps=1231,
244
- base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)
245
-
246
- lr_schedulers = {
247
- "scheduler": schedulers,
248
- "name": 'learning_rate_logs',
249
- "interval": 'step', # The scheduler updates the learning rate at every step (not epoch)
250
- 'frequency': 1 # The scheduler updates the learning rate after every batch
251
- }
252
- return [optimizer], [lr_schedulers]
253
-
254
- def update_class_thresholds(self, inputs, targets, mask):
255
- with torch.no_grad():
256
- min_threshold_value = 0.001
257
- thresholds = torch.arange(0.1, 1.0, 0.05, device=inputs.device)
258
-
259
- best_f1_score = 0
260
- best_threshold = min_threshold_value
261
-
262
- for threshold in thresholds:
263
- binary_predictions = (inputs >= threshold).float()
264
-
265
- tp = ((binary_predictions * targets) * mask).sum().item()
266
- fp = ((binary_predictions * (1 - targets)) * mask).sum().item()
267
- fn = (((1 - binary_predictions) * targets) * mask).sum().item()
268
-
269
- precision = tp / (tp + fp + 1e-7)
270
- recall = tp / (tp + fn + 1e-7)
271
- f1_score = 2 * precision * recall / (precision + recall + 1e-7)
272
-
273
- if f1_score > best_f1_score:
274
- best_f1_score = f1_score
275
- best_threshold = threshold
276
-
277
- updated_threshold = self.historical_memory * self.classification_threshold + (
278
- 1 - self.historical_memory) * best_threshold
279
- self.classification_threshold = nn.Parameter(torch.clamp(updated_threshold, min=min_threshold_value))
280
- gc.collect()
281
- torch.cuda.empty_cache()
282
-
283
- def training_epoch_end(self, outputs):
284
- gc.collect()
285
- torch.cuda.empty_cache()
286
- super().training_epoch_end(outputs)
287
-
288
- def validation_epoch_end(self, outputs):
289
- gc.collect()
290
- torch.cuda.empty_cache()
291
- super().validation_epoch_end(outputs)
292
-
293
-
294
-
295
-
296
- def main():
297
- parser = ArgumentParser()
298
-
299
- parser.add_argument("-o", dest="output_file", help="File for output of model parameters", required=True, type=str)
300
- parser.add_argument("-d", dest="dataset", required=False, type=str, default="pepnn",
301
- help="Which dataset to train on, pepnn, pepbind, or interpep")
302
- parser.add_argument("-lr", type=float, default=1e-3)
303
- parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
304
- parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
305
- parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
306
- parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
307
- parser.add_argument("-d_inner", type=int, default=64)
308
- # parser.add_argument("-sm", dest="saved_model", help="File containing initial params", required=False, type=str,
309
- # default=None)
310
- parser.add_argument("-sm", default=None, help="File containing initial params", type=str)
311
- parser.add_argument("--max_epochs", type=int, default=15, help="Max number of epochs to train")
312
- args = parser.parse_args()
313
-
314
- print(args.max_epochs)
315
-
316
- # Initialize the process group for distributed training
317
- dist.init_process_group(backend="nccl")
318
-
319
- train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_train')
320
- val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_val')
321
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
322
-
323
- data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)
324
-
325
- model = PeptideModel(6, 64, 6, 64, 128, 64, dropout=0.2,
326
- learning_rate=args.lr, max_epochs=args.max_epochs)
327
- if args.sm:
328
- model = PeptideModel.load_from_checkpoint(args.sm,
329
- n_layers=args.n_layers,
330
- d_model=args.d_model,
331
- n_head=args.n_head,
332
- d_k=64,
333
- d_v=128,
334
- d_inner=64,
335
- dropout=0.2,
336
- learning_rate=args.lr,
337
- max_epochs=args.max_epochs)
338
-
339
- run_id = str(uuid.uuid4())
340
-
341
- print("Classification Thresholds:")
342
- print(model.classification_threshold)
343
-
344
- logger = WandbLogger(project=f"bind_evaluator",
345
- name=f"finetune_lr={args.lr}_nlayers={args.n_layers}_dmodel={args.d_model}_nhead={args.n_head}_dinner={args.d_inner}",
346
- # display on the web
347
- # save_dir=f'./pl_logs/',
348
- job_type='model-training',
349
- id=run_id)
350
-
351
- checkpoint_callback = ModelCheckpoint(
352
- monitor='val_mcc',
353
- dirpath=args.output_file,
354
- filename='model-{epoch:02d}-{val_mcc:.2f}',
355
- save_top_k=1,
356
- mode='max',
357
- )
358
-
359
- early_stopping_callback = EarlyStopping(
360
- monitor='val_mcc',
361
- patience=5,
362
- verbose=True,
363
- mode='max'
364
- )
365
-
366
- accumulator = GradientAccumulationScheduler(scheduling={0: 4, 3: 3, 10: 2})
367
-
368
- trainer = pl.Trainer(
369
- max_epochs=args.max_epochs,
370
- accelerator='gpu',
371
- strategy='ddp',
372
- precision='bf16',
373
- logger=logger,
374
- devices=[0],
375
- callbacks=[checkpoint_callback, accumulator, early_stopping_callback],
376
- gradient_clip_val=1.0
377
- )
378
-
379
- trainer.fit(model, datamodule=data_module)
380
-
381
- best_model_path = checkpoint_callback.best_model_path
382
- print(best_model_path)
383
-
384
-
385
- if __name__ == "__main__":
386
- main()
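Editorial note (hedged): the loss used in training_step and validation_step above is a class-weighted binary cross-entropy averaged only over unmasked target residues. A minimal standalone sketch of that computation with small illustrative tensors (the weights mirror self.class_weights in the deleted code; shapes are illustrative only):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 7)                     # (batch, target_len) per-residue logits
    binding_site = torch.randint(0, 2, (2, 7)).float()
    mask = torch.ones(2, 7)                        # attention mask: 1 = real residue, 0 = padding
    class_weights = torch.tensor([3.6625710315221727, 0.5790496079007189])

    weight = class_weights[0] * binding_site + class_weights[1] * (1 - binding_site)
    loss = F.binary_cross_entropy_with_logits(logits, binding_site, weight=weight, reduction='none')
    mean_loss = (loss * mask).sum() / mask.sum()   # average over valid positions only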
 
muppit/models/.gitattributes DELETED
@@ -1 +0,0 @@
- ProtBert-BFD/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
 
 
muppit/models/.gitkeep DELETED
File without changes
muppit/models/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .models import *
- from .score_domain import *
- from .dataloaders import *
 
 
 
 
muppit/models/dataloaders.py DELETED
@@ -1,426 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Sat Jul 31 21:54:08 2021
4
-
5
- @author: Osama
6
- """
7
-
8
- from torch.utils.data import Dataset
9
- from Bio.PDB import Polypeptide
10
- import numpy as np
11
- import torch
12
- import pandas as pd
13
- import os
14
- import esm
15
- import ast
16
- import pdb
17
-
18
-
19
- class InterpepComplexes(Dataset):
20
-
21
- def __init__(self, mode,
22
- encoded_data_directory = "../../datasets/interpep_data/"):
23
-
24
- self.mode = mode
25
-
26
- self.encoded_data_directory = encoded_data_directory
27
-
28
- self.train_dir = "../../datasets/interpep_data/train_examples.npy"
29
-
30
- self.test_dir = "../../datasets/interpep_data/test_examples.npy"
31
-
32
- self.val_dir = "../../datasets/interpep_data/val_examples.npy"
33
-
34
-
35
- self.test_list = np.load(self.test_dir)
36
-
37
- self.train_list = np.load(self.train_dir)
38
-
39
- self.val_list = np.load(self.val_dir)
40
-
41
-
42
-
43
- if mode == "train":
44
- self.num_data = len(self.train_list)
45
- elif mode == "val":
46
- self.num_data = len(self.val_list)
47
- elif mode == "test":
48
- self.num_data = len(self.test_list)
49
-
50
-
51
-
52
- def __getitem__(self, index):
53
-
54
- if self.mode == "train":
55
- item = self.train_list[index]
56
- elif self.mode == "val":
57
- item = self.val_list[index]
58
- elif self.mode == "test":
59
- item = self.test_list[index]
60
-
61
- file_dir = self.encoded_data_directory
62
-
63
- with np.load(file_dir + "fragment_data/" + item + ".npz") as data:
64
- temp_pep_sequence = data["target_sequence"]
65
- temp_binding_sites = data["binding_sites"]
66
-
67
-
68
- with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\
69
- item.split("_")[1] + ".npz") as data:
70
- temp_nodes = data["nodes"]
71
-
72
-
73
- binding = np.zeros(len(temp_nodes))
74
- if len(temp_binding_sites) != 0:
75
- binding[temp_binding_sites] = 1
76
- target = torch.LongTensor(binding)
77
-
78
-
79
-
80
-
81
-
82
-
83
-
84
- nodes = temp_nodes[:, 0:20]
85
-
86
- prot_sequence = np.argmax(nodes, axis=-1)
87
-
88
-
89
-
90
- prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])
91
-
92
-
93
-
94
- pep_sequence = temp_pep_sequence
95
-
96
- pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)
97
-
98
-
99
-
100
-
101
-
102
- return pep_sequence, prot_sequence, target
103
-
104
- def __len__(self):
105
- return self.num_data
106
-
107
- class PPI(Dataset):
108
-
109
- def __init__(self, mode, csv_dir_path = "/home/u21307130002/PepNN/pepnn/datasets/ppi/"):
110
-
111
- self.mode = mode
112
- self.train_data = pd.read_csv(os.path.join(csv_dir_path, 'train.csv'))
113
- self.val_data = pd.read_csv(os.path.join(csv_dir_path, 'val.csv'))
114
- # self.test_data = pd.read_csv(os.path.join(csv_dir_path, 'test.csv'))
115
-
116
- if self.mode == 'train':
117
- self.num_data = len(self.train_data)
118
-
119
- def __len__(self):
120
- return self.num_data
121
-
122
- def __getitem__(self, index):
123
- # pdb.set_trace()
124
- if torch.is_tensor(index):
125
- index = index.tolist()
126
-
127
- if self.mode == "train":
128
- item = self.train_data.iloc[index]
129
- elif self.mode == "val":
130
- item = self.val_data.iloc[index]
131
- elif self.mode == "test":
132
- item = self.test_data.iloc[index]
133
- else:
134
- item = None
135
-
136
- # print(item)
137
-
138
- motif1 = ast.literal_eval(item['Chain_1_motifs'])
139
- motif2 = ast.literal_eval(item['Chain_2_motifs'])
140
-
141
- if len(motif1[0]) > len(motif2[0]):
142
- target = motif1
143
- prot_sequence = item['Sequence1']
144
- pep_sequence = item['Sequence2']
145
- else:
146
- target = motif2
147
- pep_sequence = item['Sequence1']
148
- prot_sequence = item['Sequence2']
149
-
150
- target = [int(motif.split('_')[1]) for motif in target]
151
-
152
- if target[-1] >= len(prot_sequence):
153
- pdb.set_trace()
154
-
155
- binding = np.zeros(len(prot_sequence))
156
- if len(target) != 0:
157
- binding[target] = 1
158
- target = torch.LongTensor(binding).float()
159
-
160
- # print(f"peptide length: {len(pep_sequence)}")
161
- # print(f"protein length: {len(prot_sequence)}")
162
- # print(f"target length: {len(target)}")
163
- # pdb.set_trace()
164
-
165
- return pep_sequence, prot_sequence, target
166
-
167
-
168
-
169
-
170
- class PepBindComplexes(Dataset):
171
-
172
- def __init__(self, mode,
173
- encoded_data_directory = "../../datasets/pepbind_data/"):
174
-
175
- self.mode = mode
176
-
177
- self.encoded_data_directory = encoded_data_directory
178
-
179
- self.train_dir = "../../datasets/pepbind_data/train_examples.npy"
180
-
181
- self.test_dir = "../../datasets/pepbind_data/test_examples.npy"
182
-
183
- self.val_dir = "../../datasets/pepbind_data/val_examples.npy"
184
-
185
-
186
- self.test_list = np.load(self.test_dir)
187
-
188
- self.train_list = np.load(self.train_dir)
189
-
190
- self.val_list = np.load(self.val_dir)
191
-
192
-
193
- if mode == "train":
194
- self.num_data = len(self.train_list)
195
- elif mode == "val":
196
- self.num_data = len(self.val_list)
197
- elif mode == "test":
198
- self.num_data = len(self.test_list)
199
-
200
-
201
-
202
- def __getitem__(self, index):
203
-
204
- if self.mode == "train":
205
- item = self.train_list[index]
206
-
207
-
208
- elif self.mode == "val":
209
- item = self.val_list[index]
210
-
211
-
212
- elif self.mode == "test":
213
- item = self.test_list[index]
214
-
215
-
216
-
217
- file_dir = self.encoded_data_directory
218
-
219
-
220
- with np.load(file_dir + "fragment_data/" + item + ".npz") as data:
221
- temp_pep_sequence = data["target_sequence"]
222
- temp_binding_sites = data["binding_sites"]
223
-
224
-
225
- with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\
226
- item.split("_")[1] + ".npz") as data:
227
- temp_nodes = data["nodes"]
228
-
229
-
230
- binding = np.zeros(len(temp_nodes))
231
- if len(temp_binding_sites) != 0:
232
- binding[temp_binding_sites] = 1
233
- target = torch.LongTensor(binding)
234
-
235
- nodes = temp_nodes[:, 0:20]
236
-
237
- prot_sequence = np.argmax(nodes, axis=-1)
238
-
239
-
240
- prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])
241
-
242
-
243
- pep_sequence = temp_pep_sequence
244
-
245
- pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)
246
-
247
-
248
- return pep_sequence, prot_sequence, target
249
-
250
-
251
- def __len__(self):
252
- return self.num_data
253
-
254
- class PeptideComplexes(Dataset):
255
-
256
- def __init__(self, mode,
257
- encoded_data_directory = "../../datasets/pepnn_data/all_data/"):
258
-
259
- self.mode = mode
260
-
261
- self.encoded_data_directory = encoded_data_directory
262
-
263
- self.train_dir = "../../datasets/pepnn_data/train_examples.npy"
264
-
265
- self.test_dir = "../../datasets/pepnn_test_data/test_examples.npy"
266
-
267
- self.val_dir = "../../datasets/pepnn_data/val_examples.npy"
268
-
269
-
270
- self.example_weights = np.load("../../datasets/pepnn_data/example_weights.npy")
271
-
272
- self.test_list = np.load(self.test_dir)
273
-
274
- self.train_list = np.load(self.train_dir)
275
-
276
- self.val_list = np.load(self.val_dir)
277
-
278
-
279
-
280
- if mode == "train":
281
- self.num_data = len(self.train_list)
282
- elif mode == "val":
283
- self.num_data = len(self.val_list)
284
- elif mode == "test":
285
- self.num_data = len(self.test_list)
286
-
287
-
288
-
289
- def __getitem__(self, index):
290
-
291
-
292
- if self.mode == "train":
293
- item = self.train_list[index]
294
-
295
- weight = self.example_weights[item]
296
-
297
- elif self.mode == "val":
298
- item = self.val_list[index]
299
-
300
- weight = self.example_weights[item]
301
-
302
- elif self.mode == "test":
303
- item = self.test_list[index]
304
-
305
- weight = 1
306
-
307
- if self.mode != "test":
308
- file_dir = self.encoded_data_directory
309
- else:
310
- file_dir = "../../datasets/pepnn_test_data/all_data/"
311
-
312
-
313
- with np.load(file_dir + "fragment_data/" + item + ".npz") as data:
314
- temp_pep_sequence = data["target_sequence"]
315
- temp_binding_sites = data["binding_sites"]
316
-
317
-
318
- with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\
319
- item.split("_")[1] + ".npz") as data:
320
- temp_nodes = data["nodes"]
321
-
322
-
323
- binding = np.zeros(len(temp_nodes))
324
- if len(temp_binding_sites) != 0:
325
- binding[temp_binding_sites] = 1
326
- target = torch.LongTensor(binding)
327
-
328
-
329
-
330
-
331
-
332
-
333
-
334
- nodes = temp_nodes[:, 0:20]
335
-
336
- prot_sequence = np.argmax(nodes, axis=-1)
337
-
338
-
339
-
340
- prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])
341
-
342
-
343
-
344
- pep_sequence = temp_pep_sequence
345
-
346
- pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)
347
-
348
-
349
-
350
-
351
-
352
- return pep_sequence, prot_sequence, target, weight
353
-
354
-
355
- def __len__(self):
356
- return self.num_data
357
-
358
-
359
- class BitenetComplexes(Dataset):
360
-
361
- def __init__(self, encoded_data_directory = "../bitenet_data/all_data/"):
362
-
363
-
364
- self.encoded_data_directory = encoded_data_directory
365
-
366
-
367
-
368
-
369
- self.train_dir = "../../datasets/bitenet_data/examples.npy"
370
-
371
-
372
-
373
-
374
- self.full_list = np.load(self.train_dir)
375
-
376
-
377
-
378
-
379
- self.num_data = len(self.full_list)
380
-
381
-
382
-
383
-
384
- def __getitem__(self, index):
385
-
386
- item = self.full_list[index]
387
-
388
- file_dir = self.encoded_data_directory
389
-
390
- with np.load(file_dir + "fragment_data/" + item[:-1] + "_" + item[-1] + ".npz") as data:
391
- temp_pep_sequence = data["target_sequence"]
392
- temp_binding_matrix = data["binding_matrix"]
393
-
394
-
395
- with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +\
396
- item.split("_")[1][0] + ".npz") as data:
397
- temp_nodes = data["nodes"]
398
-
399
-
400
- binding_sum = np.sum(temp_binding_matrix, axis=0).T
401
-
402
- target = torch.LongTensor(binding_sum >= 1)
403
-
404
-
405
-
406
- nodes = temp_nodes[:, 0:20]
407
-
408
- prot_sequence = np.argmax(nodes, axis=-1)
409
-
410
-
411
-
412
- prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])
413
-
414
-
415
-
416
- pep_sequence = temp_pep_sequence
417
-
418
- pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)
419
-
420
-
421
-
422
-
423
- return pep_sequence, prot_sequence, target
424
-
425
- def __len__(self):
426
- return self.num_data
 
muppit/models/layers.py DELETED
@@ -1,44 +0,0 @@
- from torch import nn
- from .modules import *
-
- class ReciprocalLayer(nn.Module):
-
-     def __init__(self, d_model, d_inner, n_head, d_k, d_v):
-
-         super().__init__()
-
-         self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
-                                                                    d_k, d_v)
-
-         self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
-                                                                   d_k, d_v)
-
-         self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model,
-                                                                        d_k, d_v)
-
-
-
-         self.ffn_seq = FFN(d_model, d_inner)
-
-         self.ffn_protein = FFN(d_model, d_inner)
-
-     def forward(self, sequence_enc, protein_seq_enc):
-         prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
-
-         seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
-
-
-         prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc,
-                                                                                                     seq_enc,
-                                                                                                     seq_enc,
-                                                                                                     prot_enc)
-         prot_enc = self.ffn_protein(prot_enc)
-
-         seq_enc = self.ffn_seq(seq_enc)
-
-
-
-         return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
-
-
-
 
muppit/models/models.py DELETED
@@ -1,238 +0,0 @@
1
- import pdb
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- from .layers import *
7
- from .modules import *
8
- import pdb
9
- from transformers import EsmModel, EsmTokenizer
10
-
11
- def to_var(x):
12
- if torch.cuda.is_available():
13
- x = x.cuda()
14
- return x
15
-
16
-
17
- class RepeatedModule2(nn.Module):
18
- def __init__(self, n_layers, d_model,
19
- n_head, d_k, d_v, d_inner, dropout=0.1):
20
- super().__init__()
21
-
22
- self.linear1 = nn.Linear(1280, d_model)
23
- self.linear2 = nn.Linear(1280, d_model)
24
- self.sequence_embedding = nn.Embedding(20, d_model)
25
- self.d_model = d_model
26
-
27
- self.reciprocal_layer_stack = nn.ModuleList([
28
- ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
29
- for _ in range(n_layers)])
30
-
31
- self.dropout = nn.Dropout(dropout)
32
- self.dropout_2 = nn.Dropout(dropout)
33
-
34
- def forward(self, peptide_sequence, protein_sequence):
35
- sequence_attention_list = []
36
-
37
- prot_attention_list = []
38
-
39
- prot_seq_attention_list = []
40
-
41
- seq_prot_attention_list = []
42
-
43
- sequence_enc = self.dropout(self.linear1(peptide_sequence))
44
-
45
- prot_enc = self.dropout_2(self.linear2(protein_sequence))
46
-
47
- for reciprocal_layer in self.reciprocal_layer_stack:
48
- prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
49
- reciprocal_layer(sequence_enc, prot_enc)
50
-
51
- sequence_attention_list.append(sequence_attention)
52
-
53
- prot_attention_list.append(prot_attention)
54
-
55
- prot_seq_attention_list.append(prot_seq_attention)
56
-
57
- seq_prot_attention_list.append(seq_prot_attention)
58
-
59
- return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
60
- prot_seq_attention_list, seq_prot_attention_list
61
-
62
-
63
- class RepeatedModule(nn.Module):
64
-
65
- def __init__(self, n_layers, d_model,
66
- n_head, d_k, d_v, d_inner, dropout=0.1):
67
-
68
- super().__init__()
69
-
70
- self.linear = nn.Linear(1024, d_model)
71
- self.sequence_embedding = nn.Embedding(20, d_model)
72
- self.d_model = d_model
73
-
74
- self.reciprocal_layer_stack = nn.ModuleList([
75
- ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
76
- for _ in range(n_layers)])
77
-
78
- self.dropout = nn.Dropout(dropout)
79
- self.dropout_2 = nn.Dropout(dropout)
80
-
81
-
82
-
83
- def _positional_embedding(self, batches, number):
84
-
85
- result = torch.exp(torch.arange(0, self.d_model,2,dtype=torch.float32)*-1*(np.log(10000)/self.d_model))
86
-
87
- numbers = torch.arange(0, number, dtype=torch.float32)
88
-
89
- numbers = numbers.unsqueeze(0)
90
-
91
- numbers = numbers.unsqueeze(2)
92
-
93
- result = numbers*result
94
-
95
- result = torch.cat((torch.sin(result), torch.cos(result)),2)
96
-
97
- return result
98
-
99
- def forward(self, peptide_sequence, protein_sequence):
100
-
101
-
102
- sequence_attention_list = []
103
-
104
- prot_attention_list = []
105
-
106
- prot_seq_attention_list = []
107
-
108
- seq_prot_attention_list = []
109
-
110
- sequence_enc = self.sequence_embedding(peptide_sequence)
111
-
112
- sequence_enc += to_var(self._positional_embedding(peptide_sequence.shape[0],
113
- peptide_sequence.shape[1]))
114
- sequence_enc = self.dropout(sequence_enc)
115
-
116
-
117
-
118
-
119
-
120
- prot_enc = self.dropout_2(self.linear(protein_sequence))
121
-
122
-
123
-
124
-
125
- for reciprocal_layer in self.reciprocal_layer_stack:
126
-
127
- prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention =\
128
- reciprocal_layer(sequence_enc, prot_enc)
129
-
130
- sequence_attention_list.append(sequence_attention)
131
-
132
- prot_attention_list.append(prot_attention)
133
-
134
- prot_seq_attention_list.append(prot_seq_attention)
135
-
136
- seq_prot_attention_list.append(seq_prot_attention)
137
-
138
-
139
-
140
- return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,\
141
- prot_seq_attention_list, seq_prot_attention_list
142
-
143
-
144
- class FullModel(nn.Module):
145
-
146
- def __init__(self, n_layers, d_model, n_head,
147
- d_k, d_v, d_inner, return_attention=False, dropout=0.2):
148
- super().__init__()
149
-
150
- self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
151
-
152
- # freeze all the esm_model parameters
153
- for param in self.esm_model.parameters():
154
- param.requires_grad = False
155
-
156
- self.repeated_module = RepeatedModule2(n_layers, d_model,
157
- n_head, d_k, d_v, d_inner, dropout=dropout)
158
-
159
- self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
160
- d_k, d_v, dropout=dropout)
161
-
162
- self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
163
-
164
- self.output_projection_prot = nn.Linear(d_model, 1)
165
- self.sigmoid = nn.Sigmoid()
166
-
167
- self.return_attention = return_attention
168
-
169
- def forward(self, binder_tokens, target_tokens):
170
-
171
- with torch.no_grad():
172
- peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
173
- protein_sequence = self.esm_model(**target_tokens).last_hidden_state
174
-
175
- # pdb.set_trace()
176
-
177
- prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
178
- prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
179
- protein_sequence)
180
-
181
- prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
182
-
183
- # pdb.set_trace()
184
-
185
- prot_enc = self.final_ffn(prot_enc)
186
-
187
- prot_enc = self.sigmoid(self.output_projection_prot(prot_enc))
188
-
189
- return prot_enc
190
-
191
-
192
-
193
- class Original_FullModel(nn.Module):
194
-
195
- def __init__(self, n_layers, d_model, n_head,
196
- d_k, d_v, d_inner, return_attention=False, dropout=0.2):
197
-
198
- super().__init__()
199
- self.repeated_module = RepeatedModule(n_layers, d_model,
200
- n_head, d_k, d_v, d_inner, dropout=dropout)
201
-
202
- self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
203
- d_k, d_v, dropout=dropout)
204
-
205
- self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
206
- self.output_projection_prot = nn.Linear(d_model, 2)
207
-
208
-
209
-
210
- self.softmax_prot =nn.LogSoftmax(dim=-1)
211
-
212
-
213
- self.return_attention = return_attention
214
-
215
- def forward(self, peptide_sequence, protein_sequence):
216
-
217
- prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,\
218
- prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
219
- protein_sequence)
220
-
221
-
222
-
223
- prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
224
-
225
- prot_enc = self.final_ffn(prot_enc)
226
-
227
- prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc))
228
-
229
-
230
-
231
-
232
-
233
- if not self.return_attention:
234
- return prot_enc
235
- else:
236
- return prot_enc, sequence_attention_list, prot_attention_list,\
237
- prot_seq_attention_list, seq_prot_attention_list
238
-
 
muppit/models/modules.py DELETED
@@ -1,213 +0,0 @@
1
- from torch import nn
2
- import numpy as np
3
- import torch
4
- import torch.nn.functional as F
5
-
6
-
7
- def to_var(x):
8
- if torch.cuda.is_available():
9
- x = x.cuda()
10
- return x
11
-
12
-
13
-
14
-
15
-
16
- class MultiHeadAttentionSequence(nn.Module):
17
-
18
-
19
- def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
20
-
21
- super().__init__()
22
-
23
- self.n_head = n_head
24
- self.d_model = d_model
25
- self.d_k = d_k
26
- self.d_v = d_v
27
-
28
-
29
- self.W_Q = nn.Linear(d_model, n_head*d_k)
30
- self.W_K = nn.Linear(d_model, n_head*d_k)
31
- self.W_V = nn.Linear(d_model, n_head*d_v)
32
- self.W_O = nn.Linear(n_head*d_v, d_model)
33
-
34
-
35
- self.layer_norm = nn.LayerNorm(d_model)
36
-
37
- self.dropout = nn.Dropout(dropout)
38
-
39
-
40
- def forward(self, q, k, v):
41
-
42
- batch, len_q, _ = q.size()
43
- batch, len_k, _ = k.size()
44
- batch, len_v, _ = v.size()
45
-
46
-
47
-
48
- Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
49
- K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
50
- V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
51
-
52
-
53
-
54
-
55
- Q = Q.transpose(1, 2)
56
- K = K.transpose(1, 2).transpose(2, 3)
57
- V = V.transpose(1, 2)
58
-
59
-
60
- attention = torch.matmul(Q, K)
61
-
62
-
63
-
64
-
65
- attention = attention /np.sqrt(self.d_k)
66
-
67
-
68
-
69
- attention = F.softmax(attention, dim=-1)
70
-
71
-
72
-
73
-
74
- output = torch.matmul(attention, V)
75
-
76
-
77
-
78
- output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
79
-
80
-
81
-
82
- output = self.W_O(output)
83
-
84
-
85
- output = self.dropout(output)
86
-
87
- output = self.layer_norm(output + q)
88
-
89
-
90
-
91
-
92
-
93
- return output, attention
94
-
95
- class MultiHeadAttentionReciprocal(nn.Module):
96
-
97
-
98
- def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
99
-
100
- super().__init__()
101
-
102
- self.n_head = n_head
103
- self.d_model = d_model
104
- self.d_k = d_k
105
- self.d_v = d_v
106
-
107
-
108
- self.W_Q = nn.Linear(d_model, n_head*d_k)
109
- self.W_K = nn.Linear(d_model, n_head*d_k)
110
- self.W_V = nn.Linear(d_model, n_head*d_v)
111
- self.W_O = nn.Linear(n_head*d_v, d_model)
112
- self.W_V_2 = nn.Linear(d_model, n_head*d_v)
113
- self.W_O_2 = nn.Linear(n_head*d_v, d_model)
114
-
115
- self.layer_norm = nn.LayerNorm(d_model)
116
-
117
- self.dropout = nn.Dropout(dropout)
118
-
119
- self.layer_norm_2 = nn.LayerNorm(d_model)
120
-
121
- self.dropout_2 = nn.Dropout(dropout)
122
-
123
-
124
-
125
-
126
- def forward(self, q, k, v, v_2):
127
-
128
- batch, len_q, _ = q.size()
129
- batch, len_k, _ = k.size()
130
- batch, len_v, _ = v.size()
131
- batch, len_v_2, _ = v_2.size()
132
-
133
-
134
- Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
135
- K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
136
- V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
137
- V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
138
-
139
-
140
-
141
- Q = Q.transpose(1, 2)
142
- K = K.transpose(1, 2).transpose(2, 3)
143
- V = V.transpose(1, 2)
144
- V_2 = V_2.transpose(1,2)
145
-
146
- attention = torch.matmul(Q, K)
147
-
148
-
149
- attention = attention /np.sqrt(self.d_k)
150
-
151
- attention_2 = attention.transpose(-2, -1)
152
-
153
-
154
-
155
- attention = F.softmax(attention, dim=-1)
156
-
157
- attention_2 = F.softmax(attention_2, dim=-1)
158
-
159
-
160
- output = torch.matmul(attention, V)
161
-
162
- output_2 = torch.matmul(attention_2, V_2)
163
-
164
- output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
165
-
166
- output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
167
-
168
- output = self.W_O(output)
169
-
170
- output_2 = self.W_O_2(output_2)
171
-
172
- output = self.dropout(output)
173
-
174
- output = self.layer_norm(output + q)
175
-
176
- output_2 = self.dropout_2(output_2)
177
-
178
- output_2 = self.layer_norm_2(output_2 + k)
179
-
180
-
181
-
182
-
183
-
184
- return output, output_2, attention, attention_2
185
-
186
-
187
- class FFN(nn.Module):
188
-
189
- def __init__(self, d_in, d_hid, dropout=0.1):
190
- super().__init__()
191
-
192
- self.layer_1 = nn.Conv1d(d_in, d_hid,1)
193
- self.layer_2 = nn.Conv1d(d_hid, d_in,1)
194
- self.relu = nn.ReLU()
195
- self.layer_norm = nn.LayerNorm(d_in)
196
-
197
- self.dropout = nn.Dropout(dropout)
198
-
199
- def forward(self, x):
200
-
201
- residual = x
202
- output = self.layer_1(x.transpose(1, 2))
203
-
204
- output = self.relu(output)
205
-
206
- output = self.layer_2(output)
207
-
208
- output = self.dropout(output)
209
-
210
- output = self.layer_norm(output.transpose(1, 2)+residual)
211
-
212
- return output
213
-
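Editorial note (hedged): a small shape smoke test for the modules above, using the same dimension settings that finetune.py passes to PeptideModel (d_model=64, n_head=6, d_k=64, d_v=128, d_inner=64). Sketch only, assuming the classes defined in this deleted file are importable:

    import torch

    mha = MultiHeadAttentionSequence(n_head=6, d_model=64, d_k=64, d_v=128)
    ffn = FFN(d_in=64, d_hid=64)

    x = torch.randn(2, 10, 64)        # (batch, seq_len, d_model)
    out, attn = mha(x, x, x)          # self-attention over the sequence
    print(out.shape)                  # torch.Size([2, 10, 64])
    print(attn.shape)                 # torch.Size([2, 6, 10, 10]), one map per head
    print(ffn(out).shape)             # torch.Size([2, 10, 64])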
 
muppit/models/score_domain.py DELETED
@@ -1,40 +0,0 @@
- from scipy.stats import norm
- import numpy as np
- import os
-
-
- def score(outputs):
-
-     weight = 0.03
-     binding_size_dist = np.load(os.path.join(os.path.dirname(__file__), "../params/binding_size_train_dist.npy"))
-
-
-     mean = np.mean(binding_size_dist)
-
-     std = np.std(binding_size_dist)
-
-     dist = norm(mean, std)
-
-
-     max_score = 0
-
-
-
-     scores = np.exp(outputs[0])[:, 1]
-
-     indices = np.argsort(-1*scores)
-
-     for j in range(1, len(indices)):
-
-
-
-         score = (1-weight)*np.mean(scores[indices[:j]]) + weight*(dist.pdf(j/len(indices)))
-
-
-         if score > max_score:
-
-             max_score = score
-
-
-     return max_score
-
 
muppit/predict.py DELETED
@@ -1,118 +0,0 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- from torch.utils.data import DataLoader
4
- from datasets import load_from_disk
5
- from transformers import AutoTokenizer
6
- from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
7
- from argparse import ArgumentParser
8
- import os
9
- import torch.distributed as dist
10
-
11
- from models import * # Import your model and other necessary classes/functions here
12
-
13
-
14
- class PeptideModel(pl.LightningModule):
15
- def __init__(self, n_layers, d_model, n_head,
16
- d_k, d_v, d_inner, dropout=0.2,
17
- learning_rate=0.00001, max_epochs=15):
18
- super(PeptideModel, self).__init__()
19
-
20
- self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
21
- # freeze all the esm_model parameters
22
- for param in self.esm_model.parameters():
23
- param.requires_grad = False
24
-
25
- self.repeated_module = RepeatedModule2(n_layers, d_model,
26
- n_head, d_k, d_v, d_inner, dropout=dropout)
27
-
28
- self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
29
- d_k, d_v, dropout=dropout)
30
-
31
- self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
32
-
33
- self.output_projection_prot = nn.Linear(d_model, 1)
34
-
35
- self.learning_rate = learning_rate
36
- self.max_epochs = max_epochs
37
-
38
- self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial threshold
39
- self.historical_memory = 0.9
40
- self.class_weights = torch.tensor(
41
- [3.000471363174231, 0.5999811490272925]) # binding-site weights, non-binding-site weights
42
-
43
- def forward(self, binder_tokens, target_tokens):
44
- peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
45
- protein_sequence = self.esm_model(**target_tokens).last_hidden_state
46
-
47
- prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
48
- seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
49
- protein_sequence)
50
-
51
- prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
52
-
53
- prot_enc = self.final_ffn(prot_enc)
54
-
55
- prot_enc = self.output_projection_prot(prot_enc)
56
-
57
- return torch.sigmoid(prot_enc)
58
-
59
-
60
- def main():
61
- parser = ArgumentParser()
62
- parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt',
63
- help="File containing initial params", type=str)
64
- parser.add_argument("-batch_size", type=int, default=32, help="Batch size")
65
- parser.add_argument("-lr", type=float, default=1e-3)
66
- parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
67
- parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
68
- parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
69
- parser.add_argument("-d_inner", type=int, default=64)
70
- parser.add_argument("-target", type=str)
71
- parser.add_argument("-binder", type=str)
72
- args = parser.parse_args()
73
- # print(args)
74
- device = 'cuda:0'
75
-
76
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
77
-
78
- anchor_tokens = tokenizer(args.target, return_tensors='pt', padding=True, truncation=True, max_length=40000)
79
-
80
- positive_tokens = tokenizer(args.binder, return_tensors='pt', padding=True, truncation=True, max_length=40000)
81
-
82
- anchor_tokens['attention_mask'][0][0] = 0
83
- anchor_tokens['attention_mask'][0][-1] = 0
84
- positive_tokens['attention_mask'][0][0] = 0
85
- positive_tokens['attention_mask'][0][-1] = 0
86
-
87
- target_tokens = {'input_ids': anchor_tokens["input_ids"].to(device),
88
- 'attention_mask': anchor_tokens["attention_mask"].to(device)}
89
- binder_tokens = {'input_ids': positive_tokens['input_ids'].to(device),
90
- 'attention_mask': positive_tokens['attention_mask'].to(device)}
91
-
92
- print(binder_tokens['input_ids'].shape)
93
-
94
- model = PeptideModel.load_from_checkpoint(args.sm,
95
- n_layers=args.n_layers,
96
- d_model=args.d_model,
97
- n_head=args.n_head,
98
- d_k=64,
99
- d_v=128,
100
- d_inner=64).to(device)
101
-
102
- model.eval()
103
-
104
- prediction = model(binder_tokens, target_tokens).squeeze(-1)[0][1:-1]
105
- print(prediction.shape)
106
- print(model.classification_threshold)
107
-
108
- binding_site = []
109
- for i in range(len(prediction)):
110
- if prediction[i] >= model.classification_threshold:
111
- binding_site.append(i)
112
-
113
- print(binding_site)
114
- print(len(binding_site))
115
-
116
-
117
- if __name__ == "__main__":
118
- main()
 
muppit/scripts/.gitkeep DELETED
File without changes
muppit/scripts/predict_binding_site.py DELETED
@@ -1,149 +0,0 @@
1
- from Bio import SeqIO
2
- from Bio.PDB import Polypeptide
3
- from transformers import BertModel, BertTokenizer, pipeline
4
- from pepnn_seq.models import FullModel
5
- from pepnn_seq.models import score
6
- import pandas as pd
7
- import numpy as np
8
- import torch
9
- import argparse
10
- import os
11
-
12
-
13
- def to_var(x):
14
- if torch.cuda.is_available():
15
- x = x.cuda()
16
- return x
17
-
18
-
19
- if __name__ == "__main__":
20
-
21
- parser = argparse.ArgumentParser()
22
-
23
-
24
- parser.add_argument("-prot", dest="input_protein_file", required=False, type=str,
25
- help="Fasta file with protein sequence")
26
-
27
- parser.add_argument("-pep", dest="input_peptide_file", required=False, type=str, default=None,
28
- help="Fasta file with peptide sequence")
29
-
30
-
31
-
32
- parser.add_argument("-o", dest="output_directory", required=False, type=str, default=None,
33
- help="Output directory")
34
-
35
-
36
- parser.add_argument("-p", dest="params", required=False, type=str, default="../params/params.pth",
37
- help="Model parameters")
38
-
39
- args = parser.parse_args()
40
-
41
- if args.output_directory == None:
42
- output_directory = os.path.split(args.input_protein_file)[-1].split(".")[0] + "_seq"
43
- else:
44
- output_directory = args.output_directory
45
-
46
- if not os.path.exists(output_directory):
47
- os.mkdir(output_directory)
48
-
49
-
50
- records = SeqIO.parse(args.input_protein_file, format="fasta")
51
-
52
- prot_sequence = ' '.join(list(records)[0].seq)
53
-
54
-
55
- protbert_dir = os.path.join(os.path.dirname(__file__), '../models/ProtBert-BFD/')
56
-
57
- vocabFilePath = os.path.join(protbert_dir, 'vocab.txt')
58
- tokenizer = BertTokenizer(vocabFilePath, do_lower_case=False )
59
- seq_embedding = BertModel.from_pretrained(protbert_dir)
60
-
61
- if torch.cuda.is_available():
62
- seq_embedding = pipeline('feature-extraction', model=seq_embedding, tokenizer=tokenizer, device=0)
63
- else:
64
- seq_embedding = pipeline('feature-extraction', model=seq_embedding, tokenizer=tokenizer, device=-1)
65
-
66
- embedding = seq_embedding(prot_sequence)
67
-
68
- embedding = np.array(embedding)
69
-
70
- seq_len = len(prot_sequence.replace(" ", ""))
71
- start_Idx = 1
72
- end_Idx = seq_len+1
73
- seq_emd = embedding[0][start_Idx:end_Idx]
74
-
75
-
76
- prot_seq = to_var(torch.FloatTensor(seq_emd).unsqueeze(0))
77
-
78
-
79
-
80
- if args.input_peptide_file != None:
81
-
82
- records = SeqIO.parse(args.input_peptide_file, format="fasta")
83
-
84
- pep_sequence = str(list(records)[0].seq).replace("X", "")
85
-
86
- pep_sequence = [Polypeptide.d1_to_index[i] for i in pep_sequence]
87
-
88
- else:
89
-
90
- pep_sequence = [5 for i in range(10)]
91
-
92
- pep_seq = to_var(torch.LongTensor(pep_sequence).unsqueeze(0))
93
-
94
-
95
- model = FullModel(6, 64, 6,
96
-
97
- 64, 128, 64, dropout=0.2)
98
- if torch.cuda.is_available():
99
- model.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), args.params)))
100
- else:
101
- model.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), args.params),
102
- map_location='cpu'))
103
-
104
- if torch.cuda.is_available():
105
- torch.cuda.empty_cache()
106
- model.cuda()
107
-
108
-
109
- model.eval()
110
-
111
- if torch.cuda.is_available():
112
- outputs = model(pep_seq, prot_seq).cpu().detach().numpy()
113
- else:
114
- outputs = model(pep_seq, prot_seq).detach().numpy()
115
-
116
- # compute score for the domain and output file
117
-
118
- score_prm = score(outputs)
119
-
120
- with open(output_directory + "/prm_score.txt", 'w') as output_file:
121
-
122
- output_file.writelines("The input protein's score is {0:.2f}".format(score_prm))
123
-
124
-
125
- # output prediction as csv
126
-
127
- outputs = np.exp(outputs[0])
128
-
129
- amino_acids = []
130
-
131
- probabilities = []
132
-
133
- position = []
134
- for index, aa in enumerate(prot_sequence.split(" ")):
135
- probabilities.append(outputs[index, 1])
136
- amino_acids.append(aa)
137
- position.append(index+1)
138
-
139
-
140
- output = pd.DataFrame()
141
-
142
- output["Position"] = position
143
- output["Amino acid"] = amino_acids
144
- output["Probabilities"] = probabilities
145
-
146
-
147
-
148
- output.to_csv(output_directory + "/binding_site_prediction.csv", index=False)
149
-
 
muppit/test_evaluator.py DELETED
@@ -1,197 +0,0 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- from torch.utils.data import DataLoader
4
- from datasets import load_from_disk
5
- from transformers import AutoTokenizer, EsmModel
6
- from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
7
- from argparse import ArgumentParser
8
- import os
9
- import torch.distributed as dist
10
-
11
- from models import *  # project-local module; assumed to provide RepeatedModule2, MultiHeadAttentionSequence, FFN, nn, and F used below
12
-
13
-
14
- def collate_fn(batch):
15
- # Unpack the batch
16
- anchors = []
17
- positives = []
18
- # negatives = []
19
- binding_sites = []
20
-
21
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
22
-
23
- for b in batch:
24
- anchors.append(b['anchors'])
25
- positives.append(b['positives'])
26
- # negatives.append(b['negatives'])
27
- binding_sites.append(b['binding_site'])
28
-
29
- # Collate the tensors using torch's pad_sequence
30
- anchor_input_ids = torch.nn.utils.rnn.pad_sequence(
31
- [torch.Tensor(item['input_ids']).squeeze(0) for item in anchors], batch_first=True, padding_value=tokenizer.pad_token_id)
32
- anchor_attention_mask = torch.nn.utils.rnn.pad_sequence(
33
- [torch.Tensor(item['attention_mask']).squeeze(0) for item in anchors], batch_first=True, padding_value=0)
34
-
35
- positive_input_ids = torch.nn.utils.rnn.pad_sequence(
36
- [torch.Tensor(item['input_ids']).squeeze(0) for item in positives], batch_first=True, padding_value=tokenizer.pad_token_id)
37
- positive_attention_mask = torch.nn.utils.rnn.pad_sequence(
38
- [torch.Tensor(item['attention_mask']).squeeze(0) for item in positives], batch_first=True, padding_value=0)
39
-
40
- n, max_length = anchor_input_ids.shape[0], anchor_input_ids.shape[1]
41
- site = torch.zeros(n, max_length)
42
- for i in range(len(binding_sites)):
43
- binding_site = binding_sites[i]
44
- site[i, binding_site] = 1
45
-
46
- # Return the collated batch
47
- return {
48
- 'anchor_input_ids': anchor_input_ids.int(),
49
- 'anchor_attention_mask': anchor_attention_mask.int(),
50
- 'positive_input_ids': positive_input_ids.int(),
51
- 'positive_attention_mask': positive_attention_mask.int(),
52
- # 'negative_input_ids': negative_input_ids.int(),
53
- # 'negative_attention_mask': negative_attention_mask.int(),
54
- 'binding_site': site
55
- }
56
-
57
-
58
- class CustomDataModule(pl.LightningDataModule):
59
- def __init__(self, tokenizer, batch_size: int = 128):
60
- super().__init__()
61
- self.batch_size = batch_size
62
- self.tokenizer = tokenizer
63
-
64
- def test_dataloader(self):
65
- test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/test_dataset_drop_500')
66
- return DataLoader(test_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8, pin_memory=True)
67
-
68
-
69
- class PeptideModel(pl.LightningModule):
70
- def __init__(self, n_layers, d_model, n_head,
71
- d_k, d_v, d_inner, dropout=0.2,
72
- learning_rate=0.00001, max_epochs=15):
73
- super(PeptideModel, self).__init__()
74
-
75
- self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
76
- # freeze all the esm_model parameters
77
- for param in self.esm_model.parameters():
78
- param.requires_grad = False
79
-
80
- self.repeated_module = RepeatedModule2(n_layers, d_model,
81
- n_head, d_k, d_v, d_inner, dropout=dropout)
82
-
83
- self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
84
- d_k, d_v, dropout=dropout)
85
-
86
- self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
87
-
88
- self.output_projection_prot = nn.Linear(d_model, 1)
89
-
90
- self.learning_rate = learning_rate
91
- self.max_epochs = max_epochs
92
-
93
- self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial threshold
94
- self.historical_memory = 0.9
95
- self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site weight, non-binding-site weight
96
-
97
- def forward(self, binder_tokens, target_tokens):
98
- peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
99
- protein_sequence = self.esm_model(**target_tokens).last_hidden_state
100
-
101
- prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
102
- seq_prot_attention_list, prot_seq_attention_list = self.repeated_module(peptide_sequence,
103
- protein_sequence)
104
-
105
- prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
106
-
107
- prot_enc = self.final_ffn(prot_enc)
108
-
109
- prot_enc = self.output_projection_prot(prot_enc)
110
-
111
- return prot_enc
112
-
113
- def test_step(self, batch, batch_idx):
114
- target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
115
- 'attention_mask': batch['anchor_attention_mask'].to(self.device)}
116
- binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
117
- 'attention_mask': batch['positive_attention_mask'].to(self.device)}
118
- binding_site = batch['binding_site'].to(self.device)
119
- mask = target_tokens['attention_mask']
120
-
121
- outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)
122
-
123
- weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
124
- loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')
125
-
126
- masked_loss = loss * mask
127
- mean_loss = masked_loss.sum() / mask.sum()
128
-
129
- sigmoid_outputs = torch.sigmoid(outputs_nodes)
130
- total = mask.sum()
131
-
132
- # self.update_class_thresholds(sigmoid_outputs, binding_site, mask)
133
- # self.log('threshold', self.classification_threshold, on_epoch=True)
134
-
135
- predict = (sigmoid_outputs >= self.classification_threshold).float()
136
- correct = ((predict == binding_site) * mask).sum()
137
- accuracy = correct / total
138
-
139
- outputs_nodes_flat = sigmoid_outputs[mask.bool()].float().cpu().detach().numpy().flatten()
140
- binding_site_flat = binding_site[mask.bool()].float().cpu().detach().numpy().flatten()
141
- predictions_flat = predict[mask.bool()].float().cpu().detach().numpy().flatten()
142
-
143
- auc = roc_auc_score(binding_site_flat, outputs_nodes_flat)
144
- f1 = f1_score(binding_site_flat, predictions_flat)
145
- mcc = matthews_corrcoef(binding_site_flat, predictions_flat)
146
-
147
- self.log('test_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
148
- self.log('test_accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
149
- self.log('test_auc', auc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
150
- self.log('test_f1', f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
151
- self.log('test_mcc', mcc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
152
-
153
-
154
- def main():
155
- parser = ArgumentParser()
156
- parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt', help="File containing initial params", type=str)
157
- parser.add_argument("-batch_size", type=int, default=32, help="Batch size")
158
- parser.add_argument("-lr", type=float, default=1e-3)
159
- parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
160
- parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
161
- parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
162
- parser.add_argument("-d_inner", type=int, default=64)
163
- args = parser.parse_args()
164
- print(args.sm)
165
-
166
- # Initialize the process group for distributed training
167
- dist.init_process_group(backend='nccl')
168
-
169
- test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_test')  # loaded for inspection only; CustomDataModule reads its own test split
170
- # print(len(test_dataset))
171
-
172
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
173
-
174
- data_module = CustomDataModule(tokenizer, args.batch_size)
175
-
176
- model = PeptideModel.load_from_checkpoint(args.sm,
177
- n_layers=args.n_layers,
178
- d_model=args.d_model,
179
- n_head=args.n_head,
180
- d_k=64,
181
- d_v=128,
182
- d_inner=64)
183
-
184
- print(f"Class threshold = {model.classification_threshold}")
185
-
186
- trainer = pl.Trainer(accelerator='gpu',
187
- devices=[0,1,2,3,4,5,6,7],
188
- strategy='ddp',
189
- precision='bf16')
190
-
191
- results = trainer.test(model, datamodule=data_module)
192
-
193
- print(results)
194
-
195
-
196
- if __name__ == "__main__":
197
- main()
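Note on the deleted test_evaluator.py above: the heart of collate_fn is padding variable-length tokenized sequences and expanding per-sequence binding-site index lists into a dense 0/1 matrix aligned with the padded targets. A minimal standalone sketch of that logic with toy tensors (pad id 1 is assumed here, matching ESM-2's pad token):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two "tokenized" targets of different lengths (toy values; 1 stands in for the pad id).
anchor_ids = [torch.tensor([0, 5, 6, 7, 2]), torch.tensor([0, 8, 9, 2])]
padded = pad_sequence(anchor_ids, batch_first=True, padding_value=1)  # shape (2, 5)

# Binding sites arrive as index lists per sequence and are expanded to a dense 0/1 matrix.
binding_sites = [[1, 3], [2]]
site = torch.zeros(padded.shape[0], padded.shape[1])
for i, idx in enumerate(binding_sites):
    site[i, idx] = 1

print(padded)
print(site)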
 
muppit/train_evaluator.py DELETED
@@ -1,408 +0,0 @@
1
- import pdb
2
- from pytorch_lightning.strategies import DDPStrategy
3
- import torch
4
- import torch.nn.functional as F
5
- from torch.utils.data import DataLoader, DistributedSampler
6
- from datasets import load_from_disk
7
- import pytorch_lightning as pl
8
- from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
9
- Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
10
- from pytorch_lightning.loggers import WandbLogger
11
- from torch.optim.lr_scheduler import _LRScheduler
12
- from transformers.optimization import get_cosine_schedule_with_warmup
13
- from argparse import ArgumentParser
14
- import os
15
- import uuid
16
- import numpy as np
17
- import torch.distributed as dist
18
- from models import *  # project-local module; assumed to provide RepeatedModule2, MultiHeadAttentionSequence, FFN, and nn used below
19
- from torch.nn.utils.rnn import pad_sequence
20
- from transformers import AutoTokenizer, EsmModel
21
- from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
22
- from torch.optim import Adam, AdamW
23
- from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
24
- import gc
25
-
26
-
27
- os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
28
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
29
- os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
30
-
31
-
32
- def compute_class_weights(targets):
33
- num_binding_sites = targets.sum()
34
- num_non_binding_sites = targets.numel() - num_binding_sites
35
- total = num_binding_sites + num_non_binding_sites
36
- weight_for_binding = total / (2 * num_binding_sites)
37
- weight_for_non_binding = total / (2 * num_non_binding_sites)
38
- return torch.tensor([weight_for_non_binding, weight_for_binding])
39
-
40
-
41
- def collate_fn(batch):
42
- # Unpack the batch
43
- anchors = []
44
- positives = []
45
- # negatives = []
46
- binding_sites = []
47
-
48
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
49
-
50
- for b in batch:
51
- anchors.append(b['anchors'])
52
- positives.append(b['positives'])
53
- # negatives.append(b['negatives'])
54
- binding_sites.append(b['binding_site'])
55
-
56
- # Collate the tensors using torch's pad_sequence
57
- anchor_input_ids = torch.nn.utils.rnn.pad_sequence(
58
- [torch.Tensor(item['input_ids']).squeeze(0) for item in anchors], batch_first=True, padding_value=tokenizer.pad_token_id)
59
- anchor_attention_mask = torch.nn.utils.rnn.pad_sequence(
60
- [torch.Tensor(item['attention_mask']).squeeze(0) for item in anchors], batch_first=True, padding_value=0)
61
-
62
- positive_input_ids = torch.nn.utils.rnn.pad_sequence(
63
- [torch.Tensor(item['input_ids']).squeeze(0) for item in positives], batch_first=True, padding_value=tokenizer.pad_token_id)
64
- positive_attention_mask = torch.nn.utils.rnn.pad_sequence(
65
- [torch.Tensor(item['attention_mask']).squeeze(0) for item in positives], batch_first=True, padding_value=0)
66
-
67
- # negative_input_ids = torch.nn.utils.rnn.pad_sequence(
68
- # [torch.Tensor(item['input_ids']).squeeze(0) for item in negatives], batch_first=True, padding_value=tokenizer.pad_token_id)
69
- # negative_attention_mask = torch.nn.utils.rnn.pad_sequence(
70
- # [torch.Tensor(item['attention_mask']).squeeze(0) for item in negatives], batch_first=True, padding_value=0)
71
-
72
- # assert anchor_input_ids.shape == negative_input_ids.shape
73
-
74
- n, max_length = anchor_input_ids.shape[0], anchor_input_ids.shape[1]
75
- site = torch.zeros(n, max_length)
76
- for i in range(len(binding_sites)):
77
- binding_site = binding_sites[i]
78
- site[i, binding_site] = 1
79
-
80
- # Return the collated batch
81
- return {
82
- 'anchor_input_ids': anchor_input_ids.int(),
83
- 'anchor_attention_mask': anchor_attention_mask.int(),
84
- 'positive_input_ids': positive_input_ids.int(),
85
- 'positive_attention_mask': positive_attention_mask.int(),
86
- # 'negative_input_ids': negative_input_ids.int(),
87
- # 'negative_attention_mask': negative_attention_mask.int(),
88
- 'binding_site': site
89
- }
90
-
91
-
92
- class CustomDataModule(pl.LightningDataModule):
93
- def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
94
- super().__init__()
95
- self.train_dataset = train_dataset
96
- self.val_dataset = val_dataset
97
- self.batch_size = batch_size
98
- self.tokenizer = tokenizer
99
-
100
- def train_dataloader(self):
101
- return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
102
- num_workers=8, pin_memory=True)
103
-
104
- def val_dataloader(self):
105
- return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
106
- pin_memory=True)
107
-
108
- def setup(self, stage=None):
109
- if stage == 'test' or stage is None:
110
- self.test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/test_dataset_static')
111
- self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
112
- num_workers=8, pin_memory=True)
113
-
114
-
115
- class CosineAnnealingWithWarmup(_LRScheduler):
116
- def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
117
- self.warmup_steps = warmup_steps
118
- self.total_steps = total_steps
119
- self.base_lr = base_lr
120
- self.max_lr = max_lr
121
- self.min_lr = min_lr
122
- super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
123
- print(f"SELF BASE LRS = {self.base_lrs}")
124
-
125
- def get_lr(self):
126
- if self.last_epoch < self.warmup_steps:
127
- # Linear warmup phase from base_lr to max_lr
128
- return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps) for _ in self.base_lrs]
129
-
130
- # Cosine annealing phase from max_lr to min_lr
131
- progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
132
- cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
133
- decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
134
-
135
- return [decayed_lr for _ in self.base_lrs]
136
-
137
- class PeptideModel(pl.LightningModule):
138
- def __init__(self, n_layers, d_model, n_head,
139
- d_k, d_v, d_inner, dropout=0.2,
140
- learning_rate=0.00001, max_epochs=15):
141
- super(PeptideModel, self).__init__()
142
-
143
- self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
144
- # freeze all the esm_model parameters
145
- for param in self.esm_model.parameters():
146
- param.requires_grad = False
147
-
148
- self.repeated_module = RepeatedModule2(n_layers, d_model,
149
- n_head, d_k, d_v, d_inner, dropout=dropout)
150
-
151
- self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
152
- d_k, d_v, dropout=dropout)
153
-
154
- self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
155
-
156
- self.output_projection_prot = nn.Linear(d_model, 1)
157
-
158
- self.learning_rate = learning_rate
159
- self.max_epochs = max_epochs
160
-
161
- self.classification_threshold = nn.Parameter(torch.tensor(0.5)) # Initial threshold
162
- self.historical_memory = 0.9
163
- self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site weight, non-binding-site weight
164
-
165
- def forward(self, binder_tokens, target_tokens):
166
- peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
167
- protein_sequence = self.esm_model(**target_tokens).last_hidden_state
168
-
169
- prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
170
- seq_prot_attention_list, prot_seq_attention_list = self.repeated_module(peptide_sequence,
171
- protein_sequence)
172
-
173
- prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
174
-
175
- prot_enc = self.final_ffn(prot_enc)
176
-
177
- prot_enc = self.output_projection_prot(prot_enc)
178
-
179
- return prot_enc
180
-
181
- def training_step(self, batch, batch_idx):
182
- opt = self.optimizers()
183
- lr = opt.param_groups[0]['lr']
184
- self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
185
-
186
- target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
187
- 'attention_mask': batch['anchor_attention_mask'].to(self.device)}
188
- binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
189
- 'attention_mask': batch['positive_attention_mask'].to(self.device)}
190
- binding_site = batch['binding_site'].to(self.device)
191
- mask = target_tokens['attention_mask']
192
-
193
- outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)
194
-
195
- weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
196
- loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')
197
-
198
- masked_loss = loss * mask
199
- mean_loss = masked_loss.sum() / mask.sum()
200
-
201
- # print('logging')
202
- self.log('train_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
203
- return mean_loss
204
-
205
- def validation_step(self, batch, batch_idx):
206
- target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
207
- 'attention_mask': batch['anchor_attention_mask'].to(self.device)}
208
- binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
209
- 'attention_mask': batch['positive_attention_mask'].to(self.device)}
210
- binding_site = batch['binding_site'].to(self.device)
211
- mask = target_tokens['attention_mask']
212
-
213
- outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)
214
-
215
- weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
216
- loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')
217
-
218
- # Apply the mask to the loss
219
- masked_loss = loss * mask
220
-
221
- # Compute the mean loss only over the valid positions
222
- mean_loss = masked_loss.sum() / mask.sum()
223
-
224
- # Calculate predictions and apply mask
225
- sigmoid_outputs = torch.sigmoid(outputs_nodes)
226
- total = mask.sum()
227
-
228
- self.update_class_thresholds(sigmoid_outputs, binding_site, mask)
229
- self.log('threshold', self.classification_threshold, on_epoch=True)
230
-
231
- predict = (sigmoid_outputs >= self.classification_threshold).float()
232
- correct = ((predict == binding_site) * mask).sum()
233
- accuracy = correct / total
234
-
235
- # Compute AUC
236
- outputs_nodes_flat = sigmoid_outputs[mask.bool()].float().cpu().detach().numpy().flatten()
237
- binding_site_flat = binding_site[mask.bool()].float().cpu().detach().numpy().flatten()
238
- predictions_flat = predict[mask.bool()].float().cpu().detach().numpy().flatten()
239
-
240
- auc = roc_auc_score(binding_site_flat, outputs_nodes_flat)
241
- f1 = f1_score(binding_site_flat, predictions_flat)
242
- mcc = matthews_corrcoef(binding_site_flat, predictions_flat)
243
-
244
- self.log('val_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
245
- self.log('val_accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
246
- self.log('val_auc', auc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
247
- self.log('val_f1', f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
248
- self.log('val_mcc', mcc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
249
-
250
- def configure_optimizers(self):
251
- print(f"MAX STEPS = {self.max_epochs}")
252
- optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
253
- # schedulers = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=0.1*self.max_epochs,
254
- # max_epochs=self.max_epochs,
255
- # warmup_start_lr=5e-4,
256
- # eta_min=0.1 * self.learning_rate)
257
-
258
- base_lr = 1e-4
259
- max_lr = self.learning_rate
260
- min_lr = 0.1 * self.learning_rate
261
-
262
- schedulers = CosineAnnealingWithWarmup(optimizer, warmup_steps=692, total_steps=8293,
263
- base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)
264
-
265
- lr_schedulers = {
266
- "scheduler": schedulers,
267
- "name": 'learning_rate_logs',
268
- "interval": 'step', # The scheduler updates the learning rate at every step (not epoch)
269
- 'frequency': 1 # The scheduler updates the learning rate after every batch
270
- }
271
- return [optimizer], [lr_schedulers]
272
-
273
- def update_class_thresholds(self, inputs, targets, mask):
274
- with torch.no_grad():
275
- min_threshold_value = 0.001
276
- thresholds = torch.arange(0.1, 1.0, 0.05, device=inputs.device)
277
-
278
- best_f1_score = 0
279
- best_threshold = min_threshold_value
280
-
281
- for threshold in thresholds:
282
- binary_predictions = (inputs >= threshold).float()
283
-
284
- tp = ((binary_predictions * targets) * mask).sum().item()
285
- fp = ((binary_predictions * (1 - targets)) * mask).sum().item()
286
- fn = (((1 - binary_predictions) * targets) * mask).sum().item()
287
-
288
- precision = tp / (tp + fp + 1e-7)
289
- recall = tp / (tp + fn + 1e-7)
290
- f1_score = 2 * precision * recall / (precision + recall + 1e-7)
291
-
292
- if f1_score > best_f1_score:
293
- best_f1_score = f1_score
294
- best_threshold = threshold
295
-
296
- updated_threshold = self.historical_memory * self.classification_threshold + (
297
- 1 - self.historical_memory) * best_threshold
298
- self.classification_threshold = nn.Parameter(torch.clamp(updated_threshold, min=min_threshold_value))
299
- gc.collect()
300
- torch.cuda.empty_cache()
301
-
302
- def training_epoch_end(self, outputs):
303
- gc.collect()
304
- torch.cuda.empty_cache()
305
- super().training_epoch_end(outputs)
306
-
307
- def validation_epoch_end(self, outputs):
308
- gc.collect()
309
- torch.cuda.empty_cache()
310
- super().validation_epoch_end(outputs)
311
-
312
-
313
-
314
-
315
- def main():
316
- parser = ArgumentParser()
317
-
318
- parser.add_argument("-o", dest="output_file", help="File for output of model parameters", required=True, type=str)
319
- parser.add_argument("-d", dest="dataset", required=False, type=str, default="pepnn",
320
- help="Which dataset to train on, pepnn, pepbind, or interpep")
321
- parser.add_argument("-lr", type=float, default=1e-3)
322
- parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
323
- parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
324
- parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
325
- parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
326
- parser.add_argument("-d_inner", type=int, default=64)
327
- # parser.add_argument("-sm", dest="saved_model", help="File containing initial params", required=False, type=str,
328
- # default=None)
329
- parser.add_argument("-sm", default=None, help="File containing initial params", type=str)
330
- parser.add_argument("--max_epochs", type=int, default=15, help="Max number of epochs to train")
331
- args = parser.parse_args()
332
-
333
- # Initialize the process group for distributed training
334
- dist.init_process_group(backend='nccl')
335
-
336
- train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/train_dataset_drop_500')
337
- val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/val_dataset_drop_500')
338
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
339
-
340
- data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)
341
-
342
- # Calculate the number of training steps and warm-up steps
343
- train_dataloader = data_module.train_dataloader()
344
- num_training_steps = len(train_dataloader) * args.max_epochs
345
- num_warmup_steps = int(0.1 * num_training_steps) # Warm-up for 10% of training steps
346
-
347
- model = PeptideModel(6, 64, 6, 64, 128, 64, dropout=0.2,
348
- learning_rate=args.lr, max_epochs=num_training_steps)
349
- if args.sm:
350
- model = PeptideModel.load_from_checkpoint(args.sm,
351
- n_layers=args.n_layers,
352
- d_model=args.d_model,
353
- n_head=args.n_head,
354
- d_k=64,
355
- d_v=128,
356
- d_inner=64,
357
- dropout=0.3,
358
- learning_rate=args.lr,
359
- max_epochs=args.max_epochs)
360
-
361
- run_id = str(uuid.uuid4())
362
-
363
- print("Classification Thresholds:")
364
- print(model.classification_threshold)
365
-
366
- logger = WandbLogger(project=f"bind_evaluator",
367
- name=f"continue_lr={args.lr}_nlayers={args.n_layers}_dmodel={args.d_model}_nhead={args.n_head}_dinner={args.d_inner}",
368
- # display on the web
369
- # save_dir=f'./pl_logs/',
370
- job_type='model-training',
371
- id=run_id)
372
-
373
- checkpoint_callback = ModelCheckpoint(
374
- monitor='val_mcc',
375
- dirpath=args.output_file,
376
- filename='model-{epoch:02d}-{val_loss:.2f}',
377
- save_top_k=1,
378
- mode='max',
379
- )
380
-
381
- early_stopping_callback = EarlyStopping(
382
- monitor='val_mcc',
383
- patience=5,
384
- verbose=True,
385
- mode='max'
386
- )
387
-
388
- accumulator = GradientAccumulationScheduler(scheduling={0: 4, 2: 2, 7: 1})
389
-
390
- trainer = pl.Trainer(
391
- max_epochs=args.max_epochs,
392
- accelerator='gpu',
393
- strategy='ddp',
394
- precision='bf16',
395
- logger=logger,
396
- devices=[0,1,2,3,4,5,6],
397
- callbacks=[checkpoint_callback, accumulator, early_stopping_callback],
398
- gradient_clip_val=1.0
399
- )
400
-
401
- trainer.fit(model, datamodule=data_module)
402
-
403
- best_model_path = checkpoint_callback.best_model_path
404
- print(best_model_path)
405
-
406
-
407
- if __name__ == "__main__":
408
- main()
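Note on the deleted train_evaluator.py above: CosineAnnealingWithWarmup ramps linearly from base_lr to max_lr over warmup_steps, then decays along a cosine from max_lr to min_lr by total_steps. A minimal function-only sketch of the same curve (no optimizer attached), using the values wired into configure_optimizers above with max_lr taken as the -lr default of 1e-3:

import numpy as np

def lr_at_step(step, warmup_steps, total_steps, base_lr, max_lr, min_lr):
    """Learning rate at a global step, mirroring the deleted scheduler's get_lr()."""
    if step < warmup_steps:
        # Linear warmup from base_lr to max_lr.
        return base_lr + (max_lr - base_lr) * (step / warmup_steps)
    # Cosine annealing from max_lr down to min_lr.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + np.cos(np.pi * progress))

# warmup_steps=692, total_steps=8293, base_lr=1e-4, max_lr=1e-3, min_lr=1e-4
for step in (0, 346, 692, 4000, 8293):
    print(step, round(lr_at_step(step, 692, 8293, 1e-4, 1e-3, 1e-4), 6))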