Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,689 Bytes
2ac1c2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
from dataclasses import dataclass, field
from datetime import datetime
from omegaconf import OmegaConf
from .core import debug, find, info, warn
from .typing import *
# ============ Register OmegaConf Resolvers ============= #
# Custom interpolation resolvers usable inside YAML configs, e.g. "${add:1,2}".
# `replace=True` makes registration idempotent: without it OmegaConf raises
# ValueError when a resolver name is registered a second time, which happens
# whenever this module is re-imported (reloaded, or imported via two paths).
OmegaConf.register_new_resolver(
    # per-step rate r such that r ** n == factor (exponential decay over n steps)
    "calc_exp_lr_decay_rate",
    lambda factor, n: factor ** (1.0 / n),
    replace=True,
)
OmegaConf.register_new_resolver("add", lambda a, b: a + b, replace=True)
OmegaConf.register_new_resolver("sub", lambda a, b: a - b, replace=True)
OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True)
OmegaConf.register_new_resolver("div", lambda a, b: a / b, replace=True)
OmegaConf.register_new_resolver("idiv", lambda a, b: a // b, replace=True)
OmegaConf.register_new_resolver(
    "basename", lambda p: os.path.basename(p), replace=True
)
# replace every space in `s` with `sub`
OmegaConf.register_new_resolver(
    "rmspace", lambda s, sub: s.replace(" ", sub), replace=True
)
# scalar -> two-element list [float(s), float(s)]
OmegaConf.register_new_resolver(
    "tuple2", lambda s: [float(s), float(s)], replace=True
)
OmegaConf.register_new_resolver("gt0", lambda s: s > 0, replace=True)
OmegaConf.register_new_resolver("not", lambda s: not s, replace=True)
def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8):
    """Return the total number of training steps.

    Computes ``int(num_data // (num_nodes * num_cards * batch_size)) * max_epochs``:
    the number of full global batches per epoch (a trailing partial batch is
    dropped) times the number of epochs. Presumably ``batch_size`` is the
    per-device batch size and ``num_cards`` the devices per node.

    Args:
        num_data: number of training samples.
        batch_size: samples per batch on each device.
        max_epochs: number of epochs.
        num_nodes: number of nodes.
        num_cards: devices per node (default 8).

    Returns:
        Total step count (an ``int`` when ``max_epochs`` is an ``int``).
    """
    global_batch_size = num_nodes * num_cards * batch_size
    # Floor division keeps integer inputs exact; the former int(a / b) went
    # through float and could round up for large num_data.
    steps_per_epoch = int(num_data // global_batch_size)
    return steps_per_epoch * max_epochs
# Expose calc_num_train_steps to YAML as "${calc_num_train_steps:...}".
# replace=True keeps a re-import of this module from raising "already registered".
OmegaConf.register_new_resolver(
    "calc_num_train_steps", calc_num_train_steps, replace=True
)
# ======================================================= #
# ============== Automatic Name Resolvers =============== #
def get_naming_convention(cfg):
    """Build an experiment name from the config (used when ``name == "auto"``).

    The name currently encodes only the backbone depth, e.g. ``"lrm_12"``.

    Args:
        cfg: config exposing ``cfg.system.backbone.num_layers``.

    Returns:
        The generated name string.
    """
    # TODO: encode more hyper-parameters in the generated name.
    num_layers = cfg.system.backbone.num_layers
    return "lrm_{}".format(num_layers)
# ======================================================= #
@dataclass
class ExperimentConfig:
    """Top-level schema for an experiment config, consumed by ``load_config``.

    Merged YAML/CLI values are applied on top of these defaults via
    ``parse_structured``. ``load_config`` then overwrites the derived fields:
    ``trial_name = tag + timestamp``, ``exp_dir = exp_root_dir/name`` and
    ``trial_dir = exp_dir/trial_name``.
    """

    # experiment name; "auto" makes load_config derive it via get_naming_convention
    name: str = "default"
    description: str = ""
    # trial label; load_config requires it to be non-empty unless use_timestamp is True
    tag: str = ""
    seed: int = 0
    # append "@YYYYmmdd-HHMMSS" to the trial name (load_config skips this when n_gpus > 1)
    use_timestamp: bool = True
    # None -> filled in by load_config; a config being resumed presumably keeps its original value
    timestamp: Optional[str] = None
    exp_root_dir: str = "outputs"
    ### these shouldn't be set manually
    exp_dir: str = "outputs/default"
    trial_name: str = "exp"
    trial_dir: str = "outputs/default/exp"
    n_gpus: int = 1
    ###
    # NOTE(review): not read in this file; presumably a checkpoint path to resume from — confirm with callers
    resume: Optional[str] = None
    # NOTE(review): presumably an identifier of the data module class — confirm with callers
    data_cls: str = ""
    data: dict = field(default_factory=dict)
    # NOTE(review): presumably an identifier of the system/model class — confirm with callers
    system_cls: str = ""
    system: dict = field(default_factory=dict)
    # accept pytorch-lightning trainer parameters
    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api
    trainer: dict = field(default_factory=dict)
    # accept pytorch-lightning checkpoint callback parameters
    # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    checkpoint: dict = field(default_factory=dict)
def _load_with_extends(source, parse_func):
    """Parse ``source`` and return it preceded by the chain of configs it ``extends``.

    ``source`` is parsed with ``parse_func``. Each ``extends`` target is a file
    path checked with ``os.path.exists`` (so relative paths resolve against the
    current working directory) and loaded with ``OmegaConf.load``. A parent may
    itself ``extends`` another file. The returned list is ordered base-first, so
    merging it left to right lets each config override the one it extends.
    The ``extends`` key is popped from every config so it never reaches the
    structured ``ExperimentConfig`` merge, which has no such field.

    Raises:
        AssertionError: if an ``extends`` target file does not exist.
        ValueError: if the ``extends`` chain loops back on itself.
    """
    conf = parse_func(source)
    chain = [conf]
    seen = set()
    parent = conf.pop("extends", None)
    while parent:
        assert os.path.exists(parent), f"File {parent} does not exist."
        parent_key = os.path.abspath(parent)
        if parent_key in seen:
            raise ValueError(f"Cyclic 'extends' chain detected at {parent}.")
        seen.add(parent_key)
        parent_conf = OmegaConf.load(parent)
        chain.append(parent_conf)
        parent = parent_conf.pop("extends", None)
    chain.reverse()
    return chain


def load_config(
    *yamls: str,
    cli_args: Optional[list] = None,
    from_string=False,
    makedirs=True,
    **kwargs,
) -> Any:
    """Load, merge and post-process an experiment config.

    Sources are merged left to right (later wins):
    1. each entry of ``yamls``, preceded by the files it ``extends``;
    2. dotlist overrides from ``cli_args``;
    3. ``kwargs``.
    The result is resolved and validated against ``ExperimentConfig``. Then the
    derived fields are filled in: ``name`` (when "auto"), ``timestamp``,
    ``trial_name``, ``exp_dir`` and ``trial_dir``.

    Args:
        *yamls: YAML file paths, or YAML strings when ``from_string`` is True.
        cli_args: dotlist overrides such as ``["seed=1"]``; None means no overrides.
        from_string: parse ``yamls`` with ``OmegaConf.create`` instead of loading files.
        makedirs: create ``trial_dir`` on disk.
        **kwargs: extra overrides merged last.

    Returns:
        The structured config produced by ``parse_structured(ExperimentConfig, ...)``.

    Raises:
        ValueError: if neither ``tag`` nor ``use_timestamp`` is set, or an
            ``extends`` chain is cyclic.
        AssertionError: if an ``extends`` target file does not exist.
    """
    # None instead of a shared mutable [] default. It must become [] (not stay
    # None): OmegaConf.from_cli(None) falls back to parsing sys.argv.
    if cli_args is None:
        cli_args = []
    parse_func = OmegaConf.create if from_string else OmegaConf.load

    yaml_confs = []
    for y in yamls:
        yaml_confs.extend(_load_with_extends(y, parse_func))
    cli_conf = OmegaConf.from_cli(cli_args)
    cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)

    # post processing
    # auto naming
    if scfg.name == "auto":
        scfg.name = get_naming_convention(scfg)
    # add timestamp
    if not scfg.tag and not scfg.use_timestamp:
        raise ValueError("Either tag is specified or use_timestamp is True.")
    scfg.trial_name = scfg.tag
    # if resume from an existing config, scfg.timestamp should not be None,
    # so an existing timestamp is kept and only a missing one is generated
    if scfg.timestamp is None:
        scfg.timestamp = ""
        if scfg.use_timestamp:
            if scfg.n_gpus > 1:
                # presumably each process would otherwise pick a different
                # timestamp and thus a different trial dir
                warn(
                    "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag."
                )
            else:
                scfg.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S")
    # make directories
    scfg.trial_name += scfg.timestamp
    scfg.exp_dir = os.path.join(scfg.exp_root_dir, scfg.name)
    scfg.trial_dir = os.path.join(scfg.exp_dir, scfg.trial_name)
    if makedirs:
        os.makedirs(scfg.trial_dir, exist_ok=True)
    return scfg
