File size: 1,435 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from dataclasses import dataclass
from typing import Literal

from jaxtyping import Float
from torch import Tensor
import torch
import torch.nn.functional as F
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss


@dataclass()
class LossOpacityCfg:
    """Configuration for the opacity loss."""

    # Scalar multiplier applied to the opacity loss value.
    weight: float
    # Variant selector. NOTE: the visible LossOpacity.forward does not read this field.
    type: Literal["exp", "mean", "exp+mean"] = "exp+mean"


@dataclass()
class LossOpacityCfgWrapper:
    """Hold an opacity-loss config in the ``opacity`` field.

    Presumably lets a config loader pick this loss by that field name;
    verify against the loader.
    """

    opacity: LossOpacityCfg


class LossOpacity(Loss[LossOpacityCfg, LossOpacityCfgWrapper]):
    """Supervise rendered opacity (alpha) against the context valid mask."""

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the weighted MSE between ``prediction.alpha`` and the valid mask.

        Args:
            prediction: Decoder output; only ``prediction.alpha`` is read.
            batch: Batched example; only ``batch["context"]["valid_mask"]`` is
                read, and it is cast to float before the comparison.
            gaussians: Unused by this loss.
            depth_dict: Unused by this loss.
            global_step: Unused by this loss.

        Returns:
            Scalar ``cfg.weight * mean((alpha - valid_mask) ** 2)``. A NaN or
            +/-inf loss is replaced by 0 before weighting.

        Note:
            ``cfg.type`` is currently not consulted: every setting
            (``"exp"``, ``"mean"``, ``"exp+mean"``) yields this mask MSE.
        """
        alpha = prediction.alpha
        target = batch["context"]["valid_mask"].float()
        # NOTE(review): F.mse_loss only warns and then broadcasts when the input
        # and target shapes differ; presumably alpha is rendered at the context
        # views so the shapes match exactly -- confirm against the decoder.
        opacity_loss = F.mse_loss(alpha, target, reduction="mean")
        # Best-effort guard: a non-finite loss contributes 0 instead of
        # propagating NaN/inf into the total training loss.
        return self.cfg.weight * torch.nan_to_num(
            opacity_loss, nan=0.0, posinf=0.0, neginf=0.0
        )