AnySplat / src /dataset /view_sampler /view_sampler_bounded.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame
5.62 kB
from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float, Int64
from torch import Tensor
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerBoundedCfg:
    """Configuration for ``ViewSamplerBounded``.

    Gaps and distances are measured in frame indices (they are added directly
    to frame indices by the sampler).
    """

    # Sampler type tag (annotated as the literal "bounded").
    name: Literal["bounded"]
    # Largest supported number of context views; the sampler builds its
    # per-count gap table for counts 2..num_context_views.
    num_context_views: int
    # Number of random target indices drawn per example outside the test stage.
    num_target_views: int
    # Cap applied to the per-count lower gap bound (n * min_gap_multiplier).
    min_distance_between_context_views: int
    # Cap applied to the per-count upper gap bound
    # (max(n * max_gap_multiplier, n ** 2)).
    max_distance_between_context_views: int
    # Minimum index distance between a training target and either context
    # endpoint; the lower gap bound is raised to at least twice this value.
    min_distance_to_context_views: int
    # Length of the linear warm-up used by ViewSamplerBounded.schedule.
    warm_up_steps: int
    # Starting gap bounds for a warm-up schedule.
    # NOTE(review): the sampler's active code path does not read these (the
    # schedule-based gap selection is disabled) — confirm whether still needed.
    initial_min_distance_between_context_views: int
    initial_max_distance_between_context_views: int
    # Not read by ViewSamplerBounded; presumably consumed elsewhere
    # (e.g. batching) — verify against callers.
    max_img_per_gpu: int
    # Per-context-view multiplier for the lower gap bound.
    min_gap_multiplier: int
    # Per-context-view multiplier for the upper gap bound.
    max_gap_multiplier: int
