import os
import sys
from dataclasses import dataclass
from typing import Literal, TypeVar

import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss

# Adds the parent of this module's directory to sys.path.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything
from src.misc.utils import vis_depth_map

T_cfg = TypeVar("T_cfg")
T_wrapper = TypeVar("T_wrapper")


@dataclass
class LossDepthGTCfg:
    weight: float
    type: Literal["l1", "mse", "silog", "gradient", "l1+gradient"] | None

@dataclass
class LossDepthGTCfgWrapper:
    depthgt: LossDepthGTCfg


class LossDepthGT(Loss[LossDepthGTCfg, LossDepthGTCfgWrapper]):
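    def gradient_loss(self, gs_depth, target_depth, target_valid_mask):
        # Gradient-matching loss on the depth residual (prediction - target).
        # Penalizing differences between neighboring residuals ignores a constant
        # depth offset while still pushing predicted depth edges onto target edges.
        # Inputs are [B, V, H, W]; target_valid_mask marks pixels with valid depth.
        valid = target_valid_mask.bool()
        diff = gs_depth - target_depth

        # Finite differences of the residual along width (x) and height (y).
        grad_x_diff = diff[:, :, :, 1:] - diff[:, :, :, :-1]
        grad_y_diff = diff[:, :, 1:, :] - diff[:, :, :-1, :]

        # A neighbor difference counts only if both pixels are valid.
        mask_x = valid[:, :, :, 1:] & valid[:, :, :, :-1]
        mask_y = valid[:, :, 1:, :] & valid[:, :, :-1, :]

        # torch.where (rather than multiplying by the mask) keeps NaN/inf values at
        # invalid pixels out of the sum, since NaN * 0 is still NaN.
        grad_x_diff = torch.where(mask_x, grad_x_diff, torch.zeros_like(grad_x_diff))
        grad_y_diff = torch.where(mask_y, grad_y_diff, torch.zeros_like(grad_y_diff))

        # Clamp to limit the influence of large depth discontinuities.
        grad_x_diff = grad_x_diff.clamp(min=-100, max=100)
        grad_y_diff = grad_y_diff.clamp(min=-100, max=100)

        loss_x = grad_x_diff.abs().sum()
        loss_y = grad_y_diff.abs().sum()
        num_valid = mask_x.sum() + mask_y.sum()

        # Average over valid pairs. Clamping the count to 1 yields a zero tensor when
        # nothing is valid (returning the Python int 0 would make torch.nan_to_num
        # in forward() raise a TypeError).
        return (loss_x + loss_y) / num_valid.clamp(min=1)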
    
    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        global_step: int,
    ) -> Float[Tensor, ""]:
        # Supervise the rendered depth with ground-truth depth on valid pixels only.
        # Shapes (as implied by the indexing in gradient_loss): depth tensors are
        # [B, V, H, W] and valid_mask is a boolean tensor of the same shape.
        target_depth = batch["target"]["depth"]
        target_valid_mask = batch["target"]["valid_mask"]
        # Keep predicted depth strictly positive so the log in "silog" is finite.
        gs_depth = prediction.depth.clamp(min=1e-3)

        if self.cfg.type == "l1":
            depth_loss = torch.abs(gs_depth[target_valid_mask] - target_depth[target_valid_mask]).mean()
        elif self.cfg.type == "mse":
            depth_loss = F.mse_loss(gs_depth[target_valid_mask], target_depth[target_valid_mask])
        elif self.cfg.type == "silog":
            # Scale-invariant log loss: with d = log(pred) - log(gt),
            # loss = mean(d^2) - 0.5 * mean(d)^2.
            log_diff = torch.log(gs_depth[target_valid_mask]) - torch.log(
                target_depth[target_valid_mask].clamp(min=1e-3)
            )
            depth_loss = (log_diff**2).mean() - 0.5 * log_diff.mean() ** 2
        elif self.cfg.type == "gradient":
            depth_loss = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
        elif self.cfg.type == "l1+gradient":
            depth_loss_l1 = torch.abs(gs_depth[target_valid_mask] - target_depth[target_valid_mask]).mean()
            depth_loss_gradient = self.gradient_loss(gs_depth, target_depth, target_valid_mask)
            depth_loss = depth_loss_l1 + depth_loss_gradient
        else:
            raise ValueError(f"Unsupported depth loss type: {self.cfg.type!r}")

        # An empty valid mask makes .mean() return NaN; map NaN/inf to 0 so such
        # batches contribute no depth loss.
        return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0, posinf=0.0, neginf=0.0)
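

# Reference sketch (hypothetical helpers, not used by LossDepthGT): plain-Python
# versions of the standard scale-invariant log (SILog) loss and of the masked
# gradient-matching loss that gradient_loss computes, for a single H x W image
# given as nested lists. They are illustrations under these assumptions, meant
# for hand-checking the tensor code on tiny inputs, not a drop-in replacement.
import math


def _silog_reference(pred, target, lam=0.5):
    """SILog on flat lists of positive depths: mean(d^2) - lam * mean(d)^2,
    with d = log(pred) - log(target). lam=1.0 makes it fully scale-invariant."""
    d = [math.log(p) - math.log(t) for p, t in zip(pred, target)]
    n = len(d)
    return sum(x * x for x in d) / n - lam * (sum(d) / n) ** 2


def _gradient_loss_reference(pred, target, valid, max_abs=100.0):
    """Mean absolute difference between neighboring residuals (pred - target),
    over horizontal and vertical pairs whose two pixels are both valid.
    Each pair's contribution is capped at max_abs, like the clamp to [-100, 100]."""
    h, w = len(pred), len(pred[0])
    diff = [[pred[i][j] - target[i][j] for j in range(w)] for i in range(h)]
    total, count = 0.0, 0
    for i in range(h):
        for j in range(w):
            # Horizontal neighbor (x direction).
            if j + 1 < w and valid[i][j] and valid[i][j + 1]:
                total += min(abs(diff[i][j + 1] - diff[i][j]), max_abs)
                count += 1
            # Vertical neighbor (y direction).
            if i + 1 < h and valid[i][j] and valid[i + 1][j]:
                total += min(abs(diff[i + 1][j] - diff[i][j]), max_abs)
                count += 1
    return total / count if count else 0.0


# Usage: a constant depth offset is invisible to the gradient loss, and a global
# depth scale is invisible to SILog when lam=1.0.
# _gradient_loss_reference([[1.0, 2.0], [3.0, 4.0]],
#                          [[0.0, 1.0], [2.0, 3.0]],
#                          [[True, True], [True, True]])  # 0.0
# _silog_reference([2.0, 4.0], [1.0, 2.0], lam=1.0)       # ~0.0 up to rounding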