def config_to_primitive(config, resolve: bool = True) -> Any:
    """Convert an OmegaConf container into plain Python dicts/lists.

    Args:
        config: the OmegaConf container to convert.
        resolve: resolve interpolations before converting (default True).

    Returns:
        The result of ``OmegaConf.to_container``.
    """
    primitive = OmegaConf.to_container(config, resolve=resolve)
    return primitive
def dump_config(path: str, config) -> None:
    """Write ``config`` to ``path`` as YAML via ``OmegaConf.save``, overwriting the file.

    Args:
        path: destination file path.
        config: OmegaConf config (or anything ``OmegaConf.save`` accepts).
    """
    # Explicit UTF-8: OmegaConf emits non-ASCII characters unescaped, which can
    # fail or be mis-encoded under a non-UTF-8 platform default encoding.
    with open(path, "w", encoding="utf-8") as fp:
        OmegaConf.save(config=config, f=fp)
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Build a structured (schema-checked) config from ``fields`` plus optional overrides.

    Args:
        fields: dataclass type (or instance) defining the schema and defaults.
        cfg: values merged on top of the defaults; None means "defaults only".

    Returns:
        A DictConfig typed by ``fields``.
    """
    scfg = OmegaConf.structured(fields)
    # OmegaConf.merge rejects None ("Cannot merge with a None config"), so the
    # documented default cfg=None has to be handled explicitly.
    if cfg is not None:
        scfg = OmegaConf.merge(scfg, cfg)
    return scfg
|