AnySplat / src /dataset /view_sampler /view_sampler_evaluation.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame
2.17 kB
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import torch
from dacite import Config, from_dict
from jaxtyping import Float, Int64
from torch import Tensor
from ...evaluation.evaluation_index_generator import IndexEntry
from ...global_cfg import get_cfg
from ...misc.step_tracker import StepTracker
from ..types import Stage
from .three_view_hack import add_third_context_index
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerEvaluationCfg:
    """Settings for the view sampler that replays a precomputed evaluation index.

    Attributes:
        name: Sampler identifier, typed as the literal ``"evaluation"``.
        index_path: Location of the JSON file holding per-scene context/target
            view indices.
        num_context_views: Configured number of context views.
    """

    name: Literal["evaluation"]
    index_path: Path
    num_context_views: int
class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]):
    """View sampler that replays precomputed context/target indices per scene.

    The indices come from a JSON file (``cfg.index_path``) mapping scene names
    either to ``null`` or to an object that ``dacite`` converts into an
    ``IndexEntry``; :meth:`sample` reads the entry's ``context`` and ``target``
    attributes.
    """

    # Scene name -> parsed index entry; ``None`` marks a scene stored as null in
    # the JSON file (i.e. no indices available for it).
    index: dict[str, IndexEntry | None]

    def __init__(
        self,
        cfg: ViewSamplerEvaluationCfg,
        stage: Stage,
        is_overfitting: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
    ) -> None:
        """Initialize the base sampler and load the evaluation index.

        Args:
            cfg: Sampler configuration; ``cfg.index_path`` is the JSON index.
            stage: Forwarded unchanged to the base ``ViewSampler``.
            is_overfitting: Forwarded unchanged to the base ``ViewSampler``.
            cameras_are_circular: Forwarded unchanged to the base ``ViewSampler``.
            step_tracker: Forwarded unchanged to the base ``ViewSampler``.

        Raises:
            OSError: If the index file cannot be opened.
            json.JSONDecodeError: If the index file is not valid JSON.
            Presumably a ``dacite`` error if an entry does not match
            ``IndexEntry`` (TODO confirm against the dacite version in use).
        """
        super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker)

        # Let dacite cast JSON lists into tuple-typed fields (assumes IndexEntry
        # declares its index fields as tuples -- TODO confirm).
        dacite_config = Config(cast=[tuple])

        # JSON text is UTF-8 by specification; don't depend on the locale's
        # default encoding. Parse after the file handle is released.
        with cfg.index_path.open("r", encoding="utf-8") as f:
            raw_index = json.load(f)

        self.index = {
            k: None if v is None else from_dict(IndexEntry, v, dacite_config)
            for k, v in raw_index.items()
        }

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, "1"],  # placeholder, always a single zero
    ]:
        """Return the precomputed view indices stored for ``scene``.

        Args:
            scene: Key into the loaded index.
            extrinsics: Accepted for interface compatibility; not used here.
            intrinsics: Accepted for interface compatibility; not used here.
            device: Device on which the returned tensors are allocated.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``(context_indices, target_indices, placeholder)``. The first two are
            int64 tensors built from the entry's ``context`` and ``target``;
            the third is ``torch.zeros(1)``. NOTE(review): the placeholder
            presumably mirrors an extra value returned by other samplers --
            confirm what callers expect it to mean.

        Raises:
            ValueError: If ``scene`` is absent from the index or stored as null.
        """
        entry = self.index.get(scene)
        if entry is None:
            # Covers both an unknown scene and one explicitly stored as null.
            raise ValueError(f"No indices available for scene {scene}.")

        context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device)
        target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device)
        # Keep the placeholder on the same device as the index tensors.
        placeholder = torch.zeros(1, device=device)
        return context_indices, target_indices, placeholder

    @property
    def num_context_views(self) -> int:
        """Return 0; the real count comes from each scene's index entry.

        NOTE(review): ``cfg.num_context_views`` exists but is not consulted here.
        """
        return 0

    @property
    def num_target_views(self) -> int:
        """Return 0; the real count comes from each scene's index entry."""
        return 0