# AnySplat / src/loss/loss_mse.py
# alexnasa's picture
# Upload 243 files
# 2568013 verified
from dataclasses import dataclass
from jaxtyping import Float
from torch import Tensor
import torch
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
@dataclass
class LossMseCfg:
    """Settings for the MSE image loss.

    The three boolean flags choose which pixel mask restricts the loss. When
    more than one is set, ``mask`` wins over ``alpha``, which wins over
    ``conf``; when none is set, every pixel contributes.
    """

    # Scalar factor the final loss value is multiplied by.
    weight: float
    # Restrict the loss to depth_dict['conf_valid_mask'].
    conf: bool = False
    # Restrict the loss to batch['context']['valid_mask'].
    mask: bool = False
    # Restrict the loss using prediction.alpha as the mask.
    alpha: bool = False
@dataclass
class LossMseCfgWrapper:
    """Container that exposes the MSE loss settings under the ``mse`` field."""

    # Settings consumed by the MSE loss.
    mse: LossMseCfg
class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]):
    """Weighted mean-squared-error loss between rendered and target images."""

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the masked MSE between ``prediction.color`` and the targets.

        The pixel mask is chosen from the config, in this order of precedence:
        ``cfg.mask`` -> ``batch['context']['valid_mask']``,
        ``cfg.alpha`` -> ``prediction.alpha``,
        ``cfg.conf`` -> ``depth_dict['conf_valid_mask']``,
        otherwise an all-true mask shaped like ``prediction.alpha``.

        Args:
            prediction: Decoder output; ``color`` and ``alpha`` are read.
            batch: Batch dict; ``batch['context']['image']`` and
                ``batch['using_index']`` are read, plus
                ``batch['context']['valid_mask']`` when ``cfg.mask`` is set.
            gaussians: Unused by this loss.
            depth_dict: Must contain ``'conf_valid_mask'`` when ``cfg.conf``
                is the selected mask source; otherwise unused.
            global_step: Unused by this loss.

        Returns:
            ``cfg.weight`` times the mean squared error over the selected
            pixels, with NaN/inf results replaced by 0.

        Raises:
            ValueError: If ``cfg.conf`` selects the mask but ``depth_dict``
                is ``None``.
        """
        alpha = prediction.alpha

        # Each mask source is looked up only in the branch that needs it, so a
        # batch lacking e.g. 'valid_mask' does not fail when it is not used.
        if self.cfg.mask:
            mask = batch["context"]["valid_mask"]
        elif self.cfg.alpha:
            # NOTE(review): alpha is used directly as an index mask; this
            # presumably requires it to be boolean -- confirm against the
            # decoder's output dtype.
            mask = alpha
        elif self.cfg.conf:
            if depth_dict is None:
                raise ValueError(
                    "LossMse with cfg.conf=True requires depth_dict containing "
                    "'conf_valid_mask', but depth_dict is None."
                )
            mask = depth_dict["conf_valid_mask"]
        else:
            mask = torch.ones_like(alpha, dtype=torch.bool)

        # Move axis 2 to the end so the mask can index the leading four axes
        # (presumably (batch, view, channel, H, W) -> (batch, view, H, W,
        # channel) -- TODO confirm).
        pred_img = prediction.color.permute(0, 1, 3, 4, 2)[mask]

        # Select the target views along dim 1 and rescale with (x + 1) / 2
        # (presumably mapping [-1, 1] images to [0, 1] -- TODO confirm).
        target = batch["context"]["image"][:, batch["using_index"]]
        gt_img = ((target + 1) / 2).permute(0, 1, 3, 4, 2)[mask]

        delta = pred_img - gt_img

        # The mean of an empty selection is NaN; nan_to_num turns that (and
        # any +/-inf) into 0 so the loss stays finite.
        mse = torch.nan_to_num((delta**2).mean(), nan=0.0, posinf=0.0, neginf=0.0)
        return self.cfg.weight * mse