# Checkpoint utilities: save and load PyTorch training state, pruning old
# checkpoints and optionally loading through a Lightning Fabric object.
import os
import pathlib

import torch


def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
    pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
    # Prune old checkpoints so that, counting the one about to be written,
    # at most `keep_latest` remain in ckpt_dir.
    prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
    prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    if len(prev_ckpts) > keep_latest - 1:
        for f in prev_ckpts[keep_latest - 1:]:
            f.unlink()
    # Zero-pad the step so filenames sort lexicographically by step.
    save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
    save_dict = {
        "model": module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "global_step": global_step,
    }
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()
    print(f"saving {save_path}")
    torch.save(save_dict, save_path)
    # Return the path so callers can log it or reload from it.
    return save_path
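
# Retention sketch (assumption: the step values below are illustrative): with
# the default keep_latest=2, calling save at steps 100, 200, and 300 leaves
# only model-000000200.pth and model-000000300.pth in ckpt_dir, because the
# oldest files (by mtime) are unlinked before each new write.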


def load(fabric, ckpt_path, model, optimizer=None, scheduler=None, model_ema=None, step=0,
         model_name='model', ignore_load=None, strict=True, verbose=True, weights_only=False):
    # Note: model_ema is accepted for API compatibility but is not restored here.
    if verbose:
        print('reading ckpt from %s' % ckpt_path)
    if not os.path.exists(ckpt_path):
        print('...there is no full checkpoint in %s' % ckpt_path)
        print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_path --')
        assert False
    else:
        if os.path.isfile(ckpt_path):
            # ckpt_path points directly at a checkpoint file.
            path = ckpt_path
            if verbose:
                print('...found checkpoint %s' % path)
        else:
            # ckpt_path is a directory: pick the most recently modified checkpoint.
            prev_ckpts = list(pathlib.Path(ckpt_path).glob('%s-*pth' % model_name))
            prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True)
            if len(prev_ckpts):
                path = prev_ckpts[0]
                # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
                # OR ./whatever.pth
                step = int(str(path).split('-')[-1].split('.')[0])
                if verbose:
                    print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
            else:
                print('...there is no full checkpoint here!')
                return 0
        if fabric is not None:
            # Let the Fabric object handle device placement when loading.
            checkpoint = fabric.load(path)
        else:
            checkpoint = torch.load(path, weights_only=weights_only)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler'])
        assert ignore_load is None  # not ready yet
        # Accept either a full training checkpoint or a bare state_dict.
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict, strict=strict)
        return step
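
# Minimal round-trip sketch (assumption: the toy Linear model, optimizer,
# scheduler, and './checkpoints/demo' directory below are illustrative and
# not part of this module):
#
#     net = torch.nn.Linear(8, 8)
#     opt = torch.optim.Adam(net.parameters(), lr=1e-3)
#     sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1000)
#     save('./checkpoints/demo', net, opt, sched, global_step=50000)
#     # fabric=None falls back to torch.load; the step is parsed back
#     # out of the checkpoint filename.
#     step = load(None, './checkpoints/demo', net, optimizer=opt, scheduler=sched)
#     assert step == 50000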