Delete muppit
- muppit/.gitkeep +0 -0
- muppit/__init__.py +0 -0
- muppit/calculate_steps.py +0 -72
- muppit/finetune.py +0 -386
- muppit/models/.gitattributes +0 -1
- muppit/models/.gitkeep +0 -0
- muppit/models/__init__.py +0 -3
- muppit/models/dataloaders.py +0 -426
- muppit/models/layers.py +0 -44
- muppit/models/models.py +0 -238
- muppit/models/modules.py +0 -213
- muppit/models/score_domain.py +0 -40
- muppit/predict.py +0 -118
- muppit/scripts/.gitkeep +0 -0
- muppit/scripts/predict_binding_site.py +0 -149
- muppit/test_evaluator.py +0 -197
- muppit/train_evaluator.py +0 -408
muppit/.gitkeep
DELETED
File without changes
muppit/__init__.py
DELETED
File without changes
muppit/calculate_steps.py
DELETED
@@ -1,72 +0,0 @@
import math

# def calculate_steps_per_epoch(total_samples, batch_size_per_gpu, num_gpus, scheduling):
#     # Calculate total batch size across all GPUs
#     total_batch_size = batch_size_per_gpu * num_gpus
#
#     # Calculate total batches per epoch
#     batches_per_epoch = math.ceil(total_samples / total_batch_size)
#
#     steps_per_epoch = []
#     current_accumulation_factor = 1  # Default accumulation factor
#
#     for epoch in range(max(scheduling.keys()) + 1):
#         # Update accumulation factor if it's defined for the current epoch
#         if epoch in scheduling:
#             current_accumulation_factor = scheduling[epoch]
#
#         effective_steps = math.ceil(batches_per_epoch / current_accumulation_factor)
#         steps_per_epoch.append(effective_steps)
#
#     return steps_per_epoch


def calculate_total_steps(total_samples, batch_size, num_gpus, accumulation_schedule, max_epochs):
    total_steps = 0

    for epoch in range(max_epochs):
        # Determine the accumulation steps for the current epoch
        for start_epoch, steps in accumulation_schedule.items():
            if start_epoch > epoch:
                break
            accumulation_steps = steps

        effective_batch_size = batch_size * num_gpus * accumulation_steps
        steps_per_epoch = (total_samples + effective_batch_size - 1) // effective_batch_size

        total_steps += steps_per_epoch
        print(f'Epoch {epoch}: {steps_per_epoch} steps (accumulation_steps={accumulation_steps})')

    return total_steps


total_samples = 4804  # Replace with the actual number of samples in your dataset
batch_size = 32
num_gpus = 1
accumulation_schedule = {0: 4, 3: 3, 10: 2}
max_epochs = 20

total_steps = calculate_total_steps(total_samples, batch_size, num_gpus, accumulation_schedule, max_epochs)
print(f"Total Steps: {total_steps}")

# total_samples = 309503  # Replace with the actual number of samples in your dataset
# batch_size = 32
# num_gpus = 7
# accumulation_schedule = {0: 4, 2: 2, 7: 1}
# max_epochs = 10
#
# total_steps = calculate_total_steps(total_samples, batch_size, num_gpus, accumulation_schedule, max_epochs)
# print(f"Total Steps: {total_steps}")

# # Example usage
# total_samples = 309503
# batch_size_per_gpu = 16
# num_gpus = 7
# scheduling = {0: 4, 5: 3, 10: 2, 13: 1}
#
# steps_per_epoch = calculate_steps_per_epoch(total_samples, batch_size_per_gpu, num_gpus, scheduling)
# for epoch, steps in enumerate(steps_per_epoch):
#     print(f"Epoch {epoch}: {steps} steps")
#
# print(f"Total steps: {sum(steps_per_epoch)}")
muppit/finetune.py
DELETED
@@ -1,386 +0,0 @@
import pdb
from pytorch_lightning.strategies import DDPStrategy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_from_disk
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
    Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import _LRScheduler
from transformers.optimization import get_cosine_schedule_with_warmup
from argparse import ArgumentParser
import os
import uuid
import numpy as np
import torch.distributed as dist
from models import *
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import Adam, AdamW
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
import gc


os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'


def collate_fn(batch):
    # Unpack the batch
    anchors = []
    positives = []
    # negatives = []
    binding_sites = []

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    for b in batch:
        anchors.append(b['anchors'])
        positives.append(b['positives'])
        binding_sites.append(b['binding_site'])

    # Collate the tensors using torch's pad_sequence
    anchor_input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['input_ids']).squeeze(0) for item in anchors], batch_first=True, padding_value=tokenizer.pad_token_id)
    anchor_attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['attention_mask']).squeeze(0) for item in anchors], batch_first=True, padding_value=0)

    positive_input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['input_ids']).squeeze(0) for item in positives], batch_first=True, padding_value=tokenizer.pad_token_id)
    positive_attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['attention_mask']).squeeze(0) for item in positives], batch_first=True, padding_value=0)

    n, max_length = anchor_input_ids.shape[0], anchor_input_ids.shape[1]
    site = torch.zeros(n, max_length)
    for i in range(len(binding_sites)):
        binding_site = binding_sites[i]
        site[i, binding_site] = 1

    # Return the collated batch
    return {
        'anchor_input_ids': anchor_input_ids.int(),
        'anchor_attention_mask': anchor_attention_mask.int(),
        'positive_input_ids': positive_input_ids.int(),
        'positive_attention_mask': positive_attention_mask.int(),
        'binding_site': site
    }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
                          num_workers=8, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
                          pin_memory=True)

    def setup(self, stage=None):
        if stage == 'test' or stage is None:
            test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_test')
            self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                                              num_workers=8, pin_memory=True)


class CosineAnnealingWithWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
        print(f"SELF BASE LRS = {self.base_lrs}")

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase from base_lr to max_lr
            return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps) for base_lr in self.base_lrs]

        # Cosine annealing phase from max_lr to min_lr
        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay

        return [decayed_lr for base_lr in self.base_lrs]


class PeptideModel(pl.LightningModule):
    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15):
        super(PeptideModel, self).__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
        self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.6625710315221727, 0.5790496079007189])  # binding-site weight, non-binding-site weight

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.output_projection_prot(prot_enc)

        return prot_enc

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        lr = opt.param_groups[0]['lr']
        self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
                         'attention_mask': batch['anchor_attention_mask'].to(self.device)}
        binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
                         'attention_mask': batch['positive_attention_mask'].to(self.device)}
        binding_site = batch['binding_site'].to(self.device)
        mask = target_tokens['attention_mask']

        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)

        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')

        masked_loss = loss * mask
        mean_loss = masked_loss.sum() / mask.sum()

        # print('logging')
        self.log('train_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return mean_loss

    def validation_step(self, batch, batch_idx):
        target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
                         'attention_mask': batch['anchor_attention_mask'].to(self.device)}
        binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
                         'attention_mask': batch['positive_attention_mask'].to(self.device)}
        binding_site = batch['binding_site'].to(self.device)
        mask = target_tokens['attention_mask']

        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)

        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')

        # Apply the mask to the loss
        masked_loss = loss * mask

        # Compute the mean loss only over the valid positions
        mean_loss = masked_loss.sum() / mask.sum()

        # Calculate predictions and apply mask
        sigmoid_outputs = torch.sigmoid(outputs_nodes)
        total = mask.sum()

        self.update_class_thresholds(sigmoid_outputs, binding_site, mask)
        self.log('threshold', self.classification_threshold, on_epoch=True)

        predict = (sigmoid_outputs >= self.classification_threshold).float()
        correct = ((predict == binding_site) * mask).sum()
        accuracy = correct / total

        # Compute AUC
        outputs_nodes_flat = sigmoid_outputs[mask.bool()].float().cpu().detach().numpy().flatten()
        binding_site_flat = binding_site[mask.bool()].float().cpu().detach().numpy().flatten()
        predictions_flat = predict[mask.bool()].float().cpu().detach().numpy().flatten()

        auc = roc_auc_score(binding_site_flat, outputs_nodes_flat)
        f1 = f1_score(binding_site_flat, predictions_flat)
        mcc = matthews_corrcoef(binding_site_flat, predictions_flat)

        self.log('val_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_auc', auc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_f1', f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_mcc', mcc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        print(f"MAX STEPS = {self.max_epochs}")
        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
        # schedulers = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=0.1*self.max_epochs,
        #                                            max_epochs=self.max_epochs,
        #                                            warmup_start_lr=5e-4,
        #                                            eta_min=0.1 * self.learning_rate)

        base_lr = 0
        max_lr = self.learning_rate
        min_lr = 0.1 * self.learning_rate

        schedulers = CosineAnnealingWithWarmup(optimizer, warmup_steps=76, total_steps=1231,
                                               base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)

        lr_schedulers = {
            "scheduler": schedulers,
            "name": 'learning_rate_logs',
            "interval": 'step',  # The scheduler updates the learning rate at every step (not epoch)
            'frequency': 1  # The scheduler updates the learning rate after every batch
        }
        return [optimizer], [lr_schedulers]

    def update_class_thresholds(self, inputs, targets, mask):
        with torch.no_grad():
            min_threshold_value = 0.001
            thresholds = torch.arange(0.1, 1.0, 0.05, device=inputs.device)

            best_f1_score = 0
            best_threshold = min_threshold_value

            for threshold in thresholds:
                binary_predictions = (inputs >= threshold).float()

                tp = ((binary_predictions * targets) * mask).sum().item()
                fp = ((binary_predictions * (1 - targets)) * mask).sum().item()
                fn = (((1 - binary_predictions) * targets) * mask).sum().item()

                precision = tp / (tp + fp + 1e-7)
                recall = tp / (tp + fn + 1e-7)
                f1_score = 2 * precision * recall / (precision + recall + 1e-7)

                if f1_score > best_f1_score:
                    best_f1_score = f1_score
                    best_threshold = threshold

            updated_threshold = self.historical_memory * self.classification_threshold + (
                    1 - self.historical_memory) * best_threshold
            self.classification_threshold = nn.Parameter(torch.clamp(updated_threshold, min=min_threshold_value))
            gc.collect()
            torch.cuda.empty_cache()

    def training_epoch_end(self, outputs):
        gc.collect()
        torch.cuda.empty_cache()
        super().training_epoch_end(outputs)

    def validation_epoch_end(self, outputs):
        gc.collect()
        torch.cuda.empty_cache()
        super().validation_epoch_end(outputs)


def main():
    parser = ArgumentParser()

    parser.add_argument("-o", dest="output_file", help="File for output of model parameters", required=True, type=str)
    parser.add_argument("-d", dest="dataset", required=False, type=str, default="pepnn",
                        help="Which dataset to train on, pepnn, pepbind, or interpep")
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
    parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
    parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
    parser.add_argument("-d_inner", type=int, default=64)
    # parser.add_argument("-sm", dest="saved_model", help="File containing initial params", required=False, type=str,
    #                     default=None)
    parser.add_argument("-sm", default=None, help="File containing initial params", type=str)
    parser.add_argument("--max_epochs", type=int, default=15, help="Max number of epochs to train")
    args = parser.parse_args()

    print(args.max_epochs)

    # Initialize the process group for distributed training
    dist.init_process_group(backend="nccl")

    train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_train')
    val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_val')
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)

    model = PeptideModel(6, 64, 6, 64, 128, 64, dropout=0.2,
                         learning_rate=args.lr, max_epochs=args.max_epochs)
    if args.sm:
        model = PeptideModel.load_from_checkpoint(args.sm,
                                                  n_layers=args.n_layers,
                                                  d_model=args.d_model,
                                                  n_head=args.n_head,
                                                  d_k=64,
                                                  d_v=128,
                                                  d_inner=64,
                                                  dropout=0.2,
                                                  learning_rate=args.lr,
                                                  max_epochs=args.max_epochs)

    run_id = str(uuid.uuid4())

    print("Classification Thresholds:")
    print(model.classification_threshold)

    logger = WandbLogger(project=f"bind_evaluator",
                         name=f"finetune_lr={args.lr}_nlayers={args.n_layers}_dmodel={args.d_model}_nhead={args.n_head}_dinner={args.d_inner}",
                         # display on the web
                         # save_dir=f'./pl_logs/',
                         job_type='model-training',
                         id=run_id)

    checkpoint_callback = ModelCheckpoint(
        monitor='val_mcc',
        dirpath=args.output_file,
        filename='model-{epoch:02d}-{val_mcc:.2f}',
        save_top_k=1,
        mode='max',
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_mcc',
        patience=5,
        verbose=True,
        mode='max'
    )

    accumulator = GradientAccumulationScheduler(scheduling={0: 4, 3: 3, 10: 2})

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator='gpu',
        strategy='ddp',
        precision='bf16',
        logger=logger,
        devices=[0],
        callbacks=[checkpoint_callback, accumulator, early_stopping_callback],
        gradient_clip_val=1.0
    )

    trainer.fit(model, datamodule=data_module)

    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)


if __name__ == "__main__":
    main()
muppit/models/.gitattributes
DELETED
@@ -1 +0,0 @@
ProtBert-BFD/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
muppit/models/.gitkeep
DELETED
File without changes
muppit/models/__init__.py
DELETED
@@ -1,3 +0,0 @@
from .models import *
from .score_domain import *
from .dataloaders import *
muppit/models/dataloaders.py
DELETED
@@ -1,426 +0,0 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 31 21:54:08 2021

@author: Osama
"""

from torch.utils.data import Dataset
from Bio.PDB import Polypeptide
import numpy as np
import torch
import pandas as pd
import os
import esm
import ast
import pdb


class InterpepComplexes(Dataset):

    def __init__(self, mode,
                 encoded_data_directory="../../datasets/interpep_data/"):

        self.mode = mode
        self.encoded_data_directory = encoded_data_directory

        self.train_dir = "../../datasets/interpep_data/train_examples.npy"
        self.test_dir = "../../datasets/interpep_data/test_examples.npy"
        self.val_dir = "../../datasets/interpep_data/val_examples.npy"

        self.test_list = np.load(self.test_dir)
        self.train_list = np.load(self.train_dir)
        self.val_list = np.load(self.val_dir)

        if mode == "train":
            self.num_data = len(self.train_list)
        elif mode == "val":
            self.num_data = len(self.val_list)
        elif mode == "test":
            self.num_data = len(self.test_list)

    def __getitem__(self, index):

        if self.mode == "train":
            item = self.train_list[index]
        elif self.mode == "val":
            item = self.val_list[index]
        elif self.mode == "test":
            item = self.test_list[index]

        file_dir = self.encoded_data_directory

        with np.load(file_dir + "fragment_data/" + item + ".npz") as data:
            temp_pep_sequence = data["target_sequence"]
            temp_binding_sites = data["binding_sites"]

        with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +
                     item.split("_")[1] + ".npz") as data:
            temp_nodes = data["nodes"]

        binding = np.zeros(len(temp_nodes))
        if len(temp_binding_sites) != 0:
            binding[temp_binding_sites] = 1
        target = torch.LongTensor(binding)

        nodes = temp_nodes[:, 0:20]
        prot_sequence = np.argmax(nodes, axis=-1)
        prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])

        pep_sequence = temp_pep_sequence
        pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)

        return pep_sequence, prot_sequence, target

    def __len__(self):
        return self.num_data


class PPI(Dataset):

    def __init__(self, mode, csv_dir_path="/home/u21307130002/PepNN/pepnn/datasets/ppi/"):

        self.mode = mode
        self.train_data = pd.read_csv(os.path.join(csv_dir_path, 'train.csv'))
        self.val_data = pd.read_csv(os.path.join(csv_dir_path, 'val.csv'))
        # self.test_data = pd.read_csv(os.path.join(csv_dir_path, 'test.csv'))

        if self.mode == 'train':
            self.num_data = len(self.train_data)

    def __len__(self):
        return self.num_data

    def __getitem__(self, index):
        # pdb.set_trace()
        if torch.is_tensor(index):
            index = index.tolist()

        if self.mode == "train":
            item = self.train_data.iloc[index]
        elif self.mode == "val":
            item = self.val_data.iloc[index]
        elif self.mode == "test":
            item = self.test_data.iloc[index]
        else:
            item = None

        # print(item)

        motif1 = ast.literal_eval(item['Chain_1_motifs'])
        motif2 = ast.literal_eval(item['Chain_2_motifs'])

        if len(motif1[0]) > len(motif2[0]):
            target = motif1
            prot_sequence = item['Sequence1']
            pep_sequence = item['Sequence2']
        else:
            target = motif2
            pep_sequence = item['Sequence1']
            prot_sequence = item['Sequence2']

        target = [int(motif.split('_')[1]) for motif in target]

        if target[-1] >= len(prot_sequence):
            pdb.set_trace()

        binding = np.zeros(len(prot_sequence))
        if len(target) != 0:
            binding[target] = 1
        target = torch.LongTensor(binding).float()

        # print(f"peptide length: {len(pep_sequence)}")
        # print(f"protein length: {len(prot_sequence)}")
        # print(f"target length: {len(target)}")
        # pdb.set_trace()

        return pep_sequence, prot_sequence, target


class PepBindComplexes(Dataset):

    def __init__(self, mode,
                 encoded_data_directory="../../datasets/pepbind_data/"):

        self.mode = mode
        self.encoded_data_directory = encoded_data_directory

        self.train_dir = "../../datasets/pepbind_data/train_examples.npy"
        self.test_dir = "../../datasets/pepbind_data/test_examples.npy"
        self.val_dir = "../../datasets/pepbind_data/val_examples.npy"

        self.test_list = np.load(self.test_dir)
        self.train_list = np.load(self.train_dir)
        self.val_list = np.load(self.val_dir)

        if mode == "train":
            self.num_data = len(self.train_list)
        elif mode == "val":
            self.num_data = len(self.val_list)
        elif mode == "test":
            self.num_data = len(self.test_list)

    def __getitem__(self, index):

        if self.mode == "train":
            item = self.train_list[index]
        elif self.mode == "val":
            item = self.val_list[index]
        elif self.mode == "test":
            item = self.test_list[index]

        file_dir = self.encoded_data_directory

        with np.load(file_dir + "fragment_data/" + item + ".npz") as data:
            temp_pep_sequence = data["target_sequence"]
            temp_binding_sites = data["binding_sites"]

        with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +
                     item.split("_")[1] + ".npz") as data:
            temp_nodes = data["nodes"]

        binding = np.zeros(len(temp_nodes))
        if len(temp_binding_sites) != 0:
            binding[temp_binding_sites] = 1
        target = torch.LongTensor(binding)

        nodes = temp_nodes[:, 0:20]
        prot_sequence = np.argmax(nodes, axis=-1)
        prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])

        pep_sequence = temp_pep_sequence
        pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)

        return pep_sequence, prot_sequence, target

    def __len__(self):
        return self.num_data


class PeptideComplexes(Dataset):

    def __init__(self, mode,
                 encoded_data_directory="../../datasets/pepnn_data/all_data/"):

        self.mode = mode
        self.encoded_data_directory = encoded_data_directory

        self.train_dir = "../../datasets/pepnn_data/train_examples.npy"
        self.test_dir = "../../datasets/pepnn_test_data/test_examples.npy"
        self.val_dir = "../../datasets/pepnn_data/val_examples.npy"

        self.example_weights = np.load("../../datasets/pepnn_data/example_weights.npy")

        self.test_list = np.load(self.test_dir)
        self.train_list = np.load(self.train_dir)
        self.val_list = np.load(self.val_dir)

        if mode == "train":
            self.num_data = len(self.train_list)
        elif mode == "val":
            self.num_data = len(self.val_list)
        elif mode == "test":
            self.num_data = len(self.test_list)

    def __getitem__(self, index):

        if self.mode == "train":
            item = self.train_list[index]
            weight = self.example_weights[item]
        elif self.mode == "val":
            item = self.val_list[index]
            weight = self.example_weights[item]
        elif self.mode == "test":
            item = self.test_list[index]
            weight = 1

        if self.mode != "test":
            file_dir = self.encoded_data_directory
        else:
            file_dir = "../../datasets/pepnn_test_data/all_data/"

        with np.load(file_dir + "fragment_data/" + item + ".npz") as data:
            temp_pep_sequence = data["target_sequence"]
            temp_binding_sites = data["binding_sites"]

        with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +
                     item.split("_")[1] + ".npz") as data:
            temp_nodes = data["nodes"]

        binding = np.zeros(len(temp_nodes))
        if len(temp_binding_sites) != 0:
            binding[temp_binding_sites] = 1
        target = torch.LongTensor(binding)

        nodes = temp_nodes[:, 0:20]
        prot_sequence = np.argmax(nodes, axis=-1)
        prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])

        pep_sequence = temp_pep_sequence
        pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)

        return pep_sequence, prot_sequence, target, weight

    def __len__(self):
        return self.num_data


class BitenetComplexes(Dataset):

    def __init__(self, encoded_data_directory="../bitenet_data/all_data/"):

        self.encoded_data_directory = encoded_data_directory
        self.train_dir = "../../datasets/bitenet_data/examples.npy"
        self.full_list = np.load(self.train_dir)
        self.num_data = len(self.full_list)

    def __getitem__(self, index):

        item = self.full_list[index]
        file_dir = self.encoded_data_directory

        with np.load(file_dir + "fragment_data/" + item[:-1] + "_" + item[-1] + ".npz") as data:
            temp_pep_sequence = data["target_sequence"]
            temp_binding_matrix = data["binding_matrix"]

        with np.load(file_dir + "receptor_data/" + item.split("_")[0] + "_" +
                     item.split("_")[1][0] + ".npz") as data:
            temp_nodes = data["nodes"]

        binding_sum = np.sum(temp_binding_matrix, axis=0).T
        target = torch.LongTensor(binding_sum >= 1)

        nodes = temp_nodes[:, 0:20]
        prot_sequence = np.argmax(nodes, axis=-1)
        prot_sequence = " ".join([Polypeptide.index_to_one(i) for i in prot_sequence])

        pep_sequence = temp_pep_sequence
        pep_sequence = torch.argmax(torch.FloatTensor(pep_sequence), dim=-1)

        return pep_sequence, prot_sequence, target

    def __len__(self):
        return self.num_data
muppit/models/layers.py
DELETED
@@ -1,44 +0,0 @@
from torch import nn
from .modules import *


class ReciprocalLayer(nn.Module):

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):

        super().__init__()

        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                   d_k, d_v)
        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                  d_k, d_v)
        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model,
                                                                       d_k, d_v)

        self.ffn_seq = FFN(d_model, d_inner)
        self.ffn_protein = FFN(d_model, d_inner)

    def forward(self, sequence_enc, protein_seq_enc):
        prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)

        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc,
                                                                                                    seq_enc,
                                                                                                    seq_enc,
                                                                                                    prot_enc)
        prot_enc = self.ffn_protein(prot_enc)
        seq_enc = self.ffn_seq(seq_enc)

        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
muppit/models/models.py
DELETED
@@ -1,238 +0,0 @@
import pdb

import numpy as np
import torch
import torch.nn as nn
from .layers import *
from .modules import *
import pdb
from transformers import EsmModel, EsmTokenizer


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return x


class RepeatedModule2(nn.Module):
    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):
        super().__init__()

        self.linear1 = nn.Linear(1280, d_model)
        self.linear2 = nn.Linear(1280, d_model)
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, peptide_sequence, protein_sequence):
        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.dropout(self.linear1(peptide_sequence))
        prot_enc = self.dropout_2(self.linear2(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:
            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list


class RepeatedModule(nn.Module):

    def __init__(self, n_layers, d_model,
                 n_head, d_k, d_v, d_inner, dropout=0.1):

        super().__init__()

        self.linear = nn.Linear(1024, d_model)
        self.sequence_embedding = nn.Embedding(20, d_model)
        self.d_model = d_model

        self.reciprocal_layer_stack = nn.ModuleList([
            ReciprocalLayer(d_model, d_inner, n_head, d_k, d_v)
            for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def _positional_embedding(self, batches, number):

        result = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -1 * (np.log(10000) / self.d_model))

        numbers = torch.arange(0, number, dtype=torch.float32)
        numbers = numbers.unsqueeze(0)
        numbers = numbers.unsqueeze(2)

        result = numbers * result
        result = torch.cat((torch.sin(result), torch.cos(result)), 2)

        return result

    def forward(self, peptide_sequence, protein_sequence):

        sequence_attention_list = []
        prot_attention_list = []
        prot_seq_attention_list = []
        seq_prot_attention_list = []

        sequence_enc = self.sequence_embedding(peptide_sequence)
        sequence_enc += to_var(self._positional_embedding(peptide_sequence.shape[0],
                                                          peptide_sequence.shape[1]))
        sequence_enc = self.dropout(sequence_enc)

        prot_enc = self.dropout_2(self.linear(protein_sequence))

        for reciprocal_layer in self.reciprocal_layer_stack:

            prot_enc, sequence_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention = \
                reciprocal_layer(sequence_enc, prot_enc)

            sequence_attention_list.append(sequence_attention)
            prot_attention_list.append(prot_attention)
            prot_seq_attention_list.append(prot_seq_attention)
            seq_prot_attention_list.append(seq_prot_attention)

        return prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list


class FullModel(nn.Module):

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):
        super().__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

        # freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

        self.return_attention = return_attention

    def forward(self, binder_tokens, target_tokens):

        with torch.no_grad():
            peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
            protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        # pdb.set_trace()

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        # pdb.set_trace()

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.sigmoid(self.output_projection_prot(prot_enc))

        return prot_enc


class Original_FullModel(nn.Module):

    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, return_attention=False, dropout=0.2):

        super().__init__()
        self.repeated_module = RepeatedModule(n_layers, d_model,
                                              n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
        self.output_projection_prot = nn.Linear(d_model, 2)

        self.softmax_prot = nn.LogSoftmax(dim=-1)

        self.return_attention = return_attention

    def forward(self, peptide_sequence, protein_sequence):

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.softmax_prot(self.output_projection_prot(prot_enc))

        if not self.return_attention:
            return prot_enc
        else:
            return prot_enc, sequence_attention_list, prot_attention_list, \
                seq_prot_attention_list, seq_prot_attention_list
muppit/models/modules.py
DELETED
@@ -1,213 +0,0 @@
from torch import nn
import numpy as np
import torch
import torch.nn.functional as F


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return x


class MultiHeadAttentionSequence(nn.Module):

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):

        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head * d_k)
        self.W_K = nn.Linear(d_model, n_head * d_k)
        self.W_V = nn.Linear(d_model, n_head * d_v)
        self.W_O = nn.Linear(n_head * d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):

        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)

        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)
        attention = F.softmax(attention, dim=-1)

        output = torch.matmul(attention, V)
        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v * self.n_head])

        output = self.W_O(output)
        output = self.dropout(output)
        output = self.layer_norm(output + q)

        return output, attention


class MultiHeadAttentionReciprocal(nn.Module):

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):

        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_head * d_k)
        self.W_K = nn.Linear(d_model, n_head * d_k)
        self.W_V = nn.Linear(d_model, n_head * d_v)
        self.W_O = nn.Linear(n_head * d_v, d_model)
        self.W_V_2 = nn.Linear(d_model, n_head * d_v)
        self.W_O_2 = nn.Linear(n_head * d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, q, k, v, v_2):

        batch, len_q, _ = q.size()
        batch, len_k, _ = k.size()
        batch, len_v, _ = v.size()
        batch, len_v_2, _ = v_2.size()

        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2).transpose(2, 3)
        V = V.transpose(1, 2)
        V_2 = V_2.transpose(1, 2)

        attention = torch.matmul(Q, K)
        attention = attention / np.sqrt(self.d_k)

        attention_2 = attention.transpose(-2, -1)

        attention = F.softmax(attention, dim=-1)
        attention_2 = F.softmax(attention_2, dim=-1)

        output = torch.matmul(attention, V)
        output_2 = torch.matmul(attention_2, V_2)

        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v * self.n_head])
        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v * self.n_head])

        output = self.W_O(output)
        output_2 = self.W_O_2(output_2)

        output = self.dropout(output)
        output = self.layer_norm(output + q)

        output_2 = self.dropout(output_2)
        output_2 = self.layer_norm(output_2 + k)

        return output, output_2, attention, attention_2


class FFN(nn.Module):

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()

        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
        self.relu = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_in)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        residual = x
        output = self.layer_1(x.transpose(1, 2))
        output = self.relu(output)
        output = self.layer_2(output)
        output = self.dropout(output)
        output = self.layer_norm(output.transpose(1, 2) + residual)

        return output
muppit/models/score_domain.py
DELETED
@@ -1,40 +0,0 @@
from scipy.stats import norm
import numpy as np
import os


def score(outputs):

    weight = 0.03
    binding_size_dist = np.load(os.path.join(os.path.dirname(__file__), "../params/binding_size_train_dist.npy"))

    mean = np.mean(binding_size_dist)
    std = np.std(binding_size_dist)
    dist = norm(mean, std)

    max_score = 0

    scores = np.exp(outputs[0])[:, 1]
    indices = np.argsort(-1 * scores)

    for j in range(1, len(indices)):

        score = (1 - weight) * np.mean(scores[indices[:j]]) + weight * (dist.pdf(j / len(indices)))

        if score > max_score:
            max_score = score

    return max_score
muppit/predict.py
DELETED
@@ -1,118 +0,0 @@
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import AutoTokenizer
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
from argparse import ArgumentParser
import os
import torch.distributed as dist

from models import *  # Import your model and other necessary classes/functions here


class PeptideModel(pl.LightningModule):
    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15):
        super(PeptideModel, self).__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
        self.historical_memory = 0.9
        self.class_weights = torch.tensor(
            [3.000471363174231, 0.5999811490272925])  # binding-site weights, non-binding-site weights

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.output_projection_prot(prot_enc)

        return torch.sigmoid(prot_enc)


def main():
    parser = ArgumentParser()
    parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt',
                        help="File containing initial params", type=str)
    parser.add_argument("-batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
    parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
    parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
    parser.add_argument("-d_inner", type=int, default=64)
    parser.add_argument("-target", type=str)
    parser.add_argument("-binder", type=str)
    args = parser.parse_args()
    # print(args)
    device = 'cuda:0'

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    anchor_tokens = tokenizer(args.target, return_tensors='pt', padding=True, truncation=True, max_length=40000)

    positive_tokens = tokenizer(args.binder, return_tensors='pt', padding=True, truncation=True, max_length=40000)

    anchor_tokens['attention_mask'][0][0] = 0
    anchor_tokens['attention_mask'][0][-1] = 0
    positive_tokens['attention_mask'][0][0] = 0
    positive_tokens['attention_mask'][0][-1] = 0

    target_tokens = {'input_ids': anchor_tokens["input_ids"].to(device),
                     'attention_mask': anchor_tokens["attention_mask"].to(device)}
    binder_tokens = {'input_ids': positive_tokens['input_ids'].to(device),
                     'attention_mask': positive_tokens['attention_mask'].to(device)}

    print(binder_tokens['input_ids'].shape)

    model = PeptideModel.load_from_checkpoint(args.sm,
                                              n_layers=args.n_layers,
                                              d_model=args.d_model,
                                              n_head=args.n_head,
                                              d_k=64,
                                              d_v=128,
                                              d_inner=64).to(device)

    model.eval()

    prediction = model(binder_tokens, target_tokens).squeeze(-1)[0][1:-1]
    print(prediction.shape)
    print(model.classification_threshold)

    binding_site = []
    for i in range(len(prediction)):
        if prediction[i] >= model.classification_threshold:
            binding_site.append(i)

    print(binding_site)
    print(len(binding_site))


if __name__ == "__main__":
    main()
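The tail of predict.py converts the per-residue sigmoid scores into a list of binding-site indices by comparing against the learned classification threshold. The same step, vectorised, under the assumption that scores is a 1-D tensor of per-residue probabilities (the function name and values below are illustrative):

import torch


def binding_site_indices(scores: torch.Tensor, threshold: float) -> list:
    """Return the 0-based residue indices whose probability reaches the threshold."""
    return torch.nonzero(scores >= threshold, as_tuple=False).flatten().tolist()


if __name__ == "__main__":
    scores = torch.tensor([0.1, 0.7, 0.4, 0.9, 0.55])
    print(binding_site_indices(scores, threshold=0.5))   # [1, 3, 4]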
muppit/scripts/.gitkeep
DELETED
File without changes
|
muppit/scripts/predict_binding_site.py
DELETED
@@ -1,149 +0,0 @@
from Bio import SeqIO
from Bio.PDB import Polypeptide
from transformers import BertModel, BertTokenizer, pipeline
from pepnn_seq.models import FullModel
from pepnn_seq.models import score
import pandas as pd
import numpy as np
import torch
import argparse
import os


def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return x


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument("-prot", dest="input_protein_file", required=False, type=str,
                        help="Fasta file with protein sequence")

    parser.add_argument("-pep", dest="input_peptide_file", required=False, type=str, default=None,
                        help="Fasta file with peptide sequence")

    parser.add_argument("-o", dest="output_directory", required=False, type=str, default=None,
                        help="Output directory")

    parser.add_argument("-p", dest="params", required=False, type=str, default="../params/params.pth",
                        help="Model parameters")

    args = parser.parse_args()

    if args.output_directory == None:
        output_directory = os.path.split(args.input_protein_file)[-1].split(".")[0] + "_seq"
    else:
        output_directory = args.output_directory

    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    records = SeqIO.parse(args.input_protein_file, format="fasta")

    prot_sequence = ' '.join(list(records)[0].seq)

    protbert_dir = os.path.join(os.path.dirname(__file__), '../models/ProtBert-BFD/')

    vocabFilePath = os.path.join(protbert_dir, 'vocab.txt')
    tokenizer = BertTokenizer(vocabFilePath, do_lower_case=False)
    seq_embedding = BertModel.from_pretrained(protbert_dir)

    if torch.cuda.is_available():
        seq_embedding = pipeline('feature-extraction', model=seq_embedding, tokenizer=tokenizer, device=0)
    else:
        seq_embedding = pipeline('feature-extraction', model=seq_embedding, tokenizer=tokenizer, device=-1)

    embedding = seq_embedding(prot_sequence)

    embedding = np.array(embedding)

    seq_len = len(prot_sequence.replace(" ", ""))
    start_Idx = 1
    end_Idx = seq_len + 1
    seq_emd = embedding[0][start_Idx:end_Idx]

    prot_seq = to_var(torch.FloatTensor(seq_emd).unsqueeze(0))

    if args.input_peptide_file != None:

        records = SeqIO.parse(args.input_peptide_file, format="fasta")

        pep_sequence = str(list(records)[0].seq).replace("X", "")

        pep_sequence = [Polypeptide.d1_to_index[i] for i in pep_sequence]

    else:

        pep_sequence = [5 for i in range(10)]

    pep_seq = to_var(torch.LongTensor(pep_sequence).unsqueeze(0))

    model = FullModel(6, 64, 6,
                      64, 128, 64, dropout=0.2)
    if torch.cuda.is_available():
        model.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), args.params)))
    else:
        model.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), args.params),
                                         map_location='cpu'))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        model.cuda()

    model.eval()

    if torch.cuda.is_available():
        outputs = model(pep_seq, prot_seq).cpu().detach().numpy()
    else:
        outputs = model(pep_seq, prot_seq).detach().numpy()

    # compute score for the domain and output file

    score_prm = score(outputs)

    with open(output_directory + "/prm_score.txt", 'w') as output_file:

        output_file.writelines("The input protein's score is {0:.2f}".format(score_prm))

    # output prediction as csv

    outputs = np.exp(outputs[0])

    amino_acids = []

    probabilities = []

    position = []
    for index, aa in enumerate(prot_sequence.split(" ")):
        probabilities.append(outputs[index, 1])
        amino_acids.append(aa)
        position.append(index + 1)

    output = pd.DataFrame()

    output["Position"] = position
    output["Amino acid"] = amino_acids
    output["Probabilities"] = probabilities

    output.to_csv(output_directory + "/binding_site_prediction.csv", index=False)
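The script ends by writing one row per residue (position, amino acid, predicted probability) to binding_site_prediction.csv. A compact sketch of that last step with a toy sequence and synthetic probabilities (the sequence and values are made up, not model output):

import numpy as np
import pandas as pd

residues = list("MKTAYIAK")                       # toy protein sequence
probs = np.round(np.random.default_rng(1).random(len(residues)), 3)

table = pd.DataFrame({
    "Position": range(1, len(residues) + 1),      # 1-based residue positions
    "Amino acid": residues,
    "Probabilities": probs,
})
table.to_csv("binding_site_prediction.csv", index=False)
print(table.head())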
muppit/test_evaluator.py
DELETED
@@ -1,197 +0,0 @@
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import AutoTokenizer
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
from argparse import ArgumentParser
import os
import torch.distributed as dist

from models import *  # Import your model and other necessary classes/functions here


def collate_fn(batch):
    # Unpack the batch
    anchors = []
    positives = []
    # negatives = []
    binding_sites = []

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    for b in batch:
        anchors.append(b['anchors'])
        positives.append(b['positives'])
        # negatives.append(b['negatives'])
        binding_sites.append(b['binding_site'])

    # Collate the tensors using torch's pad_sequence
    anchor_input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['input_ids']).squeeze(0) for item in anchors], batch_first=True, padding_value=tokenizer.pad_token_id)
    anchor_attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['attention_mask']).squeeze(0) for item in anchors], batch_first=True, padding_value=0)

    positive_input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['input_ids']).squeeze(0) for item in positives], batch_first=True, padding_value=tokenizer.pad_token_id)
    positive_attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['attention_mask']).squeeze(0) for item in positives], batch_first=True, padding_value=0)

    n, max_length = anchor_input_ids.shape[0], anchor_input_ids.shape[1]
    site = torch.zeros(n, max_length)
    for i in range(len(binding_sites)):
        binding_site = binding_sites[i]
        site[i, binding_site] = 1

    # Return the collated batch
    return {
        'anchor_input_ids': anchor_input_ids.int(),
        'anchor_attention_mask': anchor_attention_mask.int(),
        'positive_input_ids': positive_input_ids.int(),
        'positive_attention_mask': positive_attention_mask.int(),
        # 'negative_input_ids': negative_input_ids.int(),
        # 'negative_attention_mask': negative_attention_mask.int(),
        'binding_site': site
    }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, tokenizer, batch_size: int = 128):
        super().__init__()
        self.batch_size = batch_size
        self.tokenizer = tokenizer

    def test_dataloader(self):
        test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/test_dataset_drop_500')
        return DataLoader(test_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8, pin_memory=True)


class PeptideModel(pl.LightningModule):
    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15):
        super(PeptideModel, self).__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
        self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site weights, non-binding-site weights

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.output_projection_prot(prot_enc)

        return prot_enc

    def test_step(self, batch, batch_idx):
        target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
                         'attention_mask': batch['anchor_attention_mask'].to(self.device)}
        binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
                         'attention_mask': batch['positive_attention_mask'].to(self.device)}
        binding_site = batch['binding_site'].to(self.device)
        mask = target_tokens['attention_mask']

        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)

        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')

        masked_loss = loss * mask
        mean_loss = masked_loss.sum() / mask.sum()

        sigmoid_outputs = torch.sigmoid(outputs_nodes)
        total = mask.sum()

        # self.update_class_thresholds(sigmoid_outputs, binding_site, mask)
        # self.log('threshold', self.classification_threshold, on_epoch=True)

        predict = (sigmoid_outputs >= self.classification_threshold).float()
        correct = ((predict == binding_site) * mask).sum()
        accuracy = correct / total

        outputs_nodes_flat = sigmoid_outputs[mask.bool()].float().cpu().detach().numpy().flatten()
        binding_site_flat = binding_site[mask.bool()].float().cpu().detach().numpy().flatten()
        predictions_flat = predict[mask.bool()].float().cpu().detach().numpy().flatten()

        auc = roc_auc_score(binding_site_flat, outputs_nodes_flat)
        f1 = f1_score(binding_site_flat, predictions_flat)
        mcc = matthews_corrcoef(binding_site_flat, predictions_flat)

        self.log('test_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('test_accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('test_auc', auc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('test_f1', f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('test_mcc', mcc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)


def main():
    parser = ArgumentParser()
    parser.add_argument("-sm", default='/home/tc415/muPPIt/muppit/train_base_1/model-epoch=14-val_loss=0.40.ckpt', help="File containing initial params", type=str)
    parser.add_argument("-batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
    parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
    parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
    parser.add_argument("-d_inner", type=int, default=64)
    args = parser.parse_args()
    print(args.sm)

    # Initialize the process group for distributed training
    dist.init_process_group(backend='nccl')

    test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/pep_prot_test')
    # print(len(test_dataset))

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    data_module = CustomDataModule(tokenizer, args.batch_size)

    model = PeptideModel.load_from_checkpoint(args.sm,
                                              n_layers=args.n_layers,
                                              d_model=args.d_model,
                                              n_head=args.n_head,
                                              d_k=64,
                                              d_v=128,
                                              d_inner=64)

    print(f"Class threshold = {model.classification_threshold}")

    trainer = pl.Trainer(accelerator='gpu',
                         devices=[0, 1, 2, 3, 4, 5, 6, 7],
                         strategy='ddp',
                         precision='bf16')

    results = trainer.test(model, datamodule=data_module)

    print(results)


if __name__ == "__main__":
    main()
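test_step computes its metrics only over positions where the attention mask is 1: probabilities, labels, and thresholded predictions are flattened through the boolean mask before being handed to scikit-learn. The same masking pattern on toy tensors (shapes and values are illustrative only):

import torch
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef

# toy batch: 2 sequences of length 4, last position of each is padding
probs = torch.tensor([[0.9, 0.2, 0.8, 0.5],
                      [0.1, 0.7, 0.3, 0.5]])
labels = torch.tensor([[1., 0., 1., 0.],
                       [0., 1., 0., 0.]])
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 1, 0]])

preds = (probs >= 0.5).float()

# keep only unpadded positions, then flatten for scikit-learn
p = probs[mask.bool()].numpy()
y = labels[mask.bool()].numpy()
yhat = preds[mask.bool()].numpy()

print("auc", roc_auc_score(y, p))
print("f1 ", f1_score(y, yhat))
print("mcc", matthews_corrcoef(y, yhat))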
muppit/train_evaluator.py
DELETED
@@ -1,408 +0,0 @@
import pdb
from pytorch_lightning.strategies import DDPStrategy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_from_disk
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
    Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import _LRScheduler
from transformers.optimization import get_cosine_schedule_with_warmup
from argparse import ArgumentParser
import os
import uuid
import numpy as np
import torch.distributed as dist
from models import *
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import Adam, AdamW
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
import gc


os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'


def compute_class_weights(targets):
    num_binding_sites = targets.sum()
    num_non_binding_sites = targets.numel() - num_binding_sites
    total = num_binding_sites + num_non_binding_sites
    weight_for_binding = total / (2 * num_binding_sites)
    weight_for_non_binding = total / (2 * num_non_binding_sites)
    return torch.tensor([weight_for_non_binding, weight_for_binding])


def collate_fn(batch):
    # Unpack the batch
    anchors = []
    positives = []
    # negatives = []
    binding_sites = []

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    for b in batch:
        anchors.append(b['anchors'])
        positives.append(b['positives'])
        # negatives.append(b['negatives'])
        binding_sites.append(b['binding_site'])

    # Collate the tensors using torch's pad_sequence
    anchor_input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['input_ids']).squeeze(0) for item in anchors], batch_first=True, padding_value=tokenizer.pad_token_id)
    anchor_attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['attention_mask']).squeeze(0) for item in anchors], batch_first=True, padding_value=0)

    positive_input_ids = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['input_ids']).squeeze(0) for item in positives], batch_first=True, padding_value=tokenizer.pad_token_id)
    positive_attention_mask = torch.nn.utils.rnn.pad_sequence(
        [torch.Tensor(item['attention_mask']).squeeze(0) for item in positives], batch_first=True, padding_value=0)

    # negative_input_ids = torch.nn.utils.rnn.pad_sequence(
    #     [torch.Tensor(item['input_ids']).squeeze(0) for item in negatives], batch_first=True, padding_value=tokenizer.pad_token_id)
    # negative_attention_mask = torch.nn.utils.rnn.pad_sequence(
    #     [torch.Tensor(item['attention_mask']).squeeze(0) for item in negatives], batch_first=True, padding_value=0)

    # assert anchor_input_ids.shape == negative_input_ids.shape

    n, max_length = anchor_input_ids.shape[0], anchor_input_ids.shape[1]
    site = torch.zeros(n, max_length)
    for i in range(len(binding_sites)):
        binding_site = binding_sites[i]
        site[i, binding_site] = 1

    # Return the collated batch
    return {
        'anchor_input_ids': anchor_input_ids.int(),
        'anchor_attention_mask': anchor_attention_mask.int(),
        'positive_input_ids': positive_input_ids.int(),
        'positive_attention_mask': positive_attention_mask.int(),
        # 'negative_input_ids': negative_input_ids.int(),
        # 'negative_attention_mask': negative_attention_mask.int(),
        'binding_site': site
    }


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
                          num_workers=8, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
                          pin_memory=True)

    def setup(self, stage=None):
        if stage == 'test' or stage is None:
            test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/test_dataset_static')
            self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                                              num_workers=8, pin_memory=True)


class CosineAnnealingWithWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
        print(f"SELF BASE LRS = {self.base_lrs}")

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase from base_lr to max_lr
            return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps) for base_lr in self.base_lrs]

        # Cosine annealing phase from max_lr to min_lr
        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay

        return [decayed_lr for base_lr in self.base_lrs]


class PeptideModel(pl.LightningModule):
    def __init__(self, n_layers, d_model, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15):
        super(PeptideModel, self).__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        # freeze all the esm_model parameters
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule2(n_layers, d_model,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
        self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site weights, non-binding-site weights

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            seq_prot_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
                                                                                    protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.output_projection_prot(prot_enc)

        return prot_enc

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        lr = opt.param_groups[0]['lr']
        self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
                         'attention_mask': batch['anchor_attention_mask'].to(self.device)}
        binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
                         'attention_mask': batch['positive_attention_mask'].to(self.device)}
        binding_site = batch['binding_site'].to(self.device)
        mask = target_tokens['attention_mask']

        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)

        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')

        masked_loss = loss * mask
        mean_loss = masked_loss.sum() / mask.sum()

        # print('logging')
        self.log('train_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return mean_loss

    def validation_step(self, batch, batch_idx):
        target_tokens = {'input_ids': batch['anchor_input_ids'].to(self.device),
                         'attention_mask': batch['anchor_attention_mask'].to(self.device)}
        binder_tokens = {'input_ids': batch['positive_input_ids'].to(self.device),
                         'attention_mask': batch['positive_attention_mask'].to(self.device)}
        binding_site = batch['binding_site'].to(self.device)
        mask = target_tokens['attention_mask']

        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)

        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')

        # Apply the mask to the loss
        masked_loss = loss * mask

        # Compute the mean loss only over the valid positions
        mean_loss = masked_loss.sum() / mask.sum()

        # Calculate predictions and apply mask
        sigmoid_outputs = torch.sigmoid(outputs_nodes)
        total = mask.sum()

        self.update_class_thresholds(sigmoid_outputs, binding_site, mask)
        self.log('threshold', self.classification_threshold, on_epoch=True)

        predict = (sigmoid_outputs >= self.classification_threshold).float()
        correct = ((predict == binding_site) * mask).sum()
        accuracy = correct / total

        # Compute AUC
        outputs_nodes_flat = sigmoid_outputs[mask.bool()].float().cpu().detach().numpy().flatten()
        binding_site_flat = binding_site[mask.bool()].float().cpu().detach().numpy().flatten()
        predictions_flat = predict[mask.bool()].float().cpu().detach().numpy().flatten()

        auc = roc_auc_score(binding_site_flat, outputs_nodes_flat)
        f1 = f1_score(binding_site_flat, predictions_flat)
        mcc = matthews_corrcoef(binding_site_flat, predictions_flat)

        self.log('val_loss', mean_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_auc', auc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_f1', f1, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_mcc', mcc, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        print(f"MAX STEPS = {self.max_epochs}")
        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
        # schedulers = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=0.1*self.max_epochs,
        #                                            max_epochs=self.max_epochs,
        #                                            warmup_start_lr=5e-4,
        #                                            eta_min=0.1 * self.learning_rate)

        base_lr = 1e-4
        max_lr = self.learning_rate
        min_lr = 0.1 * self.learning_rate

        schedulers = CosineAnnealingWithWarmup(optimizer, warmup_steps=692, total_steps=8293,
                                               base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)

        lr_schedulers = {
            "scheduler": schedulers,
            "name": 'learning_rate_logs',
            "interval": 'step',  # The scheduler updates the learning rate at every step (not epoch)
            'frequency': 1  # The scheduler updates the learning rate after every batch
        }
        return [optimizer], [lr_schedulers]

    def update_class_thresholds(self, inputs, targets, mask):
        with torch.no_grad():
            min_threshold_value = 0.001
            thresholds = torch.arange(0.1, 1.0, 0.05, device=inputs.device)

            best_f1_score = 0
            best_threshold = min_threshold_value

            for threshold in thresholds:
                binary_predictions = (inputs >= threshold).float()

                tp = ((binary_predictions * targets) * mask).sum().item()
                fp = ((binary_predictions * (1 - targets)) * mask).sum().item()
                fn = (((1 - binary_predictions) * targets) * mask).sum().item()

                precision = tp / (tp + fp + 1e-7)
                recall = tp / (tp + fn + 1e-7)
                f1_score = 2 * precision * recall / (precision + recall + 1e-7)

                if f1_score > best_f1_score:
                    best_f1_score = f1_score
                    best_threshold = threshold

            updated_threshold = self.historical_memory * self.classification_threshold + (
                    1 - self.historical_memory) * best_threshold
            self.classification_threshold = nn.Parameter(torch.clamp(updated_threshold, min=min_threshold_value))
            gc.collect()
            torch.cuda.empty_cache()

    def training_epoch_end(self, outputs):
        gc.collect()
        torch.cuda.empty_cache()
        super().training_epoch_end(outputs)

    def validation_epoch_end(self, outputs):
        gc.collect()
        torch.cuda.empty_cache()
        super().validation_epoch_end(outputs)


def main():
    parser = ArgumentParser()

    parser.add_argument("-o", dest="output_file", help="File for output of model parameters", required=True, type=str)
    parser.add_argument("-d", dest="dataset", required=False, type=str, default="pepnn",
                        help="Which dataset to train on, pepnn, pepbind, or interpep")
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of layers")
    parser.add_argument("-d_model", type=int, default=64, help="Dimension of model")
    parser.add_argument("-n_head", type=int, default=6, help="Number of heads")
    parser.add_argument("-d_inner", type=int, default=64)
    # parser.add_argument("-sm", dest="saved_model", help="File containing initial params", required=False, type=str,
    #                     default=None)
    parser.add_argument("-sm", default=None, help="File containing initial params", type=str)
    parser.add_argument("--max_epochs", type=int, default=15, help="Max number of epochs to train")
    args = parser.parse_args()

    # Initialize the process group for distributed training
    dist.init_process_group(backend='nccl')

    train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/train_dataset_drop_500')
    val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/val_dataset_drop_500')
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)

    # Calculate the number of training steps and warm-up steps
    train_dataloader = data_module.train_dataloader()
    num_training_steps = len(train_dataloader) * args.max_epochs
    num_warmup_steps = int(0.1 * num_training_steps)  # Warm-up for 10% of training steps

    model = PeptideModel(6, 64, 6, 64, 128, 64, dropout=0.2,
                         learning_rate=args.lr, max_epochs=num_training_steps)
    if args.sm:
        model = PeptideModel.load_from_checkpoint(args.sm,
                                                  n_layers=args.n_layers,
                                                  d_model=args.d_model,
                                                  n_head=args.n_head,
                                                  d_k=64,
                                                  d_v=128,
                                                  d_inner=64,
                                                  dropout=0.3,
                                                  learning_rate=args.lr,
                                                  max_epochs=args.max_epochs)

    run_id = str(uuid.uuid4())

    print("Classification Thresholds:")
    print(model.classification_threshold)

    logger = WandbLogger(project=f"bind_evaluator",
                         name=f"continue_lr={args.lr}_nlayers={args.n_layers}_dmodel={args.d_model}_nhead={args.n_head}_dinner={args.d_inner}",
                         # display on the web
                         # save_dir=f'./pl_logs/',
                         job_type='model-training',
                         id=run_id)

    checkpoint_callback = ModelCheckpoint(
        monitor='val_mcc',
        dirpath=args.output_file,
        filename='model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=1,
        mode='max',
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_mcc',
        patience=5,
        verbose=True,
        mode='max'
    )

    accumulator = GradientAccumulationScheduler(scheduling={0: 4, 2: 2, 7: 1})

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator='gpu',
        strategy='ddp',
        precision='bf16',
        logger=logger,
        devices=[0, 1, 2, 3, 4, 5, 6],
        callbacks=[checkpoint_callback, accumulator, early_stopping_callback],
        gradient_clip_val=1.0
    )

    trainer.fit(model, datamodule=data_module)

    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)


if __name__ == "__main__":
    main()
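configure_optimizers drives the learning rate with a linear warm-up from base_lr to max_lr over warmup_steps, followed by cosine annealing from max_lr down to min_lr over the remaining steps. The schedule itself, written as a plain function and evaluated at a few steps (the step counts and rates below are illustrative, not the training configuration):

import math


def warmup_cosine_lr(step, warmup_steps, total_steps, base_lr, max_lr, min_lr):
    """Linear warm-up to max_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr + (max_lr - base_lr) * (step / warmup_steps)
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine_decay


if __name__ == "__main__":
    for step in (0, 50, 100, 500, 1000):
        lr = warmup_cosine_lr(step, warmup_steps=100, total_steps=1000,
                              base_lr=1e-4, max_lr=1e-3, min_lr=1e-4)
        print(f"step {step:4d}: lr = {lr:.6f}")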