from dataclasses import dataclass
from typing import Literal

import torch
from jaxtyping import Float, Int64
from torch import Tensor

from .view_sampler import ViewSampler
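

# Configuration for the bounded view sampler: how many context and target views
# to draw, the allowed index gap between the context views, and the warm-up
# schedule applied to that gap during early training.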
@dataclass
class ViewSamplerBoundedCfg:
    name: Literal["bounded"]
    num_context_views: int
    num_target_views: int
    min_distance_between_context_views: int
    max_distance_between_context_views: int
    min_distance_to_context_views: int
    warm_up_steps: int
    initial_min_distance_between_context_views: int
    initial_max_distance_between_context_views: int
    max_img_per_gpu: int
    min_gap_multiplier: int
    max_gap_multiplier: int


class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
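    # Linearly interpolate between `initial` and `final` over the warm-up steps,
    # capping the result at `final`.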
    def schedule(self, initial: int, final: int) -> int:
        fraction = self.global_step / self.cfg.warm_up_steps
        return min(initial + int((final - initial) * fraction), final)

    def sample(
        self,
        scene: str,
        num_context_views: int,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],
        Int64[Tensor, " target_view"],
        Float[Tensor, " overlap"],
    ]:
        num_views, _, _ = extrinsics.shape

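        # Compute the context view spacing based on the current global step.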
        if self.stage == "test":
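            # When testing, always use the full gap.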
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.max_distance_between_context_views
        elif self.cfg.warm_up_steps > 0:
            max_gap = self.schedule(
                self.cfg.initial_max_distance_between_context_views,
                self.cfg.max_distance_between_context_views,
            )
            min_gap = self.schedule(
                self.cfg.initial_min_distance_between_context_views,
                self.cfg.min_distance_between_context_views,
            )
        else:
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.min_distance_between_context_views

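        # Override the gap bounds based on how many context views are sampled.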
        min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
        max_gap = min(max_gap, num_views - 1)

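        # Pick the gap between the context views.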
        if not self.cameras_are_circular:
            max_gap = min(num_views - 1, max_gap)
        min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)
        if max_gap < min_gap:
            raise ValueError("Example does not have enough frames!")
        context_gap = torch.randint(
            min_gap,
            max_gap + 1,
            size=tuple(),
            device=device,
        ).item()

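        # Pick the left and right context indices.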
        index_context_left = torch.randint(
            num_views if self.cameras_are_circular else num_views - context_gap,
            size=tuple(),
            device=device,
        ).item()
        if self.stage == "test":
            index_context_left = index_context_left * 0
        index_context_right = index_context_left + context_gap

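        # When overfitting, pin the context views to the start of the sequence and
        # use the maximum gap.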
        if self.is_overfitting:
            index_context_left *= 0
            index_context_right *= 0
            index_context_right += max_gap

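        # Pick the target view indices.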
        if self.stage == "test":
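            # When testing, pick every view from the left to the right context view
            # (inclusive).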
            index_target = torch.arange(
                index_context_left,
                index_context_right + 1,
                device=device,
            )
        else:
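            # Otherwise, pick target views at random, keeping them at least
            # min_distance_to_context_views away from either context view.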
            index_target = torch.randint(
                index_context_left + self.cfg.min_distance_to_context_views,
                index_context_right + 1 - self.cfg.min_distance_to_context_views,
                size=(self.cfg.num_target_views,),
                device=device,
            )

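        # Apply modulo for circular datasets.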
        if self.cameras_are_circular:
            index_target %= num_views
            index_context_right %= num_views

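        # Sample any additional context views strictly between the left and right
        # context views, re-drawing until the extra indices are all distinct.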
        if num_context_views > 2:
            num_extra_views = num_context_views - 2
            extra_views = []
            while len(set(extra_views)) != num_extra_views:
                extra_views = torch.randint(
                    index_context_left + 1,
                    index_context_right,
                    (num_extra_views,),
                ).tolist()
        else:
            extra_views = []

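        # The overlap between the context views is not estimated by this sampler;
        # a constant value is returned.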
        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)

        return (
            torch.tensor((index_context_left, *extra_views, index_context_right)),
            index_target,
            overlap,
        )

    @property
    def num_context_views(self) -> int:
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        return self.cfg.num_target_views

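    # Map each possible number of context views to a [min_gap, max_gap] pair derived
    # from the configured gap multipliers, capped at the configured min/max distances.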
    @property
    def num_ctxt_gap_mapping(self) -> dict:
        mapping = dict()
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [
                min(num_ctxt * self.cfg.min_gap_multiplier, self.cfg.min_distance_between_context_views),
                min(max(num_ctxt * self.cfg.max_gap_multiplier, num_ctxt ** 2), self.cfg.max_distance_between_context_views),
            ]
        return mapping