# AnySplat/src/dataset/view_sampler/view_sampler_arbitrary.py
# Hosting-page residue kept for provenance: "alexnasa's picture",
# "Upload 243 files", "2568013 verified".
from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float, Int64
from torch import Tensor
from .three_view_hack import add_third_context_index
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerArbitraryCfg:
name: Literal["arbitrary"]
num_context_views: int
num_target_views: int
context_views: list[int] | None
target_views: list[int] | None
class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]):
    """Sample context and target views without any geometric constraint.

    Indices are drawn uniformly at random (with replacement) over all views
    of the scene, unless the configuration pins them to fixed lists.
    """

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, " overlap"],  # overlap
    ]:
        """Arbitrarily sample context and target views.

        Args:
            scene: Scene identifier. Not used by this sampler.
            extrinsics: Per-view extrinsics; only the leading (view)
                dimension is read, to learn how many views exist.
            intrinsics: Per-view intrinsics. Not used by this sampler.
            device: Device on which the returned tensors are created.

        Returns:
            A tuple ``(index_context, index_target, overlap)``. The two
            index tensors are int64. ``overlap`` is a one-element float32
            placeholder fixed at 0.5, not a measured quantity.

        Raises:
            AssertionError: If a fixed view list in the configuration does
                not have the configured number of entries. The single
                exception is two fixed context views with
                ``num_context_views == 3``, which is completed by
                ``add_third_context_index``.
        """
        cfg = self.cfg
        num_views, _, _ = extrinsics.shape

        # The random draw is made even when it is overridden below, so the
        # amount of randomness consumed per call matches the original code.
        index_context = torch.randint(
            0,
            num_views,
            size=(cfg.num_context_views,),
            device=device,
        )

        # Allow the context views to be fixed.
        if cfg.context_views is not None:
            index_context = torch.tensor(
                cfg.context_views, dtype=torch.int64, device=device
            )
            num_fixed_context = len(cfg.context_views)
            if cfg.num_context_views == 3 and num_fixed_context == 2:
                # Two fixed views but three requested: the helper supplies
                # the missing index (see three_view_hack for how).
                index_context = add_third_context_index(index_context)
            elif num_fixed_context != cfg.num_context_views:
                # Raised explicitly rather than via `assert` so the check
                # survives `python -O`; the exception type is unchanged.
                raise AssertionError(
                    f"context_views has {num_fixed_context} entries but "
                    f"num_context_views is {cfg.num_context_views}"
                )

        index_target = torch.randint(
            0,
            num_views,
            size=(cfg.num_target_views,),
            device=device,
        )

        # Allow the target views to be fixed.
        if cfg.target_views is not None:
            num_fixed_target = len(cfg.target_views)
            if num_fixed_target != cfg.num_target_views:
                raise AssertionError(
                    f"target_views has {num_fixed_target} entries but "
                    f"num_target_views is {cfg.num_target_views}"
                )
            index_target = torch.tensor(
                cfg.target_views, dtype=torch.int64, device=device
            )

        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)  # dummy

        return index_context, index_target, overlap

    @property
    def num_context_views(self) -> int:
        """Number of context views, as configured."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Number of target views, as configured."""
        return self.cfg.num_target_views