import os
from dataclasses import dataclass, field
from datetime import datetime
from omegaconf import OmegaConf
from .core import debug, find, info, warn
from .typing import *
# ============ Register OmegaConf Resolvers ============= #
OmegaConf.register_new_resolver(
    "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
)
OmegaConf.register_new_resolver("add", lambda a, b: a + b)
OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("div", lambda a, b: a / b)
OmegaConf.register_new_resolver("idiv", lambda a, b: a // b)
OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub))
OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)])
OmegaConf.register_new_resolver("gt0", lambda s: s > 0)
OmegaConf.register_new_resolver("not", lambda s: not s)


def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8):
    # steps per epoch = dataset size / global batch size
    # (num_nodes * num_cards devices, each with a per-device batch_size)
    return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs


OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps)
# ======================================================= #
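
# A minimal sketch of how these resolvers can be referenced from YAML
# (the config keys below are hypothetical, not part of any schema here):
#
#   trainer:
#     max_steps: ${calc_num_train_steps:${data.num_data},${data.batch_size},${trainer.max_epochs},1}
#   lr_decay_rate: ${calc_exp_lr_decay_rate:0.1,${trainer.max_steps}}
#   image_size_hw: ${tuple2:${data.image_size}}
#
# OmegaConf evaluates these interpolations when the config is resolved,
# e.g. by the OmegaConf.resolve(cfg) call in load_config below.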
# ============== Automatic Name Resolvers =============== #
def get_naming_convention(cfg):
    # TODO
    name = f"lrm_{cfg.system.backbone.num_layers}"
    return name
# ======================================================= #


@dataclass
class ExperimentConfig:
    name: str = "default"
    description: str = ""
    tag: str = ""
    seed: int = 0
    use_timestamp: bool = True
    timestamp: Optional[str] = None
    exp_root_dir: str = "outputs"

    ### these shouldn't be set manually
    exp_dir: str = "outputs/default"
    trial_name: str = "exp"
    trial_dir: str = "outputs/default/exp"
    n_gpus: int = 1
    ###

    resume: Optional[str] = None

    data_cls: str = ""
    data: dict = field(default_factory=dict)

    system_cls: str = ""
    system: dict = field(default_factory=dict)

    # accept pytorch-lightning trainer parameters
    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api
    trainer: dict = field(default_factory=dict)

    # accept pytorch-lightning checkpoint callback parameters
    # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    checkpoint: dict = field(default_factory=dict)
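
# A minimal sketch of a YAML file matching this schema (all values are
# illustrative; class paths are hypothetical):
#
#   name: auto
#   tag: baseline
#   seed: 42
#   data_cls: mypkg.data.MyDataModule
#   data:
#     batch_size: 4
#   system_cls: mypkg.systems.MySystem
#   system:
#     backbone:
#       num_layers: 12
#   trainer:
#     max_epochs: 10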


def load_config(
    *yamls: str, cli_args: list = [], from_string=False, makedirs=True, **kwargs
) -> Any:
    if from_string:
        parse_func = OmegaConf.create
    else:
        parse_func = OmegaConf.load
    yaml_confs = []
    for y in yamls:
        conf = parse_func(y)
        extends = conf.pop("extends", None)
        if extends:
            assert os.path.exists(extends), f"File {extends} does not exist."
            yaml_confs.append(OmegaConf.load(extends))
        yaml_confs.append(conf)
    cli_conf = OmegaConf.from_cli(cli_args)
    cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)

    # post processing
    # auto naming
    if scfg.name == "auto":
        scfg.name = get_naming_convention(scfg)
    # add timestamp
    if not scfg.tag and not scfg.use_timestamp:
        raise ValueError("Either tag must be specified or use_timestamp must be True.")
    scfg.trial_name = scfg.tag
    # when resuming from an existing config, scfg.timestamp is already set
    # and should not be regenerated
    if scfg.timestamp is None:
        scfg.timestamp = ""
        if scfg.use_timestamp:
            if scfg.n_gpus > 1:
                warn(
                    "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag."
                )
            else:
                scfg.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S")
    # make directories
    scfg.trial_name += scfg.timestamp
    scfg.exp_dir = os.path.join(scfg.exp_root_dir, scfg.name)
    scfg.trial_dir = os.path.join(scfg.exp_dir, scfg.trial_name)
    if makedirs:
        os.makedirs(scfg.trial_dir, exist_ok=True)
    return scfg
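
# A minimal usage sketch (file names and overrides are hypothetical):
#
#   cfg = load_config("configs/base.yaml", cli_args=["seed=1", "tag=run1"])
#   print(cfg.trial_dir)  # e.g. outputs/<name>/run1@20240101-000000
#
# Later sources win on merge: values from additional YAMLs, then CLI
# arguments, then **kwargs override earlier ones; an `extends` key inside a
# YAML is loaded first, so the extending file overrides its base.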


def config_to_primitive(config, resolve: bool = True) -> Any:
    return OmegaConf.to_container(config, resolve=resolve)


def dump_config(path: str, config) -> None:
    with open(path, "w") as fp:
        OmegaConf.save(config=config, f=fp)


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    # merge user-provided values over the structured schema defaults;
    # tolerate cfg=None by merging with an empty dict
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg or {})
    return scfg
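
# E.g. validating a plain dict against the schema and saving the result
# (values are illustrative):
#
#   scfg = parse_structured(ExperimentConfig, {"name": "demo", "seed": 3})
#   dump_config("parsed.yaml", scfg)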