|
import os |
|
from omegaconf import OmegaConf, DictConfig |
|
from dataclasses import dataclass, field |
|
from typing import Any, Dict, List, Optional, Union |
|
from datetime import datetime |
|
|
|
@dataclass
class ExperimentConfig:
    """Top-level schema for an experiment configuration.

    ``parse_structured`` builds an instance via ``ExperimentConfig(**cfg)``,
    so every top-level key of the merged config must match one of the
    fields below. Nested sections are typed as plain ``dict`` and are not
    validated further by this schema.

    Field order matters: it defines the positional order of the generated
    ``__init__`` and the key order of the structured OmegaConf config.
    """

    # Naming / timestamp options.
    name: str = "default"
    tag: str = ""
    use_timestamp: bool = False
    timestamp: Optional[str] = None

    # Output locations.
    exp_root_dir: str = "outputs"
    # NOTE(review): the next three fields are static defaults; nothing visible
    # in this module derives them from ``exp_root_dir`` / ``name`` / ``tag``,
    # so overriding ``name`` alone leaves ``exp_dir`` at "outputs/default".
    exp_dir: str = "outputs/default"
    trial_name: str = "exp"
    trial_dir: str = "outputs/default/exp"

    # Optional paths; presumably used to resume from a previous run — confirm
    # against the consumer of this config.
    resume: Optional[str] = None
    ckpt_path: Optional[str] = None

    # Free-form component sections. ``default_factory`` gives each instance
    # its own empty dict instead of one shared mutable default.
    data: dict = field(default_factory=dict)
    model_pl: dict = field(default_factory=dict)

    trainer: dict = field(default_factory=dict)
    checkpoint: dict = field(default_factory=dict)
    checkpoint_epoch: Optional[dict] = None
    wandb: dict = field(default_factory=dict)
|
|
|
|
|
def load_config(
    *yamls: str,
    cli_args: Union[List[str], tuple, None] = (),
    from_string: bool = False,
    **kwargs,
) -> Any:
    """Merge YAML sources, CLI overrides and kwargs into an ``ExperimentConfig``.

    Sources are combined with ``OmegaConf.merge`` from left to right, so
    later sources win: each entry of ``yamls`` in order, then ``cli_args``,
    then ``kwargs``. Interpolations are resolved in place before the merged
    config is converted with ``parse_structured(ExperimentConfig, ...)``.

    Args:
        *yamls: YAML file paths, or YAML text when ``from_string`` is True.
        cli_args: Dotlist overrides such as ``["trainer.max_epochs=10"]``.
            Defaults to no overrides. An explicit ``None`` is forwarded to
            ``OmegaConf.from_cli`` unchanged, which then reads ``sys.argv[1:]``.
        from_string: Treat every element of ``yamls`` as YAML text rather
            than a file path.
        **kwargs: Extra top-level values, merged last (highest priority).

    Returns:
        The structured config returned by ``parse_structured``.
    """
    if from_string:
        yaml_confs = [OmegaConf.create(s) for s in yamls]
    else:
        yaml_confs = [OmegaConf.load(f) for f in yamls]

    # Immutable default instead of a shared ``[]``; copy into a fresh list so
    # the caller's object is never handed to OmegaConf. ``None`` keeps its
    # original meaning (from_cli falls back to sys.argv).
    cli_list = None if cli_args is None else list(cli_args)
    cli_conf = OmegaConf.from_cli(cli_list)

    cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
    OmegaConf.resolve(cfg)
    # merge() is typed as returning DictConfig or ListConfig; narrow it here.
    assert isinstance(cfg, DictConfig)

    return parse_structured(ExperimentConfig, cfg)
|
|
|
|
|
def config_to_primitive(config, resolve: bool = True) -> Any:
    """Turn an OmegaConf container into plain Python ``dict``/``list`` objects.

    Args:
        config: The OmegaConf container to convert.
        resolve: Forwarded to ``OmegaConf.to_container``; when True (the
            default here) interpolations are evaluated during conversion.

    Returns:
        Whatever ``OmegaConf.to_container`` produces for ``config``.
    """
    # ``resolve`` must be passed by keyword: it is keyword-only in to_container.
    primitive = OmegaConf.to_container(config, resolve=resolve)
    return primitive
|
|
|
|
|
def dump_config(path: str, config, resolve: bool = False) -> None:
    """Write ``config`` to ``path`` as YAML, creating or truncating the file.

    Args:
        path: Destination file path.
        config: An OmegaConf container, or anything else ``OmegaConf.save``
            accepts.
        resolve: Forwarded to ``OmegaConf.save``. False (the default, and
            OmegaConf's own default) writes interpolations verbatim; True
            writes their resolved values.
    """
    # Explicit UTF-8 so non-ASCII values are written correctly regardless of
    # the platform's locale encoding (OmegaConf.save itself uses UTF-8 when
    # it opens a path).
    with open(path, "w", encoding="utf-8") as fp:
        OmegaConf.save(config=config, f=fp, resolve=resolve)
|
|
|
|
|
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Instantiate the schema class ``fields`` from ``cfg`` and wrap it with OmegaConf.

    Args:
        fields: The schema class to instantiate, typically a dataclass such
            as ``ExperimentConfig``.
        cfg: Mapping of field name to value, passed as keyword arguments to
            ``fields``. ``None`` (the default) instantiates ``fields`` with
            all of its defaults.

    Returns:
        The structured config built by ``OmegaConf.structured``.

    Raises:
        TypeError: If ``cfg`` contains a key that ``fields`` does not accept
            (raised by the class constructor).
    """
    # ``fields(**None)`` raises TypeError, which made the advertised default
    # unusable; treat None as "no overrides".
    if cfg is None:
        cfg = {}
    return OmegaConf.structured(fields(**cfg))