AlienChen committed
Commit ed7b048 · verified · 1 Parent(s): 0977aa0

Create train.py

Files changed (1):
  1. train.py +394 -0
train.py ADDED
@@ -0,0 +1,394 @@
+ import pdb
+ from pytorch_lightning.strategies import DDPStrategy
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, DistributedSampler, BatchSampler, Sampler
+ from datasets import load_from_disk
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import (EarlyStopping, ModelCheckpoint, Timer, TQDMProgressBar,
+                                          LearningRateMonitor, StochasticWeightAveraging,
+                                          GradientAccumulationScheduler)
+ from pytorch_lightning.loggers import WandbLogger
+ from torch.optim.lr_scheduler import _LRScheduler
+ from argparse import ArgumentParser
+ import os
+ import uuid
+ import esm
+ import numpy as np
+ import torch.distributed as dist
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
+ # from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+ from torch.optim import Adam, AdamW
+ from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
+ import gc
+
+ from models.graph import ProteinGraph
+ from models.modules_vec import IntraGraphAttention, DiffEmbeddingLayer, MIM, CrossGraphAttention
+
+ os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
+ os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+ os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+
+
+ def collate_fn(batch):
+     # Unpack the batch
+     binders = []
+     mutants = []
+     wildtypes = []
+
+     # The tokenizer is only needed here for its pad token id; loading it inside collate_fn keeps the
+     # function self-contained, at the cost of re-reading the cached tokenizer files for every batch.
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+     for b in batch:
+         binders.append(torch.tensor(b['binder_tokens']).squeeze(0))      # shape: 1*L1 -> L1
+         mutants.append(torch.tensor(b['mutant_tokens']).squeeze(0))      # shape: 1*L2 -> L2
+         wildtypes.append(torch.tensor(b['wildtype_tokens']).squeeze(0))  # shape: 1*L3 -> L3
+
+     # Collate the tensors using torch's pad_sequence
+     binder_input_ids = pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
+     mutant_input_ids = pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
+     wildtype_input_ids = pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
+
+     # Return the collated batch
+     return {
+         'binder_input_ids': binder_input_ids.int(),
+         'mutant_input_ids': mutant_input_ids.int(),
+         'wildtype_input_ids': wildtype_input_ids.int(),
+     }
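+ # Illustration of the collate output (hypothetical token lists, not taken from the real dataset):
+ # shorter sequences in a batch are right-padded with the ESM-2 tokenizer's pad_token_id, so every
+ # returned tensor has shape (batch_size, max_len_in_batch).
+ #
+ #   batch = [
+ #       {'binder_tokens': [[0, 5, 6, 2]],       'mutant_tokens': [[0, 7, 8, 9, 2]], 'wildtype_tokens': [[0, 7, 8, 4, 2]]},
+ #       {'binder_tokens': [[0, 5, 6, 7, 8, 2]], 'mutant_tokens': [[0, 7, 2]],       'wildtype_tokens': [[0, 7, 2]]},
+ #   ]
+ #   out = collate_fn(batch)
+ #   out['binder_input_ids'].shape    # torch.Size([2, 6])
+ #   out['mutant_input_ids'].shape    # torch.Size([2, 5])
+ #   out['wildtype_input_ids'].shape  # torch.Size([2, 5])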
+
+
+ class LengthAwareDistributedSampler(DistributedSampler):
+     def __init__(self, dataset, key, batch_size, num_replicas=None, rank=None, shuffle=True):
+         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+         self.dataset = dataset
+         self.key = key
+         self.batch_size = batch_size
+
+         # Sort indices by the length of the mutant sequence
+         self.indices = sorted(range(len(self.dataset)), key=lambda i: len(self.dataset[i][key]))
+
+     def __iter__(self):
+         # Divide indices among replicas
+         indices = self.indices[self.rank::self.num_replicas]
+
+         if self.shuffle:
+             # Shuffle the selected indices themselves (not just their positions), deterministically per epoch
+             g = torch.Generator()
+             g.manual_seed(self.epoch)
+             perm = torch.randperm(len(indices), generator=g).tolist()
+             indices = [indices[i] for i in perm]
+
+         # Yield indices in batches
+         for i in range(0, len(indices), self.batch_size):
+             yield indices[i:i + self.batch_size]
+
+     def __len__(self):
+         return len(self.indices) // self.num_replicas
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+
+ class CustomDataModule(pl.LightningDataModule):
+     def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
+         super().__init__()
+         self.train_dataset = train_dataset
+         self.val_dataset = val_dataset
+         self.batch_size = batch_size
+         self.tokenizer = tokenizer
+
+     def train_dataloader(self):
+         # batch_sampler = LengthAwareDistributedSampler(self.train_dataset, 'mutant_tokens', self.batch_size)
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
+                           num_workers=8, pin_memory=True)
+
+     def val_dataloader(self):
+         # batch_sampler = LengthAwareDistributedSampler(self.val_dataset, 'mutant_tokens', self.batch_size)
+         return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
+                           pin_memory=True)
+
+     def setup(self, stage=None):
+         if stage == 'test' or stage is None:
+             pass
+
+
+ class CosineAnnealingWithWarmup(_LRScheduler):
+     def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.base_lr = base_lr
+         self.max_lr = max_lr
+         self.min_lr = min_lr
+         super().__init__(optimizer, last_epoch)
+         print(f"SELF BASE LRS = {self.base_lrs}")
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             # Linear warmup phase from base_lr to max_lr
+             return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps)
+                     for _ in self.base_lrs]
+
+         # Cosine annealing phase from max_lr to min_lr
+         progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+         cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
+         decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
+
+         return [decayed_lr for _ in self.base_lrs]
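+ # Worked example of the schedule, using the values set in configure_optimizers below
+ # (base_lr=1e-5, max_lr=1e-3, min_lr=1e-4 when the script is run with the default -lr 1e-3,
+ # warmup_steps=600, total_steps=15390):
+ #   step 0     -> 1e-5                            (start of linear warmup)
+ #   step 300   -> 1e-5 + 0.5 * (1e-3 - 1e-5) ≈ 5.05e-4
+ #   step 600   -> 1e-3                            (cosine phase begins at max_lr)
+ #   step 15390 -> 1e-4                            (cosine phase ends at min_lr)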
+
+ class muPPIt(pl.LightningModule):
+     def __init__(self, d_node, d_edge, d_cross_edge, d_position, num_heads,
+                  num_intra_layers, num_mim_layers, num_cross_layers, lr, delta=1.0):
+         super().__init__()
+
+         # Frozen ESM-2 encoder used as a feature extractor
+         self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
+         for param in self.esm.parameters():
+             param.requires_grad = False
+
+         self.graph = ProteinGraph(d_node, d_edge, d_position)
+
+         self.intra_graph_att_layers = nn.ModuleList([
+             IntraGraphAttention(d_node, d_edge, num_heads) for _ in range(num_intra_layers)
+         ])
+
+         self.diff_layer = DiffEmbeddingLayer(d_node)
+
+         self.mim_layers = nn.ModuleList([
+             MIM(d_node, d_edge, d_node, num_heads) for _ in range(num_mim_layers)
+         ])
+
+         self.cross_graph_att_layers = nn.ModuleList([
+             CrossGraphAttention(d_node, d_cross_edge, d_node, num_heads) for _ in range(num_cross_layers)
+         ])
+
+         self.cross_graph_edge_mapping = nn.Linear(1, d_cross_edge)
+         self.mapping = nn.Linear(d_cross_edge, 1)
+
+         self.d_cross_edge = d_cross_edge
+         self.learning_rate = lr
+         self.delta = delta
+
+     def forward(self, binder_tokens, wt_tokens, mut_tokens):
+         device = binder_tokens.device
+
+         # Construct Graph
+         binder_node, binder_edge, binder_node_mask, binder_edge_mask = self.graph(binder_tokens, self.esm, self.alphabet)
+         wt_node, wt_edge, wt_node_mask, wt_edge_mask = self.graph(wt_tokens, self.esm, self.alphabet)
+         mut_node, mut_edge, mut_node_mask, mut_edge_mask = self.graph(mut_tokens, self.esm, self.alphabet)
+
+         # Intra-Graph Attention
+         for layer in self.intra_graph_att_layers:
+             binder_node, binder_edge = layer(binder_node, binder_edge)
+             binder_node = binder_node * binder_node_mask.unsqueeze(-1)
+             binder_edge = binder_edge * binder_edge_mask.unsqueeze(-1)
+
+             wt_node, wt_edge = layer(wt_node, wt_edge)
+             wt_node = wt_node * wt_node_mask.unsqueeze(-1)
+             wt_edge = wt_edge * wt_edge_mask.unsqueeze(-1)
+
+             mut_node, mut_edge = layer(mut_node, mut_edge)
+             mut_node = mut_node * mut_node_mask.unsqueeze(-1)
+             mut_edge = mut_edge * mut_edge_mask.unsqueeze(-1)
+
+         # Differential Embedding Layer
+         diff_vec = self.diff_layer(wt_node, mut_node)
+
+         # Mutation Impact Module
+         for layer in self.mim_layers:
+             wt_node, wt_edge = layer(wt_node, wt_edge, diff_vec)
+             wt_node = wt_node * wt_node_mask.unsqueeze(-1)
+             wt_edge = wt_edge * wt_edge_mask.unsqueeze(-1)
+
+             mut_node, mut_edge = layer(mut_node, mut_edge, diff_vec)
+             mut_node = mut_node * mut_node_mask.unsqueeze(-1)
+             mut_edge = mut_edge * mut_edge_mask.unsqueeze(-1)
+
+         # Initialize cross-graph edges
+         B = mut_node.shape[0]
+         L_mut = mut_node.shape[1]
+         L_wt = wt_node.shape[1]
+         L_binder = binder_node.shape[1]
+
+         mut_binder_edges = torch.randn(B, L_mut, L_binder, self.d_cross_edge, device=device)
+         wt_binder_edges = torch.randn(B, L_wt, L_binder, self.d_cross_edge, device=device)
+
+         mut_binder_mask = mut_node_mask.unsqueeze(-1) * binder_node_mask.unsqueeze(1).to(device)
+         wt_binder_mask = wt_node_mask.unsqueeze(-1) * binder_node_mask.unsqueeze(1).to(device)
+
+         # pdb.set_trace()
+
+         # Cross-Graph Attention
+         for layer in self.cross_graph_att_layers:
+             wt_node, binder_node, wt_binder_edges = layer(wt_node, binder_node, wt_binder_edges, diff_vec)
+             wt_node = wt_node * wt_node_mask.unsqueeze(-1)
+             binder_node = binder_node * binder_node_mask.unsqueeze(-1)
+             wt_binder_edges = wt_binder_edges * wt_binder_mask.unsqueeze(-1)
+
+             mut_node, binder_node, mut_binder_edges = layer(mut_node, binder_node, mut_binder_edges, diff_vec)
+             mut_node = mut_node * mut_node_mask.unsqueeze(-1)
+             binder_node = binder_node * binder_node_mask.unsqueeze(-1)
+             mut_binder_edges = mut_binder_edges * mut_binder_mask.unsqueeze(-1)
+
+         # Pool cross-graph edge features and map each pair to a scalar interaction score
+         wt_binder_edges = torch.mean(wt_binder_edges, dim=(1, 2))
+         mut_binder_edges = torch.mean(mut_binder_edges, dim=(1, 2))
+
+         wt_pred = torch.sigmoid(self.mapping(wt_binder_edges))
+         mut_pred = torch.sigmoid(self.mapping(mut_binder_edges))
+
+         return wt_pred, mut_pred
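+     # Shape sketch for forward (assuming ProteinGraph returns per-residue node features of shape
+     # (B, L, d_node) and node masks of shape (B, L); those modules live in models/ and are not shown here):
+     #   wt_binder_edges / mut_binder_edges start as (B, L_prot, L_binder, d_cross_edge),
+     #   are averaged over both length dimensions to (B, d_cross_edge),
+     #   and self.mapping + sigmoid reduce them to scores wt_pred, mut_pred of shape (B, 1) in (0, 1).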
+
+     def training_step(self, batch, batch_idx):
+         opt = self.optimizers()
+         lr = opt.param_groups[0]['lr']
+         self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+
+         binder_tokens = batch['binder_input_ids'].to(self.device)
+         mut_tokens = batch['mutant_input_ids'].to(self.device)
+         wt_tokens = batch['wildtype_input_ids'].to(self.device)
+
+         wt_pred, mut_pred = self.forward(binder_tokens, wt_tokens, mut_tokens)
+
+         # Push the wild-type/binder score toward 1 and the mutant/binder score toward 0
+         wt_loss = (torch.relu(1 - wt_pred) ** 2).mean()
+         mut_loss = (torch.relu(mut_pred) ** 2).mean()
+         loss = wt_loss + mut_loss
+
+         # pdb.set_trace()
+         self.log('train_wt_loss', wt_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
+         self.log('train_mut_loss', mut_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
+         self.log('train_loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+         return loss
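+     # Note on the objective: because wt_pred and mut_pred are sigmoid outputs in (0, 1), the relu calls
+     # are effectively no-ops and the loss reduces to a squared error toward target 1 for the wild-type/binder
+     # pair and target 0 for the mutant/binder pair. For a single example with wt_pred = 0.8 and mut_pred = 0.3,
+     # loss = (1 - 0.8)^2 + 0.3^2 = 0.04 + 0.09 = 0.13.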
+
+
+     def validation_step(self, batch, batch_idx):
+         binder_tokens = batch['binder_input_ids'].to(self.device)
+         mut_tokens = batch['mutant_input_ids'].to(self.device)
+         wt_tokens = batch['wildtype_input_ids'].to(self.device)
+
+         wt_pred, mut_pred = self.forward(binder_tokens, wt_tokens, mut_tokens)
+
+         wt_loss = (torch.relu(1 - wt_pred) ** 2).mean()
+         mut_loss = (torch.relu(mut_pred) ** 2).mean()
+         loss = wt_loss + mut_loss
+
+         self.log('val_wt_loss', wt_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
+         self.log('val_mut_loss', mut_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
+         self.log('val_loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+
+     def configure_optimizers(self):
+         optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
+
+         base_lr = 1e-5
+         max_lr = self.learning_rate
+         min_lr = 0.1 * self.learning_rate
+
+         scheduler = CosineAnnealingWithWarmup(optimizer, warmup_steps=600, total_steps=15390,
+                                               base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)
+
+         lr_scheduler = {
+             "scheduler": scheduler,
+             "name": 'learning_rate_logs',
+             "interval": 'step',  # update the learning rate every optimizer step (not every epoch)
+             "frequency": 1       # update after every batch at that interval
+         }
+         return [optimizer], [lr_scheduler]
+
+     def on_train_epoch_end(self):
+         # Release cached memory between epochs
+         gc.collect()
+         torch.cuda.empty_cache()
+
+     # def on_validation_epoch_end(self):
+     #     gc.collect()
+     #     torch.cuda.empty_cache()
+
+
+ def main():
+     parser = ArgumentParser()
+
+     parser.add_argument("-o", dest="output_file", help="Directory for output of model checkpoints", required=True, type=str)
+     parser.add_argument("-lr", type=float, default=1e-3)
+     parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
+     parser.add_argument("-d_node", type=int, default=1024, help="Node Representation Dimension")
+     parser.add_argument("-d_edge", type=int, default=512, help="Intra-Graph Edge Representation Dimension")
+     parser.add_argument("-d_cross_edge", type=int, default=512, help="Cross-Graph Edge Representation Dimension")
+     parser.add_argument("-d_position", type=int, default=8, help="Positional Embedding Dimension")
+     parser.add_argument("-n_heads", type=int, default=8)
+     parser.add_argument("-n_intra_layers", type=int, default=1)
+     parser.add_argument("-n_mim_layers", type=int, default=1)
+     parser.add_argument("-n_cross_layers", type=int, default=1)
+     parser.add_argument("-sm", default=None, help="File containing initial params", type=str)
+     parser.add_argument("-max_epochs", type=int, default=15, help="Max number of epochs to train")
+     parser.add_argument("-dropout", type=float, default=0.2)
+     parser.add_argument("-grad_clip", type=float, default=0.5)
+     parser.add_argument("-delta", type=float, default=1)
+     args = parser.parse_args()
+
+     # Initialize the process group for distributed training
+     # (expects torchrun-style environment variables; Lightning's DDP strategy would otherwise set this up itself)
+     dist.init_process_group(backend='nccl')
+
+     train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/train/ppiref')
+     val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/val/ppiref')
+     # val_dataset = None
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+     data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)
+
+     if args.sm:
+         model = muPPIt.load_from_checkpoint(args.sm,
+                                             d_node=args.d_node, d_edge=args.d_edge, d_cross_edge=args.d_cross_edge,
+                                             d_position=args.d_position, num_heads=args.n_heads,
+                                             num_intra_layers=args.n_intra_layers, num_mim_layers=args.n_mim_layers,
+                                             num_cross_layers=args.n_cross_layers, lr=args.lr, delta=args.delta)
+     else:
+         model = muPPIt(args.d_node, args.d_edge, args.d_cross_edge, args.d_position, args.n_heads,
+                        args.n_intra_layers, args.n_mim_layers, args.n_cross_layers, args.lr, args.delta)
+         print("Train from scratch!")
+
+     run_id = str(uuid.uuid4())
+
+     logger = WandbLogger(project="muppit",
+                          name="debug",
+                          # name=f"lr={args.lr}_dnode={args.d_node}_dedge={args.d_edge}_dcross={args.d_cross_edge}_dposition={args.d_position}",
+                          job_type='model-training',
+                          id=run_id)
+
+     checkpoint_callback = ModelCheckpoint(
+         monitor='val_loss',
+         dirpath=args.output_file,
+         filename='model-{epoch:02d}-{val_loss:.2f}',
+         save_top_k=-1,
+         mode='min',
+     )
+
+     early_stopping_callback = EarlyStopping(
+         monitor='val_loss',
+         patience=5,
+         verbose=True,
+         mode='min'
+     )
+
+     accumulator = GradientAccumulationScheduler(scheduling={0: 8, 3: 4, 20: 2})
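+     # The scheduling dict maps epoch -> accumulation factor: 8 batches per optimizer step from epoch 0,
+     # 4 from epoch 3, and 2 from epoch 20 (the last entry only applies if -max_epochs is raised above 20).
+     # With the default -batch_size 2 and devices=[0, 1, 2] below, the effective global batch during the
+     # first phase is roughly 2 * 8 * 3 = 48 (binder, wild-type, mutant) triples per optimizer step.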
+
+     trainer = pl.Trainer(
+         max_epochs=args.max_epochs,
+         accelerator='gpu',
+         strategy='ddp_find_unused_parameters_true',
+         precision='bf16',
+         # logger=logger,
+         devices=[0, 1, 2],
+         callbacks=[checkpoint_callback, accumulator, early_stopping_callback],
+         gradient_clip_val=args.grad_clip,
+         # num_sanity_val_steps=0
+     )
+
+     trainer.fit(model, datamodule=data_module)
+
+     best_model_path = checkpoint_callback.best_model_path
+     print(best_model_path)
+
+
+ if __name__ == "__main__":
+     main()
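+ # Example launch (illustrative values only; the dataset paths above are hard-coded, and the explicit
+ # dist.init_process_group call expects torchrun-style rendezvous environment variables):
+ #   torchrun --nproc_per_node=3 train.py -o ./checkpoints -lr 1e-3 -batch_size 2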