import json
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import torch
from dacite import Config, from_dict
from jaxtyping import Float, Int64
from torch import Tensor

from ...evaluation.evaluation_index_generator import IndexEntry
from ...global_cfg import get_cfg
from ...misc.step_tracker import StepTracker
from ..types import Stage
from .three_view_hack import add_third_context_index
from .view_sampler import ViewSampler


@dataclass
class ViewSamplerEvaluationCfg:
    """Configuration for the view sampler that reads indices from a file."""

    name: Literal["evaluation"]
    # JSON file mapping scene keys to index entries (or null).
    index_path: Path
    # Declared in the config; not read anywhere in this module.
    num_context_views: int


class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]):
    """View sampler that returns indices looked up in a pre-built JSON index.

    The index is loaded once in ``__init__`` from ``cfg.index_path``. Each
    non-null JSON value is converted to an ``IndexEntry``; null values are
    kept as ``None``.
    """

    # Scene key -> entry parsed from the index file (None for null entries).
    index: dict[str, IndexEntry | None]

    def __init__(
        self,
        cfg: ViewSamplerEvaluationCfg,
        stage: Stage,
        is_overfitting: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
    ) -> None:
        """Initialise the base sampler and load the index file.

        All arguments are forwarded unchanged to ``ViewSampler.__init__``;
        ``cfg.index_path`` is additionally read here to populate
        ``self.index``.
        """
        super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker)

        # JSON has no tuple type, so ask dacite to cast values into
        # tuple-typed fields (presumably IndexEntry declares some -- verify
        # against its definition).
        dacite_config = Config(cast=[tuple])

        # Explicit encoding so decoding does not depend on the platform locale.
        with cfg.index_path.open("r", encoding="utf-8") as f:
            raw_index = json.load(f)

        self.index = {
            k: None if v is None else from_dict(IndexEntry, v, dacite_config)
            for k, v in raw_index.items()
        }

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, "1"],  # constant zero tensor of shape (1,)
    ]:
        """Return the stored context and target view indices for ``scene``.

        ``extrinsics``, ``intrinsics`` and any extra keyword arguments are
        accepted but not used here; the result depends only on the loaded
        index.

        Args:
            scene: Key to look up in the loaded index.
            extrinsics: Unused.
            intrinsics: Unused.
            device: Device on which the two index tensors are created.
            **kwargs: Ignored.

        Returns:
            A 3-tuple ``(context_indices, target_indices, zeros)``. The first
            two are int64 tensors built from ``entry.context`` and
            ``entry.target``; the third is ``torch.zeros(1)``.

        Raises:
            ValueError: If ``scene`` is missing from the index or maps to a
                null entry.
        """
        entry = self.index.get(scene)
        if entry is None:
            raise ValueError(f"No indices available for scene {scene}.")

        context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device)
        target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device)

        # NOTE(review): the third element is created on the default device,
        # not on `device`; left as-is because callers may rely on it. Its
        # meaning is not visible in this file -- looks like a placeholder.
        return context_indices, target_indices, torch.zeros(1)

    @property
    def num_context_views(self) -> int:
        """Always 0; ``cfg.num_context_views`` is not consulted here."""
        # NOTE(review): confirm that ignoring the cfg value is intended.
        return 0

    @property
    def num_target_views(self) -> int:
        """Always 0."""
        return 0