|
from dataclasses import dataclass |
|
from typing import Literal |
|
|
|
import torch |
|
from jaxtyping import Float, Int64 |
|
from torch import Tensor |
|
|
|
from .view_sampler import ViewSampler |
|
|
|
|
|
@dataclass
class ViewSamplerAllCfg:
    """Configuration for :class:`ViewSamplerAll`.

    The only field is the ``name`` tag, whose annotation accepts exactly the
    string ``"all"``.
    """

    # Tag identifying this sampler configuration; annotated as the literal "all".
    name: Literal["all"]
|
|
|
|
|
class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]):
    """View sampler that returns every view index as both context and target."""

    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],
        Int64[Tensor, " target_view"],
    ]:
        """Return the indices of all views for both the context and target sets.

        Args:
            scene: Scene identifier. Not used by this sampler.
            extrinsics: Per-view extrinsics; only the size of the leading
                (view) dimension is read.
            intrinsics: Per-view intrinsics. Not used by this sampler.
            device: Device on which the index tensor is created.

        Returns:
            A pair ``(context_indices, target_indices)``. Both entries are the
            same tensor object holding ``0 .. view - 1``.

        Raises:
            ValueError: If ``extrinsics`` does not have exactly three
                dimensions (the shape unpacking below fails).
        """
        # Three-way unpacking doubles as a rank check on `extrinsics`.
        v, _, _ = extrinsics.shape
        all_frames = torch.arange(v, device=device)
        return all_frames, all_frames

    @property
    def num_context_views(self) -> int:
        """Number of context views reported by this sampler; always ``0``.

        NOTE(review): ``sample`` returns one index per input view, so the
        actual count varies with the input; ``0`` looks like a placeholder.
        Confirm how callers interpret it.
        """
        return 0

    @property
    def num_target_views(self) -> int:
        """Number of target views reported by this sampler; always ``0``.

        NOTE(review): ``sample`` returns one index per input view, so the
        actual count varies with the input; ``0`` looks like a placeholder.
        Confirm how callers interpret it.
        """
        return 0
|
|