# NOTE(review): removed copy/paste residue from a web code viewer
# (file size, commit hash c165cd8, and a line-number gutter were not code).
import os
import shutil
import accelerate
import torch
import glob
def restore_checkpoint(
    checkpoint_dir,
    accelerator: "accelerate.Accelerator",
    logger=None
):
    """Restore the most recent checkpoint from ``checkpoint_dir``.

    Checkpoint folders are named by their zero-padded global step (e.g.
    ``000100``, as written by ``save_checkpoint``).  The latest checkpoint is
    selected by *numeric* step rather than lexicographic order, so unevenly
    padded names resolve correctly and stray files in the directory (logs,
    temp files) are ignored instead of crashing ``int()`` below.

    Args:
        checkpoint_dir: Directory containing one sub-folder per checkpoint.
        accelerator: Accelerator whose state is restored via ``load_state``.
        logger: Optional logger for progress messages.

    Returns:
        int: Global step of the restored checkpoint, or 0 when none exists.
    """
    # Only consider entries whose basename is a pure integer step; the
    # original unfiltered glob would pick up any file and fail on int().
    candidates = [
        p for p in glob.glob(os.path.join(checkpoint_dir, "*"))
        if os.path.basename(p).isdigit()
    ]
    if not candidates:
        if logger is not None:
            logger.info("Checkpoint does not exist. Starting a new training run.")
        return 0
    # Highest step wins, compared numerically (robust to padding differences).
    path = max(candidates, key=lambda p: int(os.path.basename(p)))
    if logger is not None:
        logger.info(f"Resuming from checkpoint {path}")
    accelerator.load_state(path)
    return int(os.path.basename(path))
def save_checkpoint(save_dir,
                    accelerator: "accelerate.Accelerator",
                    step=0,
                    total_limit=3):
    """Save accelerator state to ``save_dir`` and prune old checkpoints.

    The new checkpoint is written to ``save_dir/{step:06d}``.  When
    ``total_limit`` > 0, the oldest checkpoints are removed first so that at
    most ``total_limit`` checkpoints remain *including* the one being written.

    Args:
        save_dir: Directory holding one sub-folder per checkpoint.
        accelerator: Accelerator whose state is saved via ``save_state``.
        step: Global training step; becomes the checkpoint folder name.
        total_limit: Maximum number of checkpoints to keep; <= 0 disables
            pruning.
    """
    if total_limit > 0:
        # Prune only directories whose name is a pure integer step.  The
        # original unfiltered glob would rmtree unrelated folders and crash
        # (NotADirectoryError) on any plain file living in save_dir.
        checkpoints = [
            p for p in glob.glob(os.path.join(save_dir, "*"))
            if os.path.isdir(p) and os.path.basename(p).isdigit()
        ]
        # Numeric sort so pruning order is correct even with uneven padding.
        checkpoints.sort(key=lambda p: int(os.path.basename(p)))
        # Keep total_limit - 1 old checkpoints; the new one restores the
        # count to total_limit.
        excess = len(checkpoints) - (total_limit - 1)
        if excess > 0:
            for folder in checkpoints[:excess]:
                shutil.rmtree(folder)
    accelerator.save_state(os.path.join(save_dir, f"{step:06d}"))