Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import torch | |
import torch.nn.functional as F | |
# TODO: check if the functions can be moved somewhere else | |
from scenedino.common.util import kl_div, normalized_entropy | |
from scenedino.models.prediction_heads.layers import ssim, geo | |
# TODO: consider separate overloads for the masked and the unmasked call signature.
# NOTE: the intended purpose of ``mask`` is still unclear (ask Felix); it is currently ignored.
def compute_l1ssim(
    img0: torch.Tensor, img1: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:  # (img0 == pred, img1 == GT)
    """Calculate the per-pixel photometric error as a weighted mix of SSIM and L1.

    The error is ``0.85 * mean_c(ssim_term) + 0.15 * mean_c(|img0 - img1|)``, where both
    terms are averaged over the channel dimension (dim 1). The SSIM term is produced by
    ``ssim(..., comp_mode=True)``; presumably that mode returns a dissimilarity so that
    both terms behave as errors — TODO confirm against the ``ssim`` implementation.

    Args:
        img0 (torch.Tensor): predicted images of shape (B, c, h, w).
        img1 (torch.Tensor): ground truth images of shape (B, c, h, w).
        mask (torch.Tensor | None, optional): accepted for API compatibility but currently
            NOT applied; the error is computed for every pixel. Defaults to None.

    Returns:
        torch.Tensor: per-pixel error of shape (B, h, w), assuming ``ssim`` preserves the
        spatial size of its inputs.
    """
    ssim_error = torch.mean(
        ssim(img0, img1, pad_reflection=False, gaussian_average=True, comp_mode=True),
        dim=1,
    )
    l1_error = torch.mean(torch.abs(img0 - img1), dim=1)
    return 0.85 * ssim_error + 0.15 * l1_error  # (B, h, w)
def compute_normalized_l1(flow0: torch.Tensor, flow1: torch.Tensor) -> torch.Tensor:
    """Return the element-wise L1 error between two flows, scaled by the magnitude of ``flow0``.

    The scale is the L2 norm of ``flow0`` over dim 1 (kept as a singleton axis so it
    broadcasts). It is detached, so no gradient flows through the normaliser. ``1e-4`` is
    added to avoid division by zero.

    Args:
        flow0: reference flow; dim 1 is assumed to hold the flow components — TODO confirm.
        flow1: flow to compare against, broadcastable to ``flow0``.

    Returns:
        Tensor with the broadcast shape of ``flow0 - flow1``.
    """
    abs_diff = (flow0 - flow1).abs()
    magnitude = flow0.detach().norm(dim=1, keepdim=True)
    return abs_diff / (magnitude + 1e-4)
# TODO: integrate the mask | |
def compute_edge_aware_smoothness( | |
gt_img: torch.Tensor, input: torch.Tensor, mask: torch.Tensor | None = None, temperature: int = 1 | |
) -> torch.Tensor: | |
"""Compute the edge aware smoothness loss of the depth prediction based on the gradient of the original image. | |
Args: | |
gt_img (torch.Tensor): ground truth images of shape (B, c, h, w) | |
input (torch.Tensor): predicted tensor of shape (B, c, h, w) | |
mask (torch.Tensor | None, optional): Not used yet. Defaults to None. | |
Returns: | |
torch.Tensor: per pixel edge aware smoothness loss of shape (B, h, w) | |
""" | |
_, _, h, w = gt_img.shape | |
# TODO: check whether interpolation is necessary | |
# gt_img = F.interpolate(gt_img, (h, w)) | |
input_dx = torch.mean( | |
torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]), 1, keepdim=True | |
) # (B, 1, h, w-1) | |
input_dy = torch.mean( | |
torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]), 1, keepdim=True | |
) # (B, 1, h-1, w) | |
i_dx = torch.mean( | |
torch.abs(gt_img[:, :, :, :-1] - gt_img[:, :, :, 1:]), 1, keepdim=True | |
) # (B, 1, h, w-1) | |
i_dy = torch.mean( | |
torch.abs(gt_img[:, :, :-1, :] - gt_img[:, :, 1:, :]), 1, keepdim=True | |
) # (B, 1, h-1, w) | |
input_dx *= torch.exp(-temperature * i_dx) # (B, 1, h, w-1) | |
input_dy *= torch.exp(-temperature * i_dy) # (B, 1, h-1, w) | |
errors = F.pad(input_dx, pad=(0, 1), mode="constant", value=0) + F.pad( | |
input_dy, pad=(0, 0, 0, 1), mode="constant", value=0 | |
) # (B, 1, h, w) | |
return errors[:, 0, :, :] # (B, h, w) | |
def compute_3d_smoothness(
    feature_sample: torch.Tensor, sigma_sample: torch.Tensor
) -> torch.Tensor:
    """Return the variance of ``feature_sample`` along dim 2.

    Uses ``torch.var``'s default (Bessel-corrected) estimator. ``sigma_sample`` is accepted
    for API compatibility but is not used.
    """
    spread = feature_sample.var(dim=2)
    return spread
