|
from typing import Any |
|
|
|
from ...misc.step_tracker import StepTracker |
|
from ..types import Stage |
|
from .view_sampler import ViewSampler |
|
from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg |
|
from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg |
|
from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg |
|
from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg |
|
from .view_sampler_rank import ViewSamplerRank, ViewSamplerRankCfg |
|
# Registry mapping a sampler name to the sampler *class* (not an instance).
# get_view_sampler() indexes this table with ``cfg.name`` and then calls the
# class it finds, so the value type is ``type[ViewSampler[Any]]``.
VIEW_SAMPLERS: dict[str, type[ViewSampler[Any]]] = {
    "all": ViewSamplerAll,
    "arbitrary": ViewSamplerArbitrary,
    "bounded": ViewSamplerBounded,
    "evaluation": ViewSamplerEvaluation,
    "rank": ViewSamplerRank,
}
|
|
|
# Union of the sampler configuration types imported by this module; used as
# the annotation for the ``cfg`` argument of get_view_sampler().
# NOTE(review): member order is kept exactly as written in case a config
# loader resolves union members in order -- confirm before reordering.
ViewSamplerCfg = (
    ViewSamplerArbitraryCfg
    | ViewSamplerBoundedCfg
    | ViewSamplerEvaluationCfg
    | ViewSamplerAllCfg
    | ViewSamplerRankCfg
)
|
|
|
def get_view_sampler(
    cfg: ViewSamplerCfg,
    stage: Stage,
    overfit: bool,
    cameras_are_circular: bool,
    step_tracker: StepTracker | None,
) -> ViewSampler[Any]:
    """Build the view sampler registered under ``cfg.name``.

    The class is looked up in ``VIEW_SAMPLERS`` and called with all five
    arguments positionally, in the order they appear in this signature.

    Args:
        cfg: Sampler configuration; its ``name`` attribute selects the class.
        stage: Forwarded unchanged to the sampler constructor.
        overfit: Forwarded unchanged to the sampler constructor.
        cameras_are_circular: Forwarded unchanged to the sampler constructor.
        step_tracker: Forwarded unchanged to the sampler constructor; may be
            ``None``.

    Returns:
        The object produced by calling the selected sampler class.

    Raises:
        KeyError: If ``cfg.name`` is not a key of ``VIEW_SAMPLERS``.
    """
    # Only the lookup is guarded, so a KeyError raised inside a sampler's own
    # constructor is never mistaken for an unknown sampler name.
    try:
        sampler_cls = VIEW_SAMPLERS[cfg.name]
    except KeyError:
        known = ", ".join(sorted(VIEW_SAMPLERS))
        raise KeyError(
            f"Unknown view sampler {cfg.name!r}; expected one of: {known}"
        ) from None

    return sampler_cls(
        cfg,
        stage,
        overfit,
        cameras_are_circular,
        step_tracker,
    )
|
|