# KoFace-AI/helper/loader.py
import torch
import torch.nn as nn

from helper.ema import EMA
class Loader:
    """Restores checkpoints saved during training, either for inference
    or to resume training."""

    def __init__(self, device=None):
        self.device = device

    def print_model(self, check_point):
        # Summarize the bookkeeping fields stored alongside the weights.
        print("Epoch: " + str(check_point["epoch"]))
        print("Training step: " + str(check_point["training_steps"]))
        print("Best loss: " + str(check_point["best_loss"]))
        print("Batch size: " + str(check_point["batch_size"]))
        print("Number of batches: " + str(check_point["number_of_batches"]))
    def model_load(self, file_name: str, model: nn.Module,
                   print_dict: bool = True, is_ema: bool = True):
        """Load a checkpoint for inference and return the model in eval mode."""
        check_point = torch.load(file_name + ".pth", map_location=self.device,
                                 weights_only=True)
        if print_dict:
            self.print_model(check_point)
        if is_ema:
            # Wrap the model so the EMA state dict can be loaded, then keep
            # only the averaged weights for inference.
            model = EMA(model)
            model.load_state_dict(check_point['ema_state_dict'])
            model = model.ema_model
        else:
            model.load_state_dict(check_point['model_state_dict'])
        model.eval()
        print("===Model loaded!===")
        return model
    def load_for_training(self, file_name: str, model: nn.Module, print_dict: bool = True):
        """Load a checkpoint and rebuild everything needed to resume training."""
        check_point = torch.load(file_name + ".pth", map_location=self.device,
                                 weights_only=True)
        if print_dict:
            self.print_model(check_point)
        model.load_state_dict(check_point['model_state_dict'])
        model.train()
        ema = EMA(model)
        ema.load_state_dict(check_point['ema_state_dict'])
        ema.train()
        # Recreate the optimizer and scheduler with the same classes used at
        # save time; load_state_dict then restores their saved state, so the
        # constructor arguments here (lr, T_max) are only placeholders.
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        optimizer.load_state_dict(check_point["optimizer_state_dict"])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
        scheduler.load_state_dict(check_point["scheduler_state_dict"])
        epoch = check_point["epoch"]
        loss = check_point["best_loss"]
        print("===Model/EMA/Optimizer/Scheduler/Epoch/Loss loaded!===")
        return model, ema, optimizer, scheduler, epoch, loss
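

# --- Usage sketch (assumption: not part of the original file) ---
# A minimal example of both loading paths. The checkpoint prefix
# "checkpoints/best" and the placeholder network are hypothetical; both
# methods append ".pth" to the prefix, and the module's weights must match
# the shapes stored in the checkpoint.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loader = Loader(device=device)

    # Stand-in for the project's real model class.
    net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

    # Inference: restores the EMA-averaged weights and calls eval().
    model = loader.model_load("checkpoints/best", net, is_ema=True)

    # Resuming training: also restores the EMA wrapper, optimizer,
    # scheduler, and bookkeeping values.
    model, ema, optimizer, scheduler, epoch, best_loss = (
        loader.load_for_training("checkpoints/best", net))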