from typing import Callable, Literal, TypedDict
from jaxtyping import Float, Int64
from torch import Tensor

Stage = Literal["train", "val", "test"]


# The following types mainly exist to make type-hinted keys show up in VS Code. Some
# dimensions are annotated as "_" because either:
# 1. They're expected to change as part of a function call (e.g., resizing the dataset).
# 2. They're expected to vary within the same function call (e.g., the number of views,
#    which differs between context and target BatchedViews).
class BatchedViews(TypedDict, total=False):
    extrinsics: Float[Tensor, "batch _ 4 4"]  # batch view 4 4
    intrinsics: Float[Tensor, "batch _ 3 3"]  # batch view 3 3
    image: Float[Tensor, "batch _ _ _ _"]  # batch view channel height width
    near: Float[Tensor, "batch _"]  # batch view
    far: Float[Tensor, "batch _"]  # batch view
    index: Int64[Tensor, "batch _"]  # batch view
    overlap: Float[Tensor, "batch _"]  # batch view


class BatchedExample(TypedDict, total=False):
    target: BatchedViews
    context: BatchedViews
    scene: list[str]
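

# Illustrative sketch (not part of the original module): one way to build a
# BatchedExample from dummy tensors. It shows why the view dimension is annotated as
# "_" above: the number of context views (2 here) and target views (3 here) differ
# within the same example. The helper name, view counts, and image resolution are
# arbitrary choices for this sketch.
def _example_batched_example(batch: int = 1) -> BatchedExample:
    import torch

    def views(n: int) -> BatchedViews:
        return {
            "extrinsics": torch.eye(4).expand(batch, n, 4, 4),  # batch view 4 4
            "intrinsics": torch.eye(3).expand(batch, n, 3, 3),  # batch view 3 3
            "image": torch.zeros(batch, n, 3, 256, 256),  # batch view channel height width
            "near": torch.full((batch, n), 0.1),  # batch view
            "far": torch.full((batch, n), 100.0),  # batch view
            "index": torch.arange(n).expand(batch, n),  # batch view
        }

    return {
        "context": views(2),  # 2 context views
        "target": views(3),  # 3 target views
        "scene": ["example_scene"] * batch,
    }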


class UnbatchedViews(TypedDict, total=False):
    extrinsics: Float[Tensor, "_ 4 4"]
    intrinsics: Float[Tensor, "_ 3 3"]
    image: Float[Tensor, "_ 3 height width"]
    near: Float[Tensor, " _"]
    far: Float[Tensor, " _"]
    index: Int64[Tensor, " _"]


class UnbatchedExample(TypedDict, total=False):
    target: UnbatchedViews
    context: UnbatchedViews
    scene: str


# A data shim modifies the example after it's been returned from the data loader.
DataShim = Callable[[BatchedExample], BatchedExample]
AnyExample = BatchedExample | UnbatchedExample
AnyViews = BatchedViews | UnbatchedViews
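

# Illustrative sketch (not part of the original module): a hypothetical data shim. The
# only contract implied above is that a DataShim takes the BatchedExample produced by
# the data loader and returns a (possibly modified) BatchedExample. The normalization
# below is an arbitrary example of such a modification, not the project's actual shim.
def _example_normalize_shim(example: BatchedExample) -> BatchedExample:
    for views in (example.get("context"), example.get("target")):
        if views is not None and "image" in views:
            # Map images from [0, 1] to [-1, 1] in place.
            views["image"] = views["image"] * 2 - 1
    return example


# The sketch above satisfies the DataShim alias.
_example_shim: DataShim = _example_normalize_shim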