from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Type, TypeVar

from dacite import Config, from_dict
from omegaconf import DictConfig, OmegaConf

from .dataset import DatasetCfgWrapper
from .dataset.data_module import DataLoaderCfg
from .loss import LossCfgWrapper
from .model.decoder import DecoderCfg
from .model.encoder import EncoderCfg
from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg


@dataclass
class CheckpointingCfg:
    """Typed view of the ``checkpointing`` section of the root config."""

    load: Optional[str]  # Not a path, since it could be something like wandb://...
    every_n_train_steps: int
    save_top_k: int
    save_weights_only: bool


@dataclass
class ModelCfg:
    """Typed view of the ``model`` section: one decoder and one encoder config."""

    decoder: DecoderCfg
    encoder: EncoderCfg


@dataclass
class TrainerCfg:
    """Typed view of the ``trainer`` section of the root config.

    NOTE(review): the field names look like they mirror the arguments of a
    Lightning-style trainer; confirm against the code that consumes them.
    """

    max_steps: int
    val_check_interval: int | float | None
    gradient_clip_val: int | float | None
    num_nodes: int = 1
    accumulate_grad_batches: int = 1
    precision: Literal["32", "16-mixed", "bf16-mixed"] = "32"


@dataclass
class RootCfg:
    """Top-level typed configuration produced by ``load_typed_root_config``."""

    wandb: dict
    mode: Literal["train", "test"]
    dataset: list[DatasetCfgWrapper]
    data_loader: DataLoaderCfg
    model: ModelCfg
    optimizer: OptimizerCfg
    checkpointing: CheckpointingCfg
    trainer: TrainerCfg
    loss: list[LossCfgWrapper]
    test: TestCfg
    train: TrainCfg
    seed: int


# Type hooks that are applied to every config loaded through this module.
TYPE_HOOKS = {
    Path: Path,
}


T = TypeVar("T")


def load_typed_config(
    cfg: DictConfig,
    data_class: Type[T],
    extra_type_hooks: Optional[dict] = None,
) -> T:
    """Convert an OmegaConf config into an instance of ``data_class``.

    Args:
        cfg: The OmegaConf config to convert. It is first turned into plain
            Python containers via ``OmegaConf.to_container``.
        data_class: The dataclass type to instantiate.
        extra_type_hooks: Optional mapping from type to converter callable.
            These are merged on top of the module-level ``TYPE_HOOKS``, so an
            entry here overrides a default hook for the same type.

    Returns:
        An instance of ``data_class`` built by ``dacite.from_dict``.
    """
    # A ``None`` default avoids sharing one mutable dict between calls.
    type_hooks = {**TYPE_HOOKS, **(extra_type_hooks or {})}
    return from_dict(
        data_class,
        OmegaConf.to_container(cfg),
        config=Config(type_hooks=type_hooks),
    )


def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
    """Split a joined loss mapping into one ``LossCfgWrapper`` per entry.

    Each ``key: value`` pair of ``joined`` is loaded on its own as a
    single-entry mapping ``{key: value}``.

    Args:
        joined: Mapping whose items each describe one loss configuration.

    Returns:
        A list with one converted wrapper per item, in ``joined``'s
        iteration order.
    """

    # The dummy allows the union to be converted.
    @dataclass
    class Dummy:
        dummy: LossCfgWrapper

    return [
        load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy
        for k, v in joined.items()
    ]


def separate_dataset_cfg_wrappers(joined: dict) -> list[DatasetCfgWrapper]:
    """Split a joined dataset mapping into one ``DatasetCfgWrapper`` per entry.

    Each ``key: value`` pair of ``joined`` is loaded on its own as a
    single-entry mapping ``{key: value}``.

    Args:
        joined: Mapping whose items each describe one dataset configuration.

    Returns:
        A list with one converted wrapper per item, in ``joined``'s
        iteration order.
    """

    # The dummy allows the union to be converted.
    @dataclass
    class Dummy:
        dummy: DatasetCfgWrapper

    return [
        load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy
        for k, v in joined.items()
    ]


def load_typed_root_config(cfg: DictConfig) -> RootCfg:
    """Load the full config as a ``RootCfg``.

    Type hooks are registered for ``list[LossCfgWrapper]`` and
    ``list[DatasetCfgWrapper]`` so that the ``loss`` and ``dataset`` fields,
    which hold those list types, are built by the ``separate_*`` helpers.

    Args:
        cfg: The root OmegaConf config.

    Returns:
        The populated ``RootCfg`` instance.
    """
    return load_typed_config(
        cfg,
        RootCfg,
        {
            list[LossCfgWrapper]: separate_loss_cfg_wrappers,
            list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers,
        },
    )