# KoFace-AI / helper/trainer.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from typing import Callable, Optional
from helper.ema import EMA


class Trainer:
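    """Training loop wrapper built on top of Hugging Face Accelerate.

    Handles fp16 mixed precision, gradient accumulation, gradient clipping,
    EMA weight tracking, and saving the best checkpoint seen so far.
    """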
    def __init__(self,
                 model: nn.Module,
                 loss_fn: Callable,
                 ema: Optional[EMA] = None,
                 optimizer: Optional[torch.optim.Optimizer] = None,
                 scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
                 start_epoch: int = 0,
                 best_loss: float = float("inf"),
                 accumulation_steps: int = 1,
                 max_grad_norm: float = 1.0):
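        """Set up the Accelerator, EMA shadow weights, and training defaults.

        If no optimizer or scheduler is given, AdamW (lr=1e-4) and
        CosineAnnealingLR (T_max=100) are used as fallbacks.
        """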
        self.accelerator = Accelerator(mixed_precision='fp16', gradient_accumulation_steps=accumulation_steps)
self.model = model.to(self.accelerator.device)
if ema is None:
self.ema = EMA(self.model).to(self.accelerator.device)
else:
self.ema = ema.to(self.accelerator.device)
self.loss_fn = loss_fn
self.optimizer = optimizer
if self.optimizer is None:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4)
self.scheduler = scheduler
if self.scheduler is None:
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100)
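            # Note: T_max=100 assumes the cosine schedule spans roughly 100 epochs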
self.start_epoch = start_epoch
self.best_loss = best_loss
self.accumulation_steps = accumulation_steps
self.max_grad_norm = max_grad_norm
    def train(self, dl: DataLoader, epochs: int, file_name: str, no_label: bool = False):
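        """Train for `epochs` epochs over `dl`, saving the best model to `file_name + '.pth'`.

        Set `no_label=True` when the dataloader yields inputs without targets;
        the loss function is then called as `loss_fn(x)` instead of `loss_fn(x, y=y)`.
        """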
self.model.train()
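        # prepare() wraps the model, optimizer, dataloader, and scheduler for the
        # current hardware setup (device placement, fp16, multi-GPU), so the loop
        # below stays device-agnostic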
self.model, self.optimizer, data_loader, self.scheduler = self.accelerator.prepare(
self.model, self.optimizer, dl, self.scheduler
)
for epoch in range(self.start_epoch + 1, epochs + 1):
epoch_loss = 0.0
            progress_bar = tqdm(data_loader, leave=False, desc=f"Epoch {epoch}/{epochs}", colour="#005500", disable=not self.accelerator.is_local_main_process)
for step, batch in enumerate(progress_bar):
with self.accelerator.accumulate(self.model): # Context manager for accumulation
if no_label:
if isinstance(batch, list):
x = batch[0].to(self.accelerator.device)
else:
x = batch.to(self.accelerator.device)
else:
x, y = batch[0].to(self.accelerator.device), batch[1].to(self.accelerator.device)
with self.accelerator.autocast():
if no_label:
loss = self.loss_fn(x)
else:
loss = self.loss_fn(x, y=y)
                    # accelerator.backward scales the loss for fp16 and divides it
                    # by the number of accumulation steps internally
                    self.accelerator.backward(loss)
                    # Gradient clipping, applied only when gradients are synced
                    # (i.e. at the end of an accumulation cycle)
                    if self.max_grad_norm is not None and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    # The prepared optimizer only performs a real step once enough
                    # gradients have been accumulated; otherwise step() is a no-op
                    self.optimizer.step()
                    if self.accelerator.sync_gradients:
                        # Update the EMA weights only after an actual optimizer step
                        self.ema.update()
                    self.optimizer.zero_grad()
                epoch_loss += loss.item()
                progress_bar.set_postfix(loss=epoch_loss / (step + 1))  # running mean loss
            self.accelerator.wait_for_everyone()
            # Step the epoch-level LR scheduler on every process so that learning
            # rates stay in sync across GPUs
            self.scheduler.step()
            if self.accelerator.is_main_process:
                epoch_loss = epoch_loss / len(data_loader)
                log_string = f"Loss at epoch {epoch}: {epoch_loss:.4f}"
# Save the best model
if self.best_loss > epoch_loss:
self.best_loss = epoch_loss
torch.save({
"model_state_dict": self.accelerator.get_state_dict(self.model),
"ema_state_dict": self.ema.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"epoch": epoch,
"training_steps": epoch * len(dl),
"best_loss": self.best_loss,
"batch_size": dl.batch_size,
"number_of_batches": len(dl)
}, file_name + '.pth')
log_string += " --> Best model ever (stored)"
print(log_string)
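

# --- Usage sketch (illustrative only, not part of the original file) --------
# A minimal, hypothetical example of how this Trainer could be driven.
# `MyModel`, `my_dataset`, and `diffusion_loss` are placeholders that do not
# exist in this file; the real model, dataset, and loss live elsewhere in
# the KoFace-AI repository.
#
#   from torch.utils.data import DataLoader
#
#   model = MyModel()
#
#   def diffusion_loss(x, y=None):
#       # must return a scalar loss tensor for a batch (and optional labels)
#       ...
#
#   dl = DataLoader(my_dataset, batch_size=32, shuffle=True)
#   trainer = Trainer(model, loss_fn=diffusion_loss, accumulation_steps=4)
#   trainer.train(dl, epochs=100, file_name="checkpoints/best", no_label=True)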