File size: 4,716 Bytes
5ab5cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from typing import Callable
from helper.ema import EMA

class Trainer():
    """Train a model with HuggingFace Accelerate, keeping an EMA copy and checkpointing the best epoch.

    The model, optimizer, data loader and scheduler are wrapped with an ``Accelerator``
    configured for mixed precision and gradient accumulation. Each epoch runs a training
    loop, and ``<file_name>.pth`` is written whenever the mean epoch loss (as computed on
    the main process) improves on ``best_loss``.

    Args:
        model: Module to train; moved to the accelerator's device on construction.
        loss_fn: Callable returning a scalar loss tensor. It is called as ``loss_fn(x)``
            when training with ``no_label=True`` and as ``loss_fn(x, y=y)`` otherwise.
        ema: Optional EMA tracker; when ``None``, ``EMA(model)`` is created.
        optimizer: Optional optimizer; defaults to ``AdamW(model.parameters(), lr=1e-4)``.
        scheduler: Optional LR scheduler; defaults to ``CosineAnnealingLR(T_max=100)``.
            It is stepped once per epoch.
        start_epoch: Training runs epochs ``start_epoch + 1`` .. ``epochs``.
        best_loss: Best epoch loss so far; a checkpoint is written only when it is beaten.
        accumulation_steps: Micro-batches accumulated per optimizer step.
        max_grad_norm: Gradient-norm clipping threshold; ``None`` disables clipping.
        mixed_precision: Mode passed to ``Accelerator`` (default ``'fp16'``). Use
            ``'no'`` or ``'bf16'`` on hardware without fp16 support.
    """

    def __init__(self,
                 model: nn.Module,
                 loss_fn: Callable,
                 ema: EMA = None,
                 optimizer: torch.optim.Optimizer = None,
                 scheduler: "torch.optim.lr_scheduler.LRScheduler" = None,
                 start_epoch: int = 0,
                 best_loss: float = float("inf"),
                 accumulation_steps: int = 1,
                 max_grad_norm: float = 1.0,
                 mixed_precision: str = 'fp16'):
        self.accelerator = Accelerator(mixed_precision=mixed_precision,
                                       gradient_accumulation_steps=accumulation_steps)
        self.model = model.to(self.accelerator.device)
        # The default EMA is built from the model *after* it has been moved to the device.
        self.ema = (EMA(self.model) if ema is None else ema).to(self.accelerator.device)
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        if self.optimizer is None:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4)
        self.scheduler = scheduler
        if self.scheduler is None:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100)
        self.start_epoch = start_epoch
        self.best_loss = best_loss
        self.accumulation_steps = accumulation_steps
        self.max_grad_norm = max_grad_norm

    def _unpack_batch(self, batch, no_label: bool):
        """Return ``(x, y)`` moved to the accelerator device; ``y`` is ``None`` when ``no_label``."""
        device = self.accelerator.device
        if no_label:
            # Unlabelled loaders may yield a bare tensor or a (list/tuple) sequence whose
            # first element is the input.
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            return x.to(device), None
        return batch[0].to(device), batch[1].to(device)

    def _save_checkpoint(self, dl: DataLoader, epoch: int, file_name: str):
        """Write model/EMA/optimizer/scheduler state and run metadata to ``file_name + '.pth'``."""
        torch.save({
            "model_state_dict": self.accelerator.get_state_dict(self.model),
            "ema_state_dict": self.ema.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "epoch": epoch,
            # Counted in batches of the original (un-prepared) loader, not optimizer steps.
            "training_steps": epoch * len(dl),
            "best_loss": self.best_loss,
            "batch_size": dl.batch_size,
            "number_of_batches": len(dl)
            }, file_name + '.pth')

    def train(self, dl : DataLoader, epochs: int, file_name : str, no_label : bool = False):
        """Train for epochs ``start_epoch + 1`` .. ``epochs``, checkpointing on improvement.

        Args:
            dl: Training data loader. Batches are ``(x, y, ...)`` sequences, or, when
                ``no_label`` is true, either a bare input tensor or a sequence whose first
                element is the input.
            epochs: Index of the last epoch to run (inclusive).
            file_name: Checkpoint path without the ``.pth`` suffix.
            no_label: If true, call ``loss_fn(x)`` instead of ``loss_fn(x, y=y)``.

        Raises:
            ValueError: If the data loader yields no batches in an epoch.
        """
        self.model.train()
        self.model, self.optimizer, data_loader, self.scheduler = self.accelerator.prepare(
            self.model, self.optimizer, dl, self.scheduler
            )

        for epoch in range(self.start_epoch + 1, epochs + 1):
            epoch_loss = 0.0
            num_batches = 0
            progress_bar = tqdm(data_loader, leave=False, desc=f"Epoch {epoch}/{epochs}", colour="#005500", disable = not self.accelerator.is_local_main_process)
            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    x, y = self._unpack_batch(batch, no_label)

                    with self.accelerator.autocast():
                        loss = self.loss_fn(x) if no_label else self.loss_fn(x, y=y)

                    # accelerator.backward handles loss scaling for gradient accumulation
                    # and the fp16 grad scaler, so no manual normalisation happens here.
                    self.accelerator.backward(loss)

                    # Clip only on the micro-step whose gradients are synchronised,
                    # i.e. right before a real optimizer update.
                    if self.max_grad_norm is not None and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                    # Inside `accumulate`, the prepared optimizer's step()/zero_grad() only
                    # take effect once enough micro-batches have been accumulated.
                    self.optimizer.step()
                    # Track real optimizer updates only. Updating on every micro-batch would
                    # update the EMA `accumulation_steps` times against unchanged weights.
                    if self.accelerator.sync_gradients:
                        self.ema.update()
                    self.optimizer.zero_grad()

                    epoch_loss += loss.item()
                    num_batches += 1
                    progress_bar.set_postfix(loss=epoch_loss / num_batches)

            if num_batches == 0:
                raise ValueError("data loader yielded no batches")

            self.accelerator.wait_for_everyone()
            # Every process must step its own scheduler. Stepping only on the main process
            # would leave the other ranks training at a different learning rate.
            # NOTE(review): an Accelerate-prepared scheduler may advance once per process per
            # call (it assumes per-batch stepping) — confirm T_max accounts for this on multi-GPU.
            self.scheduler.step()
            if self.accelerator.is_main_process:
                # Mean loss over this process's batches (not gathered across ranks).
                epoch_loss = epoch_loss / num_batches
                log_string = f"Loss at epoch {epoch}: {epoch_loss :.4f}"

                # Save the best model
                if self.best_loss > epoch_loss:
                    self.best_loss = epoch_loss
                    self._save_checkpoint(dl, epoch, file_name)
                    log_string += " --> Best model ever (stored)"
                print(log_string)