File size: 1,736 Bytes
491eded |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import os
from omegaconf import OmegaConf, DictConfig
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from datetime import datetime
@dataclass
class ExperimentConfig:
name: str = "default"
tag: str = ""
use_timestamp: bool = False
timestamp: Optional[str] = None
exp_root_dir: str = "outputs"
### these shouldn't be set manually
exp_dir: str = "outputs/default"
trial_name: str = "exp"
trial_dir: str = "outputs/default/exp"
###
resume: Optional[str] = None
ckpt_path: Optional[str] = None
data: dict = field(default_factory=dict)
model_pl: dict = field(default_factory=dict)
trainer: dict = field(default_factory=dict)
checkpoint: dict = field(default_factory=dict)
checkpoint_epoch: Optional[dict] = None
wandb: dict = field(default_factory=dict)
def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any:
if from_string:
yaml_confs = [OmegaConf.create(s) for s in yamls]
else:
yaml_confs = [OmegaConf.load(f) for f in yamls]
cli_conf = OmegaConf.from_cli(cli_args)
cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
OmegaConf.resolve(cfg)
assert isinstance(cfg, DictConfig)
scfg = parse_structured(ExperimentConfig, cfg)
return scfg
def config_to_primitive(config, resolve: bool = True) -> Any:
return OmegaConf.to_container(config, resolve=resolve)
def dump_config(path: str, config) -> None:
with open(path, "w") as fp:
OmegaConf.save(config=config, f=fp)
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
scfg = OmegaConf.structured(fields(**cfg))
return scfg |