# NOTE: the original capture began with "Spaces: / Running / Running" status
# lines — UI residue, not code. Kept here as a comment so the file parses.
import torch
import torch.nn as nn
from helper.ema import EMA
from transformers import get_cosine_schedule_with_warmup
class Loader():
    """Load model checkpoints, either for inference or to resume training.

    Checkpoints are ``.pth`` files produced by ``torch.save`` and are
    expected to contain ``model_state_dict`` / ``ema_state_dict`` /
    ``optimizer_state_dict`` / ``scheduler_state_dict`` plus bookkeeping
    keys (``epoch``, ``training_steps``, ``best_loss``, ``batch_size``,
    ``number_of_batches``) — TODO confirm against the saving code.
    """

    def __init__(self, device=None):
        # Target device passed to torch.load's map_location
        # (e.g. "cpu", "cuda:0", or None to keep the saved placement).
        self.device = device

    def print_model(self, check_point: dict) -> None:
        """Print the checkpoint's bookkeeping fields.

        Purely diagnostic: a missing key prints "<missing>" instead of
        raising KeyError, so a partial checkpoint can still be inspected.
        """
        fields = (
            ("Epoch", "epoch"),
            ("Training step", "training_steps"),
            ("Best loss", "best_loss"),
            ("Batch size", "batch_size"),
            ("Number of batches", "number_of_batches"),
        )
        for label, key in fields:
            print(f"{label}: {check_point.get(key, '<missing>')}")

    def model_load(self, file_name: str, model: nn.Module,
                   print_dict: bool = True, is_ema: bool = True) -> nn.Module:
        """Load weights into *model* and return it ready for inference.

        Args:
            file_name: checkpoint path WITHOUT the ".pth" extension.
            model: module instance whose architecture matches the checkpoint.
            print_dict: if True, print the checkpoint's bookkeeping fields.
            is_ema: if True, load the EMA (exponential moving average)
                weights; otherwise load the raw model weights.

        Returns:
            The model in ``eval()`` mode with the loaded weights.
        """
        check_point = torch.load(file_name + ".pth", map_location=self.device,
                                 weights_only=True)
        if print_dict:
            self.print_model(check_point)
        if is_ema:
            # Wrap the model so the EMA state-dict keys line up, then unwrap
            # the averaged sub-module for inference.
            ema_wrapper = EMA(model)
            ema_wrapper.load_state_dict(check_point['ema_state_dict'])
            model = ema_wrapper.ema_model
        else:
            model.load_state_dict(check_point['model_state_dict'])
        model.eval()
        print("===Model loaded!===")
        return model

    def load_for_training(self, file_name: str, model: nn.Module, print_dict: bool = True):
        """Restore everything needed to resume training from a checkpoint.

        Args:
            file_name: checkpoint path WITHOUT the ".pth" extension.
            model: module instance whose architecture matches the checkpoint.
            print_dict: if True, print the checkpoint's bookkeeping fields.

        Returns:
            Tuple ``(model, ema, optimizer, scheduler, epoch, loss)`` where
            ``loss`` is the checkpoint's ``best_loss``.
        """
        check_point = torch.load(file_name + ".pth", map_location=self.device,
                                 weights_only=True)
        if print_dict:
            self.print_model(check_point)
        model.load_state_dict(check_point['model_state_dict'])
        model.train()
        ema = EMA(model)
        ema.load_state_dict(check_point['ema_state_dict'])
        ema.train()
        # lr=1e-4 / T_max=100 are placeholders: the load_state_dict calls
        # below restore the actual hyper-parameters from the checkpoint.
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        optimizer.load_state_dict(check_point["optimizer_state_dict"])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        scheduler.load_state_dict(check_point["scheduler_state_dict"])
        epoch = check_point["epoch"]
        loss = check_point["best_loss"]
        print("===Model/EMA/Optimizer/Scheduler/Epoch/Loss loaded!===")
        return model, ema, optimizer, scheduler, epoch, loss