import os
import pathlib

import torch

def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
    """Save a checkpoint and prune old ones so at most `keep_latest` remain."""
    ckpt_dir = pathlib.Path(ckpt_dir)
    ckpt_dir.mkdir(exist_ok=True, parents=True)
    # Delete the oldest checkpoints, keeping keep_latest-1 of them, so that the
    # total (including the one about to be written) is keep_latest.
    prev_ckpts = sorted(ckpt_dir.glob('%s-*.pth' % model_name),
                        key=lambda p: p.stat().st_mtime, reverse=True)
    for f in prev_ckpts[keep_latest - 1:]:
        f.unlink()
    save_path = str(ckpt_dir / ('%s-%09d.pth' % (model_name, global_step)))
    save_dict = {
        'model': module.state_dict(),
        'optimizer': optimizer.state_dict(),
        'global_step': global_step,
    }
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()
    print('saving %s' % save_path)
    torch.save(save_dict, save_path)
    return save_path
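
# Example (sketch, illustrative names): inside a training loop you might call
#   if global_step % 10000 == 0:
#       save('checkpoints/run1', model, optimizer, scheduler, global_step, keep_latest=3)
# which keeps at most the 3 most recent 'model-*.pth' files in checkpoints/run1.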

def load(fabric, ckpt_path, model, optimizer=None, scheduler=None, model_ema=None,
         step=0, model_name='model', ignore_load=None, strict=True, verbose=True,
         weights_only=False):
    """Load a checkpoint from a file, or from the latest file in a directory.

    Returns the global step parsed from the checkpoint filename, or the `step`
    argument when `ckpt_path` points directly at a file.
    Note: `model_ema` is accepted for API compatibility but not used yet.
    """
    if verbose:
        print('reading ckpt from %s' % ckpt_path)
    if not os.path.exists(ckpt_path):
        print('...there is no full checkpoint in %s' % ckpt_path)
        print('-- note this function no longer prepends "saved_checkpoints/" to ckpt_path --')
        raise FileNotFoundError(ckpt_path)
    if os.path.isfile(ckpt_path):
        path = ckpt_path
        if verbose:
            print('...found checkpoint %s' % path)
    else:
        # ckpt_path is a directory: pick the most recently written checkpoint.
        prev_ckpts = sorted(pathlib.Path(ckpt_path).glob('%s-*.pth' % model_name),
                            key=lambda p: p.stat().st_mtime, reverse=True)
        if not prev_ckpts:
            print('...there is no full checkpoint here!')
            return 0
        path = prev_ckpts[0]
        # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
        step = int(str(path).split('-')[-1].split('.')[0])
        if verbose:
            print('...found checkpoint %s (parsed step %d from path)' % (path, step))
    if fabric is not None:
        checkpoint = fabric.load(path)
    else:
        checkpoint = torch.load(path, weights_only=weights_only)
    assert ignore_load is None, 'ignore_load is not implemented yet'
    # Optimizer/scheduler states are only present in full save_dicts, so guard
    # against bare state_dict checkpoints.
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None and 'scheduler' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler'])
    # The checkpoint may be a full save_dict or a bare model state_dict.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=strict)
    return step
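

if __name__ == '__main__':
    # Smoke-test sketch (an assumption, not part of the original API): save a
    # tiny model to a temporary directory and resume it, to show the intended
    # save/load round trip.
    import tempfile

    import torch.nn as nn

    ckpt_dir = tempfile.mkdtemp()
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    save(ckpt_dir, model, optimizer, scheduler=None, global_step=100)
    resumed_step = load(None, ckpt_dir, model, optimizer=optimizer)
    assert resumed_step == 100
    print('resumed at step %d' % resumed_step)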