def compute_occupancy_error( | |
teacher_field: torch.Tensor, | |
student_field: torch.Tensor, | |
mask: torch.Tensor | None = None, | |
) -> torch.Tensor: | |
"""Compute the distillation error between the teacher and student density. | |
Args: | |
teacher_density (torch.Tensor): teacher occpancy map of shape (B) | |
student_density (torch.Tensor): student occupancy map of shape (B) | |
mask (torch.Tensor | None, optional): Mask indicating bad occpancy values for student or teacher, e.g. invalid occupancies due to out of frustum. Defaults to None. | |
Returns: | |
torch.Tensor: distillation error of shape (B) | |
""" | |
if mask is not None: | |
teacher_field = teacher_field[mask] | |
student_field = student_field[mask] | |
return torch.nn.MSELoss(reduction="mean")(teacher_field, student_field) # (1) | |
def depth_regularization(depth: torch.Tensor) -> torch.Tensor:
    """Compute a squared-gradient smoothness penalty on a depth map.

    Forward differences are taken along the last two dimensions (rows, then columns).
    The squared differences are averaged separately per direction, and the two means
    are summed.

    Args:
        depth (torch.Tensor): depth map, e.g. of shape (B, 1, h, w). Any tensor with at
            least two dimensions is accepted; the last two are treated as (h, w).

    Returns:
        torch.Tensor: scalar (0-dim) regularization loss, averaged over all elements.
    """
    # Differences between vertically adjacent pixels (along h).
    depth_grad_y = depth[..., 1:, :] - depth[..., :-1, :]
    # Differences between horizontally adjacent pixels (along w).
    depth_grad_x = depth[..., :, 1:] - depth[..., :, :-1]
    return (depth_grad_y**2).mean() + (depth_grad_x**2).mean()
def alpha_regularization( | |
alphas: torch.Tensor, invalids: torch.Tensor | None = None | |
) -> torch.Tensor: | |
# TODO: make configurable | |
alpha_reg_fraction = 1 / 8 | |
alpha_reg_reduction = "ray" | |
"""Compute the alpha regularization loss. | |
Args: | |
alphas (torch.Tensor): alpha map of shape (B, 1, h, w) | |
invalids (torch.Tensor | None, optional): Mask indicating bad alpha values, e.g. invalid alpha due to out of frustum. Defaults to None. | |
Returns: | |
torch.Tensor: alpha regularization loss of shape (B) | |
""" | |
n_smps = alphas.shape[-1] | |
alpha_sum = alphas[..., :-1].sum(-1) | |
min_cap = torch.ones_like(alpha_sum) * (n_smps * alpha_reg_fraction) | |
if invalids is not None: | |
alpha_sum = alpha_sum * (1 - invalids.squeeze(-1).to(torch.float32)) | |
min_cap = min_cap * (1 - invalids.squeeze(-1).to(torch.float32)) | |
match alpha_reg_reduction: | |
case "ray": | |
alpha_reg_loss = (alpha_sum - min_cap).clamp_min(0) | |
case "slice": | |
alpha_reg_loss = (alpha_sum.sum(dim=-1) - min_cap.sum(dim=-1)).clamp_min( | |
0 | |
) / alpha_sum.shape[-1] | |
case _: | |
raise ValueError(f"Invalid alpha_reg_reduction: {alpha_reg_reduction}") | |
return alpha_reg_loss | |
def surfaceness_regularization( | |
alphas: torch.Tensor, invalids: torch.Tensor | None = None | |
) -> torch.Tensor: | |
p = -torch.log(torch.exp(-alphas.abs()) + torch.exp(-(1 - alphas).abs())) | |
p = p.mean(-1) | |
if invalids is not None: | |
p = p * (1 - invalids.squeeze(-1).to(torch.float32)) | |
surfaceness_reg_loss = p.mean() | |
return surfaceness_reg_loss | |
def depth_smoothness_regularization(depths: torch.Tensor) -> torch.Tensor:
    """Return the mean squared forward difference along each of the last two dimensions, summed.

    Args:
        depths: tensor with at least two dimensions; the last two are treated as (h, w).

    Returns:
        Scalar (0-dim) loss.
    """
    vertical = depths[..., :-1, :] - depths[..., 1:, :]
    horizontal = depths[..., :, :-1] - depths[..., :, 1:]
    return vertical.pow(2).mean() + horizontal.pow(2).mean()
