from dataclasses import dataclass, fields
from typing import Generic, Literal, TypeVar
import os
import sys

import torch
import torch.nn.functional as F
from einops import reduce
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything
from src.misc.utils import vis_depth_map

T_cfg = TypeVar("T_cfg")
T_wrapper = TypeVar("T_wrapper")

@dataclass
class LossDepthGTCfg:
    weight: float
    type: Literal["l1", "mse", "silog", "gradient", "l1+gradient"] | None


@dataclass
class LossDepthGTCfgWrapper:
    depthgt: LossDepthGTCfg
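
# Example (assumed, not from the original file): a config that mixes L1 and
# gradient supervision with a weight of 0.05 would be constructed as
#   LossDepthGTCfgWrapper(depthgt=LossDepthGTCfg(weight=0.05, type="l1+gradient"))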

class LossDepthGT(Loss[LossDepthGTCfg, LossDepthGTCfgWrapper]):
    def gradient_loss(self, gs_depth, target_depth, target_valid_mask):
        """Mean absolute error of the horizontal/vertical depth-error gradients,
        computed only where both neighboring pixels are valid."""
        diff = gs_depth - target_depth
        grad_x_diff = diff[:, :, :, 1:] - diff[:, :, :, :-1]
        grad_y_diff = diff[:, :, 1:, :] - diff[:, :, :-1, :]
        # A gradient is valid only if both pixels that form it are valid.
        mask_x = target_valid_mask[:, :, :, 1:] * target_valid_mask[:, :, :, :-1]
        mask_y = target_valid_mask[:, :, 1:, :] * target_valid_mask[:, :, :-1, :]
        grad_x_diff = grad_x_diff * mask_x
        grad_y_diff = grad_y_diff * mask_y
        # Clamp so that extreme outliers cannot dominate the sum.
        grad_x_diff = grad_x_diff.clamp(min=-100, max=100)
        grad_y_diff = grad_y_diff.clamp(min=-100, max=100)
        loss_x = grad_x_diff.abs().sum()
        loss_y = grad_y_diff.abs().sum()
        num_valid = mask_x.sum() + mask_y.sum()
        if num_valid == 0:
            # No valid gradients: return a zero tensor (not a Python int) so the
            # caller's nan_to_num and weighting still work.
            gradient_loss = torch.zeros((), device=gs_depth.device)
        else:
            gradient_loss = (loss_x + loss_y) / (num_valid + 1e-6)
        return gradient_loss
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Compare the rendered depth against the ground-truth depth wherever the
        # target validity mask is set.
        target_depth = batch["target"]["depth"]
        target_valid_mask = batch["target"]["valid_mask"]
        gs_depth = prediction.depth.clamp(1e-3)
        if self.cfg.type == "l1":
            depth_loss = torch.abs(target_depth[target_valid_mask] - gs_depth[target_valid_mask]).mean()
        elif self.cfg.type == "mse":
            depth_loss = F.mse_loss(target_depth[target_valid_mask], gs_depth[target_valid_mask])
        elif self.cfg.type == "silog":
            depth_loss = (
                torch.log(gs_depth[target_valid_mask]) ** 2
                + (gs_depth[target_valid_mask] - target_depth[target_valid_mask]) ** 2
                - 0.5
            )
            depth_loss = depth_loss.mean()
        elif self.cfg.type == "gradient":
            depth_loss = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
        elif self.cfg.type == "l1+gradient":
            depth_loss_l1 = torch.abs(target_depth[target_valid_mask] - gs_depth[target_valid_mask]).mean()
            depth_loss_gradient = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
            depth_loss = depth_loss_l1 + depth_loss_gradient
        else:
            raise ValueError(f"Unknown depth loss type: {self.cfg.type}")
        return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0, posinf=0.0, neginf=0.0)
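

# --- Hedged usage sketch (not part of the original file) ---------------------
# A minimal sanity check of gradient_loss on random tensors, assuming depth and
# mask tensors shaped (batch, view, height, width). gradient_loss never touches
# self, so it is called unbound here; no config or Loss wrapper is constructed.
if __name__ == "__main__":
    gs_depth = torch.rand(1, 2, 32, 32) + 0.1           # fake rendered depth
    target_depth = torch.rand(1, 2, 32, 32) + 0.1       # fake ground-truth depth
    target_valid_mask = torch.rand(1, 2, 32, 32) > 0.2  # roughly 80% of pixels valid
    loss = LossDepthGT.gradient_loss(None, gs_depth, target_depth, target_valid_mask)
    print(f"gradient loss on random depths: {loss.item():.4f}")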