File size: 1,435 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from dataclasses import dataclass
from typing import Literal
from jaxtyping import Float
from torch import Tensor
import torch
import torch.nn.functional as F
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
@dataclass
class LossOpacityCfg:
    """Configuration for the opacity loss."""

    # Scalar that LossOpacity.forward multiplies the computed loss by.
    weight: float
    # Name of the loss variant. NOTE(review): LossOpacity.forward does not
    # currently read this field, so changing it has no effect there.
    type: Literal["exp", "mean", "exp+mean"] = "exp+mean"
@dataclass
class LossOpacityCfgWrapper:
    """Holds a ``LossOpacityCfg`` under the attribute name ``opacity``."""

    # NOTE(review): presumably the attribute name is the key the config
    # loader uses to select this loss -- confirm against the loader.
    opacity: LossOpacityCfg
class LossOpacity(Loss[LossOpacityCfg, LossOpacityCfgWrapper]):
    """Weighted mean-squared error between ``prediction.alpha`` and the
    context ``valid_mask`` of the batch."""

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the weighted opacity loss as a scalar tensor.

        Args:
            prediction: Decoder output; only ``prediction.alpha`` is read.
            batch: Batch; only ``batch["context"]["valid_mask"]`` is read.
            gaussians: Not used by the current implementation.
            depth_dict: Not used by the current implementation.
            global_step: Not used by the current implementation.

        Returns:
            ``self.cfg.weight`` times the mean squared error between
            ``prediction.alpha`` and the float-cast valid mask, with a
            NaN or infinite loss value replaced by 0.
        """
        alpha = prediction.alpha
        valid_mask = batch["context"]["valid_mask"].float()

        # Default reduction ("mean") averages the squared error over every
        # element, which is what `reduction="none"` followed by `.mean()` did.
        # Assumes alpha and valid_mask have matching shapes -- TODO confirm.
        opacity_loss = F.mse_loss(alpha, valid_mask)

        # NOTE: self.cfg.type is not consulted here. The earlier variants it
        # selected ("exp", "mean", "exp+mean") operated on gaussians.opacities
        # and are disabled, so neither cfg.type nor `gaussians` has any effect.

        # A non-finite loss is zeroed so it contributes nothing instead of
        # propagating NaN/inf into the total loss.
        return self.cfg.weight * torch.nan_to_num(
            opacity_loss, nan=0.0, posinf=0.0, neginf=0.0
        )
|