import os

import torch

from tools.utils.logging import get_logger


def save_ckpt(
    model,
    cfg,
    optimizer,
    lr_scheduler,
    epoch,
    global_step,
    metrics,
    is_best=False,
    logger=None,
    prefix=None,
):
| """ | |
| Saving checkpoints | |
| :param epoch: current epoch number | |
| :param log: logging information of the epoch | |
| :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar' | |
| """ | |
    if logger is None:
        logger = get_logger()
    if prefix is None:
        if is_best:
            save_path = os.path.join(cfg["Global"]["output_dir"], "best.pth")
        else:
            save_path = os.path.join(cfg["Global"]["output_dir"], "latest.pth")
    else:
        save_path = os.path.join(cfg["Global"]["output_dir"], prefix + ".pth")
    # Unwrap DistributedDataParallel so saved keys lack the "module." prefix.
    state_dict = model.module.state_dict() if cfg["Global"]["distributed"] else model.state_dict()
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "state_dict": state_dict,
        # Best checkpoints keep weights only; resume state lives in "latest".
        "optimizer": None if is_best else optimizer.state_dict(),
        "scheduler": None if is_best else lr_scheduler.state_dict(),
        "config": cfg,
        "metrics": metrics,
    }
    torch.save(state, save_path)
    logger.info(f"Saved checkpoint to {save_path}")
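

# A sketch of the cfg["Global"] keys this module reads. The exact schema comes
# from the project's YAML config, so the values below are illustrative
# assumptions only:
#
#     cfg = {
#         "Global": {
#             "output_dir": "./output",     # where checkpoints are written
#             "distributed": False,         # True when the model is DDP-wrapped
#             "checkpoints": None,          # resume path, read by load_ckpt
#             "pretrained_model": None,     # finetune weights, read by load_ckpt
#         }
#     }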


def load_ckpt(model, cfg, optimizer=None, lr_scheduler=None, logger=None):
    """
    Resume training state from cfg["Global"]["checkpoints"] if it exists;
    otherwise load weights from cfg["Global"]["pretrained_model"] for
    finetuning, or start from scratch.

    :return: dict with 'epoch', 'global_step' and 'metrics' when resuming,
        otherwise an empty dict
    """
    if logger is None:
        logger = get_logger()
    checkpoints = cfg["Global"].get("checkpoints")
    pretrained_model = cfg["Global"].get("pretrained_model")
    status = {}
    if checkpoints and os.path.exists(checkpoints):
        checkpoint = torch.load(checkpoints, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint["state_dict"], strict=True)
        # "best" checkpoints are saved without optimizer/scheduler state, so
        # guard against None before restoring.
        if optimizer is not None and checkpoint.get("optimizer") is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if lr_scheduler is not None and checkpoint.get("scheduler") is not None:
            lr_scheduler.load_state_dict(checkpoint["scheduler"])
        logger.info(f"Resumed from checkpoint {checkpoints} (epoch {checkpoint['epoch']})")
        status["global_step"] = checkpoint["global_step"]
        status["epoch"] = checkpoint["epoch"] + 1
        status["metrics"] = checkpoint["metrics"]
    elif pretrained_model and os.path.exists(pretrained_model):
        load_pretrained_params(model, pretrained_model, logger)
        logger.info(f"Finetuning from pretrained model {pretrained_model}")
    else:
        logger.info("Training from scratch")
    return status
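

# A sketch of a typical resume flow; the surrounding training loop and the
# max_epochs variable are illustrative assumptions, not part of this module:
#
#     status = load_ckpt(model, cfg, optimizer, lr_scheduler)
#     start_epoch = status.get("epoch", 0)
#     global_step = status.get("global_step", 0)
#     for epoch in range(start_epoch, max_epochs):
#         ...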


def load_pretrained_params(model, pretrained_model, logger):
    checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
    # strict=False ignores keys missing on either side, which allows loading
    # backbones whose heads differ from the current model.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    # Report every model parameter the pretrained checkpoint did not provide.
    for name in model.state_dict().keys():
        if name not in checkpoint["state_dict"]:
            logger.info(f"{name} is not in the pretrained model")
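

if __name__ == "__main__":
    # Minimal smoke test for save_ckpt/load_ckpt. This is a hedged sketch:
    # the tiny model, optimizer settings, and temporary output directory are
    # illustrative assumptions, and running it standalone still requires the
    # repository on PYTHONPATH so the tools.utils.logging import resolves.
    import logging
    import tempfile

    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("ckpt_demo")

    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = {
            "Global": {
                "output_dir": tmp_dir,
                "distributed": False,
                "checkpoints": os.path.join(tmp_dir, "latest.pth"),
                "pretrained_model": None,
            }
        }
        model = nn.Linear(4, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

        # Save "latest.pth", then resume from it and check the status dict.
        save_ckpt(model, cfg, optimizer, lr_scheduler, epoch=0, global_step=10,
                  metrics={"acc": 0.5}, logger=demo_logger)
        status = load_ckpt(model, cfg, optimizer, lr_scheduler, logger=demo_logger)
        assert status["epoch"] == 1 and status["global_step"] == 10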