|
from dataclasses import dataclass |
|
|
|
import torch |
|
from einops import reduce |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from src.dataset.types import BatchedExample |
|
from src.model.decoder.decoder import DecoderOutput |
|
from src.model.types import Gaussians |
|
from .loss import Loss |
|
from typing import Generic, Literal, Optional, TypeVar |
|
from dataclasses import fields |
|
import torch.nn.functional as F |
|
import sys

import os

# Make the repository root importable when this module is loaded directly.
# NOTE(review): mutating sys.path at import time is a side effect; this
# assumes the file sits two directory levels below the repo root — confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))



# NOTE(review): vis_depth_map is not referenced in the visible portion of
# this file — presumably used elsewhere or kept for debugging; verify.
from src.misc.utils import vis_depth_map



# Type variables parameterizing the generic Loss base class
# (config dataclass and its single-field wrapper, respectively).
T_cfg = TypeVar("T_cfg")

T_wrapper = TypeVar("T_wrapper")
|
|
|
|
|
@dataclass

class LossDepthConsisCfg:

    """Configuration for the depth-consistency loss (LossDepthConsis)."""

    # Overall multiplier applied to the final loss value.
    weight: float

    # NOTE(review): not referenced in this file — presumably consumed by
    # another loss variant; confirm before removing.
    sigma_image: float | None

    # NOTE(review): not referenced in this file.
    use_second_derivative: bool

    # Pixelwise loss to apply. NOTE(review): 'PearsonDepth' is declared here
    # but has no corresponding branch in LossDepthConsis.forward.
    loss_type: Literal['MSE', 'EdgeAwareLogL1', 'PearsonDepth'] = 'MSE'

    # If True, detach predicted depth so no gradient reaches the predictor.
    detach: bool = False

    # NOTE(review): not referenced in this file.
    conf: bool = False

    # If True, ignore validity masks and supervise every pixel.
    not_use_valid_mask: bool = False

    # The loss is forced to zero until global_step reaches this value.
    apply_after_step: int = 0
|
|
|
@dataclass

class LossDepthConsisCfgWrapper:

    """Single-field wrapper around LossDepthConsisCfg.

    The field name ("depth_consis") doubles as the loss name: it is
    extracted via dataclasses.fields() in LossDepthConsis.__init__.
    """

    depth_consis: LossDepthConsisCfg
|
|
|
|
|
class LogL1(torch.nn.Module):
    """Log-L1 loss: ``log(1 + |pred - gt|)``.

    With ``implementation="scalar"`` the mean over all elements is returned;
    with ``"per-pixel"`` the unreduced elementwise map is returned.
    """

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        super().__init__()
        self.implementation = implementation

    def forward(self, pred, gt):
        residual = torch.abs(pred - gt)
        pointwise = torch.log(1 + residual)
        if self.implementation == "scalar":
            return pointwise.mean()
        return pointwise
|
|
|
class EdgeAwareLogL1(torch.nn.Module):
    """Log-L1 depth loss down-weighted at image edges.

    Per-pixel log-L1 terms are multiplied by ``exp(-|image gradient|)`` along
    each spatial axis, so depth discrepancies at strong RGB edges are
    penalized less. Expects channels-last inputs (..., H, W, C).
    """

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        super().__init__()
        self.implementation = implementation
        # Always compute the unreduced map; reduction happens in forward().
        self.logl1 = LogL1(implementation="per-pixel")

    def forward(self, pred: Tensor, gt: Tensor, rgb: Tensor, mask: Optional[Tensor]):
        per_pixel = self.logl1(pred, gt)

        # Channel-averaged absolute RGB differences along each spatial axis.
        grad_x = torch.abs(rgb[..., :, :-1, :] - rgb[..., :, 1:, :]).mean(-1, keepdim=True)
        grad_y = torch.abs(rgb[..., :-1, :, :] - rgb[..., 1:, :, :]).mean(-1, keepdim=True)

        # Edge-aware weights: strong image edge -> weight near zero.
        weighted_x = torch.exp(-grad_x) * per_pixel[..., :, :-1, :]
        weighted_y = torch.exp(-grad_y) * per_pixel[..., :-1, :, :]

        if self.implementation == "per-pixel":
            if mask is not None:
                # Zero out invalid pixels in place before cropping.
                weighted_x[~mask[..., :, :-1, :]] = 0
                weighted_y[~mask[..., :-1, :, :]] = 0
            # Crop both maps to the common (H-1, W-1) region and sum.
            return weighted_x[..., :-1, :, :] + weighted_y[..., :, :-1, :]

        if mask is not None:
            assert mask.shape[:2] == pred.shape[:2]
            # Boolean-select valid entries; results are flattened vectors.
            weighted_x = weighted_x[mask[..., :, :-1, :]]
            weighted_y = weighted_y[mask[..., :-1, :, :]]

        if self.implementation == "scalar":
            return weighted_x.mean() + weighted_y.mean()
