File size: 4,979 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from dataclasses import dataclass

import torch
from einops import reduce
from jaxtyping import Float
from torch import Tensor

from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
from typing import Generic, Literal, Optional, TypeVar
from dataclasses import fields
import torch.nn.functional as F
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from src.loss.depth_anything.dpt import DepthAnything
from src.misc.utils import vis_depth_map

# Generic type variables for configuration objects and their single-field
# wrapper dataclasses.
T_cfg, T_wrapper = TypeVar("T_cfg"), TypeVar("T_wrapper")


@dataclass
class LossDepthConsisCfg:
    """Configuration for the depth-consistency loss.

    Field descriptions below are limited to how the loss in this module reads
    them; fields marked NOTE(review) are not read by the code visible here.
    """

    # Scalar multiplier applied to the final loss value.
    weight: float
    # NOTE(review): not read by the loss in this module — confirm it is still used.
    sigma_image: float | None
    # NOTE(review): not read by the loss in this module — confirm it is still used.
    use_second_derivative: bool
    # Penalty used to compare rendered depth with the predicted depth map.
    # NOTE(review): 'PearsonDepth' is accepted here, but the loss in this module
    # only has branches for 'MSE' and 'EdgeAwareLogL1'.
    loss_type: Literal['MSE', 'EdgeAwareLogL1', 'PearsonDepth'] = 'MSE'
    # If True, the predicted (target) depth is detached, so gradients flow only
    # through the rendered depth.
    detach: bool = False
    # NOTE(review): only referenced by commented-out code in this module.
    conf: bool = False
    # If True, the validity mask is replaced by an all-ones mask (every pixel is
    # supervised).
    not_use_valid_mask: bool = False
    # The loss evaluates to 0 while global_step < apply_after_step.
    apply_after_step: int = 0

@dataclass
class LossDepthConsisCfgWrapper:
    """Single-field wrapper around `LossDepthConsisCfg`.

    The loss constructor unpacks exactly one dataclass field from this wrapper
    and uses that field's name (``depth_consis``) as the loss name, so this
    class must keep exactly one field and its name is significant.
    """

    depth_consis: LossDepthConsisCfg


def _validate_implementation(implementation: str) -> str:
    """Return *implementation* unchanged, or raise if it is not a supported mode."""
    if implementation not in ("scalar", "per-pixel"):
        raise ValueError(
            "implementation must be 'scalar' or 'per-pixel', "
            f"got {implementation!r}"
        )
    return implementation


class LogL1(torch.nn.Module):
    """Log-L1 loss: ``log(1 + |pred - gt|)``.

    With ``implementation="scalar"`` the element-wise values are averaged to a
    0-d tensor; with ``"per-pixel"`` the element-wise tensor is returned.
    Extra keyword arguments are accepted and ignored.

    Raises:
        ValueError: if ``implementation`` is not ``"scalar"`` or ``"per-pixel"``.
    """

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        super().__init__()
        self.implementation = _validate_implementation(implementation)

    def forward(self, pred, gt):
        # log1p is more accurate than log(1 + x) for small errors.
        per_pixel = torch.log1p(torch.abs(pred - gt))
        if self.implementation == "scalar":
            return per_pixel.mean()
        return per_pixel


class EdgeAwareLogL1(torch.nn.Module):
    """Gradient-aware Log-L1 loss.

    The Log-L1 error between ``pred`` and ``gt`` is taken at horizontal and
    vertical neighbour pairs and weighted by ``exp(-|grad rgb|)``, where the
    RGB gradient is averaged over the last (channel) axis. Errors across
    strong image edges therefore receive smaller weights.

    Tensors are indexed channel-last: the last three axes are (H, W, C).

    Raises:
        ValueError: if ``implementation`` is not ``"scalar"`` or ``"per-pixel"``.
    """

    def __init__(
        self, implementation: Literal["scalar", "per-pixel"] = "scalar", **kwargs
    ):
        super().__init__()
        # Validated up front: an unknown mode previously made forward() fall
        # through every branch and silently return None.
        self.implementation = _validate_implementation(implementation)
        self.logl1 = LogL1(implementation="per-pixel")

    def forward(self, pred: Tensor, gt: Tensor, rgb: Tensor, mask: Optional[Tensor]):
        """Compute the edge-aware Log-L1 loss.

        Args:
            pred: Predicted values, ``(..., H, W, C)``.
            gt: Target values, same shape as ``pred``.
            rgb: Guidance image ``(..., H, W, C_rgb)`` whose neighbour
                differences set the weights.
            mask: Optional boolean mask aligned with ``pred``. Each neighbour
                pair is kept or dropped according to the mask at its left (x)
                or top (y) pixel.

        Returns:
            ``"scalar"``: mean of the x-terms plus mean of the y-terms (a 0-d
            tensor; NaN if the mask selects nothing).
            ``"per-pixel"``: ``(..., H-1, W-1, C)`` sum of both terms, with
            masked-out entries set to 0.
        """
        logl1 = self.logl1(pred, gt)

        # Mean absolute colour difference between horizontal / vertical neighbours.
        grad_img_x = torch.mean(
            torch.abs(rgb[..., :, :-1, :] - rgb[..., :, 1:, :]), -1, keepdim=True
        )
        grad_img_y = torch.mean(
            torch.abs(rgb[..., :-1, :, :] - rgb[..., 1:, :, :]), -1, keepdim=True
        )
        # Weights decay towards 0 as the image gradient grows.
        lambda_x = torch.exp(-grad_img_x)
        lambda_y = torch.exp(-grad_img_y)

        loss_x = lambda_x * logl1[..., :, :-1, :]
        loss_y = lambda_y * logl1[..., :-1, :, :]

        if self.implementation == "per-pixel":
            if mask is not None:
                # Out-of-place fill instead of mutating an autograd intermediate.
                loss_x = loss_x.masked_fill(~mask[..., :, :-1, :], 0.0)
                loss_y = loss_y.masked_fill(~mask[..., :-1, :, :], 0.0)
            # Crop both terms to the shared (H-1, W-1) grid before summing.
            return loss_x[..., :-1, :, :] + loss_y[..., :, :-1, :]

        if mask is not None:
            assert mask.shape[:2] == pred.shape[:2]
            loss_x = loss_x[mask[..., :, :-1, :]]
            loss_y = loss_y[mask[..., :-1, :, :]]

        # Only "scalar" remains here because the mode is validated in __init__.
        return loss_x.mean() + loss_y.mean()
        
