# AnySplat: src/dataset/view_sampler/view_sampler_arbitrary.py
# Uploaded by alexnasa ("Upload 243 files", commit 2568013, verified); 2.42 kB.
from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float, Int64
from torch import Tensor
from .three_view_hack import add_third_context_index
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerArbitraryCfg:
name: Literal["arbitrary"]
num_context_views: int
num_target_views: int
context_views: list[int] | None
target_views: list[int] | None
class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]):
    """Sample context and target view indices uniformly at random.

    Indices are drawn independently with ``torch.randint``, so duplicates are
    possible. Either group can be pinned through ``cfg.context_views`` or
    ``cfg.target_views``. ``self.cfg`` is assumed to be set by the
    ``ViewSampler`` base class, which is not visible here.
    """

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, " overlap"],  # overlap
    ]:
        """Arbitrarily sample context and target views.

        Args:
            scene: Scene identifier; not used by this sampler.
            extrinsics: Per-view camera matrices. Only the leading (view)
                dimension is read, to bound the sampled indices.
            intrinsics: Not used by this sampler.
            device: Device on which the returned tensors are allocated.

        Returns:
            ``(index_context, index_target, overlap)``. Both index tensors
            are int64. ``overlap`` is a constant placeholder ``[0.5]``.

        Raises:
            ValueError: If a fixed view list has the wrong length, or contains
                an index outside ``[-num_views, num_views)``.
        """
        num_views, _, _ = extrinsics.shape
        cfg = self.cfg

        # This draw happens even when the context views are fixed below. That
        # keeps the RNG consumption (and therefore the target draw) identical
        # to the historical behaviour for a given seed.
        index_context = torch.randint(
            0,
            num_views,
            size=(cfg.num_context_views,),
            device=device,
        )

        # Allow the context views to be fixed.
        if cfg.context_views is not None:
            # Special case: two fixed views while three are requested means the
            # third index is derived by add_third_context_index (defined in
            # three_view_hack, not visible here).
            wants_third = cfg.num_context_views == 3 and len(cfg.context_views) == 2
            if not wants_third and len(cfg.context_views) != cfg.num_context_views:
                raise ValueError(
                    f"cfg.context_views has {len(cfg.context_views)} entries "
                    f"but cfg.num_context_views is {cfg.num_context_views}"
                )
            index_context = self._fixed_indices(
                "context_views", cfg.context_views, num_views, device
            )
            if wants_third:
                index_context = add_third_context_index(index_context)

        index_target = torch.randint(
            0,
            num_views,
            size=(cfg.num_target_views,),
            device=device,
        )

        # Allow the target views to be fixed.
        if cfg.target_views is not None:
            if len(cfg.target_views) != cfg.num_target_views:
                raise ValueError(
                    f"cfg.target_views has {len(cfg.target_views)} entries "
                    f"but cfg.num_target_views is {cfg.num_target_views}"
                )
            index_target = self._fixed_indices(
                "target_views", cfg.target_views, num_views, device
            )

        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)  # dummy
        return index_context, index_target, overlap

    @staticmethod
    def _fixed_indices(
        field: str,
        indices: list[int],
        num_views: int,
        device: torch.device,
    ) -> Tensor:
        """Range-check user-supplied view indices and return them as int64."""
        # Negative indices are allowed because tensor indexing accepts them.
        # Anything outside [-num_views, num_views) would fail later anyway,
        # so report it here with context.
        bad = [i for i in indices if not -num_views <= i < num_views]
        if bad:
            raise ValueError(
                f"cfg.{field} contains indices {bad} outside the valid range "
                f"for a scene with {num_views} views"
            )
        return torch.tensor(indices, dtype=torch.int64, device=device)

    @property
    def num_context_views(self) -> int:
        """Number of context views requested by the config."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Number of target views requested by the config."""
        return self.cfg.num_target_views