alexnasa's picture
Upload 243 files
2568013 verified
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar
from jaxtyping import Float
from torch import Tensor, nn
from ..types import Gaussians
# Allowed values for the ``depth_mode`` argument of ``Decoder.forward``.
# The transform each mode applies is defined by concrete decoders, which are
# not visible here.
DepthRenderingMode = Literal["depth", "log", "disparity", "relative_disparity"]
@dataclass
class DecoderOutput:
    """Container for the tensors produced by a decoder's ``forward`` pass.

    ``color`` is always required. The remaining fields are typed as optional,
    so they default to ``None``. A decoder that does not produce them can
    therefore omit them instead of passing ``None`` explicitly. Existing
    positional or keyword construction with all four fields is unaffected.
    """

    # 3-channel color images for every (batch, view) pair.
    color: Float[Tensor, "batch view 3 height width"]
    # Per-pixel depth for every (batch, view) pair, or None if not produced.
    depth: Float[Tensor, "batch view height width"] | None = None
    # Per-pixel alpha for every (batch, view) pair, or None if not produced.
    # NOTE(review): presumably accumulated opacity; confirm in the renderers.
    alpha: Float[Tensor, "batch view height width"] | None = None
    # Extra level-of-detail rendering outputs, or None.
    # TODO: document the key schema; it is defined by the producing decoder.
    lod_rendering: dict | None = None
# Type of the configuration object a concrete decoder is constructed with.
T = TypeVar("T")


class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract ``nn.Module`` that renders ``Gaussians`` into a ``DecoderOutput``.

    Concrete decoders choose their configuration type via ``Generic[T]``. They
    receive that configuration in ``__init__`` and must implement ``forward``.
    """

    # Configuration supplied at construction time.
    cfg: T

    def __init__(self, cfg: T) -> None:
        # nn.Module must be initialised before any attribute is assigned.
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render ``gaussians`` once per (batch, view) entry.

        Args:
            gaussians: The Gaussians to render.
            extrinsics: One 4x4 matrix per (batch, view). The convention
                (camera-to-world vs. world-to-camera) is not established
                here; confirm it with the implementations.
            intrinsics: One 3x3 matrix per (batch, view).
            near: One near value per (batch, view).
            far: One far value per (batch, view).
            image_shape: Output image size, presumably ``(height, width)``.
                Confirm against ``DecoderOutput.color``'s layout.
            depth_mode: Which depth encoding to render, or ``None``
                (the default).

        Returns:
            A ``DecoderOutput`` holding the rendered tensors.
        """