AnySplat / src /loss /loss_opacity.py
alexnasa's picture
Upload 243 files
2568013 verified
from dataclasses import dataclass
from typing import Literal
from jaxtyping import Float
from torch import Tensor
import torch
import torch.nn.functional as F
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
@dataclass
class LossOpacityCfg:
    """Configuration for the opacity loss.

    Attributes:
        weight: Scalar multiplier applied to the computed opacity loss.
        type: Variant selector ("exp", "mean" or "exp+mean").
            NOTE(review): ``LossOpacity.forward`` does not currently read
            this field; the loss is always the alpha/valid-mask MSE.
    """

    # Multiplier on the final loss value.
    weight: float
    # Field name kept as ``type`` (shadows the builtin) because it is part of
    # the config schema consumers construct by keyword.
    type: Literal["exp", "mean", "exp+mean"] = "exp+mean"
@dataclass
class LossOpacityCfgWrapper:
    """Wrapper that nests ``LossOpacityCfg`` under the ``opacity`` key.

    Attributes:
        opacity: The wrapped opacity-loss configuration.
    """

    # Configuration for the opacity loss term.
    opacity: LossOpacityCfg
class LossOpacity(Loss[LossOpacityCfg, LossOpacityCfgWrapper]):
    """Opacity loss: MSE between the rendered alpha and the context valid mask."""

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return ``cfg.weight * mean((alpha - valid_mask) ** 2)``.

        Args:
            prediction: Decoder output; only ``prediction.alpha`` is read.
            batch: Only ``batch["context"]["valid_mask"]`` is read.
                NOTE(review): assumed to have the same shape as
                ``prediction.alpha`` -- ``F.mse_loss`` would otherwise
                broadcast (with a warning); confirm against the decoder.
            gaussians: Unused; kept for the shared loss signature.
            depth_dict: Unused; kept for the shared loss signature.
            global_step: Unused; kept for the shared loss signature.

        Returns:
            Scalar weighted loss. A non-finite loss value (NaN, +/-inf) is
            replaced by 0 in the returned tensor.

        Note:
            ``self.cfg.type`` is not consulted: the loss is always the
            alpha/valid-mask MSE. (Earlier variants that regularized
            ``gaussians.opacities`` according to ``type`` were removed.)
        """
        alpha = prediction.alpha
        # Match the target dtype to alpha: mse_loss's backward errors out when
        # input and target dtypes differ (e.g. fp16/bf16 alpha vs fp32 mask).
        # For fp32 alpha this is exactly the previous ``.float()`` behavior;
        # a boolean mask becomes 0/1.
        target = batch["context"]["valid_mask"].to(alpha.dtype)
        # Default reduction="mean" == elementwise loss followed by .mean().
        opacity_loss = F.mse_loss(alpha, target)
        return self.cfg.weight * torch.nan_to_num(
            opacity_loss, nan=0.0, posinf=0.0, neginf=0.0
        )