from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Literal, TypeVar

from jaxtyping import Float
from torch import Tensor, nn

from ..types import Gaussians

# Closed set of string values accepted by the ``depth_mode`` argument of
# ``Decoder.forward``.
# NOTE(review): how each mode is actually computed (e.g. the normalization used
# for "relative_disparity") lives in concrete decoders, not in this file.
DepthRenderingMode = Literal[
    "depth",
    "log",
    "disparity",
    "relative_disparity",
]


@dataclass
class DecoderOutput:
    """Bundle of tensors returned by ``Decoder.forward``.

    Tensor shapes are given by the jaxtyping annotations on each field.
    ``depth``, ``alpha`` and ``lod_rendering`` are optional and may be
    ``None``. Presumably a decoder leaves them unset when they were not
    requested or are unsupported; confirm against the concrete implementations.
    """

    # One 3-channel image per (batch, view) pair.
    color: Float[Tensor, "batch view 3 height width"]
    # Per-pixel depth-like map, or None.
    depth: Float[Tensor, "batch view height width"] | None
    # Per-pixel alpha map, or None.
    alpha: Float[Tensor, "batch view height width"] | None
    # NOTE(review): the key/value schema of this dict is not visible here.
    # Presumably it holds level-of-detail rendering extras; verify against
    # the producers and consumers.
    lod_rendering: dict | None


# Type variable for the decoder-specific configuration object.
T = TypeVar("T")


class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract ``nn.Module`` base class for Gaussian decoders.

    A subclass is parameterized by its configuration type ``T``. The config
    object passed to the constructor is stored unchanged on ``self.cfg``.
    Every concrete subclass must implement :meth:`forward`.
    """

    cfg: T

    def __init__(self, cfg: T) -> None:
        # Initialize nn.Module machinery first so that later attribute
        # assignments go through a fully set-up module.
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render ``gaussians`` for every (batch, view) camera.

        Args:
            gaussians: The Gaussian primitives to render.
            extrinsics: One 4x4 matrix per view. The pose convention
                (camera-to-world vs. world-to-camera) is not fixed here;
                confirm with implementations.
            intrinsics: One 3x3 matrix per view. Whether they are normalized
                or in pixel units is not specified here; TODO confirm.
            near: One value per view, presumably the near clipping distance.
            far: One value per view, presumably the far clipping distance.
            image_shape: Output resolution as two ints. Presumably
                ``(height, width)``, matching the order of the output
                tensor shapes.
            depth_mode: Optional depth representation to render. ``None``
                (the default) is accepted.

        Returns:
            A ``DecoderOutput`` holding the rendered tensors.
        """
        ...