alexnasa's picture
Upload 243 files
2568013 verified
from typing import Any
from ...misc.step_tracker import StepTracker
from ..types import Stage
from .view_sampler import ViewSampler
from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg
from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg
from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg
from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg
from .view_sampler_rank import ViewSamplerRank, ViewSamplerRankCfg
# Registry mapping a view-sampler config's `name` string to the ViewSampler
# subclass that implements it. `get_view_sampler` looks up `cfg.name` here and
# instantiates the class, so the values are classes (hence `type[...]`), not
# sampler instances.
VIEW_SAMPLERS: dict[str, type[ViewSampler[Any]]] = {
    "all": ViewSamplerAll,
    "arbitrary": ViewSamplerArbitrary,
    "bounded": ViewSamplerBounded,
    "evaluation": ViewSamplerEvaluation,
    "rank": ViewSamplerRank,
}
# Union of every view-sampler config type accepted by `get_view_sampler`,
# which reads `cfg.name` from whichever member it receives to pick the sampler
# class in VIEW_SAMPLERS.
# NOTE(review): member order is kept exactly as-is. If a config loader resolves
# this union by trying members in order, reordering could change which config
# type is chosen; confirm before changing.
ViewSamplerCfg = (
    ViewSamplerArbitraryCfg
    | ViewSamplerBoundedCfg
    | ViewSamplerEvaluationCfg
    | ViewSamplerAllCfg
    | ViewSamplerRankCfg
)
def get_view_sampler(
    cfg: ViewSamplerCfg,
    stage: Stage,
    overfit: bool,
    cameras_are_circular: bool,
    step_tracker: StepTracker | None,
) -> ViewSampler[Any]:
    """Construct the view sampler selected by ``cfg.name``.

    The sampler class is looked up in ``VIEW_SAMPLERS`` by ``cfg.name`` and
    instantiated with all arguments passed positionally, in the order
    ``(cfg, stage, overfit, cameras_are_circular, step_tracker)``.

    Args:
        cfg: View-sampler config; its ``name`` attribute selects the class.
        stage: Forwarded unchanged to the sampler constructor.
        overfit: Forwarded unchanged to the sampler constructor.
        cameras_are_circular: Forwarded unchanged to the sampler constructor.
        step_tracker: Forwarded unchanged (may be ``None``).

    Returns:
        A new instance of the registered sampler class.

    Raises:
        KeyError: If ``cfg.name`` is not a key of ``VIEW_SAMPLERS``. The
            message lists the registered names.
    """
    # Keep the try narrow: a KeyError raised inside a sampler's constructor
    # must not be misreported as an unknown sampler name.
    try:
        sampler_class = VIEW_SAMPLERS[cfg.name]
    except KeyError:
        raise KeyError(
            f"Unknown view sampler {cfg.name!r}; "
            f"expected one of {sorted(VIEW_SAMPLERS)}"
        ) from None
    return sampler_class(
        cfg,
        stage,
        overfit,
        cameras_are_circular,
        step_tracker,
    )