class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
    """Sample context/target frame indices inside a bounded window.

    Two endpoint context frames are placed a random gap apart; the allowed
    gap depends on how many context views are requested. Any additional
    context frames are drawn strictly between the endpoints, and target
    frames are drawn from the same window.
    """

    def schedule(self, initial: int, final: int) -> int:
        """Linearly interpolate from ``initial`` towards ``final`` over the warm-up.

        The interpolation fraction is ``self.global_step / cfg.warm_up_steps``.
        The result is passed through ``min(..., final)``. When no warm-up is
        configured (``warm_up_steps <= 0``), ``final`` is returned directly
        instead of dividing by zero.
        """
        if self.cfg.warm_up_steps <= 0:
            return final
        fraction = self.global_step / self.cfg.warm_up_steps
        return min(initial + int((final - initial) * fraction), final)

    @staticmethod
    def _sample_extra_context_views(
        start: int, span: int, num_extra_views: int
    ) -> list[int]:
        """Return ``num_extra_views`` distinct indices in ``[start + 1, start + span - 1]``.

        The result is a uniformly random ordered subset drawn without
        replacement. This is the same distribution as redrawing i.i.d.
        samples until all are distinct, but it always terminates.

        Raises:
            ValueError: If fewer than ``num_extra_views`` candidates exist.
        """
        if num_extra_views <= 0:
            return []
        num_candidates = span - 1
        if num_candidates < num_extra_views:
            raise ValueError(
                f"Cannot pick {num_extra_views} distinct extra context views "
                f"from {max(num_candidates, 0)} frames between the endpoints!"
            )
        offsets = torch.randperm(num_candidates)[:num_extra_views]
        return (offsets + start + 1).tolist()

    def sample(
        self,
        scene: str,
        num_context_views: int,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, " overlap"],  # overlap
    ]:
        """Sample context and target frame indices for one example.

        Args:
            scene: Scene identifier (not used by this sampler).
            num_context_views: Number of context indices to return. Must be a
                key of ``num_ctxt_gap_mapping`` (2..cfg.num_context_views),
                otherwise a ``KeyError`` is raised.
            extrinsics: Per-view camera matrices; only ``shape[0]`` (the number
                of views) is used.
            intrinsics: Not used by this sampler.
            device: Device for the random gap/endpoint/target draws and for the
                target and overlap tensors.

        Returns:
            ``(context_indices, target_indices, overlap)``:
            - ``context_indices`` is ``[left, *extra, right]``, where ``extra``
              holds ``num_context_views - 2`` distinct frames strictly inside
              the window, in random order.
            - ``target_indices`` covers every frame of the window in the test
              stage; otherwise it holds ``cfg.num_target_views`` random frames.
            - ``overlap`` is a constant dummy value of 0.5.

        Raises:
            ValueError: If the example does not have enough frames.
        """
        num_views, _, _ = extrinsics.shape

        # Gap bounds come from the per-count table in every stage.
        # NOTE(review): a former "test stage uses the full max gap" branch was
        # always overwritten by this lookup and was removed as dead code;
        # confirm the table is also intended for evaluation.
        min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
        # The window must fit in the sequence (applied to circular data too).
        max_gap = min(max_gap, num_views - 1)
        # Leave room for targets that stay min_distance_to_context_views
        # frames away from both endpoints.
        min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)
        if max_gap < min_gap:
            raise ValueError("Example does not have enough frames!")

        # Pick the gap between the two endpoint context views.
        context_gap = torch.randint(
            min_gap,
            max_gap + 1,
            size=tuple(),
            device=device,
        ).item()

        # Pick the left and right context indices. Testing always starts at
        # frame 0. Otherwise the start is random; for circular cameras any
        # start is allowed and the window may run past the last frame.
        if self.stage == "test":
            index_context_left = 0
        else:
            index_context_left = torch.randint(
                num_views if self.cameras_are_circular else num_views - context_gap,
                size=tuple(),
                device=device,
            ).item()
        index_context_right = index_context_left + context_gap

        if self.is_overfitting:
            # Fixed endpoints: [0, max_gap].
            index_context_left = 0
            index_context_right = max_gap

        # Pick the target view indices.
        if self.stage == "test":
            # When testing, use every frame of the window.
            index_target = torch.arange(
                index_context_left,
                index_context_right + 1,
                device=device,
            )
        else:
            # When training or validating, pick at random (randint's upper
            # bound is exclusive, so targets lie in
            # [left + d, right - d] with d = min_distance_to_context_views).
            index_target = torch.randint(
                index_context_left + self.cfg.min_distance_to_context_views,
                index_context_right + 1 - self.cfg.min_distance_to_context_views,
                size=(self.cfg.num_target_views,),
                device=device,
            )

        # Extra context views are drawn from the *unwrapped* window, before
        # the circular modulo below, so a window that wraps past the last
        # frame still yields a valid (left, right) range.
        extra_views = self._sample_extra_context_views(
            index_context_left,
            index_context_right - index_context_left,
            num_context_views - 2,
        )

        # Apply modulo for circular datasets.
        if self.cameras_are_circular:
            index_target %= num_views
            index_context_right %= num_views
            extra_views = [view % num_views for view in extra_views]

        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)  # dummy
        # NOTE(review): the context tensor is created on the default device,
        # while targets/overlap use `device`; kept as-is for compatibility.
        return (
            torch.tensor((index_context_left, *extra_views, index_context_right)),
            index_target,
            overlap,
        )

    @property
    def num_context_views(self) -> int:
        """Largest supported number of context views (from the config)."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Number of random targets drawn outside the test stage (from the config)."""
        return self.cfg.num_target_views

    @property
    def num_ctxt_gap_mapping(self) -> dict[int, list[int]]:
        """Map each context-view count ``n`` in 2..cfg.num_context_views to ``[min_gap, max_gap]``.

        The bounds are:
        - ``min_gap = min(n * min_gap_multiplier, min_distance_between_context_views)``
        - ``max_gap = min(max(n * max_gap_multiplier, n ** 2), max_distance_between_context_views)``
        """
        cfg = self.cfg
        return {
            num_ctxt: [
                min(
                    num_ctxt * cfg.min_gap_multiplier,
                    cfg.min_distance_between_context_views,
                ),
                min(
                    max(num_ctxt * cfg.max_gap_multiplier, num_ctxt**2),
                    cfg.max_distance_between_context_views,
                ),
            ]
            for num_ctxt in range(2, cfg.num_context_views + 1)
        }