# File size: 4,979 Bytes
# 2568013
from dataclasses import dataclass
import torch
from einops import reduce
from jaxtyping import Float
from torch import Tensor
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
from typing import Generic, Literal, Optional, TypeVar
from dataclasses import fields
import torch.nn.functional as F
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything
from src.misc.utils import vis_depth_map
# Type variable; not referenced by the rest of the code visible in this module.
T_cfg = TypeVar("T_cfg")
# Type variable used to annotate the `cfg` argument of LossDepthConsis.__init__.
T_wrapper = TypeVar("T_wrapper")
@dataclass
class LossDepthConsisCfg:
weight: float
sigma_image: float | None
use_second_derivative: bool
loss_type: Literal['MSE', 'EdgeAwareLogL1', 'PearsonDepth'] = 'MSE'
detach: bool = False
conf: bool = False
not_use_valid_mask: bool = False
apply_after_step: int = 0
@dataclass
class LossDepthConsisCfgWrapper:
    """Single-field wrapper around the depth-consistency loss settings."""

    # LossDepthConsis.__init__ unpacks the wrapper's only field and stores the
    # field's name ("depth_consis") as the loss name.
    depth_consis: LossDepthConsisCfg
class LogL1(torch.nn.Module):
    """Log-L1 loss: ``log(1 + |pred - gt|)``.

    With ``implementation="scalar"`` the element-wise values are averaged into
    a single scalar; with any other value the element-wise tensor is returned
    unreduced.
    """

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        self.implementation = implementation

    def forward(self, pred, gt):
        """Return the log-L1 error between ``pred`` and ``gt``."""
        elementwise = torch.log(1 + torch.abs(pred - gt))
        reduce_to_scalar = self.implementation == "scalar"
        return elementwise.mean() if reduce_to_scalar else elementwise
class EdgeAwareLogL1(torch.nn.Module):
    """Gradient-aware Log-L1 loss.

    The element-wise log-L1 error between ``pred`` and ``gt`` is weighted by
    ``exp(-g)``, where ``g`` is the channel-averaged absolute forward
    difference of ``rgb`` along dim ``-2`` (the "x" term) and along dim ``-3``
    (the "y" term). Errors located at strong image gradients therefore
    contribute less.

    Args:
        implementation: ``"scalar"`` returns the sum of the means of the x and
            y terms; ``"per-pixel"`` returns the per-element sum of the two
            terms, cropped by one along dims ``-3`` and ``-2``.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``implementation`` is not one of the two supported
            values (previously ``forward`` silently returned ``None``).
    """

    _IMPLEMENTATIONS = ("scalar", "per-pixel")

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        super().__init__()
        if implementation not in self._IMPLEMENTATIONS:
            raise ValueError(
                f"Unknown implementation {implementation!r}; "
                f"expected one of {self._IMPLEMENTATIONS}"
            )
        self.implementation = implementation
        self.logl1 = LogL1(implementation="per-pixel")

    def forward(
        self, pred: Tensor, gt: Tensor, rgb: Tensor, mask: Optional[Tensor] = None
    ):
        """Compute the edge-aware loss.

        Args:
            pred: predicted values; indexed as ``[..., :, :, :]``, so it needs
                at least three dims (presumably ``(N, H, W, 1)`` — confirm
                against the caller).
            gt: target values, same layout as ``pred``.
            rgb: guidance image with channels in the last dim.
            mask: optional boolean mask with the same layout as ``pred``;
                only elements where it is True contribute. ``None`` keeps
                every element.

        Returns:
            A scalar tensor in "scalar" mode, or the per-element loss tensor
            in "per-pixel" mode. In "scalar" mode an all-False mask yields
            NaN (mean over an empty selection).
        """
        logl1 = self.logl1(pred, gt)

        # Channel-averaged absolute forward differences of the guidance image.
        grad_img_x = torch.mean(
            torch.abs(rgb[..., :, :-1, :] - rgb[..., :, 1:, :]), -1, keepdim=True
        )
        grad_img_y = torch.mean(
            torch.abs(rgb[..., :-1, :, :] - rgb[..., 1:, :, :]), -1, keepdim=True
        )
        lambda_x = torch.exp(-grad_img_x)
        lambda_y = torch.exp(-grad_img_y)

        # Crop the error map so it lines up with each difference map.
        loss_x = lambda_x * logl1[..., :, :-1, :]
        loss_y = lambda_y * logl1[..., :-1, :, :]

        if self.implementation == "per-pixel":
            if mask is not None:
                loss_x[~mask[..., :, :-1, :]] = 0
                loss_y[~mask[..., :-1, :, :]] = 0
            # Crop both terms to their common extent before adding them.
            return loss_x[..., :-1, :, :] + loss_y[..., :, :-1, :]

        # Guard against the attribute having been changed after construction.
        if self.implementation != "scalar":
            raise ValueError(
                f"Unknown implementation {self.implementation!r}; "
                f"expected one of {self._IMPLEMENTATIONS}"
            )

        if mask is not None:
            assert mask.shape[:2] == pred.shape[:2]
            loss_x = loss_x[mask[..., :, :-1, :]]
            loss_y = loss_y[mask[..., :-1, :, :]]
        return loss_x.mean() + loss_y.mean()
