import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from typing import Callable
from helper.ema import EMA

class Trainer:
    def __init__(self,
                 model: nn.Module,
                 loss_fn: Callable,
                 ema: EMA = None,
                 optimizer: torch.optim.Optimizer = None,
                 scheduler: torch.optim.lr_scheduler.LRScheduler = None,
                 start_epoch: int = 0,
                 best_loss: float = float("inf"),
                 accumulation_steps: int = 1,
                 max_grad_norm: float = 1.0):
        # Accelerator handles device placement, fp16 autocast/grad scaling,
        # and gradient accumulation bookkeeping.
        self.accelerator = Accelerator(mixed_precision='fp16',
                                       gradient_accumulation_steps=accumulation_steps)
        self.model = model.to(self.accelerator.device)
        # Keep an exponential moving average of the weights for evaluation/sampling.
        if ema is None:
            self.ema = EMA(self.model).to(self.accelerator.device)
        else:
            self.ema = ema.to(self.accelerator.device)
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        if self.optimizer is None:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4)
        self.scheduler = scheduler
        if self.scheduler is None:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100)
        self.start_epoch = start_epoch
        self.best_loss = best_loss
        self.accumulation_steps = accumulation_steps
        self.max_grad_norm = max_grad_norm

    def train(self, dl: DataLoader, epochs: int, file_name: str, no_label: bool = False):
        self.model.train()
        # Let Accelerate wrap the model, optimizer, dataloader, and scheduler
        # for the current device / distributed setup.
        self.model, self.optimizer, data_loader, self.scheduler = self.accelerator.prepare(
            self.model, self.optimizer, dl, self.scheduler
        )
        for epoch in range(self.start_epoch + 1, epochs + 1):
            epoch_loss = 0.0
            progress_bar = tqdm(data_loader, leave=False, desc=f"Epoch {epoch}/{epochs}",
                                colour="#005500", disable=not self.accelerator.is_local_main_process)
            for step, batch in enumerate(progress_bar):
                # accumulate() suppresses gradient sync and turns
                # optimizer.step()/zero_grad() into no-ops until enough
                # micro-batches have been accumulated.
                with self.accelerator.accumulate(self.model):
                    if no_label:
                        if isinstance(batch, list):
                            x = batch[0].to(self.accelerator.device)
                        else:
                            x = batch.to(self.accelerator.device)
                    else:
                        x, y = batch[0].to(self.accelerator.device), batch[1].to(self.accelerator.device)
                    with self.accelerator.autocast():
                        if no_label:
                            loss = self.loss_fn(x)
                        else:
                            loss = self.loss_fn(x, y=y)
                    # backward() scales the loss for fp16 and for gradient accumulation.
                    self.accelerator.backward(loss)
                    # Clip gradients only on steps where they are actually synced.
                    if self.max_grad_norm is not None and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    self.optimizer.step()
                    # Update the EMA weights only after a real optimizer step.
                    if self.accelerator.sync_gradients:
                        self.ema.update()
                    self.optimizer.zero_grad()
                epoch_loss += loss.item()
                progress_bar.set_postfix(loss=epoch_loss / (step + 1))  # running average loss
            self.accelerator.wait_for_everyone()
            # Step the scheduler on every process so learning rates stay in sync.
            self.scheduler.step()
            if self.accelerator.is_main_process:
                epoch_loss = epoch_loss / len(data_loader)
                log_string = f"Loss at epoch {epoch}: {epoch_loss:.4f}"
                # Save a checkpoint whenever the epoch loss improves.
                if epoch_loss < self.best_loss:
                    self.best_loss = epoch_loss
                    torch.save({
                        "model_state_dict": self.accelerator.get_state_dict(self.model),
                        "ema_state_dict": self.ema.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                        "scheduler_state_dict": self.scheduler.state_dict(),
                        "epoch": epoch,
                        "training_steps": epoch * len(dl),
                        "best_loss": self.best_loss,
                        "batch_size": dl.batch_size,
                        "number_of_batches": len(dl)
                    }, file_name + '.pth')
                    log_string += " --> best model so far (checkpoint saved)"
                print(log_string)
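

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original project): the toy
# model, synthetic data, and loss below are hypothetical stand-ins that show
# how the Trainer is wired together. It assumes a single-process run on a
# CUDA device, since the Trainer hard-codes fp16 mixed precision.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    # Hypothetical toy regression model and random data.
    toy_model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
    xs = torch.randn(256, 16)
    ys = torch.randn(256, 1)
    toy_dl = DataLoader(TensorDataset(xs, ys), batch_size=32, shuffle=True)

    # The Trainer calls loss_fn(x) or loss_fn(x, y=y) and never runs the model
    # itself, so the loss callable is expected to do the forward pass.
    mse = nn.MSELoss()

    def toy_loss(x, y=None):
        return mse(toy_model(x), y)

    trainer = Trainer(model=toy_model, loss_fn=toy_loss, accumulation_steps=2)
    trainer.train(toy_dl, epochs=3, file_name="toy_checkpoint")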