File size: 1,521 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
import torch
from jaxtyping import Float, Int64
from torch import Tensor
from ...misc.step_tracker import StepTracker
from ..types import Stage
T = TypeVar("T")  # Type of the ``cfg`` object a concrete sampler is parameterized with.


class ViewSampler(ABC, Generic[T]):
    """Abstract strategy for picking context and target views of a scene.

    Concrete subclasses must implement :meth:`sample` and the
    ``num_context_views`` / ``num_target_views`` properties. The constructor
    only stores its arguments; it performs no validation.
    """

    cfg: T
    stage: Stage
    is_overfitting: bool
    cameras_are_circular: bool
    step_tracker: StepTracker | None

    def __init__(
        self,
        cfg: T,
        stage: Stage,
        is_overfitting: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
    ) -> None:
        """Record the sampler configuration and runtime context.

        Args:
            cfg: Sampler-specific configuration (typed by the generic ``T``).
            stage: The ``Stage`` this sampler is used in.
            is_overfitting: Flag stored for subclasses; its effect is defined
                by them (not visible here).
            cameras_are_circular: Flag stored for subclasses; presumably
                describes the camera layout of the data — confirm in subclasses.
            step_tracker: Optional tracker consulted by :attr:`global_step`.
        """
        # Plain attribute storage, in declaration order.
        self.cfg = cfg
        self.stage = stage
        self.is_overfitting = is_overfitting
        self.cameras_are_circular = cameras_are_circular
        self.step_tracker = step_tracker

    @abstractmethod
    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, " overlap"],  # overlap
    ]:
        """Choose which views of ``scene`` act as context and which as targets.

        Args:
            scene: Scene identifier (presumably a name/key — confirm with callers).
            extrinsics: One 4x4 matrix per view.
            intrinsics: One 3x3 matrix per view.
            device: Torch device passed to the implementation; defaults to CPU.
                Presumably where the returned tensors live — confirm in subclasses.

        Returns:
            A 3-tuple ``(context_indices, target_indices, overlap)``.
        """

    @property
    @abstractmethod
    def num_target_views(self) -> int:
        """Number of target views this sampler produces (defined by subclasses)."""

    @property
    @abstractmethod
    def num_context_views(self) -> int:
        """Number of context views this sampler produces (defined by subclasses)."""

    @property
    def global_step(self) -> int:
        """Step reported by ``step_tracker.get_step()``, or 0 without a tracker."""
        tracker = self.step_tracker
        if tracker is None:
            return 0
        return tracker.get_step()
|