|
from dataclasses import dataclass |
|
from typing import Literal |
|
|
|
import torch |
|
from jaxtyping import Float, Int64 |
|
from torch import Tensor |
|
|
|
from .three_view_hack import add_third_context_index |
|
from .view_sampler import ViewSampler |
|
|
|
|
|
@dataclass |
|
class ViewSamplerArbitraryCfg: |
|
name: Literal["arbitrary"] |
|
num_context_views: int |
|
num_target_views: int |
|
context_views: list[int] | None |
|
target_views: list[int] | None |
|
|
|
|
|
class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]):
    """Sample view indices uniformly at random, optionally fixed by config.

    When ``cfg.context_views`` / ``cfg.target_views`` are set, those indices
    replace the randomly drawn ones.
    """

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],
        Int64[Tensor, " target_view"],
        Float[Tensor, " overlap"],
    ]:
        """Arbitrarily sample context and target views.

        Args:
            scene: Scene identifier (not used by this sampler).
            extrinsics: Per-view extrinsics; only the leading (view) dimension
                is read, to determine how many views are available.
            intrinsics: Per-view intrinsics (not used by this sampler).
            device: Device on which all returned tensors are allocated.

        Returns:
            ``(index_context, index_target, overlap)``. The indices are int64
            tensors. ``overlap`` is always the constant ``[0.5]`` and is not
            computed from the views.

        Raises:
            ValueError: If ``cfg.context_views`` or ``cfg.target_views`` is set
                but has the wrong length.
        """
        num_views, _, _ = extrinsics.shape

        # The random draw happens even when fixed indices override it below.
        # Drawing unconditionally keeps global torch RNG consumption the same
        # regardless of configuration, so seeded runs stay reproducible.
        index_context = torch.randint(
            0,
            num_views,
            size=(self.cfg.num_context_views,),
            device=device,
        )

        fixed_context = self.cfg.context_views
        if fixed_context is not None:
            index_context = torch.tensor(
                fixed_context, dtype=torch.int64, device=device
            )
            num_fixed_context = len(fixed_context)
            if self.cfg.num_context_views == 3 and num_fixed_context == 2:
                # Special case: two fixed views are given but three are wanted;
                # the helper presumably derives the third index from the two
                # given ones (see three_view_hack).
                index_context = add_third_context_index(index_context)
            elif num_fixed_context != self.cfg.num_context_views:
                # Explicit raise instead of `assert`, which `python -O` strips.
                raise ValueError(
                    f"context_views has {num_fixed_context} entries but "
                    f"num_context_views is {self.cfg.num_context_views} "
                    "(2 entries are also accepted when num_context_views == 3)"
                )

        # Same unconditional draw as for the context views (see note above).
        index_target = torch.randint(
            0,
            num_views,
            size=(self.cfg.num_target_views,),
            device=device,
        )

        fixed_target = self.cfg.target_views
        if fixed_target is not None:
            if len(fixed_target) != self.cfg.num_target_views:
                raise ValueError(
                    f"target_views has {len(fixed_target)} entries but "
                    f"num_target_views is {self.cfg.num_target_views}"
                )
            index_target = torch.tensor(
                fixed_target, dtype=torch.int64, device=device
            )

        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)

        return index_context, index_target, overlap

    @property
    def num_context_views(self) -> int:
        """Number of context views requested by the config."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Number of target views requested by the config."""
        return self.cfg.num_target_views
|
|