File size: 506 Bytes
9e15541
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from pathlib import Path
from typing import Any, Union

import torch
from ignite.handlers import Checkpoint


def load_checkpoint(
    ckpt_path: Union[str, Path], to_save: dict[str, Any], strict: bool = False
) -> None:
    """Restore only the model weights from a checkpoint file.

    The file is loaded onto CPU with ``torch.load``. Then only the ``"model"``
    entry of the checkpoint is loaded into ``to_save["model"]`` via Ignite's
    ``Checkpoint.load_objects``. Every other entry in ``to_save`` and in the
    checkpoint (for example optimizer or trainer state) is ignored.

    Args:
        ckpt_path: Path to the checkpoint file (``str`` or ``Path``).
        to_save: Mapping of objects in the form given to Ignite's ``Checkpoint``.
            It must contain a ``"model"`` key, and only that object is updated.
        strict: Forwarded to ``Checkpoint.load_objects``, which passes it on to
            ``load_state_dict``. ``False`` (the default) tolerates missing or
            unexpected keys in the model state.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
        KeyError: If ``to_save`` or the loaded checkpoint has no ``"model"`` entry.
    """
    path = Path(ckpt_path)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not path.exists():
        raise FileNotFoundError(f"__Checkpoint '{str(path)}' is not found")

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    # map_location="cpu" lets checkpoints saved on another device load here.
    checkpoint = torch.load(str(path), map_location="cpu")

    model = to_save["model"]
    try:
        model_state = checkpoint["model"]
    except KeyError:
        raise KeyError(f"Checkpoint '{path}' has no 'model' entry") from None

    # Restrict loading to the model only. New dicts are built rather than
    # rebinding the caller-facing arguments.
    Checkpoint.load_objects(
        to_load={"model": model},
        checkpoint={"model": model_state},
        strict=strict,
    )