|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Literal, Optional, Type, TypeVar |
|
|
|
from dacite import Config, from_dict |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
from .dataset import DatasetCfgWrapper |
|
from .dataset.data_module import DataLoaderCfg |
|
from .loss import LossCfgWrapper |
|
from .model.decoder import DecoderCfg |
|
from .model.encoder import EncoderCfg |
|
from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg |
|
|
|
|
|
@dataclass |
|
class CheckpointingCfg: |
|
load: Optional[str] |
|
every_n_train_steps: int |
|
save_top_k: int |
|
save_weights_only: bool |
|
|
|
|
|
@dataclass
class ModelCfg:
    """Groups the decoder and encoder configurations of the model."""

    # Both sub-configs are typed by their own modules (`.model.decoder`,
    # `.model.encoder`); this class only bundles them under `model`.
    decoder: DecoderCfg
    encoder: EncoderCfg
|
|
|
|
|
@dataclass |
|
class TrainerCfg: |
|
max_steps: int |
|
val_check_interval: int | float | None |
|
gradient_clip_val: int | float | None |
|
num_nodes: int = 1 |
|
accumulate_grad_batches: int = 1 |
|
precision: Literal["32", "16-mixed", "bf16-mixed"] = "32" |
|
|
|
|
|
@dataclass
class RootCfg:
    """Top-level typed view of the whole configuration.

    Built by ``load_typed_root_config``. There, the ``dataset`` and ``loss``
    sections arrive as mappings and are split into lists of wrappers.
    """

    # Free-form logging settings, kept as an untyped dict.
    wandb: dict
    mode: Literal["train", "test"]

    # Data.
    dataset: list[DatasetCfgWrapper]
    data_loader: DataLoaderCfg

    # Model and optimisation.
    model: ModelCfg
    optimizer: OptimizerCfg
    checkpointing: CheckpointingCfg
    trainer: TrainerCfg
    loss: list[LossCfgWrapper]

    # Mode-specific settings.
    test: TestCfg
    train: TrainCfg

    seed: int
|
|
|
|
|
# Base dacite type hooks used by every `load_typed_config` call. A raw value
# for a field annotated `Path` is passed through `Path(...)`, so it ends up
# as a `pathlib.Path`.
TYPE_HOOKS: dict = {Path: Path}

# Lets `load_typed_config` be typed as returning an instance of the exact
# dataclass it was asked to build.
T = TypeVar("T")
|
|
|
|
|
def load_typed_config(
    cfg: DictConfig,
    data_class: Type[T],
    extra_type_hooks: Optional[dict] = None,
) -> T:
    """Convert an OmegaConf config into an instance of ``data_class``.

    The config is first turned into plain Python containers with
    ``OmegaConf.to_container``. The result is then handed to
    ``dacite.from_dict``.

    Args:
        cfg: The OmegaConf configuration to convert.
        data_class: The dataclass type to build.
        extra_type_hooks: Additional dacite type hooks, merged over the
            module-level ``TYPE_HOOKS``. Entries here win on key collisions.
            ``None`` (the default) means no extra hooks.

    Returns:
        An instance of ``data_class`` populated from ``cfg``.
    """
    # Build a fresh hook mapping per call; a `{}` default argument would be a
    # single dict shared across all calls.
    type_hooks = {**TYPE_HOOKS, **(extra_type_hooks or {})}
    return from_dict(
        data_class,
        OmegaConf.to_container(cfg),
        config=Config(type_hooks=type_hooks),
    )
|
|
|
|
|
def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
    """Split a mapping of loss configs into a list of typed loss wrappers.

    Each ``name: settings`` entry of ``joined`` is parsed on its own as a
    single-key mapping, producing one ``LossCfgWrapper`` per entry. The
    result follows the iteration order of ``joined``.
    NOTE(review): this presumes ``LossCfgWrapper`` is selected by that single
    key (e.g. a union of one-field wrapper dataclasses); confirm in `.loss`.
    """

    # The wrapper type is parsed as the only field of a throwaway dataclass,
    # because `load_typed_config` builds a dataclass.
    @dataclass
    class _Holder:
        dummy: LossCfgWrapper

    def parse_entry(name, settings) -> LossCfgWrapper:
        single = DictConfig({"dummy": {name: settings}})
        return load_typed_config(single, _Holder).dummy

    return [parse_entry(name, settings) for name, settings in joined.items()]
|
|
|
|
|
def separate_dataset_cfg_wrappers(joined: dict) -> list[DatasetCfgWrapper]:
    """Split a mapping of dataset configs into a list of typed dataset wrappers.

    Every ``name: settings`` entry of ``joined`` is parsed independently as a
    single-key mapping and yields one ``DatasetCfgWrapper``. The output order
    matches the iteration order of ``joined``.
    NOTE(review): this presumes ``DatasetCfgWrapper`` is selected by that
    single key; confirm in `.dataset`.
    """

    # `load_typed_config` builds dataclasses, so the wrapper type is parsed
    # as the single field of a temporary dataclass and then unwrapped.
    @dataclass
    class _Holder:
        dummy: DatasetCfgWrapper

    wrappers: list[DatasetCfgWrapper] = []
    for name, settings in joined.items():
        holder = load_typed_config(DictConfig({"dummy": {name: settings}}), _Holder)
        wrappers.append(holder.dummy)
    return wrappers
|
|
|
|
|
def load_typed_root_config(cfg: DictConfig) -> RootCfg:
    """Parse the complete OmegaConf config into a typed ``RootCfg``.

    The loss and dataset sections are given as mappings, which the
    ``separate_*`` helpers split into lists. So the ``list[...]`` field types
    for them are routed through those helpers as extra dacite type hooks.
    """
    list_hooks = {
        list[LossCfgWrapper]: separate_loss_cfg_wrappers,
        list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers,
    }
    return load_typed_config(cfg, RootCfg, extra_type_hooks=list_hooks)
|
|