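"""Training utilities: learning-rate schedule, loss and optimizer factories, and checkpoint saving/loading."""
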
import torch
from torch import nn, Tensor

from torch.optim import Adam
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import LambdaLR

from functools import partial
from argparse import Namespace

import os, sys, math
from typing import Any, Union, Tuple, Dict, List
from collections import OrderedDict

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import losses  # resolved via the parent directory appended to sys.path above


def cosine_annealing_warm_restarts(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    warmup_lr: float,
    T_0: int,
    T_mult: int,
    eta_min: float,
) -> float:
    """
    Learning rate scheduler.
    The learning rate will linearly increase from warmup_lr to lr in the first warmup_epochs epochs.
    Then, the learning rate will follow the cosine annealing with warm restarts strategy.
    """
    assert epoch >= 0, f"epoch must be non-negative, got {epoch}."
    assert isinstance(warmup_epochs, int) and warmup_epochs >= 0, f"warmup_epochs must be non-negative, got {warmup_epochs}."
    assert isinstance(warmup_lr, float) and warmup_lr > 0, f"warmup_lr must be positive, got {warmup_lr}."
    assert isinstance(T_0, int) and T_0 >= 1, f"T_0 must be greater than or equal to 1, got {T_0}."
    assert isinstance(T_mult, int) and T_mult >= 1, f"T_mult must be greater than or equal to 1, got {T_mult}."
    assert isinstance(eta_min, float) and eta_min > 0, f"eta_min must be positive, got {eta_min}."
    assert isinstance(base_lr, float) and base_lr > 0, f"base_lr must be positive, got {base_lr}."
    assert base_lr > eta_min, f"base_lr must be greater than eta_min, got base_lr={base_lr} and eta_min={eta_min}."
    assert warmup_lr >= eta_min, f"warmup_lr must be greater than or equal to eta_min, got warmup_lr={warmup_lr} and eta_min={eta_min}."

    if epoch < warmup_epochs:
        lr = warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    else:
        epoch -= warmup_epochs
        if T_mult == 1:
            T_cur = epoch % T_0
            T_i = T_0
        else:
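            # Cycle lengths grow geometrically (T_0, T_0 * T_mult, T_0 * T_mult**2, ...):
            # n is the index of the current cycle and T_cur the epoch offset inside it.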
            n = int(math.log((epoch / T_0 * (T_mult - 1) + 1), T_mult))
            T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1)
            T_i = T_0 * T_mult ** n
        
        lr = eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2

    return lr / base_lr


def get_loss_fn(args: Namespace) -> nn.Module:
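    """Build the training loss from the parsed CLI arguments: DMLoss when no count bins are configured, DACELoss otherwise."""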
    if args.bins is None:
        assert args.weight_ot is not None and args.weight_tv is not None, f"Expected weight_ot and weight_tv to be not None, got {args.weight_ot} and {args.weight_tv}"
        loss_fn = losses.DMLoss(
            input_size=args.input_size,
            reduction=args.reduction,
        )
    else:
        loss_fn = losses.DACELoss(
            bins=args.bins,
            reduction=args.reduction,
            weight_count_loss=args.weight_count_loss,
            count_loss=args.count_loss,
            input_size=args.input_size,
        )
    return loss_fn


def get_optimizer(args: Namespace, model: nn.Module) -> Tuple[Adam, LambdaLR]:
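    """Create an Adam optimizer over the trainable parameters and wrap the warmup + cosine-restart schedule in a LambdaLR scheduler."""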
    optimizer = Adam(
        params=filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=partial(
            cosine_annealing_warm_restarts,
            warmup_epochs=args.warmup_epochs,
            warmup_lr=args.warmup_lr,
            T_0=args.T_0,
            T_mult=args.T_mult,
            eta_min=args.eta_min,
            base_lr=args.lr
        ),
    )

    return optimizer, scheduler


def load_checkpoint(
    args: Namespace,
    model: nn.Module,
    optimizer: Adam,
    scheduler: LambdaLR,
    grad_scaler: GradScaler,
) -> Tuple[nn.Module, Adam, Union[LambdaLR, None], GradScaler, int, Union[Dict[str, float], None], Dict[str, List[float]], Dict[str, float]]:
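    """Resume model/optimizer/scheduler/scaler state from <ckpt_dir>/ckpt.pth if it exists; otherwise return freshly initialised bookkeeping."""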
    ckpt_path = os.path.join(args.ckpt_dir, "ckpt.pth")
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_epoch = ckpt["epoch"]
        loss_info = ckpt["loss_info"]
        hist_scores = ckpt["hist_scores"]
        best_scores = ckpt["best_scores"]

        if scheduler is not None:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        if grad_scaler is not None:
            grad_scaler.load_state_dict(ckpt["grad_scaler_state_dict"])

        print(f"Loaded checkpoint from {ckpt_path}.")

    else:
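        # No checkpoint found: start from epoch 1 and track the save_best_k best scores per metric
        # (lower is better, hence the +inf initialisation).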
        start_epoch = 1
        loss_info, hist_scores = None, {"mae": [], "rmse": []}
        best_scores = {k: [torch.inf] * args.save_best_k for k in hist_scores.keys()}
        print(f"Checkpoint not found at {ckpt_path}.")

    return model, optimizer, scheduler, grad_scaler, start_epoch, loss_info, hist_scores, best_scores


def save_checkpoint(
    epoch: int,
    model_state_dict: OrderedDict[str, Tensor],
    optimizer_state_dict: Dict[str, Any],
    scheduler_state_dict: Dict[str, Any],
    grad_scaler_state_dict: Dict[str, Any],
    loss_info: Dict[str, List[float]],
    hist_scores: Dict[str, List[float]],
    best_scores: Dict[str, float],
    ckpt_dir: str,
) -> None:
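    """Bundle the full training state into a single dict and write it to <ckpt_dir>/ckpt.pth."""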
    ckpt = {
        "epoch": epoch,
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
        "scheduler_state_dict": scheduler_state_dict,
        "grad_scaler_state_dict": grad_scaler_state_dict,
        "loss_info": loss_info,
        "hist_scores": hist_scores,
        "best_scores": best_scores,
    }
    torch.save(ckpt, os.path.join(ckpt_dir, "ckpt.pth"))
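

if __name__ == "__main__":
    # Illustrative sanity check (not part of the training pipeline): print the
    # LambdaLR multipliers produced by cosine_annealing_warm_restarts for a toy
    # configuration, so the warmup ramp and the restarts can be eyeballed.
    schedule = partial(
        cosine_annealing_warm_restarts,
        base_lr=1e-3,
        warmup_epochs=5,
        warmup_lr=1e-5,
        T_0=10,
        T_mult=2,
        eta_min=1e-6,
    )
    for epoch in range(0, 40, 5):
        print(f"epoch {epoch:3d}: lr multiplier = {schedule(epoch):.4f}")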