class LossDepthConsis(Loss[LossDepthConsisCfg, LossDepthConsisCfgWrapper]):
    """Depth-consistency loss between rendered depth and a predicted depth map.

    Compares ``prediction.depth`` with ``depth_dict["depth"]`` on the pixels
    selected by a validity mask, using the penalty chosen by
    ``cfg.loss_type``.
    """

    def __init__(self, cfg: T_wrapper) -> None:
        super().__init__(cfg)

        # Extract the configuration from the wrapper. The wrapper must have
        # exactly one dataclass field; its name doubles as the loss name.
        (field,) = fields(type(cfg))
        self.cfg = getattr(cfg, field.name)
        self.name = field.name

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the weighted depth-consistency loss.

        Args:
            prediction: Decoder output; only ``prediction.depth`` (the rendered
                depth) is read.
            batch: Reads ``batch["context"]["image"]`` (mapped by
                ``(x + 1) / 2``, presumably from [-1, 1] to [0, 1]) and
                ``batch["context"]["valid_mask"]``.
            gaussians: Not used by this loss.
            depth_dict: Reads ``"depth"`` (a trailing size-1 axis is squeezed)
                and ``["distill_infos"]["conf_mask"]``.
            global_step: Current step; the loss is 0 before
                ``cfg.apply_after_step``.

        Returns:
            ``cfg.weight`` times the loss, with NaN (e.g. from an empty mask)
            replaced by 0.

        Raises:
            ValueError: if ``cfg.loss_type`` is not ``'MSE'`` or
                ``'EdgeAwareLogL1'``.
        """
        # Before the specified step, don't apply the loss.
        if global_step < self.cfg.apply_after_step:
            return torch.tensor(0.0, dtype=torch.float32, device=prediction.depth.device)

        rendered_depth = prediction.depth
        gt_rgb = (batch["context"]["image"] + 1) / 2

        # Use the batch's validity mask when it selects any pixel; otherwise fall
        # back to the confidence mask supplied in depth_dict.
        # NOTE: cfg.conf is currently not applied here.
        valid_mask = depth_dict["distill_infos"]["conf_mask"]
        if batch["context"]["valid_mask"].sum() > 0:
            valid_mask = batch["context"]["valid_mask"]
        if self.cfg.not_use_valid_mask:
            # ones_like already keeps dtype and device.
            valid_mask = torch.ones_like(valid_mask)

        pred_depth = depth_dict["depth"].squeeze(-1)
        if self.cfg.detach:
            pred_depth = pred_depth.detach()

        if self.cfg.loss_type == "MSE":
            per_pixel = F.mse_loss(rendered_depth, pred_depth, reduction="none")
            depth_loss = per_pixel[valid_mask].mean()
        elif self.cfg.loss_type == "EdgeAwareLogL1":
            # Merge the two leading axes (presumably batch and view) and move to
            # the channel-last layout EdgeAwareLogL1 indexes; assumes the image
            # is channel-first after flattening — TODO confirm.
            rendered_depth = rendered_depth.flatten(0, 1).unsqueeze(-1)
            pred_depth = pred_depth.flatten(0, 1).unsqueeze(-1)
            gt_rgb = gt_rgb.flatten(0, 1).permute(0, 2, 3, 1)
            valid_mask = valid_mask.flatten(0, 1).unsqueeze(-1)
            depth_loss = EdgeAwareLogL1()(rendered_depth, pred_depth, gt_rgb, valid_mask)
        else:
            # Previously fell through and crashed with UnboundLocalError on
            # depth_loss (e.g. for loss_type='PearsonDepth').
            raise ValueError(
                f"Unsupported depth_consis loss_type {self.cfg.loss_type!r}; "
                "supported: 'MSE', 'EdgeAwareLogL1'"
            )
        return self.cfg.weight * torch.nan_to_num(depth_loss, nan=0.0)