|
from abc import ABC, abstractmethod |
|
from dataclasses import fields |
|
from typing import Generic, TypeVar |
|
|
|
from jaxtyping import Float |
|
from torch import Tensor, nn |
|
|
|
from src.dataset.types import BatchedExample |
|
from src.model.decoder.decoder import DecoderOutput |
|
from src.model.types import Gaussians |
|
|
|
# Type of the loss-specific configuration object stored on ``Loss.cfg``.
T_cfg = TypeVar("T_cfg")

# Type of the wrapper object passed to ``Loss.__init__``; the loss
# configuration is read from the wrapper's single dataclass field.
T_wrapper = TypeVar("T_wrapper")
|
|
|
|
|
class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]):
    """Abstract base class for loss modules configured via a wrapper dataclass.

    The constructor receives a wrapper dataclass instance that must declare
    exactly one field. The value of that field becomes ``self.cfg`` and the
    field's name becomes ``self.name``. Subclasses implement ``forward``.
    """

    # Configuration object extracted from the wrapper's single field.
    cfg: T_cfg
    # Name of the wrapper's single field.
    name: str

    def __init__(self, cfg: T_wrapper) -> None:
        """Unpack the wrapper dataclass into ``self.cfg`` and ``self.name``.

        Args:
            cfg: Instance of a dataclass that declares exactly one field.

        Raises:
            TypeError: If ``type(cfg)`` is not a dataclass (raised by
                ``dataclasses.fields``).
            ValueError: If the wrapper does not declare exactly one field.
        """
        super().__init__()

        # Check the field count explicitly so a malformed wrapper produces a
        # readable error instead of a bare tuple-unpacking failure.
        wrapper_fields = fields(type(cfg))
        if len(wrapper_fields) != 1:
            raise ValueError(
                f"{type(cfg).__name__} must declare exactly one dataclass "
                f"field, found {len(wrapper_fields)}."
            )

        (field,) = wrapper_fields
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    @abstractmethod
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the loss value; must be implemented by subclasses.

        Args:
            prediction: Decoder output to evaluate.
            batch: Batched example the prediction corresponds to.
            gaussians: Gaussians associated with the prediction.
            depth_dict: Dictionary of depth-related data; its schema is not
                visible here -- TODO confirm against callers.
            global_step: Integer step counter; presumably the current
                training step -- TODO confirm.

        Returns:
            A tensor annotated with an empty shape string (a scalar).
        """
|
|