Spaces:
Sleeping
Sleeping
| # Copyright 2019-present NAVER Corp. | |
| # CC BY-NC-SA 3.0 | |
| # Available only for non-commercial use | |
| import pdb | |
| from tqdm import tqdm | |
| from collections import defaultdict | |
| import torch | |
| import torch.nn as nn | |
class Trainer(nn.Module):
    """Helper class that runs one training epoch of a network per call.

    Subclass it and override `forward_backward` for your actual needs.

    Usage:
        train = Trainer(net, loader, loss, optimizer)
        for epoch in range(n_epochs):
            train()

    Note: `__call__` is overridden directly, so calling the trainer runs an
    epoch instead of going through `nn.Module.__call__` / `forward`.
    """

    def __init__(self, net, loader, loss, optimizer):
        """Store the collaborators used by the training loop.

        Args:
            net: network to train; its first parameter determines the device
                that batches are moved to.
            loader: iterable yielding one batch of inputs per iteration.
            loss: loss object, stored as `self.loss_func` for use by
                `forward_backward` implementations.
            optimizer: object exposing `zero_grad()` and `step()`.
        """
        super().__init__()
        self.net = net
        self.loader = loader
        self.loss_func = loss
        self.optimizer = optimizer

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of a non-empty sequence."""
        return sum(values) / len(values)

    def iscuda(self):
        """Return True when the network's first parameter is not on the CPU."""
        return next(self.net.parameters()).device != torch.device("cpu")

    def todevice(self, x):
        """Recursively move the tensors found in `x` to the network's device.

        Dicts are rebuilt with the same keys; tuples and lists are both
        returned as lists. Leaves that are not tensors (e.g. strings, numbers,
        None) are returned unchanged instead of raising AttributeError.
        """
        if isinstance(x, dict):
            return {k: self.todevice(v) for k, v in x.items()}
        if isinstance(x, (tuple, list)):
            return [self.todevice(v) for v in x]
        if not torch.is_tensor(x):
            return x
        if self.iscuda():
            # Target the device the network actually lives on rather than the
            # current default CUDA device.
            device = next(self.net.parameters()).device
            return x.contiguous().to(device, non_blocking=True)
        return x.cpu()

    def __call__(self):
        """Run one epoch over `self.loader` and print a summary of the losses.

        For every batch: move it to the network's device, reset the gradients,
        call `forward_backward`, then step the optimizer.

        Returns:
            The mean of the values reported under the "loss" key of the
            `details` dicts, or NaN when no such value was recorded (for
            example when the loader is empty).

        Raises:
            RuntimeError: if the loss returned by `forward_backward` is NaN.
        """
        self.net.train()
        stats = defaultdict(list)

        for inputs in tqdm(self.loader):
            inputs = self.todevice(inputs)

            # compute gradient and do model update
            self.optimizer.zero_grad()
            loss, details = self.forward_backward(inputs)
            if torch.isnan(loss):
                # Abort before the optimizer step is applied.
                raise RuntimeError("Loss is NaN")
            self.optimizer.step()

            for key, val in details.items():
                stats[key].append(val)

        print(" Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            # Compare the first and last ~10% of the iterations (at least one).
            N = 1 + len(vals) // 10
            first, last, avg = (
                self._mean(vals[:N]),
                self._mean(vals[-N:]),
                self._mean(vals),
            )
            print(f" - {loss_name:20}:", end="")
            print(f" {first:.3f} --> {last:.3f} (avg: {avg:.3f})")

        # Use .get() so that a missing key is not inserted into the defaultdict.
        losses = stats.get("loss")
        if not losses:
            return float("nan")
        return self._mean(losses)  # return average loss

    def forward_backward(self, inputs):
        """Process one batch; must be overridden by subclasses.

        Args:
            inputs: one batch from the loader, already passed through
                `todevice`.

        Returns:
            A `(loss, details)` pair, where `loss` is a tensor accepted by
            `torch.isnan` and `details` is a dict mapping loss names to
            numeric values. Implementations are expected to populate the
            gradients, since `optimizer.step()` is called right after.
        """
        raise NotImplementedError()