import os
import sys
from dataclasses import dataclass, fields
from typing import Literal, Optional, TypeVar

import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians

from .loss import Loss

# Make sibling packages importable when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything

T_wrapper = TypeVar("T_wrapper")


@dataclass
class LossDepthConsisCfg:
    weight: float
    sigma_image: float | None
    use_second_derivative: bool
    loss_type: Literal["MSE", "EdgeAwareLogL1", "PearsonDepth"] = "MSE"
    detach: bool = False
    conf: bool = False
    not_use_valid_mask: bool = False
    apply_after_step: int = 0


@dataclass
class LossDepthConsisCfgWrapper:
    depth_consis: LossDepthConsisCfg


class LogL1(torch.nn.Module):
    """Log-L1 loss: log(1 + |pred - gt|)."""

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        super().__init__()
        self.implementation = implementation

    def forward(self, pred: Tensor, gt: Tensor) -> Tensor:
        if self.implementation == "scalar":
            return torch.log(1 + torch.abs(pred - gt)).mean()
        return torch.log(1 + torch.abs(pred - gt))


class EdgeAwareLogL1(torch.nn.Module):
    """Edge-aware Log-L1 loss: per-pixel Log-L1 weighted by exp(-|image gradient|),
    so depth discrepancies are penalized less across RGB edges. Expects
    channel-last tensors: pred/gt of shape [..., H, W, 1], rgb of [..., H, W, C]."""

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        super().__init__()
        self.implementation = implementation
        self.logl1 = LogL1(implementation="per-pixel")

    def forward(
        self, pred: Tensor, gt: Tensor, rgb: Tensor, mask: Optional[Tensor]
    ) -> Tensor:
        logl1 = self.logl1(pred, gt)

        # Horizontal / vertical image gradients, averaged over the channel dim.
        grad_img_x = torch.mean(
            torch.abs(rgb[..., :, :-1, :] - rgb[..., :, 1:, :]), -1, keepdim=True
        )
        grad_img_y = torch.mean(
            torch.abs(rgb[..., :-1, :, :] - rgb[..., 1:, :, :]), -1, keepdim=True
        )
        lambda_x = torch.exp(-grad_img_x)
        lambda_y = torch.exp(-grad_img_y)

        loss_x = lambda_x * logl1[..., :, :-1, :]
        loss_y = lambda_y * logl1[..., :-1, :, :]

        if self.implementation == "per-pixel":
            if mask is not None:
                loss_x[~mask[..., :, :-1, :]] = 0
                loss_y[~mask[..., :-1, :, :]] = 0
            # Crop both terms to the common [..., H-1, W-1, 1] region.
            return loss_x[..., :-1, :, :] + loss_y[..., :, :-1, :]

        if mask is not None:
            assert mask.shape[:2] == pred.shape[:2]
            loss_x = loss_x[mask[..., :, :-1, :]]
            loss_y = loss_y[mask[..., :-1, :, :]]
        return loss_x.mean() + loss_y.mean()


class LossDepthConsis(Loss[LossDepthConsisCfg, LossDepthConsisCfgWrapper]):
    def __init__(self, cfg: T_wrapper) -> None:
        super().__init__(cfg)
        # Extract the configuration from the wrapper.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(
                0.0, dtype=torch.float32, device=prediction.depth.device
            )

        # Compare the rendered depth against the predicted (pseudo-GT) depth.
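        # Shape assumptions, inferred from the indexing below (the original code
        # does not assert them): prediction.depth and the validity masks are
        # [B, V, H, W], depth_dict["depth"] is [B, V, H, W, 1], and the context
        # images are [B, V, 3, H, W], normalized to [-1, 1].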
        # conf_valid_mask = depth_dict['conf_valid_mask']
        rendered_depth = prediction.depth
        # Map the context images from [-1, 1] back to [0, 1].
        gt_rgb = (batch["context"]["image"] + 1) / 2
        valid_mask = depth_dict["distill_infos"]["conf_mask"]
        # Prefer the dataset-provided mask whenever it marks any pixel valid.
        if batch["context"]["valid_mask"].sum() > 0:
            valid_mask = batch["context"]["valid_mask"]
        # if self.cfg.conf:
        #     valid_mask = valid_mask & conf_valid_mask
        if self.cfg.not_use_valid_mask:
            valid_mask = torch.ones_like(valid_mask)

        pred_depth = depth_dict["depth"].squeeze(-1)
        if self.cfg.detach:
            # Treat the predicted depth as a fixed target (no gradient through it).
            pred_depth = pred_depth.detach()

        if self.cfg.loss_type == "MSE":
            depth_loss = F.mse_loss(rendered_depth, pred_depth, reduction="none")[
                valid_mask
            ].mean()
        elif self.cfg.loss_type == "EdgeAwareLogL1":
            # Flatten the batch and view dims and move to channel-last layout.
            rendered_depth = rendered_depth.flatten(0, 1).unsqueeze(-1)
            pred_depth = pred_depth.flatten(0, 1).unsqueeze(-1)
            gt_rgb = gt_rgb.flatten(0, 1).permute(0, 2, 3, 1)
            valid_mask = valid_mask.flatten(0, 1).unsqueeze(-1)
            depth_loss = EdgeAwareLogL1()(rendered_depth, pred_depth, gt_rgb, valid_mask)
        else:
            # 'PearsonDepth' is allowed by the config but not implemented here;
            # fail loudly instead of hitting a NameError on the return below.
            raise NotImplementedError(f"Unsupported loss_type: {self.cfg.loss_type}")

        return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0)
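
# Minimal smoke test (a sketch added for illustration, not part of the training
# pipeline): it exercises the standalone LogL1 / EdgeAwareLogL1 modules on random
# channel-last tensors, mirroring the layout LossDepthConsis.forward produces
# after flattening. The shapes are illustrative assumptions. Run from the repo
# root so the src.* imports above resolve.
if __name__ == "__main__":
    b, h, w = 4, 32, 48
    pred = torch.rand(b, h, w, 1)        # rendered depth, channel-last
    gt = torch.rand(b, h, w, 1)          # predicted (pseudo-GT) depth
    rgb = torch.rand(b, h, w, 3)         # matching RGB, drives the edge weights
    mask = torch.rand(b, h, w, 1) > 0.1  # random validity mask

    print("LogL1 (scalar):", LogL1()(pred, gt).item())
    print("EdgeAwareLogL1 (scalar):", EdgeAwareLogL1()(pred, gt, rgb, mask).item())
    per_pixel = EdgeAwareLogL1(implementation="per-pixel")(pred, gt, rgb, mask)
    print("EdgeAwareLogL1 (per-pixel) shape:", tuple(per_pixel.shape))  # (4, 31, 47, 1)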