| 
							 | 
						from abc import ABC, abstractmethod | 
					
					
						
						| 
							 | 
						from dataclasses import fields | 
					
					
						
						| 
							 | 
						from typing import Generic, TypeVar | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from jaxtyping import Float | 
					
					
						
						| 
							 | 
						from torch import Tensor, nn | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from src.dataset.types import BatchedExample | 
					
					
						
						| 
							 | 
						from src.model.decoder.decoder import DecoderOutput | 
					
					
						
						| 
							 | 
						from src.model.types import Gaussians | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						T_cfg = TypeVar("T_cfg") | 
					
					
						
						| 
							 | 
						T_wrapper = TypeVar("T_wrapper") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): | 
					
					
						
						| 
							 | 
						    cfg: T_cfg | 
					
					
						
						| 
							 | 
						    name: str | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __init__(self, cfg: T_wrapper) -> None: | 
					
					
						
						| 
							 | 
						        super().__init__() | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        (field,) = fields(type(cfg)) | 
					
					
						
						| 
							 | 
						        self.cfg = getattr(cfg, field.name) | 
					
					
						
						| 
							 | 
						        self.name = field.name | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    @abstractmethod | 
					
					
						
						| 
							 | 
						    def forward( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        prediction: DecoderOutput, | 
					
					
						
						| 
							 | 
						        batch: BatchedExample, | 
					
					
						
						| 
							 | 
						        gaussians: Gaussians, | 
					
					
						
						| 
							 | 
						        depth_dict: dict, | 
					
					
						
						| 
							 | 
						        global_step: int, | 
					
					
						
						| 
							 | 
						    ) -> Float[Tensor, ""]: | 
					
					
						
						| 
							 | 
						        pass | 
					
					
						
						| 
							 | 
						
 |