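"""Typed config loading.

Converts an OmegaConf ``DictConfig`` into the nested dataclasses below via
``dacite.from_dict``, so downstream code gets typed attribute access.
"""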
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Type, TypeVar

from dacite import Config, from_dict
from omegaconf import DictConfig, OmegaConf

from .dataset import DatasetCfgWrapper
from .dataset.data_module import DataLoaderCfg
from .loss import LossCfgWrapper
from .model.decoder import DecoderCfg
from .model.encoder import EncoderCfg
from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg


@dataclass
class CheckpointingCfg:
    load: Optional[str]  # Not a path, since it could be something like wandb://...
    every_n_train_steps: int
    save_top_k: int
    save_weights_only: bool


@dataclass
class ModelCfg:
    decoder: DecoderCfg
    encoder: EncoderCfg


@dataclass
class TrainerCfg:
    max_steps: int
    val_check_interval: int | float | None
    gradient_clip_val: int | float | None
    num_nodes: int = 1
    accumulate_grad_batches: int = 1
    precision: Literal["32", "16-mixed", "bf16-mixed"] = "32"


@dataclass
class RootCfg:
    wandb: dict
    mode: Literal["train", "test"]
    dataset: list[DatasetCfgWrapper]
    data_loader: DataLoaderCfg
    model: ModelCfg
    optimizer: OptimizerCfg
    checkpointing: CheckpointingCfg
    trainer: TrainerCfg
    loss: list[LossCfgWrapper]
    test: TestCfg
    train: TrainCfg
    seed: int
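
# Sketch of the config shape RootCfg expects (key names other than the field
# names above are hypothetical placeholders):
#
#     mode: train
#     seed: 111
#     wandb: {...}
#     dataset: {re10k: {...}}  # dict of named datasets, converted to a list below
#     loss: {mse: {...}}       # dict of named losses, likewise
#     ...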


# Hooks applied during every conversion; e.g., strings in the config are
# coerced to Path for Path-typed fields.
TYPE_HOOKS = {
    Path: Path,
}

T = TypeVar("T")


def load_typed_config(
    cfg: DictConfig,
    data_class: Type[T],
    extra_type_hooks: dict = {},
) -> T:
    """Convert a DictConfig into an instance of the given dataclass."""
    return from_dict(
        data_class,
        OmegaConf.to_container(cfg),
        config=Config(type_hooks={**TYPE_HOOKS, **extra_type_hooks}),
    )
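
# A minimal usage sketch; ExampleCfg is hypothetical, for illustration only:
#
#     @dataclass
#     class ExampleCfg:
#         name: str
#         output_dir: Path
#
#     cfg = OmegaConf.create({"name": "run", "output_dir": "outputs/run"})
#     typed = load_typed_config(cfg, ExampleCfg)
#     assert isinstance(typed.output_dir, Path)  # TYPE_HOOKS coerced str -> Path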


def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
    # from_dict can only target a dataclass, not a bare union, so the
    # single-field Dummy gives dacite a field through which it can resolve
    # the LossCfgWrapper union for each {name: cfg} entry separately.
    @dataclass
    class Dummy:
        dummy: LossCfgWrapper

    return [
        load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy
        for k, v in joined.items()
    ]


def separate_dataset_cfg_wrappers(joined: dict) -> list[DatasetCfgWrapper]:
    # The dummy allows the union to be converted.
    @dataclass
    class Dummy:
        dummy: DatasetCfgWrapper

    return [
        load_typed_config(DictConfig({"dummy": {k: v}}), Dummy).dummy
        for k, v in joined.items()
    ]
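
# Both helpers take the "joined" mapping as it appears in the YAML, e.g.
# (hypothetical keys) {"mse": {...}, "lpips": {...}}, and return one wrapper
# per entry, each converted independently through its Dummy.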


def load_typed_root_config(cfg: DictConfig) -> RootCfg:
    return load_typed_config(
        cfg,
        RootCfg,
        {
            list[LossCfgWrapper]: separate_loss_cfg_wrappers,
            list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers,
        },
    )
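
# Typical entry-point usage (a hypothetical Hydra main; the decorator
# arguments are assumptions, not taken from this file):
#
#     @hydra.main(version_base=None, config_path="../config", config_name="main")
#     def run(cfg_dict: DictConfig) -> None:
#         cfg = load_typed_root_config(cfg_dict)
#         ...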