# NOTE: web-page scrape residue (navigation header, file-size line, and a
# pasted line-number gutter) removed from this spot; original code follows.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from typing import Callable
from helper.ema import EMA
class Trainer():
    """Training-loop wrapper built on HuggingFace Accelerate.

    Handles fp16 mixed precision, gradient accumulation, gradient clipping,
    an exponential-moving-average (EMA) copy of the model, and best-model
    checkpointing to ``<file_name>.pth``.
    """

    def __init__(self,
                 model: nn.Module,
                 loss_fn: Callable,
                 ema: EMA = None,
                 optimizer: torch.optim.Optimizer = None,
                 scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                 start_epoch: int = 0,
                 best_loss: float = float("inf"),
                 accumulation_steps: int = 1,
                 max_grad_norm: float = 1.0):
        """Set up the accelerator, model, EMA, optimizer and scheduler.

        Args:
            model: The network to train.
            loss_fn: Callable computing the scalar loss; called as
                ``loss_fn(x)`` (unlabeled) or ``loss_fn(x, y=y)`` (labeled).
            ema: Optional pre-built EMA tracker; a fresh ``EMA(model)`` is
                created when ``None``.
            optimizer: Optional optimizer; defaults to AdamW(lr=1e-4).
            scheduler: Optional LR scheduler; defaults to
                CosineAnnealingLR(T_max=100).
            start_epoch: Epoch to resume from (training starts at
                ``start_epoch + 1``).
            best_loss: Best epoch loss seen so far (for checkpoint resume).
            accumulation_steps: Micro-batches per optimizer step.
            max_grad_norm: Gradient-clipping norm; ``None`` disables clipping.
        """
        # Both fp16 autocast and accumulation bookkeeping are delegated
        # to Accelerate.
        self.accelerator = Accelerator(mixed_precision='fp16',
                                       gradient_accumulation_steps=accumulation_steps)
        self.model = model.to(self.accelerator.device)
        # Default to a fresh EMA shadow of the model when none is supplied.
        self.ema = (EMA(self.model) if ema is None else ema).to(self.accelerator.device)
        self.loss_fn = loss_fn
        self.optimizer = (optimizer if optimizer is not None
                          else torch.optim.AdamW(self.model.parameters(), lr=1e-4))
        self.scheduler = (scheduler if scheduler is not None
                          else torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100))
        self.start_epoch = start_epoch
        self.best_loss = best_loss
        self.accumulation_steps = accumulation_steps
        self.max_grad_norm = max_grad_norm

    def train(self, dl: DataLoader, epochs: int, file_name: str, no_label: bool = False):
        """Run the training loop and checkpoint the best epoch.

        Args:
            dl: DataLoader yielding either ``x`` / ``[x]`` batches
                (``no_label=True``) or ``(x, y)`` pairs.
            epochs: Final epoch number (loop runs ``start_epoch + 1 .. epochs``).
            file_name: Checkpoint path prefix; ``.pth`` is appended.
            no_label: When True, the loss is computed from inputs alone.
        """
        self.model.train()
        # NOTE(review): prepare() is not idempotent — calling train() twice on
        # the same Trainer re-wraps already-prepared objects. Fine for the
        # single-run usage this class targets.
        self.model, self.optimizer, data_loader, self.scheduler = self.accelerator.prepare(
            self.model, self.optimizer, dl, self.scheduler
        )
        for epoch in range(self.start_epoch + 1, epochs + 1):
            epoch_loss = 0.0
            progress_bar = tqdm(data_loader, leave=False,
                                desc=f"Epoch {epoch}/{epochs}", colour="#005500",
                                disable=not self.accelerator.is_local_main_process)
            for step, batch in enumerate(progress_bar):
                # Context manager lets Accelerate skip the real optimizer step
                # until `accumulation_steps` micro-batches have been seen.
                with self.accelerator.accumulate(self.model):
                    if no_label:
                        x = (batch[0] if isinstance(batch, list) else batch).to(self.accelerator.device)
                    else:
                        x = batch[0].to(self.accelerator.device)
                        y = batch[1].to(self.accelerator.device)
                    with self.accelerator.autocast():
                        loss = self.loss_fn(x) if no_label else self.loss_fn(x, y=y)
                    # Accelerate normalizes the loss for accumulation internally.
                    self.accelerator.backward(loss)
                    # Clip only on synchronizing steps, where full gradients exist.
                    if self.max_grad_norm is not None and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    # The prepared optimizer internally no-ops step() on
                    # non-sync micro-batches.
                    self.optimizer.step()
                    # BUGFIX: only update the EMA when an optimizer step actually
                    # ran; the original decayed the EMA on every micro-batch even
                    # though the weights had not changed.
                    if self.accelerator.sync_gradients:
                        self.ema.update()
                    self.optimizer.zero_grad()
                epoch_loss += loss.item()
                # Running mean of the (unscaled) per-batch loss.
                progress_bar.set_postfix(loss=epoch_loss / (step + 1))
            self.accelerator.wait_for_everyone()
            # BUGFIX: the scheduler must step on EVERY process; the original
            # stepped it only on rank 0, so learning rates diverged across
            # ranks in distributed runs.
            self.scheduler.step()
            if self.accelerator.is_main_process:
                epoch_loss = epoch_loss / len(data_loader)
                log_string = f"Loss at epoch {epoch}: {epoch_loss :.4f}"
                # Persist a full resume-able checkpoint whenever the epoch
                # loss improves.
                if self.best_loss > epoch_loss:
                    self.best_loss = epoch_loss
                    torch.save({
                        "model_state_dict": self.accelerator.get_state_dict(self.model),
                        "ema_state_dict": self.ema.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                        "scheduler_state_dict": self.scheduler.state_dict(),
                        "epoch": epoch,
                        "training_steps": epoch * len(dl),
                        "best_loss": self.best_loss,
                        "batch_size": dl.batch_size,
                        "number_of_batches": len(dl)
                    }, file_name + '.pth')
                    log_string += " --> Best model ever (stored)"
                print(log_string)