# AnySplat/src/dataset/types.py
# (Hosting-page residue kept for provenance: "alexnasa's picture",
# "Upload 243 files", "2568013 verified".)
from typing import Callable, Literal, TypedDict
from jaxtyping import Float, Int64
from torch import Tensor
# Closed set of stage names accepted wherever a `Stage` is annotated:
# "train", "val", or "test".
Stage = Literal["train", "val", "test"]
# The following types mainly exist to make type-hinted keys show up in VS Code. Some
# dimensions are annotated as "_" because either:
# 1. They're expected to change as part of a function call (e.g., resizing the dataset).
# 2. They're expected to vary within the same function call (e.g., the number of views,
# which differs between context and target BatchedViews).
class BatchedViews(TypedDict, total=False):
    """Dictionary of per-view tensors with a leading ``batch`` axis.

    Declared with ``total=False``, so every key is optional. The second axis
    of each tensor is the view axis; it is written as ``_`` in the shape
    strings rather than given a name.
    """

    # Axes: batch, view, 4, 4.
    extrinsics: Float[Tensor, "batch _ 4 4"]

    # Axes: batch, view, 3, 3.
    intrinsics: Float[Tensor, "batch _ 3 3"]

    # Axes: batch, view, channel, height, width.
    image: Float[Tensor, "batch _ _ _ _"]

    # Axes: batch, view.
    near: Float[Tensor, "batch _"]

    # Axes: batch, view.
    far: Float[Tensor, "batch _"]

    # Axes: batch, view (64-bit integer tensor).
    index: Int64[Tensor, "batch _"]

    # Axes: batch, view.
    overlap: Float[Tensor, "batch _"]
class BatchedExample(TypedDict, total=False):
    """A batched example: two `BatchedViews` groups plus scene strings.

    Declared with ``total=False``, so every key is optional.
    """

    # View group stored under the "target" key.
    target: BatchedViews

    # View group stored under the "context" key.
    context: BatchedViews

    # List of strings; presumably one scene identifier per batch element
    # (TODO confirm against the dataset code).
    scene: list[str]
class UnbatchedViews(TypedDict, total=False):
    """Dictionary of per-view tensors without a ``batch`` axis.

    Declared with ``total=False``, so every key is optional. The leading ``_``
    axis is presumably the view axis, mirroring `BatchedViews` (TODO confirm).
    Unlike `BatchedViews`, no ``overlap`` key is declared here.
    """

    # Shape: _, 4, 4.
    extrinsics: Float[Tensor, "_ 4 4"]

    # Shape: _, 3, 3.
    intrinsics: Float[Tensor, "_ 3 3"]

    # Shape: _, 3, height, width (second axis fixed to size 3).
    image: Float[Tensor, "_ 3 height width"]

    # Single axis of unspecified size.
    near: Float[Tensor, " _"]

    # Single axis of unspecified size.
    far: Float[Tensor, " _"]

    # Single axis of unspecified size (64-bit integer tensor).
    index: Int64[Tensor, " _"]
class UnbatchedExample(TypedDict, total=False):
    """A single (unbatched) example: two `UnbatchedViews` groups plus a scene.

    Declared with ``total=False``, so every key is optional.
    """

    # View group stored under the "target" key.
    target: UnbatchedViews

    # View group stored under the "context" key.
    context: UnbatchedViews

    # A single string; presumably the scene identifier (TODO confirm).
    scene: str
# A data shim modifies the example after it's been returned from the data loader.
# Signature: takes one BatchedExample and returns a BatchedExample.
DataShim = Callable[[BatchedExample], BatchedExample]

# Union aliases for code that accepts either the batched or the unbatched form.
AnyExample = BatchedExample | UnbatchedExample
AnyViews = BatchedViews | UnbatchedViews