import math
import torch
import torch.nn.functional as F
# TODO: check if the functions can be moved somewhere else
from scenedino.common.util import kl_div, normalized_entropy
from scenedino.models.prediction_heads.layers import ssim, geo
# TODO: have two signatures with override. One for mask, one without mask
# NOTE: what is the purpose of the mask? Ask Felix
def compute_l1ssim(
img0: torch.Tensor, img1: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor: ## (img0 == pred, img1 == GT)
"""Calculate the L1-SSIM error between two images. Use a mask if provided to ignore certain pixels.
Args:
img0 (torch.Tensor): torch.Tensor of shape (B, c, h, w) containing the predicted images.
img1 (torch.Tensor): torch.Tensor of shape (B, c, h, w) containing the ground truth images.
mask (torch.Tensor | None, optional): torch.Tensor of shape (B, h, w). Defaults to None.
Returns:
torch.Tensor: per patch error of shape (B, h, w)
"""
errors = 0.85 * torch.mean(
ssim(img0, img1, pad_reflection=False, gaussian_average=True, comp_mode=True),
dim=1,
) + 0.15 * torch.mean(torch.abs(img0 - img1), dim=1)
    # NOTE: the mask is currently unused here; an earlier version returned it alongside the errors.
return errors # (B, h, w)
def compute_normalized_l1(
flow0: torch.Tensor, flow1: torch.Tensor) -> torch.Tensor:
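    """Compute the L1 error between two flow fields, normalized by the magnitude of flow0.

    Note: shapes are inferred from usage. flow0 is assumed to be the prediction and is
    detached for the normalization; the channel dimension is assumed to be dim 1.

    Args:
        flow0 (torch.Tensor): predicted flow of shape (B, c, h, w).
        flow1 (torch.Tensor): reference flow of shape (B, c, h, w).

    Returns:
        torch.Tensor: per-element normalized L1 error of shape (B, c, h, w).
    """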
errors = (flow0 - flow1).abs() / (flow0.detach().norm(dim=1, keepdim=True) + 1e-4)
return errors
# TODO: integrate the mask
def compute_edge_aware_smoothness(
gt_img: torch.Tensor, input: torch.Tensor, mask: torch.Tensor | None = None, temperature: int = 1
) -> torch.Tensor:
"""Compute the edge aware smoothness loss of the depth prediction based on the gradient of the original image.
Args:
gt_img (torch.Tensor): ground truth images of shape (B, c, h, w)
input (torch.Tensor): predicted tensor of shape (B, c, h, w)
mask (torch.Tensor | None, optional): Not used yet. Defaults to None.
Returns:
torch.Tensor: per pixel edge aware smoothness loss of shape (B, h, w)
"""
_, _, h, w = gt_img.shape
# TODO: check whether interpolation is necessary
# gt_img = F.interpolate(gt_img, (h, w))
input_dx = torch.mean(
torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]), 1, keepdim=True
) # (B, 1, h, w-1)
input_dy = torch.mean(
torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]), 1, keepdim=True
) # (B, 1, h-1, w)
i_dx = torch.mean(
torch.abs(gt_img[:, :, :, :-1] - gt_img[:, :, :, 1:]), 1, keepdim=True
) # (B, 1, h, w-1)
i_dy = torch.mean(
torch.abs(gt_img[:, :, :-1, :] - gt_img[:, :, 1:, :]), 1, keepdim=True
) # (B, 1, h-1, w)
input_dx *= torch.exp(-temperature * i_dx) # (B, 1, h, w-1)
input_dy *= torch.exp(-temperature * i_dy) # (B, 1, h-1, w)
errors = F.pad(input_dx, pad=(0, 1), mode="constant", value=0) + F.pad(
input_dy, pad=(0, 0, 0, 1), mode="constant", value=0
) # (B, 1, h, w)
return errors[:, 0, :, :] # (B, h, w)
def compute_3d_smoothness(
feature_sample: torch.Tensor, sigma_sample: torch.Tensor
) -> torch.Tensor:
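    """Return the variance of the feature samples along dim 2.

    Note: sigma_sample is accepted but currently unused; the sample layout is inferred from usage.
    """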
return torch.var(feature_sample, dim=2)
def compute_occupancy_error(
teacher_field: torch.Tensor,
student_field: torch.Tensor,
mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Compute the distillation error between the teacher and student density.
Args:
teacher_density (torch.Tensor): teacher occpancy map of shape (B)
student_density (torch.Tensor): student occupancy map of shape (B)
mask (torch.Tensor | None, optional): Mask indicating bad occpancy values for student or teacher, e.g. invalid occupancies due to out of frustum. Defaults to None.
Returns:
torch.Tensor: distillation error of shape (B)
"""
if mask is not None:
teacher_field = teacher_field[mask]
student_field = student_field[mask]
return torch.nn.MSELoss(reduction="mean")(teacher_field, student_field) # (1)
def depth_regularization(depth: torch.Tensor) -> torch.Tensor:
"""Compute the depth regularization loss.
Args:
depth (torch.Tensor): depth map of shape (B, 1, h, w)
Returns:
torch.Tensor: depth regularization loss of shape (B)
"""
depth_grad_x = depth[:, :, 1:, :] - depth[:, :, :-1, :]
depth_grad_y = depth[:, :, :, 1:] - depth[:, :, :, :-1]
depth_reg_loss = (depth_grad_x**2).mean() + (depth_grad_y**2).mean()
return depth_reg_loss
def alpha_regularization(
alphas: torch.Tensor, invalids: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute the alpha regularization loss, encouraging the per-ray alpha sum to stay above a minimum fraction of the number of samples.

    Args:
        alphas (torch.Tensor): alpha values with the samples along the last dimension
        invalids (torch.Tensor | None, optional): Mask indicating bad alpha values, e.g. invalid alphas due to out-of-frustum points. Defaults to None.

    Returns:
        torch.Tensor: alpha regularization loss (per ray for the default "ray" reduction)
    """
    # TODO: make configurable
    alpha_reg_fraction = 1 / 8
    alpha_reg_reduction = "ray"
n_smps = alphas.shape[-1]
alpha_sum = alphas[..., :-1].sum(-1)
min_cap = torch.ones_like(alpha_sum) * (n_smps * alpha_reg_fraction)
if invalids is not None:
alpha_sum = alpha_sum * (1 - invalids.squeeze(-1).to(torch.float32))
min_cap = min_cap * (1 - invalids.squeeze(-1).to(torch.float32))
match alpha_reg_reduction:
case "ray":
alpha_reg_loss = (alpha_sum - min_cap).clamp_min(0)
case "slice":
alpha_reg_loss = (alpha_sum.sum(dim=-1) - min_cap.sum(dim=-1)).clamp_min(
0
) / alpha_sum.shape[-1]
case _:
raise ValueError(f"Invalid alpha_reg_reduction: {alpha_reg_reduction}")
return alpha_reg_loss
def surfaceness_regularization(
alphas: torch.Tensor, invalids: torch.Tensor | None = None
) -> torch.Tensor:
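    """Encourage alpha values to be close to either 0 or 1, i.e. a surface-like distribution along each ray.

    Note: the sample dimension is assumed to be last; invalids (after squeezing its last
    dimension) masks out invalid rays.
    """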
p = -torch.log(torch.exp(-alphas.abs()) + torch.exp(-(1 - alphas).abs()))
p = p.mean(-1)
if invalids is not None:
p = p * (1 - invalids.squeeze(-1).to(torch.float32))
surfaceness_reg_loss = p.mean()
return surfaceness_reg_loss
def depth_smoothness_regularization(depths: torch.Tensor) -> torch.Tensor:
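    """Penalize squared finite differences of the depth map along its last two (spatial) dimensions."""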
depth_smoothness_loss = ((depths[..., :-1, :] - depths[..., 1:, :]) ** 2).mean() + (
(depths[..., :, :-1] - depths[..., :, 1:]) ** 2
).mean()
return depth_smoothness_loss
def sdf_eikonal_regularization(sdf: torch.Tensor) -> torch.Tensor:
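    """Eikonal regularization for an SDF grid: penalize deviation of the finite-difference gradient norm from 1.

    Note: the grid layout (B, 1, d, h, w) is inferred from the indexing below; the loss is
    returned per batch element.
    """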
grad_x = sdf[:, :1, :-1, :-1, 1:] - sdf[:, :1, :-1, :-1, :-1]
grad_y = sdf[:, :1, :-1, 1:, :-1] - sdf[:, :1, :-1, :-1, :-1]
grad_z = sdf[:, :1, 1:, :-1, :-1] - sdf[:, :1, :-1, :-1, :-1]
grad = (torch.cat((grad_x, grad_y, grad_z), dim=1) ** 2).sum(dim=1) ** 0.5
eikonal_loss = ((grad - 1) ** 2).mean(dim=(1, 2, 3))
return eikonal_loss
def weight_entropy_regularization(
weights: torch.Tensor, invalids: torch.Tensor | None = None
) -> torch.Tensor:
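    """Compute the normalized entropy of the sample weights along the last dimension.

    Note: invalids is accepted but currently unused.
    """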
ignore_last = False
weights = weights.clone()
if ignore_last:
weights = weights[..., :-1]
weights = weights / weights.sum(dim=-1, keepdim=True)
H_max = math.log2(weights.shape[-1])
    # x * log2(x) -> 0 as x -> 0, so very small weights should contribute 0 to the entropy.
    # To keep the computation numerically stable, they are overwritten with a dummy value (2)
    # before taking the log, and their contributions are zeroed out afterwards.
    weights_too_small = weights < 2 ** (-16)
    weights[weights_too_small] = 2
    wlw = torch.log2(weights) * weights
    wlw[weights_too_small] = 0
    # Normalised entropy: negative sum of w * log2(w), divided by the maximum entropy H_max.
entropy = -wlw.sum(-1) / H_max
return entropy
def max_alpha_regularization(alphas: torch.Tensor, invalids: torch.Tensor | None = None):
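    """Encourage at least one sample along each ray (excluding the last sample) to reach an alpha close to 1.

    Note: invalids is accepted but currently unused.
    """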
alphas_max = alphas[..., :-1].max(dim=-1)[0]
alphas_reg = (1 - alphas_max).clamp(0, 1).mean()
return alphas_reg
def max_alpha_inputframe_regularization(alphas: torch.Tensor, ray_info, invalids: torch.Tensor | None = None):
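    """Variant of max_alpha_regularization applied only to rays where ray_info[..., 0] == 0
    (assumed to mark input-frame rays).

    Note: the meaning of ray_info is inferred from the function name; invalids is currently unused.
    """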
mask = ray_info[..., 0] == 0
alphas_max = alphas.max(dim=-1)[0]
alphas_reg = ((1 - alphas_max).clamp(0, 1) * mask.to(alphas_max.dtype)).mean()
return alphas_reg
def epipolar_line_regularization(data, rgb_gt, scale):
rgb = data["coarse"][scale]["rgb"]
rgb_samps = data["coarse"][scale]["rgb_samps"]
b, pc, h, w, n_samps, nv, c = rgb_samps.shape
rgb_gt = data["rgb_gt"].unsqueeze(-2).expand(rgb.shape)
alphas = data["coarse"][scale]["alphas"]
    # TODO: loss computation not implemented yet
    raise NotImplementedError("epipolar_line_regularization is not implemented yet")
def density_grid_regularization(density_grid, threshold):
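    """Penalize density values whose magnitude exceeds the given threshold.

    Note: the shape of density_grid is not fixed here; the loss is a mean over all entries.
    """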
density_grid = (density_grid.abs() - threshold).clamp_min(0)
    # Normalize by the (detached) maximum value for numerical stability, then rescale.
    max_v = density_grid.max().clamp_min(1).detach()
    error = (density_grid / max_v).mean() * max_v
    error = torch.nan_to_num(error, 0, 0, 0)
    # Black magic to prevent error messages from anomaly detection when using AMP
if torch.all(error == 0):
error = error.detach()
return error
def kl_prop(weights):
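    """Propagate sample weights from confident (low-entropy) rays to their less confident neighbors.

    For each of the four spatial neighbors, the KL divergence between the (detached) neighbor
    weights and the center weights is accumulated, scaled by how much more confident the
    neighbor is. Note: the spatial dimensions are assumed to be the two before the sample dimension.
    """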
entropy = normalized_entropy(weights.detach())
kl_prop = entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 2:, 1:-1]).clamp_min(0) * kl_div(weights[..., 2:, 1:-1, :].detach(), weights[..., 1:-1, 1:-1, :])
kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 0:-2, 1:-1]).clamp_min(0) * kl_div(weights[..., 0:-2, 1:-1, :].detach(), weights[..., 1:-1, 1:-1, :])
kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 1:-1, 2:]).clamp_min(0) * kl_div(weights[..., 1:-1, 2:, :].detach(), weights[..., 1:-1, 1:-1, :])
kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 1:-1, 0:-2]).clamp_min(0) * kl_div(weights[..., 1:-1, :-2, :].detach(), weights[..., 1:-1, 1:-1, :])
return kl_prop.mean()
def alpha_consistency(alphas, invalids, consistency_policy):
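    """Penalize deviation of the alphas from a shared target (max/min/median/mean over the last dimension).

    Note: the loss is only applied where all entries of invalids are below 0.5, i.e. on rays
    that are valid in every frame (inferred from the masking below).
    """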
invalids = torch.all(invalids < .5, dim=-1)
if consistency_policy == "max":
target = torch.max(alphas, dim=-1, keepdim=True)[0].detach()
elif consistency_policy == "min":
        target = torch.min(alphas, dim=-1, keepdim=True)[0].detach()
elif consistency_policy == "median":
target = torch.median(alphas, dim=-1, keepdim=True)[0].detach()
elif consistency_policy == "mean":
target = torch.mean(alphas, dim=-1, keepdim=True).detach()
else:
raise NotImplementedError
diff = (alphas - target).abs().mean(dim=-1)
invalids = invalids.to(diff.dtype)
diff = (diff * invalids)
return diff.mean()
def alpha_consistency_uncert(alphas, invalids, uncert):
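    """Train the uncertainty to match the mean absolute deviation of the alphas from their median.

    Note: the masking follows the same convention as alpha_consistency; the alphas are detached
    so only the uncertainty receives gradients.
    """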
invalids = torch.all(invalids < .5, dim=-1)
alphas = alphas.detach()
nf = alphas.shape[-1]
alphas_median = torch.median(alphas, dim=-1, keepdim=True)[0].detach()
target = (alphas - alphas_median).abs().mean(dim=-1) * (nf / (nf-1))
diff = (uncert[..., None] - target).abs()
invalids = invalids.to(diff.dtype)
diff = (diff * invalids)
return diff.mean()
def entropy_based_smoothness(weights, depth, invalids=None):
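    """Encourage depth smoothness where neighboring rays are more confident (lower entropy).

    For each of the four neighbors, the absolute depth difference to the (detached) neighbor is
    weighted by how much more confident that neighbor is. Note: the spatial dimensions are
    assumed to be the last two of depth and entropy.
    """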
entropy = normalized_entropy(weights.detach())
error_fn = lambda d0, d1: (d0 - d1.detach()).abs()
if invalids is None:
invalids = torch.zeros_like(depth)
# up
kl_prop_up = entropy[..., :-1, :] * (entropy[..., :-1, :] - entropy[..., 1:, :]).clamp_min(0) * error_fn(depth[..., :-1, :], depth[..., 1:, :]) * (1 - invalids[..., :-1, :])
# down
kl_prop_down = entropy[..., 1:, :] * (entropy[..., 1:, :] - entropy[..., :-1, :]).clamp_min(0) * error_fn(depth[..., 1:, :], depth[..., :-1, :]) * (1 - invalids[..., 1:, :])
# left
kl_prop_left = entropy[..., :, :-1] * (entropy[..., :, :-1] - entropy[..., :, 1:]).clamp_min(0) * error_fn(depth[..., :, :-1], depth[..., :, 1:]) * (1 - invalids[..., :, :-1])
# right
kl_prop_right = entropy[..., :, 1:] * (entropy[..., :, 1:] - entropy[..., :, :-1]).clamp_min(0) * error_fn(depth[..., :, 1:], depth[..., :, :-1]) * (1 - invalids[..., :, 1:])
kl_prop = kl_prop_up.mean() + kl_prop_down.mean() + kl_prop_left.mean() + kl_prop_right.mean()
return kl_prop.mean()
def flow_regularization(flow, gt_flow, invalids=None):
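    """L1 error between the predicted flow and the ground truth flow.

    Note: the indexing flow[..., 0, :] is assumed to select the flow of the first sample/frame;
    invalids masks out invalid rays.
    """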
flow_reg = (flow[..., 0, :] - gt_flow).abs().mean(dim=-1, keepdim=True)
if invalids is not None:
flow_reg = flow_reg * (1 - invalids)
return flow_reg.mean()