|
from abc import ABC, abstractmethod |
|
from typing import Generic, TypeVar |
|
|
|
import torch |
|
from jaxtyping import Float, Int64 |
|
from torch import Tensor |
|
|
|
from ...misc.step_tracker import StepTracker |
|
from ..types import Stage |
|
|
|
T = TypeVar("T") |
|
|
|
|
|
class ViewSampler(ABC, Generic[T]):
    """Abstract base for strategies that choose context and target views of a scene.

    Concrete samplers must implement ``sample`` and the ``num_context_views`` /
    ``num_target_views`` properties. The type parameter ``T`` is the type of
    the sampler-specific configuration object kept in ``cfg``.
    """

    cfg: T
    stage: Stage
    is_overfitting: bool
    cameras_are_circular: bool
    step_tracker: StepTracker | None

    def __init__(
        self,
        cfg: T,
        stage: Stage,
        is_overfitting: bool,
        cameras_are_circular: bool,
        step_tracker: StepTracker | None,
    ) -> None:
        """Record the configuration and run context on the instance.

        Args:
            cfg: Sampler-specific configuration object.
            stage: The ``Stage`` value this sampler is used in.
            is_overfitting: Flag stored for subclasses; the base class does
                not read it.
            cameras_are_circular: Flag stored for subclasses; the base class
                does not read it.
            step_tracker: Optional tracker consulted by ``global_step``;
                when ``None``, ``global_step`` reports 0.
        """
        # Plain attribute copies in declaration order; no validation is done.
        self.cfg, self.stage = cfg, stage
        self.is_overfitting, self.cameras_are_circular = (
            is_overfitting,
            cameras_are_circular,
        )
        self.step_tracker = step_tracker

    @abstractmethod
    def sample(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],
        Int64[Tensor, " target_view"],
        Float[Tensor, " overlap"],
    ]:
        """Pick views for ``scene``; must be implemented by subclasses.

        Per the annotations, ``extrinsics`` is a per-view stack of 4x4
        matrices and ``intrinsics`` a per-view stack of 3x3 matrices. The
        result is a triple of (context view indices, target view indices,
        overlap values). ``device`` defaults to the CPU.
        """

    @property
    @abstractmethod
    def num_target_views(self) -> int:
        """Number of target views; defined by each concrete sampler."""

    @property
    @abstractmethod
    def num_context_views(self) -> int:
        """Number of context views; defined by each concrete sampler."""

    @property
    def global_step(self) -> int:
        """Current step reported by ``step_tracker``, or 0 when there is none."""
        tracker = self.step_tracker
        if tracker is None:
            return 0
        return tracker.get_step()
|
|