|
|
|
class LossDepthConsis(Loss[LossDepthConsisCfg, LossDepthConsisCfgWrapper]):
    """Penalizes disagreement between the rendered depth and the depth
    produced by the depth branch (``depth_dict['depth']``)."""

    def __init__(self, cfg: T_wrapper) -> None:
        super().__init__(cfg)
        # The wrapper dataclass has exactly one field; unwrap it so that
        # self.cfg is the inner LossDepthConsisCfg and self.name is the
        # wrapper's field name.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted depth-consistency loss as a scalar tensor.

        Returns zero before ``cfg.apply_after_step``; NaNs (e.g. from an
        empty valid mask) are mapped to zero.

        Raises:
            ValueError: if ``cfg.loss_type`` has no implementation here.
        """
        # Warm-up: loss disabled until the configured step.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(0.0, dtype=torch.float32, device=prediction.depth.device)

        rendered_depth = prediction.depth
        # Context images are stored in [-1, 1]; map to [0, 1] for the
        # edge-aware gradient weighting.
        gt_rgb = (batch["context"]["image"] + 1) / 2

        # Default to the distillation confidence mask, but prefer the
        # dataset-provided validity mask whenever it is non-empty.
        valid_mask = depth_dict["distill_infos"]['conf_mask']
        if batch['context']['valid_mask'].sum() > 0:
            valid_mask = batch['context']['valid_mask']
        if self.cfg.not_use_valid_mask:
            # ones_like already matches device and dtype of valid_mask.
            valid_mask = torch.ones_like(valid_mask)

        pred_depth = depth_dict['depth'].squeeze(-1)
        if self.cfg.detach:
            # Stop gradients from flowing into the depth predictor.
            pred_depth = pred_depth.detach()

        if self.cfg.loss_type == 'MSE':
            depth_loss = F.mse_loss(rendered_depth, pred_depth, reduction='none')[valid_mask].mean()
        elif self.cfg.loss_type == 'EdgeAwareLogL1':
            # Collapse (batch, view) and go channels-last: EdgeAwareLogL1
            # expects (B, H, W, C)-style layouts.
            rendered_depth = rendered_depth.flatten(0, 1).unsqueeze(-1)
            pred_depth = pred_depth.flatten(0, 1).unsqueeze(-1)
            gt_rgb = gt_rgb.flatten(0, 1).permute(0, 2, 3, 1)
            valid_mask = valid_mask.flatten(0, 1).unsqueeze(-1)
            depth_loss = EdgeAwareLogL1()(rendered_depth, pred_depth, gt_rgb, valid_mask)
        else:
            # 'PearsonDepth' is declared in the config Literal but was never
            # implemented here; previously this fell through and raised a
            # confusing NameError on the return line. Fail loudly instead.
            raise ValueError(f"Unsupported loss_type: {self.cfg.loss_type}")

        # An empty mask makes .mean() return NaN; treat that as zero loss.
        return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0)