import math
import torch
import torch.nn.functional as F

# TODO: check if the functions can be moved somewhere else
from scenedino.common.util import kl_div, normalized_entropy
from scenedino.models.prediction_heads.layers import ssim, geo


# TODO: provide two overloads: one with a mask and one without
# NOTE: the mask parameter is currently unused; clarify its intended purpose
def compute_l1ssim(
    img0: torch.Tensor, img1: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:  ## (img0 == pred, img1 == GT)
    """Calculate the L1-SSIM error between two images. Use a mask if provided to ignore certain pixels.

    Args:
        img0 (torch.Tensor): torch.Tensor of shape (B, c, h, w) containing the predicted images.
        img1 (torch.Tensor): torch.Tensor of shape (B, c, h, w) containing the ground truth images.
        mask (torch.Tensor | None, optional): torch.Tensor of shape (B, h, w). Currently unused. Defaults to None.

    Returns:
        torch.Tensor: per patch error of shape (B, h, w)
    """
    errors = 0.85 * torch.mean(
        ssim(img0, img1, pad_reflection=False, gaussian_average=True, comp_mode=True),
        dim=1,
    ) + 0.15 * torch.mean(torch.abs(img0 - img1), dim=1)
    # NOTE: the mask is accepted for interface compatibility but is not applied here.
    return errors  # (B, h, w)
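
# Example (illustrative sketch, not part of the original module; shapes follow the
# docstring above and the tensors are random stand-ins, not real data):
#
#     pred = torch.rand(2, 3, 64, 64)  # (B, c, h, w) predicted images
#     gt = torch.rand(2, 3, 64, 64)    # (B, c, h, w) ground truth images
#     err = compute_l1ssim(pred, gt)   # per-pixel error of shape (2, 64, 64)
#     loss = err.mean()                # scalar photometric loss for backprop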


def compute_normalized_l1(flow0: torch.Tensor, flow1: torch.Tensor) -> torch.Tensor:
    """Calculate the L1 error between two flow fields, normalized by the magnitude of the first flow.

    Args:
        flow0 (torch.Tensor): predicted flow of shape (B, c, h, w)
        flow1 (torch.Tensor): reference flow of shape (B, c, h, w)

    Returns:
        torch.Tensor: per-pixel normalized L1 error of shape (B, c, h, w)
    """
    errors = (flow0 - flow1).abs() / (flow0.detach().norm(dim=1, keepdim=True) + 1e-4)
    return errors


# TODO: integrate the mask
def compute_edge_aware_smoothness(
    gt_img: torch.Tensor, input: torch.Tensor, mask: torch.Tensor | None = None, temperature: float = 1.0
) -> torch.Tensor:
    """Compute the edge-aware smoothness loss of the prediction based on the gradients of the original image.

    Args:
        gt_img (torch.Tensor): ground truth images of shape (B, c, h, w)
        input (torch.Tensor): predicted tensor of shape (B, c, h, w)
        mask (torch.Tensor | None, optional): Not used yet. Defaults to None.
        temperature (float, optional): scaling of the edge-based down-weighting. Defaults to 1.0.

    Returns:
        torch.Tensor: per pixel edge-aware smoothness loss of shape (B, h, w)
    """
    _, _, h, w = gt_img.shape

    # TODO: check whether interpolation is necessary
    # gt_img = F.interpolate(gt_img, (h, w))

    input_dx = torch.mean(
        torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]), 1, keepdim=True
    )  # (B, 1, h, w-1)
    input_dy = torch.mean(
        torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]), 1, keepdim=True
    )  # (B, 1, h-1, w)

    i_dx = torch.mean(
        torch.abs(gt_img[:, :, :, :-1] - gt_img[:, :, :, 1:]), 1, keepdim=True
    )  # (B, 1, h, w-1)
    i_dy = torch.mean(
        torch.abs(gt_img[:, :, :-1, :] - gt_img[:, :, 1:, :]), 1, keepdim=True
    )  # (B, 1, h-1, w)

    input_dx *= torch.exp(-temperature * i_dx)  # (B, 1, h, w-1)
    input_dy *= torch.exp(-temperature * i_dy)  # (B, 1, h-1, w)

    errors = F.pad(input_dx, pad=(0, 1), mode="constant", value=0) + F.pad(
        input_dy, pad=(0, 0, 0, 1), mode="constant", value=0
    )  # (B, 1, h, w)
    return errors[:, 0, :, :]  # (B, h, w)
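
# Example (illustrative sketch; a predicted disparity/depth map is smoothed
# everywhere except where the reference image has strong edges):
#
#     img = torch.rand(2, 3, 64, 64)   # (B, c, h, w) reference images
#     disp = torch.rand(2, 1, 64, 64)  # (B, 1, h, w) predicted disparity
#     loss = compute_edge_aware_smoothness(img, disp).mean()  # scalar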


def compute_3d_smoothness(
    feature_sample: torch.Tensor, sigma_sample: torch.Tensor
) -> torch.Tensor:
    """Compute the variance of the feature samples along the sample dimension.

    Note: sigma_sample is currently unused.
    """
    return torch.var(feature_sample, dim=2)


def compute_occupancy_error(
    teacher_field: torch.Tensor,
    student_field: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the distillation error between the teacher and student density.

    Args:
        teacher_density (torch.Tensor): teacher occpancy map of shape (B)
        student_density (torch.Tensor): student occupancy map of shape (B)
        mask (torch.Tensor | None, optional): Mask indicating bad occpancy values for student or teacher, e.g. invalid occupancies due to out of frustum. Defaults to None.

    Returns:
        torch.Tensor: distillation error of shape (B)
    """
    if mask is not None:
        teacher_field = teacher_field[mask]
        student_field = student_field[mask]

    return torch.nn.MSELoss(reduction="mean")(teacher_field, student_field)  # (1)


def depth_regularization(depth: torch.Tensor) -> torch.Tensor:
    """Compute the depth regularization loss.

    Args:
        depth (torch.Tensor): depth map of shape (B, 1, h, w)

    Returns:
        torch.Tensor: scalar depth regularization loss
    """
    depth_grad_y = depth[:, :, 1:, :] - depth[:, :, :-1, :]  # gradient along the height axis
    depth_grad_x = depth[:, :, :, 1:] - depth[:, :, :, :-1]  # gradient along the width axis
    depth_reg_loss = (depth_grad_x**2).mean() + (depth_grad_y**2).mean()

    return depth_reg_loss


def alpha_regularization(
    alphas: torch.Tensor, invalids: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute the alpha regularization loss.

    Args:
        alphas (torch.Tensor): alpha values with the per-ray samples in the last dimension.
        invalids (torch.Tensor | None, optional): Mask indicating bad alpha values, e.g. invalid alphas due to out-of-frustum samples. Defaults to None.

    Returns:
        torch.Tensor: per-ray alpha regularization loss
    """
    # TODO: make configurable
    alpha_reg_fraction = 1 / 8
    alpha_reg_reduction = "ray"

    n_smps = alphas.shape[-1]

    alpha_sum = alphas[..., :-1].sum(-1)
    min_cap = torch.ones_like(alpha_sum) * (n_smps * alpha_reg_fraction)

    if invalids is not None:
        alpha_sum = alpha_sum * (1 - invalids.squeeze(-1).to(torch.float32))
        min_cap = min_cap * (1 - invalids.squeeze(-1).to(torch.float32))

    match alpha_reg_reduction:
        case "ray":
            alpha_reg_loss = (alpha_sum - min_cap).clamp_min(0)
        case "slice":
            alpha_reg_loss = (alpha_sum.sum(dim=-1) - min_cap.sum(dim=-1)).clamp_min(
                0
            ) / alpha_sum.shape[-1]
        case _:
            raise ValueError(f"Invalid alpha_reg_reduction: {alpha_reg_reduction}")

    return alpha_reg_loss
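
# Example (illustrative sketch; with the defaults above, per-ray alpha sums
# exceeding n_smps / 8 are penalized):
#
#     alphas = torch.rand(2, 1024, 64)  # (..., n_smps) per-ray alphas
#     loss = alpha_regularization(alphas).mean()  # reduce per-ray loss to a scalar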


def surfaceness_regularization(
    alphas: torch.Tensor, invalids: torch.Tensor | None = None
) -> torch.Tensor:
    """Encourage alpha values to concentrate near 0 or 1, i.e. sharp, surface-like densities."""
    p = -torch.log(torch.exp(-alphas.abs()) + torch.exp(-(1 - alphas).abs()))
    p = p.mean(-1)

    if invalids is not None:
        p = p * (1 - invalids.squeeze(-1).to(torch.float32))

    surfaceness_reg_loss = p.mean()
    return surfaceness_reg_loss


def depth_smoothness_regularization(depths: torch.Tensor) -> torch.Tensor:
    """Penalize squared depth differences between neighboring pixels."""
    depth_smoothness_loss = ((depths[..., :-1, :] - depths[..., 1:, :]) ** 2).mean() + (
        (depths[..., :, :-1] - depths[..., :, 1:]) ** 2
    ).mean()

    return depth_smoothness_loss


def sdf_eikonal_regularization(sdf: torch.Tensor) -> torch.Tensor:
    """Eikonal regularization: the norm of the SDF gradient should be 1 everywhere."""
    # Finite-difference gradients along the three spatial axes of the SDF grid.
    grad_x = sdf[:, :1, :-1, :-1, 1:] - sdf[:, :1, :-1, :-1, :-1]
    grad_y = sdf[:, :1, :-1, 1:, :-1] - sdf[:, :1, :-1, :-1, :-1]
    grad_z = sdf[:, :1, 1:, :-1, :-1] - sdf[:, :1, :-1, :-1, :-1]
    grad = (torch.cat((grad_x, grad_y, grad_z), dim=1) ** 2).sum(dim=1) ** 0.5

    eikonal_loss = ((grad - 1) ** 2).mean(dim=(1, 2, 3))

    return eikonal_loss
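
# Example (illustrative sketch; sdf is assumed to be a dense (B, 1, d, h, w)
# grid of signed distances):
#
#     sdf = torch.rand(2, 1, 16, 16, 16)
#     loss = sdf_eikonal_regularization(sdf).mean()  # 0 iff the gradient norm is 1 everywhere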


def weight_entropy_regularization(
    weights: torch.Tensor, invalids: torch.Tensor | None = None
) -> torch.Tensor:
    """Normalized entropy of the per-ray weight distributions. Note: invalids is currently unused."""
    ignore_last = False

    weights = weights.clone()

    if ignore_last:
        weights = weights[..., :-1]
        weights = weights / weights.sum(dim=-1, keepdim=True)

    H_max = math.log2(weights.shape[-1])

    # x * log2(x) -> 0 as x -> 0. Therefore, we can set log2(x) to 0 where x is small enough.
    # The placeholder value 2 keeps torch.log2 well-defined; those entries are zeroed below.
    weights_too_small = weights < 2 ** (-16)
    weights[weights_too_small] = 2

    wlw = torch.log2(weights) * weights

    wlw[weights_too_small] = 0

    # Normalized entropy: entropy divided by the maximum possible entropy H_max.
    entropy = -wlw.sum(-1) / H_max
    return entropy
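
# Example (illustrative sketch; the weights along each ray are assumed to sum to 1):
#
#     w = torch.softmax(torch.randn(2, 1024, 64), dim=-1)  # (..., n_smps)
#     ent = weight_entropy_regularization(w)               # (2, 1024), values in [0, 1]
#     loss = ent.mean()  # low entropy = concentrated, surface-like weights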


def max_alpha_regularization(alphas: torch.Tensor, invalids: torch.Tensor | None = None):
    """Encourage at least one high-alpha sample per ray. Note: invalids is currently unused."""
    alphas_max = alphas[..., :-1].max(dim=-1)[0]
    alphas_reg = (1 - alphas_max).clamp(0, 1).mean()
    return alphas_reg


def max_alpha_inputframe_regularization(alphas: torch.Tensor, ray_info, invalids: torch.Tensor | None = None):
    mask = ray_info[..., 0] == 0
    alphas_max = alphas.max(dim=-1)[0]
    alphas_reg = ((1 - alphas_max).clamp(0, 1) * mask.to(alphas_max.dtype)).mean()
    return alphas_reg


def epipolar_line_regularization(data, rgb_gt, scale):
    # NOTE: unfinished; the rgb_gt argument is immediately overwritten below.
    rgb = data["coarse"][scale]["rgb"]
    rgb_samps = data["coarse"][scale]["rgb_samps"]

    b, pc, h, w, n_samps, nv, c = rgb_samps.shape

    rgb_gt = data["rgb_gt"].unsqueeze(-2).expand(rgb.shape)

    alphas = data["coarse"][scale]["alphas"]

    # TODO
    raise NotImplementedError("epipolar_line_regularization is not implemented yet")


def density_grid_regularization(density_grid, threshold):
    density_grid = (density_grid.abs() - threshold).clamp_min(0)

    # Normalize by the maximum value before taking the mean to improve numerical stability.
    max_v = density_grid.max().clamp_min(1).detach()

    error = (density_grid / max_v).mean() * max_v

    error = torch.nan_to_num(error, 0, 0, 0)

    # Black magic to prevent error messages from anomaly detection when using AMP
    if torch.all(error == 0):
        error = error.detach()

    return error


def kl_prop(weights):
    """Propagate per-ray weight distributions from confident (low-entropy) pixels to their
    less confident neighbors by penalizing the KL divergence to each of the four neighbors."""
    entropy = normalized_entropy(weights.detach())

    kl_prop = entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 2:, 1:-1]).clamp_min(0) * kl_div(weights[..., 2:, 1:-1, :].detach(), weights[..., 1:-1, 1:-1, :])
    kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 0:-2, 1:-1]).clamp_min(0) * kl_div(weights[..., 0:-2, 1:-1, :].detach(), weights[..., 1:-1, 1:-1, :])
    kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 1:-1, 2:]).clamp_min(0) * kl_div(weights[..., 1:-1, 2:, :].detach(), weights[..., 1:-1, 1:-1, :])
    kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 1:-1, 0:-2]).clamp_min(0) * kl_div(weights[..., 1:-1, :-2, :].detach(), weights[..., 1:-1, 1:-1, :])

    return kl_prop.mean()


def alpha_consistency(alphas, invalids, consistency_policy):
    valid = torch.all(invalids < .5, dim=-1)  # rays for which no view is marked invalid

    if consistency_policy == "max":
        target = torch.max(alphas, dim=-1, keepdim=True)[0].detach()
    elif consistency_policy == "min":
        target = torch.max(alphas, dim=-1, keepdim=True)[0].detach()
    elif consistency_policy == "median":
        target = torch.median(alphas, dim=-1, keepdim=True)[0].detach()
    elif consistency_policy == "mean":
        target = torch.mean(alphas, dim=-1, keepdim=True).detach()
    else:
        raise NotImplementedError

    diff = (alphas - target).abs().mean(dim=-1)

    valid = valid.to(diff.dtype)

    diff = diff * valid

    return diff.mean()
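
# Example (illustrative sketch; pulls per-view alphas toward their median across views):
#
#     alphas = torch.rand(2, 1024, 4)    # (..., n_views) per-view alphas
#     invalids = torch.zeros(2, 1024, 4) # no view is marked invalid
#     loss = alpha_consistency(alphas, invalids, "median")  # scalar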


def alpha_consistency_uncert(alphas, invalids, uncert):
    valid = torch.all(invalids < .5, dim=-1)  # rays for which no view is marked invalid
    alphas = alphas.detach()
    nf = alphas.shape[-1]

    alphas_median = torch.median(alphas, dim=-1, keepdim=True)[0].detach()

    target = (alphas - alphas_median).abs().mean(dim=-1) * (nf / (nf-1))

    diff = (uncert[..., None] - target).abs()

    valid = valid.to(diff.dtype)

    diff = diff * valid

    return diff.mean()


def entropy_based_smoothness(weights, depth, invalids=None):
    """Smooth depth toward neighbors whose weight distributions have lower entropy (higher confidence)."""
    entropy = normalized_entropy(weights.detach())

    error_fn = lambda d0, d1: (d0 - d1.detach()).abs()

    if invalids is None:
        invalids = torch.zeros_like(depth)

    # up
    kl_prop_up = entropy[..., :-1, :] * (entropy[..., :-1, :] - entropy[..., 1:, :]).clamp_min(0) * error_fn(depth[..., :-1, :], depth[..., 1:, :]) * (1 - invalids[..., :-1, :])
    # down
    kl_prop_down = entropy[..., 1:, :] * (entropy[..., 1:, :] - entropy[..., :-1, :]).clamp_min(0) * error_fn(depth[..., 1:, :], depth[..., :-1, :]) * (1 - invalids[..., 1:, :])
    # left
    kl_prop_left = entropy[..., :, :-1] * (entropy[..., :, :-1] - entropy[..., :, 1:]).clamp_min(0) * error_fn(depth[..., :, :-1], depth[..., :, 1:]) * (1 - invalids[..., :, :-1])
    # right
    kl_prop_right = entropy[..., :, 1:] * (entropy[..., :, 1:] - entropy[..., :, :-1]).clamp_min(0) * error_fn(depth[..., :, 1:], depth[..., :, :-1]) * (1 - invalids[..., :, 1:])

    loss = kl_prop_up.mean() + kl_prop_down.mean() + kl_prop_left.mean() + kl_prop_right.mean()

    return loss


def flow_regularization(flow, gt_flow, invalids=None):
    """L1 error between the predicted flow (first entry along the second-to-last dimension) and the ground-truth flow."""
    flow_reg = (flow[..., 0, :] - gt_flow).abs().mean(dim=-1, keepdim=True)

    if invalids is not None:
        flow_reg = flow_reg * (1 - invalids)

    return flow_reg.mean()