class LossDepthConsis(Loss[LossDepthConsisCfg, LossDepthConsisCfgWrapper]):
    """Consistency loss between the rendered depth and a predicted depth map.

    The rendered depth (``prediction.depth``) is compared with
    ``depth_dict["depth"]`` on the pixels selected by a validity mask, using
    the loss selected by ``cfg.loss_type``.
    """

    def __init__(self, cfg: T_wrapper) -> None:
        super().__init__(cfg)

        # Extract the configuration from the wrapper (it has exactly one field).
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Compute the weighted depth-consistency loss.

        Args:
            prediction: decoder output; only ``prediction.depth`` is read.
            batch: reads ``batch["context"]["valid_mask"]`` and, for the
                edge-aware loss, ``batch["context"]["image"]``.
            gaussians: unused.
            depth_dict: reads ``depth_dict["depth"]`` (its last dim is
                squeezed) and ``depth_dict["distill_infos"]["conf_mask"]``.
            global_step: current training step.

        Returns:
            ``cfg.weight`` times the loss, with NaN replaced by 0. The result
            is 0 while ``global_step < cfg.apply_after_step``.

        Raises:
            NotImplementedError: if ``cfg.loss_type`` has no implementation
                here (e.g. ``"PearsonDepth"``); previously this surfaced as
                an ``UnboundLocalError`` on the return statement.
        """
        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(
                0.0, dtype=torch.float32, device=prediction.depth.device
            )

        rendered_depth = prediction.depth

        # Default to the confidence mask; the batch's own mask takes over
        # whenever it marks at least one pixel as valid.
        valid_mask = depth_dict["distill_infos"]['conf_mask']
        if batch['context']['valid_mask'].sum() > 0:
            valid_mask = batch['context']['valid_mask']
        # NOTE: cfg.conf is currently not applied; the intersection with a
        # confidence-valid mask is disabled.
        if self.cfg.not_use_valid_mask:
            valid_mask = torch.ones_like(valid_mask)

        pred_depth = depth_dict['depth'].squeeze(-1)
        if self.cfg.detach:
            pred_depth = pred_depth.detach()

        if self.cfg.loss_type == 'MSE':
            depth_loss = F.mse_loss(
                rendered_depth, pred_depth, reduction='none'
            )[valid_mask].mean()
        elif self.cfg.loss_type == 'EdgeAwareLogL1':
            # Computes (image + 1) / 2; presumably maps [-1, 1] images to
            # [0, 1] — confirm against the data loader.
            gt_rgb = (batch["context"]["image"] + 1) / 2
            # Merge the two leading dims and move everything to a
            # channels-last layout with a trailing singleton channel.
            rendered_depth = rendered_depth.flatten(0, 1).unsqueeze(-1)
            pred_depth = pred_depth.flatten(0, 1).unsqueeze(-1)
            gt_rgb = gt_rgb.flatten(0, 1).permute(0, 2, 3, 1)
            valid_mask = valid_mask.flatten(0, 1).unsqueeze(-1)
            depth_loss = EdgeAwareLogL1()(
                rendered_depth, pred_depth, gt_rgb, valid_mask
            )
        else:
            raise NotImplementedError(
                f"loss_type {self.cfg.loss_type!r} is not implemented; "
                "supported values are 'MSE' and 'EdgeAwareLogL1'"
            )

        # An empty mask selection yields NaN from .mean(); report it as 0.
        return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0)