def sdf_eikonal_regularization(sdf: torch.Tensor) -> torch.Tensor:
    """Eikonal penalty ``(|grad sdf| - 1)^2`` using forward differences on a voxel grid.

    Only channel 0 of ``sdf`` is used. The grid is indexed as
    ``sdf[batch, channel, d0, d1, d2]``. Differences along d2, d1 and d0 are all taken at
    the common corner sub-grid ``[:-1, :-1, :-1]`` so that they align element-wise.

    Args:
        sdf: 5-D tensor of shape (B, C, D0, D1, D2).

    Returns:
        Tensor of shape (B,): the eikonal loss averaged over the grid.
    """
    origin = sdf[:, :1, :-1, :-1, :-1]
    step_last = sdf[:, :1, :-1, :-1, 1:] - origin
    step_mid = sdf[:, :1, :-1, 1:, :-1] - origin
    step_first = sdf[:, :1, 1:, :-1, :-1] - origin
    squared = torch.cat((step_last, step_mid, step_first), dim=1) ** 2
    grad_norm = squared.sum(dim=1) ** 0.5
    return ((grad_norm - 1) ** 2).mean(dim=(1, 2, 3))
def weight_entropy_regularization( | |
weights: torch.Tensor, invalids: torch.Tensor | None = None | |
) -> torch.Tensor: | |
ignore_last = False | |
weights = weights.clone() | |
if ignore_last: | |
weights = weights[..., :-1] | |
weights = weights / weights.sum(dim=-1, keepdim=True) | |
H_max = math.log2(weights.shape[-1]) | |
# x log2 (x) -> 0 . Therefore, we can set log2 (x) to 0 if x is small enough. | |
# This should ensure numerical stability. | |
weights_too_small = weights < 2 ** (-16) | |
weights[weights_too_small] = 2 | |
wlw = torch.log2(weights) * weights | |
wlw[weights_too_small] = 0 | |
# This is the formula for the normalised entropy | |
entropy = -wlw.sum(-1) / H_max | |
return entropy | |
def max_alpha_regularization(alphas: torch.Tensor, invalids: torch.Tensor | None = None): | |
alphas_max = alphas[..., :-1].max(dim=-1)[0] | |
alphas_reg = (1 - alphas_max).clamp(0, 1).mean() | |
return alphas_reg | |
def max_alpha_inputframe_regularization(alphas: torch.Tensor, ray_info, invalids: torch.Tensor | None = None): | |
mask = ray_info[..., 0] == 0 | |
alphas_max = alphas.max(dim=-1)[0] | |
alphas_reg = ((1 - alphas_max).clamp(0, 1) * mask.to(alphas_max.dtype)).mean() | |
return alphas_reg | |
def epipolar_line_regularization(data, rgb_gt, scale):
    """Unfinished epipolar-line regularizer; currently always returns ``None``.

    The body only reads the inputs an implementation would presumably need. As a side
    effect, it raises ``KeyError`` if the expected keys are missing and ``ValueError`` if
    ``rgb_samps`` is not 7-dimensional.

    NOTE: the ``rgb_gt`` argument is ignored; it is re-derived from ``data["rgb_gt"]``.
    """
    coarse_at_scale = data["coarse"][scale]
    rgb = coarse_at_scale["rgb"]
    rgb_samps = coarse_at_scale["rgb_samps"]
    # Unpacking also validates that rgb_samps is 7-D.
    (_b, _pc, _h, _w, _n_samps, _nv, _c) = rgb_samps.shape
    rgb_gt = data["rgb_gt"].unsqueeze(-2).expand(rgb.shape)
    alphas = coarse_at_scale["alphas"]
    # TODO: implement the actual loss.
    return None
