|
from dataclasses import dataclass |
|
|
|
from jaxtyping import Float |
|
from torch import Tensor |
|
import torch |
|
from src.dataset.types import BatchedExample |
|
from src.model.decoder.decoder import DecoderOutput |
|
from src.model.types import Gaussians |
|
from .loss import Loss |
|
|
|
|
|
@dataclass
class LossMseCfg:
    """Settings for the MSE loss.

    The three boolean flags choose which mask restricts the pixels entering
    the loss. ``LossMse.forward`` checks them in the order ``mask``,
    ``alpha``, ``conf`` and uses the first one that is enabled; with none
    enabled, every pixel contributes.
    """

    # Scalar factor multiplied onto the final loss value.
    weight: float
    # Use ``depth_dict['conf_valid_mask']`` as the pixel mask.
    conf: bool = False
    # Use ``batch['context']['valid_mask']`` as the pixel mask.
    mask: bool = False
    # Use the decoder's ``alpha`` output as the pixel mask.
    alpha: bool = False
|
|
|
|
|
@dataclass
class LossMseCfgWrapper:
    """Wrapper holding the MSE loss settings under the ``mse`` field."""

    # Configuration consumed by ``LossMse``.
    mse: LossMseCfg
|
|
|
|
|
class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]):
    """Weighted mean-squared-error loss between rendered colors and context images."""

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return ``cfg.weight`` times the MSE over the selected pixels.

        The pixel mask is picked from the configuration flags, first match
        wins: ``cfg.mask`` -> ``batch['context']['valid_mask']``,
        ``cfg.alpha`` -> ``prediction.alpha``, ``cfg.conf`` ->
        ``depth_dict['conf_valid_mask']``, otherwise an all-true mask shaped
        like ``prediction.alpha``.

        Args:
            prediction: Decoder output; ``color`` and ``alpha`` are read.
            batch: Batch providing ``batch['context']['image']``,
                ``batch['using_index']`` and, when ``cfg.mask`` is set,
                ``batch['context']['valid_mask']``.
            gaussians: Not used by this loss.
            depth_dict: Must contain ``'conf_valid_mask'`` when ``cfg.conf``
                is the active flag; otherwise it is not read.
            global_step: Not used by this loss.

        Returns:
            Scalar loss tensor. NaN or infinite means (for example the mean
            over an empty selection) are replaced by zero.

        Raises:
            ValueError: If ``cfg.conf`` selects the confidence mask but
                ``depth_dict`` is ``None``.
        """
        alpha = prediction.alpha

        if self.cfg.mask:
            # Read the validity mask only when it is requested, so batches
            # that do not carry one still work with masking disabled.
            # NOTE(review): the mask is not indexed with batch["using_index"]
            # while the target image below is -- confirm the shapes agree.
            mask = batch["context"]["valid_mask"]
        elif self.cfg.alpha:
            # NOTE(review): alpha is used directly as an index; if it is a
            # floating-point tensor this indexing will fail -- confirm dtype.
            mask = alpha
        elif self.cfg.conf:
            if depth_dict is None:
                raise ValueError(
                    "LossMse: cfg.conf is enabled but depth_dict is None; "
                    "'conf_valid_mask' is required."
                )
            mask = depth_dict["conf_valid_mask"]
        else:
            # Select every pixel; ones_like keeps alpha's shape and device.
            mask = torch.ones_like(alpha, dtype=torch.bool)

        # Move axis 2 (presumably channels -- TODO confirm) to the end so the
        # mask indexes the leading axes and keeps the last axis per pixel.
        pred_img = prediction.color.permute(0, 1, 3, 4, 2)[mask]

        # Rescale the target with (x + 1) / 2, presumably from [-1, 1] to
        # [0, 1] to match the rendered colors -- TODO confirm.
        target = (batch["context"]["image"][:, batch["using_index"]] + 1) / 2
        gt_img = target.permute(0, 1, 3, 4, 2)[mask]

        delta = pred_img - gt_img

        # Map a NaN/inf mean to zero so such a batch contributes no loss.
        mse = torch.nan_to_num((delta**2).mean(), nan=0.0, posinf=0.0, neginf=0.0)
        return self.cfg.weight * mse
|
|