from abc import ABC, abstractmethod
from dataclasses import fields
from typing import Generic, TypeVar

from jaxtyping import Float
from torch import Tensor, nn

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

T_cfg = TypeVar("T_cfg")
T_wrapper = TypeVar("T_wrapper")


class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]):
    """Abstract base class for loss modules configured via a wrapper dataclass.

    The constructor receives a dataclass instance that must declare exactly
    one field. That field's value is stored as ``self.cfg`` and the field's
    name is stored as ``self.name``. Subclasses implement ``forward``.
    """

    # The configuration object pulled out of the wrapper's single field.
    cfg: T_cfg
    # The name of the wrapper's single field.
    name: str

    def __init__(self, cfg: T_wrapper) -> None:
        """Unwrap the single-field wrapper dataclass.

        Args:
            cfg: A dataclass instance declaring exactly one field.

        Raises:
            TypeError: If ``type(cfg)`` is not a dataclass (raised by
                ``dataclasses.fields``).
            ValueError: If the wrapper does not declare exactly one field
                (raised by the single-element unpacking below).
        """
        super().__init__()

        # Single-element destructuring doubles as a check that the wrapper
        # has one and only one field.
        [sole_field] = fields(type(cfg))
        self.cfg = getattr(cfg, sole_field.name)
        self.name = sole_field.name

    @abstractmethod
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the loss value; must be overridden by subclasses.

        Args:
            prediction: Decoder output to evaluate.
            batch: The batched example associated with the prediction.
            gaussians: Gaussians object passed alongside the prediction.
            depth_dict: Dictionary of depth-related data; its schema is not
                visible from this file -- TODO confirm against callers.
            global_step: Integer step counter (presumably the training
                step -- verify against the caller).

        Returns:
            A tensor annotated with an empty shape string, i.e. a scalar.
        """