File size: 1,091 Bytes
c165cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import shutil

import accelerate
import torch
import glob


def restore_checkpoint(
        checkpoint_dir,
        accelerator: "accelerate.Accelerator",
        logger=None
):
    """Load the most recent checkpoint found in ``checkpoint_dir``.

    Checkpoints are sub-directories whose names are the integer training step
    (e.g. ``000500``). Any other entries in ``checkpoint_dir`` are ignored,
    such as log files or non-numeric folders. The newest checkpoint is chosen
    by numeric step rather than string order. This keeps the choice correct
    once a step outgrows the zero padding (``1000000`` vs ``999999``).

    Args:
        checkpoint_dir: Directory holding the step-numbered checkpoint folders.
            A missing or empty directory is treated as "no checkpoint".
        accelerator: Object whose ``load_state`` is called with the path of
            the selected checkpoint folder.
        logger: Optional logger; its ``info`` method reports what happened.

    Returns:
        The step parsed from the restored checkpoint's folder name, or 0 when
        no checkpoint was found.
    """
    # glob.escape so a directory containing "[", "*" or "?" is matched literally.
    pattern = os.path.join(glob.escape(checkpoint_dir), "*")
    # Only step-numbered directories are checkpoints; anything else would make
    # int() below fail or be "resumed" by mistake.
    checkpoints = [
        p for p in glob.glob(pattern)
        if os.path.isdir(p) and os.path.basename(p).isdecimal()
    ]
    path = max(checkpoints, key=lambda p: int(os.path.basename(p)), default=None)
    if path is None:
        if logger is not None:
            logger.info("Checkpoint does not exist. Starting a new training run.")
        init_step = 0
    else:
        if logger is not None:
            logger.info(f"Resuming from checkpoint {path}")
        accelerator.load_state(path)
        init_step = int(os.path.basename(path))
    return init_step


def save_checkpoint(save_dir,
                    accelerator: "accelerate.Accelerator",
                    step=0,
                    total_limit=3):
    """Save ``accelerator`` state to ``save_dir/<step:06d>``, pruning old checkpoints.

    Before saving, the oldest step-numbered checkpoint folders in ``save_dir``
    are deleted so that at most ``total_limit`` checkpoints remain once the new
    one is written. Only sub-directories with all-digit names count as
    checkpoints; other files and folders in ``save_dir`` are never touched.
    Age is determined by numeric step, not string order.

    Args:
        save_dir: Directory holding the step-numbered checkpoint folders.
        accelerator: Object whose ``save_state`` is called with the target path.
        step: Training step; zero-padded to 6 digits to form the folder name.
        total_limit: Maximum number of checkpoints to keep, including the new
            one. ``None`` or a value <= 0 disables pruning.
    """
    target_name = f"{step:06d}"
    if total_limit is not None and total_limit > 0:
        # NOTE(review): every process that calls this prunes; on a shared
        # filesystem with several ranks consider guarding with
        # accelerator.is_main_process — confirm against the training setup.
        pattern = os.path.join(glob.escape(save_dir), "*")
        # The folder being (re)written is excluded: it is not an "old"
        # checkpoint and already accounts for the "+1" slot below.
        checkpoints = sorted(
            (p for p in glob.glob(pattern)
             if os.path.isdir(p)
             and os.path.basename(p).isdecimal()
             and os.path.basename(p) != target_name),
            key=lambda p: int(os.path.basename(p)),
        )
        # Leave room for the new checkpoint. max(0, ...) prevents a negative
        # slice bound, which would delete all but the newest few folders.
        num_to_remove = max(0, len(checkpoints) + 1 - total_limit)
        for folder in checkpoints[:num_to_remove]:
            shutil.rmtree(folder)
    accelerator.save_state(os.path.join(save_dir, target_name))