File size: 2,424 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float, Int64
from torch import Tensor
from .three_view_hack import add_third_context_index
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerArbitraryCfg:
name: Literal["arbitrary"]
num_context_views: int
num_target_views: int
context_views: list[int] | None
target_views: list[int] | None
class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]):
    """View sampler that picks context/target indices at random or from config.

    Indices are drawn with ``torch.randint`` (so duplicates are possible)
    unless the config pins them via ``context_views`` / ``target_views``.
    """

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, " overlap"],  # overlap
    ]:
        """Arbitrarily sample context and target views.

        Args:
            scene: Scene identifier; not used by this sampler.
            extrinsics: Per-view extrinsics; only the leading dimension (the
                number of views) is read.
            intrinsics: Per-view intrinsics; not used by this sampler.
            device: Device on which the returned tensors are created.

        Returns:
            A tuple ``(index_context, index_target, overlap)``. ``overlap`` is
            a dummy one-element tensor holding ``0.5``.

        Raises:
            ValueError: If fixed context or target views are configured but
                their count does not match the configured number of views
                (except for the supported 2-fixed / 3-requested context case).
        """
        # NOTE(review): `self.cfg` is presumably set by the ViewSampler base
        # class; it is not defined in this file.
        cfg = self.cfg
        num_views, _, _ = extrinsics.shape

        # The random draw is made even when fixed views are configured, so
        # the RNG is consumed identically in both modes.
        index_context = torch.randint(
            0,
            num_views,
            size=(cfg.num_context_views,),
            device=device,
        )

        # Allow the context views to be fixed.
        if cfg.context_views is not None:
            # Two fixed views with three requested is a supported special
            # case: the third index is derived by add_third_context_index.
            needs_third_view = (
                cfg.num_context_views == 3 and len(cfg.context_views) == 2
            )
            # Explicit raise instead of `assert`, which is removed under -O.
            if (
                not needs_third_view
                and len(cfg.context_views) != cfg.num_context_views
            ):
                raise ValueError(
                    f"context_views has {len(cfg.context_views)} entries but "
                    f"num_context_views is {cfg.num_context_views}"
                )
            index_context = torch.tensor(
                cfg.context_views, dtype=torch.int64, device=device
            )
            if needs_third_view:
                index_context = add_third_context_index(index_context)

        index_target = torch.randint(
            0,
            num_views,
            size=(cfg.num_target_views,),
            device=device,
        )

        # Allow the target views to be fixed.
        if cfg.target_views is not None:
            # Explicit raise instead of `assert`, which is removed under -O.
            if len(cfg.target_views) != cfg.num_target_views:
                raise ValueError(
                    f"target_views has {len(cfg.target_views)} entries but "
                    f"num_target_views is {cfg.num_target_views}"
                )
            index_target = torch.tensor(
                cfg.target_views, dtype=torch.int64, device=device
            )

        overlap = torch.tensor([0.5], dtype=torch.float32, device=device)  # dummy

        return index_context, index_target, overlap

    @property
    def num_context_views(self) -> int:
        """Number of context views, as configured."""
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        """Number of target views, as configured."""
        return self.cfg.num_target_views
|