import torch
from torch import nn, Tensor
from torch.optim import Adam
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import LambdaLR
from functools import partial
from argparse import Namespace
import os, sys, math
from typing import Union, Tuple, Dict, List
from collections import OrderedDict
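
# The repository root is added to sys.path so that sibling modules (e.g. the
# `losses` module imported below) resolve when this file lives in a subfolder.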
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import losses


def cosine_annealing_warm_restarts(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    warmup_lr: float,
    T_0: int,
    T_mult: int,
    eta_min: float,
) -> float:
"""
Learning rate scheduler.
The learning rate will linearly increase from warmup_lr to lr in the first warmup_epochs epochs.
Then, the learning rate will follow the cosine annealing with warm restarts strategy.
"""
    assert epoch >= 0, f"epoch must be non-negative, got {epoch}."
    assert isinstance(warmup_epochs, int) and warmup_epochs >= 0, f"warmup_epochs must be non-negative, got {warmup_epochs}."
    assert isinstance(warmup_lr, float) and warmup_lr > 0, f"warmup_lr must be positive, got {warmup_lr}."
    assert isinstance(T_0, int) and T_0 >= 1, f"T_0 must be greater than or equal to 1, got {T_0}."
    assert isinstance(T_mult, int) and T_mult >= 1, f"T_mult must be greater than or equal to 1, got {T_mult}."
    assert isinstance(eta_min, float) and eta_min > 0, f"eta_min must be positive, got {eta_min}."
    assert isinstance(base_lr, float) and base_lr > 0, f"base_lr must be positive, got {base_lr}."
    assert base_lr > eta_min, f"base_lr must be greater than eta_min, got base_lr={base_lr} and eta_min={eta_min}."
    assert warmup_lr >= eta_min, f"warmup_lr must be greater than or equal to eta_min, got warmup_lr={warmup_lr} and eta_min={eta_min}."

    if epoch < warmup_epochs:
        lr = warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    else:
        epoch -= warmup_epochs
        if T_mult == 1:
            T_cur = epoch % T_0
            T_i = T_0
        else:
            n = int(math.log((epoch / T_0 * (T_mult - 1) + 1), T_mult))
            T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1)
            T_i = T_0 * T_mult ** n
        lr = eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2

    return lr / base_lr
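
# Illustration with placeholder values (not the project's defaults): with base_lr=1e-3,
# warmup_lr=1e-5, warmup_epochs=5, T_0=10, T_mult=2 and eta_min=1e-6, the factor
# returned above ramps linearly from 0.01 at epoch 0 to 1.0 at epoch 5, then each
# cosine cycle decays from 1.0 towards 1e-3 with cycle lengths of 10, 20, 40, ... epochs.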

def get_loss_fn(args: Namespace) -> nn.Module:
    if args.bins is None:
        assert args.weight_ot is not None and args.weight_tv is not None, f"Expected weight_ot and weight_tv to be not None, got {args.weight_ot} and {args.weight_tv}"
        loss_fn = losses.DMLoss(
            input_size=args.input_size,
            reduction=args.reduction,
        )
    else:
        loss_fn = losses.DACELoss(
            bins=args.bins,
            reduction=args.reduction,
            weight_count_loss=args.weight_count_loss,
            count_loss=args.count_loss,
            input_size=args.input_size,
        )
    return loss_fn
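
# Note: only parameters with requires_grad=True are passed to Adam, and LambdaLR
# rescales args.lr by the factor from cosine_annealing_warm_restarts on every
# scheduler.step() call.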
def get_optimizer(args: Namespace, model: nn.Module) -> Tuple[Adam, LambdaLR]:
    optimizer = Adam(
        params=filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=partial(
            cosine_annealing_warm_restarts,
            warmup_epochs=args.warmup_epochs,
            warmup_lr=args.warmup_lr,
            T_0=args.T_0,
            T_mult=args.T_mult,
            eta_min=args.eta_min,
            base_lr=args.lr,
        ),
    )
    return optimizer, scheduler
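
# Resumes from "<args.ckpt_dir>/ckpt.pth" when it exists; otherwise the passed-in
# objects are returned unchanged with start_epoch=1, loss_info=None, and the best
# scores initialised to args.save_best_k copies of infinity per metric.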
def load_checkpoint(
    args: Namespace,
    model: nn.Module,
    optimizer: Adam,
    scheduler: LambdaLR,
    grad_scaler: GradScaler,
) -> Tuple[nn.Module, Adam, Union[LambdaLR, None], GradScaler, int, Union[Dict[str, float], None], Dict[str, List[float]], Dict[str, float]]:
    ckpt_path = os.path.join(args.ckpt_dir, "ckpt.pth")
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_epoch = ckpt["epoch"]
        loss_info = ckpt["loss_info"]
        hist_scores = ckpt["hist_scores"]
        best_scores = ckpt["best_scores"]
        if scheduler is not None:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        if grad_scaler is not None:
            grad_scaler.load_state_dict(ckpt["grad_scaler_state_dict"])
        print(f"Loaded checkpoint from {ckpt_path}.")
    else:
        start_epoch = 1
        loss_info, hist_scores = None, {"mae": [], "rmse": []}
        best_scores = {k: [torch.inf] * args.save_best_k for k in hist_scores.keys()}
        print(f"Checkpoint not found at {ckpt_path}.")
    return model, optimizer, scheduler, grad_scaler, start_epoch, loss_info, hist_scores, best_scores
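
# Bundles the state dicts and training bookkeeping into a single dict and writes it
# to "<ckpt_dir>/ckpt.pth", matching the keys read back by load_checkpoint above.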
def save_checkpoint(
    epoch: int,
    model_state_dict: OrderedDict[str, Tensor],
    optimizer_state_dict: OrderedDict[str, Tensor],
    scheduler_state_dict: OrderedDict[str, Tensor],
    grad_scaler_state_dict: OrderedDict[str, Tensor],
    loss_info: Dict[str, List[float]],
    hist_scores: Dict[str, List[float]],
    best_scores: Dict[str, float],
    ckpt_dir: str,
) -> None:
    ckpt = {
        "epoch": epoch,
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
        "scheduler_state_dict": scheduler_state_dict,
        "grad_scaler_state_dict": grad_scaler_state_dict,
        "loss_info": loss_info,
        "hist_scores": hist_scores,
        "best_scores": best_scores,
    }
    torch.save(ckpt, os.path.join(ckpt_dir, "ckpt.pth"))
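

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline: the hyperparameter
    # values below are illustrative assumptions chosen only to show how get_optimizer
    # and the warm-restart schedule fit together.
    demo_args = Namespace(
        lr=1e-3,
        weight_decay=1e-4,
        warmup_epochs=5,
        warmup_lr=1e-5,
        T_0=10,
        T_mult=2,
        eta_min=1e-6,
    )
    demo_model = nn.Linear(8, 1)  # stand-in for the real model
    demo_optimizer, demo_scheduler = get_optimizer(demo_args, demo_model)
    for demo_epoch in range(20):
        # ... one training epoch would run here ...
        demo_scheduler.step()
        print(f"epoch {demo_epoch}: lr = {demo_scheduler.get_last_lr()[0]:.6f}")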