def density_grid_regularization(density_grid, threshold):
    """Penalize density magnitudes that exceed ``threshold``.

    Returns the mean of ``max(|density_grid| - threshold, 0)`` as a scalar. NaN and
    infinite results are mapped to 0. A loss that is exactly 0 is returned detached.
    """
    excess = torch.clamp(density_grid.abs() - threshold, min=0)
    # Divide by the detached peak excess (at least 1) before averaging and multiply it back
    # afterwards; this is meant to keep the intermediate values numerically well scaled.
    scale = excess.max().clamp_min(1).detach()
    error = (excess / scale).mean() * scale
    error = torch.nan_to_num(error, nan=0.0, posinf=0.0, neginf=0.0)
    # Workaround: detaching an all-zero loss prevents error messages from anomaly
    # detection when using AMP.
    if torch.all(error == 0):
        error = error.detach()
    return error
def kl_prop(weights):
    """Entropy-weighted KL propagation between each interior pixel and its 4 neighbours.

    For every interior position (a one-pixel border is excluded on the last two spatial
    dimensions) and each of its four neighbours, the following term is accumulated:

        ``H_c * clamp_min(H_c - H_n, 0) * kl_div(w_n.detach(), w_c)``

    Here ``H`` is ``normalized_entropy`` of the detached weights and ``w`` is the per-pixel
    weight vector (last dimension). Only neighbours with lower entropy than the centre
    contribute, and gradients flow only through the centre's weights. The terms are
    summed and averaged over all positions.

    Assumes ``normalized_entropy`` reduces the last dimension, i.e. ``weights`` is
    (..., H, W, n) and the entropy map is (..., H, W) — TODO confirm.
    """
    entropy = normalized_entropy(weights.detach())
    centre_entropy = entropy[..., 1:-1, 1:-1]
    centre_weights = weights[..., 1:-1, 1:-1, :]
    # (row slice, column slice) of the neighbour: next row, previous row, next column,
    # previous column.
    neighbour_offsets = (
        (slice(2, None), slice(1, -1)),
        (slice(0, -2), slice(1, -1)),
        (slice(1, -1), slice(2, None)),
        (slice(1, -1), slice(0, -2)),
    )
    total = None
    for rows, cols in neighbour_offsets:
        neighbour_entropy = entropy[..., rows, cols]
        neighbour_weights = weights[..., rows, cols, :].detach()
        term = (
            centre_entropy
            * (centre_entropy - neighbour_entropy).clamp_min(0)
            * kl_div(neighbour_weights, centre_weights)
        )
        total = term if total is None else total + term
    return total.mean()
def alpha_consistency(alphas, invalids, consistency_policy):
    """Penalize disagreement of alphas along the last dimension with a consensus target.

    Args:
        alphas: alpha values whose last dimension indexes the entries being compared
            (presumably frames/views — TODO confirm).
        invalids: invalidity flags. An element counts as valid only if ALL its flags along
            the last dimension are < 0.5. The reduced mask must broadcast against
            ``alphas.shape[:-1]``.
        consistency_policy: how the (detached) target is computed over the last
            dimension: ``"max"``, ``"min"``, ``"median"`` or ``"mean"``.

    Returns:
        Scalar mean absolute deviation from the target. Invalid elements are zeroed, but
        they still count in the denominator of the final mean.

    Raises:
        NotImplementedError: for an unknown ``consistency_policy``.
    """
    valid = torch.all(invalids < .5, dim=-1)
    if consistency_policy == "max":
        target = torch.max(alphas, dim=-1, keepdim=True)[0]
    elif consistency_policy == "min":
        # Previously used torch.max here, which made "min" behave like "max".
        target = torch.min(alphas, dim=-1, keepdim=True)[0]
    elif consistency_policy == "median":
        target = torch.median(alphas, dim=-1, keepdim=True)[0]
    elif consistency_policy == "mean":
        target = torch.mean(alphas, dim=-1, keepdim=True)
    else:
        raise NotImplementedError(f"Unknown consistency_policy: {consistency_policy!r}")
    # The target is a fixed reference; gradients only pull the individual alphas towards it.
    target = target.detach()
    diff = (alphas - target).abs().mean(dim=-1)
    return (diff * valid.to(diff.dtype)).mean()
