Create train.py
train.py (ADDED, 394 lines)

import pdb
from pytorch_lightning.strategies import DDPStrategy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler, BatchSampler, Sampler
from datasets import load_from_disk
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
    Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import _LRScheduler
from argparse import ArgumentParser
import os
import uuid
import esm
import numpy as np
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
# from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import Adam, AdamW
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
import gc

from models.graph import ProteinGraph
from models.modules_vec import IntraGraphAttention, DiffEmbeddingLayer, MIM, CrossGraphAttention

os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'


def collate_fn(batch):
    # Unpack the batch
    binders = []
    mutants = []
    wildtypes = []

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    for b in batch:
        binders.append(torch.tensor(b['binder_tokens']).squeeze(0))      # shape: 1*L1 -> L1
        mutants.append(torch.tensor(b['mutant_tokens']).squeeze(0))      # shape: 1*L2 -> L2
        wildtypes.append(torch.tensor(b['wildtype_tokens']).squeeze(0))  # shape: 1*L3 -> L3

    # Collate the tensors using torch's pad_sequence
    binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
    mutant_input_ids = torch.nn.utils.rnn.pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
    wildtype_input_ids = torch.nn.utils.rnn.pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)

    # Return the collated batch
    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
    }

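
# Optional variant (hypothetical helper, not used by the training script below): collate_fn
# above re-instantiates the ESM-2 tokenizer on every batch, so a cached version can avoid
# that per-batch cost while producing the same padded output.
_CACHED_TOKENIZER = None

def cached_collate_fn(batch):
    global _CACHED_TOKENIZER
    if _CACHED_TOKENIZER is None:
        _CACHED_TOKENIZER = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    pad_id = _CACHED_TOKENIZER.pad_token_id

    binders = [torch.tensor(b['binder_tokens']).squeeze(0) for b in batch]
    mutants = [torch.tensor(b['mutant_tokens']).squeeze(0) for b in batch]
    wildtypes = [torch.tensor(b['wildtype_tokens']).squeeze(0) for b in batch]

    return {
        'binder_input_ids': pad_sequence(binders, batch_first=True, padding_value=pad_id).int(),
        'mutant_input_ids': pad_sequence(mutants, batch_first=True, padding_value=pad_id).int(),
        'wildtype_input_ids': pad_sequence(wildtypes, batch_first=True, padding_value=pad_id).int(),
    }
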
class LengthAwareDistributedSampler(DistributedSampler):
    def __init__(self, dataset, key, batch_size, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.dataset = dataset
        self.key = key
        self.batch_size = batch_size

        # Sort indices by the length of the mutant sequence
        self.indices = sorted(range(len(self.dataset)), key=lambda i: len(self.dataset[i][key]))

    def __iter__(self):
        # Divide indices among replicas
        indices = self.indices[self.rank::self.num_replicas]

        if self.shuffle:
            # Permute this replica's indices rather than replacing them with a fresh
            # permutation of range(len(indices)), which would discard the actual
            # dataset indices assigned to this rank.
            torch.manual_seed(self.epoch)
            perm = torch.randperm(len(indices)).tolist()
            indices = [indices[i] for i in perm]

        # Yield indices in batches
        for i in range(0, len(indices), self.batch_size):
            yield indices[i:i + self.batch_size]

    def __len__(self):
        return len(self.indices) // self.num_replicas

    def set_epoch(self, epoch):
        self.epoch = epoch

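
# Usage sketch (hypothetical -- this sampler is not wired in anywhere below): because
# __iter__ yields whole batches of indices, it would be passed to a DataLoader as a
# batch_sampler rather than a sampler, e.g.
#
#   batch_sampler = LengthAwareDistributedSampler(train_dataset, 'mutant_tokens', batch_size=2)
#   loader = DataLoader(train_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)
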
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, tokenizer, batch_size: int = 128):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer

    def train_dataloader(self):
        # batch_sampler = LengthAwareDistributedSampler(self.train_dataset, 'mutant_tokens', self.batch_size)
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn,
                          num_workers=8, pin_memory=True)

    def val_dataloader(self):
        # batch_sampler = LengthAwareDistributedSampler(self.val_dataset, 'mutant_tokens', self.batch_size)
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn, num_workers=8,
                          pin_memory=True)

    def setup(self, stage=None):
        if stage == 'test' or stage is None:
            pass

+
class CosineAnnealingWithWarmup(_LRScheduler):
|
117 |
+
def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
|
118 |
+
self.warmup_steps = warmup_steps
|
119 |
+
self.total_steps = total_steps
|
120 |
+
self.base_lr = base_lr
|
121 |
+
self.max_lr = max_lr
|
122 |
+
self.min_lr = min_lr
|
123 |
+
super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)
|
124 |
+
print(f"SELF BASE LRS = {self.base_lrs}")
|
125 |
+
|
126 |
+
def get_lr(self):
|
127 |
+
if self.last_epoch < self.warmup_steps:
|
128 |
+
# Linear warmup phase from base_lr to max_lr
|
129 |
+
return [self.base_lr + (self.max_lr - self.base_lr) * (self.last_epoch / self.warmup_steps) for base_lr in self.base_lrs]
|
130 |
+
|
131 |
+
# Cosine annealing phase from max_lr to min_lr
|
132 |
+
progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
133 |
+
cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
|
134 |
+
decayed_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
|
135 |
+
|
136 |
+
return [decayed_lr for base_lr in self.base_lrs]
|
137 |
+
|
138 |
+
class muPPIt(pl.LightningModule):
|
139 |
+
def __init__(self, d_node, d_edge, d_cross_edge, d_position, num_heads,
|
140 |
+
num_intra_layers, num_mim_layers, num_cross_layers, lr, delta=1.0):
|
141 |
+
super(muPPIt, self).__init__()
|
142 |
+
|
143 |
+
self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
|
144 |
+
for param in self.esm.parameters():
|
145 |
+
param.requires_grad = False
|
146 |
+
|
147 |
+
self.graph = ProteinGraph(d_node, d_edge, d_position)
|
148 |
+
|
149 |
+
self.intra_graph_att_layers = nn.ModuleList([
|
150 |
+
IntraGraphAttention(d_node, d_edge, num_heads) for _ in range(num_intra_layers)
|
151 |
+
])
|
152 |
+
|
153 |
+
self.diff_layer = DiffEmbeddingLayer(d_node)
|
154 |
+
|
155 |
+
self.mim_layers = nn.ModuleList([
|
156 |
+
MIM(d_node, d_edge, d_node, num_heads) for _ in range(num_mim_layers)
|
157 |
+
])
|
158 |
+
|
159 |
+
self.cross_graph_att_layers = nn.ModuleList([
|
160 |
+
CrossGraphAttention(d_node, d_cross_edge, d_node, num_heads) for _ in range(num_cross_layers)
|
161 |
+
])
|
162 |
+
|
163 |
+
self.cross_graph_edge_mapping = nn.Linear(1, d_cross_edge)
|
164 |
+
self.mapping = nn.Linear(d_cross_edge, 1)
|
165 |
+
|
166 |
+
self.d_cross_edge = d_cross_edge
|
167 |
+
self.learning_rate = lr
|
168 |
+
self.delta = delta
|
169 |
+
|
170 |
+
def forward(self, binder_tokens, wt_tokens, mut_tokens):
|
171 |
+
device = binder_tokens.device
|
172 |
+
|
173 |
+
# Construct Graph
|
174 |
+
# print("Graph")
|
175 |
+
|
176 |
+
binder_node, binder_edge, binder_node_mask, binder_edge_mask = self.graph(binder_tokens, self.esm, self.alphabet)
|
177 |
+
wt_node, wt_edge, wt_node_mask, wt_edge_mask = self.graph(wt_tokens, self.esm, self.alphabet)
|
178 |
+
mut_node, mut_edge, mut_node_mask, mut_edge_mask = self.graph(mut_tokens, self.esm, self.alphabet)
|
179 |
+
|
180 |
+
# Intra-Graph Attention
|
181 |
+
# print("Intra Graph")
|
182 |
+
for layer in self.intra_graph_att_layers:
|
183 |
+
binder_node, binder_edge = layer(binder_node, binder_edge)
|
184 |
+
binder_node = binder_node * binder_node_mask.unsqueeze(-1)
|
185 |
+
binder_edge = binder_edge * binder_edge_mask.unsqueeze(-1)
|
186 |
+
|
187 |
+
wt_node, wt_edge = layer(wt_node, wt_edge)
|
188 |
+
wt_node = wt_node * wt_node_mask.unsqueeze(-1)
|
189 |
+
wt_edge = wt_edge * wt_edge_mask.unsqueeze(-1)
|
190 |
+
|
191 |
+
mut_node, mut_edge = layer(mut_node, mut_edge)
|
192 |
+
mut_node = mut_node * mut_node_mask.unsqueeze(-1)
|
193 |
+
mut_edge = mut_edge * mut_edge_mask.unsqueeze(-1)
|
194 |
+
|
195 |
+
# Differential Embedding Layer
|
196 |
+
# print("Diff")
|
197 |
+
diff_vec = self.diff_layer(wt_node, mut_node)
|
198 |
+
|
199 |
+
# Mutation Impact Module
|
200 |
+
# print("MIM")
|
201 |
+
for layer in self.mim_layers:
|
202 |
+
wt_node, wt_edge = layer(wt_node, wt_edge, diff_vec)
|
203 |
+
wt_node = wt_node * wt_node_mask.unsqueeze(-1)
|
204 |
+
wt_edge = wt_edge * wt_edge_mask.unsqueeze(-1)
|
205 |
+
|
206 |
+
mut_node, mut_edge = layer(mut_node, mut_edge, diff_vec)
|
207 |
+
mut_node = mut_node * mut_node_mask.unsqueeze(-1)
|
208 |
+
mut_edge = mut_edge * mut_edge_mask.unsqueeze(-1)
|
209 |
+
|
210 |
+
# Initialize cross-graph edges
|
211 |
+
B = mut_node.shape[0]
|
212 |
+
L_mut = mut_node.shape[1]
|
213 |
+
L_wt = wt_node.shape[1]
|
214 |
+
L_binder = binder_node.shape[1]
|
215 |
+
|
216 |
+
mut_binder_edges = torch.randn(B, L_mut, L_binder, self.d_cross_edge).to(device)
|
217 |
+
wt_binder_edges = torch.randn(B, L_wt, L_binder, self.d_cross_edge).to(device)
|
218 |
+
|
219 |
+
mut_binder_mask = mut_node_mask.unsqueeze(-1) * binder_node_mask.unsqueeze(1).to(device)
|
220 |
+
wt_binder_mask = wt_node_mask.unsqueeze(-1) * binder_node_mask.unsqueeze(1).to(device)
|
221 |
+
|
222 |
+
# pdb.set_trace()
|
223 |
+
|
224 |
+
# Cross-Graph Attention
|
225 |
+
# print("Cross")
|
226 |
+
for layer in self.cross_graph_att_layers:
|
227 |
+
wt_node, binder_node, wt_binder_edges = layer(wt_node, binder_node, wt_binder_edges, diff_vec)
|
228 |
+
wt_node = wt_node * wt_node_mask.unsqueeze(-1)
|
229 |
+
binder_node = binder_node * binder_node_mask.unsqueeze(-1)
|
230 |
+
wt_binder_edges = wt_binder_edges * wt_binder_mask.unsqueeze(-1)
|
231 |
+
|
232 |
+
mut_node, binder_node, mut_binder_edges = layer(mut_node, binder_node, mut_binder_edges, diff_vec)
|
233 |
+
mut_node = mut_node * mut_node_mask.unsqueeze(-1)
|
234 |
+
binder_node = binder_node * binder_node_mask.unsqueeze(-1)
|
235 |
+
mut_binder_edges = mut_binder_edges * mut_binder_mask.unsqueeze(-1)
|
236 |
+
|
237 |
+
wt_binder_edges = torch.mean(wt_binder_edges, dim=(1,2))
|
238 |
+
mut_binder_edges = torch.mean(mut_binder_edges, dim=(1,2))
|
239 |
+
|
240 |
+
wt_pred = torch.sigmoid(self.mapping(wt_binder_edges))
|
241 |
+
mut_pred = torch.sigmoid(self.mapping(mut_binder_edges))
|
242 |
+
|
243 |
+
return wt_pred, mut_pred
|
244 |
+
|
245 |
+
def training_step(self, batch, batch_idx):
|
246 |
+
opt = self.optimizers()
|
247 |
+
lr = opt.param_groups[0]['lr']
|
248 |
+
self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
|
249 |
+
|
250 |
+
binder_tokens = batch['binder_input_ids'].to(self.device)
|
251 |
+
mut_tokens = batch['mutant_input_ids'].to(self.device)
|
252 |
+
wt_tokens = batch['wildtype_input_ids'].to(self.device)
|
253 |
+
|
254 |
+
wt_pred, mut_pred = self.forward(binder_tokens, wt_tokens, mut_tokens)
|
255 |
+
|
256 |
+
wt_loss = (torch.relu(mut_pred) ** 2).mean()
|
257 |
+
mut_loss = (torch.relu(1 - wt_pred) ** 2).mean()
|
258 |
+
loss = wt_loss + mut_loss
|
259 |
+
|
260 |
+
# pdb.set_trace()
|
261 |
+
self.log('train_wt_loss', wt_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
|
262 |
+
self.log('train_mut_loss', mut_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
|
263 |
+
self.log('train_loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
|
264 |
+
return loss
|
265 |
+
|
266 |
+
|
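    # Worked example with hypothetical scores: for wt_pred = 0.9 and mut_pred = 0.2,
    #   relu(mut_pred) ** 2    = 0.04
    #   relu(1 - wt_pred) ** 2 = 0.01
    # so the per-sample loss is 0.05; it vanishes only when mut_pred = 0 and wt_pred = 1.
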
    def validation_step(self, batch, batch_idx):
        binder_tokens = batch['binder_input_ids'].to(self.device)
        mut_tokens = batch['mutant_input_ids'].to(self.device)
        wt_tokens = batch['wildtype_input_ids'].to(self.device)

        wt_pred, mut_pred = self.forward(binder_tokens, wt_tokens, mut_tokens)

        wt_loss = (torch.relu(mut_pred) ** 2).mean()
        mut_loss = (torch.relu(1 - wt_pred) ** 2).mean()
        loss = wt_loss + mut_loss

        self.log('val_wt_loss', wt_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('val_mut_loss', mut_loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
        self.log('val_loss', loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))

        base_lr = 1e-5
        max_lr = self.learning_rate
        min_lr = 0.1 * self.learning_rate

        scheduler = CosineAnnealingWithWarmup(optimizer, warmup_steps=600, total_steps=15390,
                                              base_lr=base_lr, max_lr=max_lr, min_lr=min_lr)

        lr_scheduler = {
            "scheduler": scheduler,
            "name": 'learning_rate_logs',
            "interval": 'step',  # update the learning rate every optimizer step, not every epoch
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler]

    def on_train_epoch_end(self):
        # Free cached GPU memory between training epochs.
        gc.collect()
        torch.cuda.empty_cache()

    # def on_validation_epoch_end(self):
    #     gc.collect()
    #     torch.cuda.empty_cache()


def main():
    parser = ArgumentParser()

    parser.add_argument("-o", dest="output_file", help="Output directory for model checkpoints", required=True, type=str)
    parser.add_argument("-lr", type=float, default=1e-3)
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("-d_node", type=int, default=1024, help="Node Representation Dimension")
    parser.add_argument("-d_edge", type=int, default=512, help="Intra-Graph Edge Representation Dimension")
    parser.add_argument("-d_cross_edge", type=int, default=512, help="Cross-Graph Edge Representation Dimension")
    parser.add_argument("-d_position", type=int, default=8, help="Positional Embedding Dimension")
    parser.add_argument("-n_heads", type=int, default=8)
    parser.add_argument("-n_intra_layers", type=int, default=1)
    parser.add_argument("-n_mim_layers", type=int, default=1)
    parser.add_argument("-n_cross_layers", type=int, default=1)
    parser.add_argument("-sm", default=None, help="File containing initial params", type=str)
    parser.add_argument("-max_epochs", type=int, default=15, help="Max number of epochs to train")
    parser.add_argument("-dropout", type=float, default=0.2)
    parser.add_argument("-grad_clip", type=float, default=0.5)
    parser.add_argument("-delta", type=float, default=1)
    args = parser.parse_args()

    # Initialize the process group for distributed training
    dist.init_process_group(backend='nccl')

    train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/train/ppiref')
    val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/val/ppiref')
    # val_dataset = None
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    data_module = CustomDataModule(train_dataset, val_dataset, tokenizer=tokenizer, batch_size=args.batch_size)

    model = muPPIt(args.d_node, args.d_edge, args.d_cross_edge, args.d_position, args.n_heads,
                   args.n_intra_layers, args.n_mim_layers, args.n_cross_layers, args.lr, args.delta)
    if args.sm:
        # load_from_checkpoint takes the checkpoint path first and the constructor
        # arguments as keyword overrides.
        model = muPPIt.load_from_checkpoint(args.sm,
                                            d_node=args.d_node, d_edge=args.d_edge,
                                            d_cross_edge=args.d_cross_edge, d_position=args.d_position,
                                            num_heads=args.n_heads, num_intra_layers=args.n_intra_layers,
                                            num_mim_layers=args.n_mim_layers, num_cross_layers=args.n_cross_layers,
                                            lr=args.lr, delta=args.delta)
    else:
        print("Train from scratch!")

    run_id = str(uuid.uuid4())

    logger = WandbLogger(project="muppit",
                         name="debug",
                         # name=f"lr={args.lr}_dnode={args.d_node}_dedge={args.d_edge}_dcross={args.d_cross_edge}_dposition={args.d_position}",
                         job_type='model-training',
                         id=run_id)

    # 'val_loss' is the only validation metric logged above, so both callbacks track it.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=args.output_file,
        filename='model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=-1,
        mode='min',
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min'
    )

    accumulator = GradientAccumulationScheduler(scheduling={0: 8, 3: 4, 20: 2})

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_true',
        precision='bf16',
        # logger=logger,
        devices=[0, 1, 2],
        callbacks=[checkpoint_callback, accumulator, early_stopping_callback],
        gradient_clip_val=args.grad_clip,
        # num_sanity_val_steps=0
    )

    trainer.fit(model, datamodule=data_module)

    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)


if __name__ == "__main__":
    main()
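
# Example launch (assumption: 3 local GPUs to match `devices=[0, 1, 2]` above, and torchrun
# supplies the env vars that dist.init_process_group(backend='nccl') expects; the output
# directory is a hypothetical path):
#
#   torchrun --nproc_per_node=3 train.py -o checkpoints/muppit_run -lr 1e-4 -batch_size 2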