import math import torch import torch.nn.functional as F # TODO: check if the functions can be moved somewhere else from scenedino.common.util import kl_div, normalized_entropy from scenedino.models.prediction_heads.layers import ssim, geo # TODO: have two signatures with override. One for mask, one without mask # NOTE: what is the purpose of the mask. Ask Felix def compute_l1ssim( img0: torch.Tensor, img1: torch.Tensor, mask: torch.Tensor | None = None ) -> torch.Tensor: ## (img0 == pred, img1 == GT) """Calculate the L1-SSIM error between two images. Use a mask if provided to ignore certain pixels. Args: img0 (torch.Tensor): torch.Tensor of shape (B, c, h, w) containing the predicted images. img1 (torch.Tensor): torch.Tensor of shape (B, c, h, w) containing the ground truth images. mask (torch.Tensor | None, optional): torch.Tensor of shape (B, h, w). Defaults to None. Returns: torch.Tensor: per patch error of shape (B, h, w) """ errors = 0.85 * torch.mean( ssim(img0, img1, pad_reflection=False, gaussian_average=True, comp_mode=True), dim=1, ) + 0.15 * torch.mean(torch.abs(img0 - img1), dim=1) # checking if a mask is provided. If a mask is provided, it is returned along with the errors. Otherwise, only the errors are returned. # if mask is not None: # return ( # errors, # mask, # ) return errors # (B, h, w) def compute_normalized_l1( flow0: torch.Tensor, flow1: torch.Tensor) -> torch.Tensor: errors = (flow0 - flow1).abs() / (flow0.detach().norm(dim=1, keepdim=True) + 1e-4) return errors # TODO: integrate the mask def compute_edge_aware_smoothness( gt_img: torch.Tensor, input: torch.Tensor, mask: torch.Tensor | None = None, temperature: int = 1 ) -> torch.Tensor: """Compute the edge aware smoothness loss of the depth prediction based on the gradient of the original image. Args: gt_img (torch.Tensor): ground truth images of shape (B, c, h, w) input (torch.Tensor): predicted tensor of shape (B, c, h, w) mask (torch.Tensor | None, optional): Not used yet. Defaults to None. Returns: torch.Tensor: per pixel edge aware smoothness loss of shape (B, h, w) """ _, _, h, w = gt_img.shape # TODO: check whether interpolation is necessary # gt_img = F.interpolate(gt_img, (h, w)) input_dx = torch.mean( torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]), 1, keepdim=True ) # (B, 1, h, w-1) input_dy = torch.mean( torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]), 1, keepdim=True ) # (B, 1, h-1, w) i_dx = torch.mean( torch.abs(gt_img[:, :, :, :-1] - gt_img[:, :, :, 1:]), 1, keepdim=True ) # (B, 1, h, w-1) i_dy = torch.mean( torch.abs(gt_img[:, :, :-1, :] - gt_img[:, :, 1:, :]), 1, keepdim=True ) # (B, 1, h-1, w) input_dx *= torch.exp(-temperature * i_dx) # (B, 1, h, w-1) input_dy *= torch.exp(-temperature * i_dy) # (B, 1, h-1, w) errors = F.pad(input_dx, pad=(0, 1), mode="constant", value=0) + F.pad( input_dy, pad=(0, 0, 0, 1), mode="constant", value=0 ) # (B, 1, h, w) return errors[:, 0, :, :] # (B, h, w) def compute_3d_smoothness( feature_sample: torch.Tensor, sigma_sample: torch.Tensor ) -> torch.Tensor: return torch.var(feature_sample, dim=2) def compute_occupancy_error( teacher_field: torch.Tensor, student_field: torch.Tensor, mask: torch.Tensor | None = None, ) -> torch.Tensor: """Compute the distillation error between the teacher and student density. Args: teacher_density (torch.Tensor): teacher occpancy map of shape (B) student_density (torch.Tensor): student occupancy map of shape (B) mask (torch.Tensor | None, optional): Mask indicating bad occpancy values for student or teacher, e.g. invalid occupancies due to out of frustum. Defaults to None. Returns: torch.Tensor: distillation error of shape (B) """ if mask is not None: teacher_field = teacher_field[mask] student_field = student_field[mask] return torch.nn.MSELoss(reduction="mean")(teacher_field, student_field) # (1) def depth_regularization(depth: torch.Tensor) -> torch.Tensor: """Compute the depth regularization loss. Args: depth (torch.Tensor): depth map of shape (B, 1, h, w) Returns: torch.Tensor: depth regularization loss of shape (B) """ depth_grad_x = depth[:, :, 1:, :] - depth[:, :, :-1, :] depth_grad_y = depth[:, :, :, 1:] - depth[:, :, :, :-1] depth_reg_loss = (depth_grad_x**2).mean() + (depth_grad_y**2).mean() return depth_reg_loss def alpha_regularization( alphas: torch.Tensor, invalids: torch.Tensor | None = None ) -> torch.Tensor: # TODO: make configurable alpha_reg_fraction = 1 / 8 alpha_reg_reduction = "ray" """Compute the alpha regularization loss. Args: alphas (torch.Tensor): alpha map of shape (B, 1, h, w) invalids (torch.Tensor | None, optional): Mask indicating bad alpha values, e.g. invalid alpha due to out of frustum. Defaults to None. Returns: torch.Tensor: alpha regularization loss of shape (B) """ n_smps = alphas.shape[-1] alpha_sum = alphas[..., :-1].sum(-1) min_cap = torch.ones_like(alpha_sum) * (n_smps * alpha_reg_fraction) if invalids is not None: alpha_sum = alpha_sum * (1 - invalids.squeeze(-1).to(torch.float32)) min_cap = min_cap * (1 - invalids.squeeze(-1).to(torch.float32)) match alpha_reg_reduction: case "ray": alpha_reg_loss = (alpha_sum - min_cap).clamp_min(0) case "slice": alpha_reg_loss = (alpha_sum.sum(dim=-1) - min_cap.sum(dim=-1)).clamp_min( 0 ) / alpha_sum.shape[-1] case _: raise ValueError(f"Invalid alpha_reg_reduction: {alpha_reg_reduction}") return alpha_reg_loss def surfaceness_regularization( alphas: torch.Tensor, invalids: torch.Tensor | None = None ) -> torch.Tensor: p = -torch.log(torch.exp(-alphas.abs()) + torch.exp(-(1 - alphas).abs())) p = p.mean(-1) if invalids is not None: p = p * (1 - invalids.squeeze(-1).to(torch.float32)) surfaceness_reg_loss = p.mean() return surfaceness_reg_loss def depth_smoothness_regularization(depths: torch.Tensor) -> torch.Tensor: depth_smoothness_loss = ((depths[..., :-1, :] - depths[..., 1:, :]) ** 2).mean() + ( (depths[..., :, :-1] - depths[..., :, 1:]) ** 2 ).mean() return depth_smoothness_loss def sdf_eikonal_regularization(sdf: torch.Tensor) -> torch.Tensor: grad_x = sdf[:, :1, :-1, :-1, 1:] - sdf[:, :1, :-1, :-1, :-1] grad_y = sdf[:, :1, :-1, 1:, :-1] - sdf[:, :1, :-1, :-1, :-1] grad_z = sdf[:, :1, 1:, :-1, :-1] - sdf[:, :1, :-1, :-1, :-1] grad = (torch.cat((grad_x, grad_y, grad_z), dim=1) ** 2).sum(dim=1) ** 0.5 eikonal_loss = ((grad - 1) ** 2).mean(dim=(1, 2, 3)) return eikonal_loss def weight_entropy_regularization( weights: torch.Tensor, invalids: torch.Tensor | None = None ) -> torch.Tensor: ignore_last = False weights = weights.clone() if ignore_last: weights = weights[..., :-1] weights = weights / weights.sum(dim=-1, keepdim=True) H_max = math.log2(weights.shape[-1]) # x log2 (x) -> 0 . Therefore, we can set log2 (x) to 0 if x is small enough. # This should ensure numerical stability. weights_too_small = weights < 2 ** (-16) weights[weights_too_small] = 2 wlw = torch.log2(weights) * weights wlw[weights_too_small] = 0 # This is the formula for the normalised entropy entropy = -wlw.sum(-1) / H_max return entropy def max_alpha_regularization(alphas: torch.Tensor, invalids: torch.Tensor | None = None): alphas_max = alphas[..., :-1].max(dim=-1)[0] alphas_reg = (1 - alphas_max).clamp(0, 1).mean() return alphas_reg def max_alpha_inputframe_regularization(alphas: torch.Tensor, ray_info, invalids: torch.Tensor | None = None): mask = ray_info[..., 0] == 0 alphas_max = alphas.max(dim=-1)[0] alphas_reg = ((1 - alphas_max).clamp(0, 1) * mask.to(alphas_max.dtype)).mean() return alphas_reg def epipolar_line_regularization(data, rgb_gt, scale): rgb = data["coarse"][scale]["rgb"] rgb_samps = data["coarse"][scale]["rgb_samps"] b, pc, h, w, n_samps, nv, c = rgb_samps.shape rgb_gt = data["rgb_gt"].unsqueeze(-2).expand(rgb.shape) alphas = data["coarse"][scale]["alphas"] # TODO def density_grid_regularization(density_grid, threshold): density_grid = (density_grid.abs() - threshold).clamp_min(0) # Attempt to make it more numerically stable max_v = density_grid.max().clamp_min(1).detach() # print(max_v.item()) error = (((density_grid / max_v)).mean() * max_v) error = torch.nan_to_num(error, 0, 0, 0) # Black magic to prevent error massages from anomaly detection when using AMP if torch.all(error == 0): error = error.detach() return error def kl_prop(weights): entropy = normalized_entropy(weights.detach()) kl_prop = entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 2:, 1:-1]).clamp_min(0) * kl_div(weights[..., 2:, 1:-1, :].detach(), weights[..., 1:-1, 1:-1, :]) kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 0:-2, 1:-1]).clamp_min(0) * kl_div(weights[..., 0:-2, 1:-1, :].detach(), weights[..., 1:-1, 1:-1, :]) kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 1:-1, 2:]).clamp_min(0) * kl_div(weights[..., 1:-1, 2:, :].detach(), weights[..., 1:-1, 1:-1, :]) kl_prop += entropy[..., 1:-1, 1:-1] * (entropy[..., 1:-1, 1:-1] - entropy[..., 1:-1, 0:-2]).clamp_min(0) * kl_div(weights[..., 1:-1, :-2, :].detach(), weights[..., 1:-1, 1:-1, :]) return kl_prop.mean() def alpha_consistency(alphas, invalids, consistency_policy): invalids = torch.all(invalids < .5, dim=-1) if consistency_policy == "max": target = torch.max(alphas, dim=-1, keepdim=True)[0].detach() elif consistency_policy == "min": target = torch.max(alphas, dim=-1, keepdim=True)[0].detach() elif consistency_policy == "median": target = torch.median(alphas, dim=-1, keepdim=True)[0].detach() elif consistency_policy == "mean": target = torch.mean(alphas, dim=-1, keepdim=True).detach() else: raise NotImplementedError diff = (alphas - target).abs().mean(dim=-1) invalids = invalids.to(diff.dtype) diff = (diff * invalids) return diff.mean() def alpha_consistency_uncert(alphas, invalids, uncert): invalids = torch.all(invalids < .5, dim=-1) alphas = alphas.detach() nf = alphas.shape[-1] alphas_median = torch.median(alphas, dim=-1, keepdim=True)[0].detach() target = (alphas - alphas_median).abs().mean(dim=-1) * (nf / (nf-1)) diff = (uncert[..., None] - target).abs() invalids = invalids.to(diff.dtype) diff = (diff * invalids) return diff.mean() def entropy_based_smoothness(weights, depth, invalids=None): entropy = normalized_entropy(weights.detach()) error_fn = lambda d0, d1: (d0 - d1.detach()).abs() if invalids is None: invalids = torch.zeros_like(depth) # up kl_prop_up = entropy[..., :-1, :] * (entropy[..., :-1, :] - entropy[..., 1:, :]).clamp_min(0) * error_fn(depth[..., :-1, :], depth[..., 1:, :]) * (1 - invalids[..., :-1, :]) # down kl_prop_down = entropy[..., 1:, :] * (entropy[..., 1:, :] - entropy[..., :-1, :]).clamp_min(0) * error_fn(depth[..., 1:, :], depth[..., :-1, :]) * (1 - invalids[..., 1:, :]) # left kl_prop_left = entropy[..., :, :-1] * (entropy[..., :, :-1] - entropy[..., :, 1:]).clamp_min(0) * error_fn(depth[..., :, :-1], depth[..., :, 1:]) * (1 - invalids[..., :, :-1]) # right kl_prop_right = entropy[..., :, 1:] * (entropy[..., :, 1:] - entropy[..., :, :-1]).clamp_min(0) * error_fn(depth[..., :, 1:], depth[..., :, :-1]) * (1 - invalids[..., :, 1:]) kl_prop = kl_prop_up.mean() + kl_prop_down.mean() + kl_prop_left.mean() + kl_prop_right.mean() return kl_prop.mean() def flow_regularization(flow, gt_flow, invalids=None): flow_reg = (flow[..., 0, :] - gt_flow).abs().mean(dim=-1, keepdim=True) if invalids is not None: flow_reg = flow_reg * (1 - invalids) return flow_reg.mean()