AnySplat / src /config.py
alexnasa's picture
Upload 243 files
2568013 verified
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Type, TypeVar
from dacite import Config, from_dict
from omegaconf import DictConfig, OmegaConf
from .dataset import DatasetCfgWrapper
from .dataset.data_module import DataLoaderCfg
from .loss import LossCfgWrapper
from .model.decoder import DecoderCfg
from .model.encoder import EncoderCfg
from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg
@dataclass
class CheckpointingCfg:
    """Typed ``checkpointing`` section of ``RootCfg``."""

    load: Optional[str]  # Not a path, since it could be something like wandb://...
    # NOTE(review): presumably the checkpoint save interval in training steps — confirm with the trainer setup.
    every_n_train_steps: int
    # NOTE(review): presumably how many best checkpoints to retain — confirm against the checkpoint callback.
    save_top_k: int
    # NOTE(review): presumably whether only model weights (no optimizer state) are saved — confirm.
    save_weights_only: bool
@dataclass
class ModelCfg:
    """Typed ``model`` section of ``RootCfg``: decoder and encoder configs."""

    decoder: DecoderCfg
    encoder: EncoderCfg
@dataclass
class TrainerCfg:
    """Typed ``trainer`` section of ``RootCfg``.

    The three fields without defaults must appear in the config. The last three
    fall back to 1, 1 and ``"32"`` respectively.
    """

    max_steps: int
    # An int, a float, or None are all accepted. NOTE(review): presumably forwarded to the trainer's own option of the same name — confirm.
    val_check_interval: int | float | None
    # None is permitted. NOTE(review): presumably None disables gradient clipping — confirm in the trainer setup.
    gradient_clip_val: int | float | None
    num_nodes: int = 1
    accumulate_grad_batches: int = 1
    # Restricted to these three precision strings; the default is full 32-bit.
    precision: Literal["32", "16-mixed", "bf16-mixed"] = "32"
@dataclass
class RootCfg:
    """Top-level typed configuration built by ``load_typed_root_config``."""

    # Left as a plain dict; its contents are not type-checked here.
    wandb: dict
    mode: Literal["train", "test"]
    # Built from a name -> settings mapping by separate_dataset_cfg_wrappers.
    dataset: list[DatasetCfgWrapper]
    data_loader: DataLoaderCfg
    model: ModelCfg
    optimizer: OptimizerCfg
    checkpointing: CheckpointingCfg
    trainer: TrainerCfg
    # Built from a name -> settings mapping by separate_loss_cfg_wrappers.
    loss: list[LossCfgWrapper]
    test: TestCfg
    train: TrainCfg
    seed: int
# Default dacite type hooks: each key is a target field type and each value is
# the converter applied to raw values for fields of that type. Here, plain
# strings headed for ``Path`` fields are turned into ``pathlib.Path`` objects.
TYPE_HOOKS = {Path: Path}

# Generic parameter linking ``load_typed_config``'s return type to ``data_class``.
T = TypeVar("T")
def load_typed_config(
    cfg: DictConfig,
    data_class: Type[T],
    extra_type_hooks: Optional[dict] = None,
) -> T:
    """Convert an OmegaConf config into an instance of ``data_class``.

    ``cfg`` is first flattened to plain Python containers with
    ``OmegaConf.to_container``. dacite's ``from_dict`` then builds the result
    from those containers.

    Args:
        cfg: The OmegaConf config to convert.
        data_class: The dataclass type to instantiate.
        extra_type_hooks: Optional additional dacite type hooks, mapping a type
            to a converter. They are merged over ``TYPE_HOOKS``, so an entry
            here replaces the default hook for the same type. ``None`` (the
            default) means no extra hooks.

    Returns:
        An instance of ``data_class`` populated from ``cfg``.
    """
    # The default is None rather than a shared mutable ``{}``. Callers that
    # pass a dict behave exactly as before.
    hooks = {**TYPE_HOOKS, **(extra_type_hooks or {})}
    return from_dict(
        data_class,
        OmegaConf.to_container(cfg),
        config=Config(type_hooks=hooks),
    )
def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
    """Turn a ``name -> settings`` mapping of losses into typed wrappers.

    Each entry is converted separately as a single-key mapping. The resulting
    list follows the iteration order of ``joined``.
    """

    # from_dict needs a dataclass at the top level. Wrapping each entry in a
    # throwaway dataclass whose only field is the LossCfgWrapper union lets
    # dacite pick the matching union member.
    @dataclass
    class Dummy:
        dummy: LossCfgWrapper

    wrappers = []
    for name, settings in joined.items():
        boxed = load_typed_config(DictConfig({"dummy": {name: settings}}), Dummy)
        wrappers.append(boxed.dummy)
    return wrappers
def separate_dataset_cfg_wrappers(joined: dict) -> list[DatasetCfgWrapper]:
    """Turn a ``name -> settings`` mapping of datasets into typed wrappers.

    Each entry is converted separately as a single-key mapping. The resulting
    list follows the iteration order of ``joined``.
    """

    # from_dict needs a dataclass at the top level. Wrapping each entry in a
    # throwaway dataclass whose only field is the DatasetCfgWrapper union lets
    # dacite pick the matching union member.
    @dataclass
    class Dummy:
        dummy: DatasetCfgWrapper

    wrappers = []
    for name, settings in joined.items():
        boxed = load_typed_config(DictConfig({"dummy": {name: settings}}), Dummy)
        wrappers.append(boxed.dummy)
    return wrappers
def load_typed_root_config(cfg: DictConfig) -> RootCfg:
    """Convert the full OmegaConf config into a typed ``RootCfg``.

    The ``loss`` and ``dataset`` sections are typed as lists of union wrappers.
    Dedicated hooks convert each of their entries separately.
    """
    list_hooks = {
        list[LossCfgWrapper]: separate_loss_cfg_wrappers,
        list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers,
    }
    return load_typed_config(cfg, RootCfg, list_hooks)