|
import os
import sys
from dataclasses import dataclass
from typing import Literal

import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

# Path hack kept from the original layout so the module resolves its
# package imports when executed directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss

|
@dataclass
class LossDepthGTCfg:
    weight: float
    type: Literal["l1", "mse", "silog", "gradient", "l1+gradient"] | None


@dataclass
class LossDepthGTCfgWrapper:
    depthgt: LossDepthGTCfg
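# A hypothetical example (the values are illustrative, not taken from any
# config in this repository): selecting the combined loss with a small
# weight would look like
#   LossDepthGTCfgWrapper(depthgt=LossDepthGTCfg(weight=0.05, type="l1+gradient"))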
|
|
|
|
|
class LossDepthGT(Loss[LossDepthGTCfg, LossDepthGTCfgWrapper]):
    def gradient_loss(
        self,
        gs_depth: Float[Tensor, "batch view height width"],
        target_depth: Float[Tensor, "batch view height width"],
        target_valid_mask: Tensor,
    ) -> Float[Tensor, ""]:
        """Mean absolute gradient of the depth residual, computed only where
        both pixels spanned by a gradient are valid."""
        diff = gs_depth - target_depth

        # Forward differences of the residual along width (x) and height (y).
        grad_x_diff = diff[:, :, :, 1:] - diff[:, :, :, :-1]
        grad_y_diff = diff[:, :, 1:, :] - diff[:, :, :-1, :]

        # A gradient is valid only if both pixels it spans are valid.
        mask_x = target_valid_mask[:, :, :, 1:] * target_valid_mask[:, :, :, :-1]
        mask_y = target_valid_mask[:, :, 1:, :] * target_valid_mask[:, :, :-1, :]

        grad_x_diff = grad_x_diff * mask_x
        grad_y_diff = grad_y_diff * mask_y

        # Clamp so isolated depth outliers cannot dominate the sum.
        grad_x_diff = grad_x_diff.clamp(min=-100, max=100)
        grad_y_diff = grad_y_diff.clamp(min=-100, max=100)

        loss_x = grad_x_diff.abs().sum()
        loss_y = grad_y_diff.abs().sum()
        num_valid = mask_x.sum() + mask_y.sum()

        if num_valid == 0:
            # Return a tensor rather than a Python int so downstream tensor
            # ops (e.g. torch.nan_to_num in forward) keep working.
            return gs_depth.new_zeros(())
        return (loss_x + loss_y) / (num_valid + 1e-6)

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        target_depth = batch["target"]["depth"]
        target_valid_mask = batch["target"]["valid_mask"]
        # Clamp predicted depth away from zero so the log in the "silog"
        # branch stays defined.
        gs_depth = prediction.depth.clamp(1e-3)

        if self.cfg.type == "l1":
            depth_loss = (target_depth[target_valid_mask] - gs_depth[target_valid_mask]).abs().mean()
        elif self.cfg.type == "mse":
            depth_loss = F.mse_loss(gs_depth[target_valid_mask], target_depth[target_valid_mask])
        elif self.cfg.type == "silog":
            # Note: this is not the classic scale-invariant log (SILog) loss;
            # it penalizes the squared log depth plus the squared residual,
            # offset by a constant.
            depth_loss = (
                torch.log(gs_depth[target_valid_mask]) ** 2
                + (gs_depth[target_valid_mask] - target_depth[target_valid_mask]) ** 2
                - 0.5
            )
            depth_loss = depth_loss.mean()
        elif self.cfg.type == "gradient":
            depth_loss = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
        elif self.cfg.type == "l1+gradient":
            depth_loss_l1 = (target_depth[target_valid_mask] - gs_depth[target_valid_mask]).abs().mean()
            depth_loss_gradient = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
            depth_loss = depth_loss_l1 + depth_loss_gradient
        else:
            # Previously an unrecognized (or None) type left depth_loss
            # unbound; fail loudly instead.
            raise ValueError(f"Unknown depth loss type: {self.cfg.type}")

        # Guard against NaN/Inf from empty masks or degenerate depths.
        return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0, posinf=0.0, neginf=0.0)
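

if __name__ == "__main__":
    # Minimal smoke test for the gradient term alone; a sketch under two
    # assumptions: depth tensors are shaped (batch, view, height, width),
    # and the Loss base-class constructor can be bypassed, which is safe
    # here because gradient_loss never touches `self`. Run it from the
    # repository root with `python -m <package>.loss_depthgt` (module path
    # depends on your layout) so the relative import of Loss resolves.
    b, v, h, w = 2, 3, 16, 16
    gs_depth = torch.rand(b, v, h, w) + 1e-3
    target_depth = torch.rand(b, v, h, w) + 1e-3
    valid_mask = torch.rand(b, v, h, w) > 0.1
    print(LossDepthGT.gradient_loss(None, gs_depth, target_depth, valid_mask))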