AnySplat / src /loss /loss.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame contribute delete
934 Bytes
from abc import ABC, abstractmethod
from dataclasses import fields
from typing import Generic, TypeVar
from jaxtyping import Float
from torch import Tensor, nn
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
# Type of the loss-specific configuration; used as the annotation of `Loss.cfg`.
T_cfg = TypeVar("T_cfg")
# Type of the wrapper object passed to `Loss.__init__`, from which the
# configuration is extracted.
T_wrapper = TypeVar("T_wrapper")
class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]):
    """Abstract base class for loss modules configured through a wrapper.

    The constructor receives a wrapper object (``T_wrapper``) that must be a
    dataclass declaring exactly one field. The value of that field becomes
    ``self.cfg`` and the field's name becomes ``self.name``.

    Subclasses must implement :meth:`forward`.
    """

    cfg: T_cfg
    name: str

    def __init__(self, cfg: T_wrapper) -> None:
        """Unwrap the single-field configuration wrapper.

        Args:
            cfg: Dataclass instance with exactly one field holding the actual
                loss configuration.

        Raises:
            TypeError: If ``type(cfg)`` is not a dataclass (raised by
                ``dataclasses.fields``).
            ValueError: If the wrapper does not declare exactly one field.
        """
        super().__init__()

        # Extract the configuration from the wrapper. Validate the field count
        # explicitly so a malformed wrapper produces an actionable message
        # instead of a bare tuple-unpacking error.
        wrapper_fields = fields(type(cfg))
        if len(wrapper_fields) != 1:
            found = ", ".join(f.name for f in wrapper_fields) or "<none>"
            raise ValueError(
                f"{type(self).__name__} expects a config wrapper with exactly "
                f"one field, but {type(cfg).__name__} declares "
                f"{len(wrapper_fields)}: {found}"
            )
        (field,) = wrapper_fields

        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    @abstractmethod
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the loss value; must be overridden by subclasses.

        Args:
            prediction: Decoder output to evaluate.
            batch: The batched example the prediction belongs to.
            gaussians: Gaussians associated with the prediction.
            depth_dict: Dictionary of depth-related data; its schema is not
                visible from this file -- NOTE(review): confirm with callers.
            global_step: Integer step counter -- presumably the training
                step; verify against the caller.

        Returns:
            A tensor annotated as a zero-dimensional (scalar) float tensor.
        """