def alpha_consistency_uncert(alphas, invalids, uncert):
    """Regress ``uncert`` towards the observed spread of ``alphas`` over the last dimension.

    The target is the mean absolute deviation of the (detached) alphas from their median
    over the last dimension, scaled by ``nf / (nf - 1)``, where ``nf`` is the size of that
    dimension (``nf == 1`` therefore raises ZeroDivisionError). The loss is
    ``|uncert[..., None] - target|``, zeroed wherever any ``invalids`` flag along the last
    dimension is >= 0.5, and then averaged.

    NOTE(review): ``uncert[..., None]`` has a trailing singleton axis while the target has
    shape ``alphas.shape[:-1]``. The two broadcast against each other rather than aligning
    element-wise unless ``uncert`` has an unusual shape — confirm the intended shapes with
    the caller.
    """
    frozen = alphas.detach()
    n_frames = frozen.shape[-1]
    centre = torch.median(frozen, dim=-1, keepdim=True)[0]
    spread = (frozen - centre).abs().mean(dim=-1) * (n_frames / (n_frames - 1))
    diff = (uncert[..., None] - spread).abs()
    valid = torch.all(invalids < .5, dim=-1).to(diff.dtype)
    return (diff * valid).mean()
def entropy_based_smoothness(weights, depth, invalids=None):
    """Entropy-guided depth smoothness towards more certain neighbours.

    For each of the four neighbour directions, the following is penalized:

        ``H_p * clamp_min(H_p - H_q, 0) * |depth_p - depth_q.detach()|``

    Here ``p`` is a pixel, ``q`` its neighbour, and ``H`` is ``normalized_entropy`` of the
    detached weights. Only neighbours with lower entropy contribute, and gradients reach
    only ``depth_p``. Pixels with ``invalids == 1`` are zeroed. Each direction is averaged
    separately and the four means are summed.

    Assumes the entropy map and ``depth`` share their last two (spatial) dimensions —
    TODO confirm.
    """
    entropy = normalized_entropy(weights.detach())
    if invalids is None:
        invalids = torch.zeros_like(depth)

    def directional_term(here, there):
        # ``here`` / ``there`` are index tuples selecting a pixel and its neighbour.
        own_entropy = entropy[here]
        certainty_gap = (own_entropy - entropy[there]).clamp_min(0)
        depth_gap = (depth[here] - depth[there].detach()).abs()
        return own_entropy * certainty_gap * depth_gap * (1 - invalids[here])

    head = slice(None, -1)
    tail = slice(1, None)
    full = slice(None)
    up = directional_term((..., head, full), (..., tail, full))
    down = directional_term((..., tail, full), (..., head, full))
    left = directional_term((..., full, head), (..., full, tail))
    right = directional_term((..., full, tail), (..., full, head))
    return up.mean() + down.mean() + left.mean() + right.mean()
def flow_regularization(flow, gt_flow, invalids=None):
    """Mean absolute error between the first entry of ``flow`` along dim -2 and ``gt_flow``.

    ``flow[..., 0, :]`` is compared with ``gt_flow``. The absolute error is averaged over
    the last dimension (kept as a singleton axis), optionally multiplied by
    ``1 - invalids``, and finally averaged to a scalar.
    """
    per_item = torch.mean(torch.abs(flow[..., 0, :] - gt_flow), dim=-1, keepdim=True)
    if invalids is not None:
        per_item = per_item * (1 - invalids)
    return per_item.mean()