#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import torch.nn.functional as F
from math import exp
from torch import Tensor, nn
from typing import Literal
from jaxtyping import Bool, Float

L1Loss = nn.L1Loss
MSELoss = nn.MSELoss


def l1_loss(network_output, gt, mask=None):
    """Mean absolute error, optionally restricted to the pixels selected by `mask`."""
    l1 = torch.abs(network_output - gt)
    if mask is not None:
        l1 = l1[:, mask]
    return l1.mean()


def l2_loss(network_output, gt):
    """Mean squared error."""
    return ((network_output - gt) ** 2).mean()


def gaussian(window_size, sigma):
    """1D Gaussian kernel of length `window_size`, normalized to sum to 1."""
    gauss = torch.tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """2D Gaussian window of shape (channel, 1, window_size, window_size) for grouped convolution."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def ssim(img1, img2, window_size=11, size_average=True):
    """Structural similarity (SSIM) between two images; the channel dimension is expected third from last."""
    channel = img1.size(-3)
    window = create_window(window_size, channel)
    window = window.to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)
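

# A minimal usage sketch (illustrative only; the (N, C, H, W) shapes below are
# assumptions for the example, not a requirement of `ssim` itself):
#
#   pred = torch.rand(1, 3, 64, 64)
#   gt = torch.rand(1, 3, 64, 64)
#   score = ssim(pred, gt)  # scalar tensor; 1.0 means identical images
#   loss = 1.0 - score      # common way to turn SSIM into a training loss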


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Core SSIM computation using a precomputed Gaussian window."""
    # local means via grouped Gaussian filtering
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    # local variances and covariance
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
    # stability constants from the SSIM paper, assuming pixel values in [0, 1]
    C1 = 0.01**2
    C2 = 0.03**2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)


def ssim_loss(img1, img2, window_size=11, size_average=True, mask=None):
    """1 - SSIM as a loss, optionally restricted to the pixels selected by `mask`."""
    channel = img1.size(-3)
    window = create_window(window_size, channel)
    window = window.to(device=img1.device, dtype=img1.dtype)
    return _ssim_loss(img1, img2, window, window_size, channel, size_average, mask)


def _ssim_loss(img1, img2, window, window_size, channel, size_average=True, mask=None):
    """Core 1 - SSIM computation using a precomputed Gaussian window."""
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
    C1 = 0.01**2
    C2 = 0.03**2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    # convert similarity to a loss, then drop masked-out pixels
    ssim_map = 1 - ssim_map
    if mask is not None:
        ssim_map = ssim_map[:, mask]
    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)


def masked_reduction(
    input_tensor: Float[Tensor, "1 32 mult"],
    mask: Bool[Tensor, "1 32 mult"],
    reduction_type: Literal["image", "batch"],
) -> Tensor:
    """
    Consolidate the input_tensor either across the whole batch or per image.
    Args:
        input_tensor: input tensor
        mask: mask tensor (the loss classes below pass per-image valid-pixel counts)
        reduction_type: either "batch" or "image"
    Returns:
        input_tensor: reduced input_tensor
    """
    if reduction_type == "batch":
        # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
        divisor = torch.sum(mask)
        if divisor == 0:
            return torch.tensor(0.0, device=input_tensor.device)
        input_tensor = torch.sum(input_tensor) / divisor
    elif reduction_type == "image":
        # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
        valid = mask.nonzero()
        input_tensor[valid] = input_tensor[valid] / mask[valid]
        input_tensor = torch.mean(input_tensor)
    return input_tensor
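

# A minimal sketch of the two reduction modes (the 1D shapes mirror how the
# loss classes below call this helper: one summed loss and one valid-pixel
# count per image):
#
#   image_loss = torch.tensor([4.0, 0.0])  # per-image summed losses
#   counts = torch.tensor([2, 0])          # per-image valid-pixel counts
#   masked_reduction(image_loss, counts, "batch")  # 4 / (2 + 0) -> tensor(2.)
#   masked_reduction(image_loss, counts, "image")  # mean(4 / 2, 0) -> tensor(1.)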


def normalized_depth_scale_and_shift(
    prediction: Float[Tensor, "1 32 mult"], target: Float[Tensor, "1 32 mult"], mask: Bool[Tensor, "1 32 mult"]
):
    """
    More info here: https://arxiv.org/pdf/2206.00665.pdf supplementary section A2 Depth Consistency Loss
    This function computes the scale/shift required to normalize the predicted depth map,
    so that normalized depth maps from monocular depth estimation networks can be used as input.
    These networks are trained to predict normalized depth maps.
    Solves for scale/shift using a least-squares approach with a closed-form solution:
    Based on:
    https://github.com/autonomousvision/monosdf/blob/d9619e948bf3d85c6adec1a643f679e2e8e84d4b/code/model/loss.py#L7
    Args:
        prediction: predicted depth map
        target: ground truth depth map
        mask: mask of valid pixels
    Returns:
        scale and shift for depth prediction
    """
    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
    a_00 = torch.sum(mask * prediction * prediction, (1, 2))
    a_01 = torch.sum(mask * prediction, (1, 2))
    a_11 = torch.sum(mask, (1, 2))
    # right hand side: b = [b_0, b_1]
    b_0 = torch.sum(mask * prediction * target, (1, 2))
    b_1 = torch.sum(mask * target, (1, 2))
    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
    scale = torch.zeros_like(b_0)
    shift = torch.zeros_like(b_1)
    det = a_00 * a_11 - a_01 * a_01
    valid = det.nonzero()
    scale[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
    shift[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
    return scale, shift
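

# A minimal sketch of the closed-form alignment (the (N, H, W) shapes are an
# assumption for the example): the recovered scale/shift maps `prediction`
# back onto `target`.
#
#   target = torch.rand(1, 8, 8)
#   prediction = (target - 0.5) / 2.0  # a "normalized" prediction
#   mask = torch.ones_like(target, dtype=torch.bool)
#   scale, shift = normalized_depth_scale_and_shift(prediction, target, mask)
#   # scale ~ 2.0 and shift ~ 0.5, so scale * prediction + shift ~ target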


class MiDaSMSELoss(nn.Module):
    """
    data term from the MiDaS paper
    """

    def __init__(self, reduction_type: Literal["image", "batch"] = "batch"):
        super().__init__()
        self.reduction_type: Literal["image", "batch"] = reduction_type
        # reduction here is different from the image/batch-based reduction. This is either "mean" or "sum"
        self.mse_loss = MSELoss(reduction="none")

    def forward(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        Args:
            prediction: predicted depth map
            target: ground truth depth map
            mask: mask of valid pixels
        Returns:
            mse loss based on reduction function
        """
        summed_mask = torch.sum(mask, (1, 2))
        image_loss = torch.sum(self.mse_loss(prediction, target) * mask, (1, 2))
        # the factor of 2 appears to match the 1/(2M) normalization of the MiDaS data term
        image_loss = masked_reduction(image_loss, 2 * summed_mask, self.reduction_type)
        return image_loss
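

# A minimal sketch (the (N, H, W) shape is an assumption for the example): the
# data term averages squared error over valid pixels, including the extra
# factor-of-2 normalization noted above.
#
#   data = MiDaSMSELoss(reduction_type="batch")
#   pred = torch.zeros(1, 4, 4)
#   gt = torch.ones(1, 4, 4)
#   mask = torch.ones_like(gt, dtype=torch.bool)
#   data(pred, gt, mask)  # 16 / (2 * 16) -> tensor(0.5)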


class GradientLoss(nn.Module):
    """
    Multiscale, scale-invariant gradient matching term applied in disparity space.
    This term biases discontinuities to be sharp and to coincide with discontinuities in the ground truth.
    More info here https://arxiv.org/pdf/1907.01341.pdf Equation 11
    """

    def __init__(self, scales: int = 4, reduction_type: Literal["image", "batch"] = "batch"):
        """
        Args:
            scales: number of scales to use
            reduction_type: either "batch" or "image"
        """
        super().__init__()
        self.reduction_type: Literal["image", "batch"] = reduction_type
        self.__scales = scales

    def forward(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        Args:
            prediction: predicted depth map
            target: ground truth depth map
            mask: mask of valid pixels
        Returns:
            gradient loss based on reduction function
        """
        assert self.__scales >= 1
        total = 0.0
        for scale in range(self.__scales):
            step = pow(2, scale)
            grad_loss = self.gradient_loss(
                prediction[:, ::step, ::step],
                target[:, ::step, ::step],
                mask[:, ::step, ::step],
            )
            total += grad_loss
        assert isinstance(total, Tensor)
        return total

    def gradient_loss(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        Gradient matching term at a single scale.
        Args:
            prediction: predicted depth map
            target: ground truth depth map
            mask: mask of valid pixels
        Returns:
            gradient loss based on reduction function
        """
        summed_mask = torch.sum(mask, (1, 2))
        diff = prediction - target
        diff = torch.mul(mask, diff)
        grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
        mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
        grad_x = torch.mul(mask_x, grad_x)
        grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
        mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
        grad_y = torch.mul(mask_y, grad_y)
        image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
        image_loss = masked_reduction(image_loss, summed_mask, self.reduction_type)
        return image_loss
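

# A minimal sketch (the (N, H, W) shapes are an assumption for the example):
# the gradient term penalizes edges in the prediction/target difference,
# accumulated over 4 progressively downsampled resolutions.
#
#   grad = GradientLoss(scales=4, reduction_type="batch")
#   pred = torch.rand(1, 32, 32)
#   gt = torch.rand(1, 32, 32)
#   mask = torch.ones_like(gt, dtype=torch.bool)
#   reg = grad(pred, gt, mask)  # scalar tensor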


class ScaleAndShiftInvariantLoss(nn.Module):
    """
    Scale and shift invariant loss as described in
    "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer"
    https://arxiv.org/pdf/1907.01341.pdf
    """

    def __init__(self, alpha: float = 0.5, scales: int = 4, reduction_type: Literal["image", "batch"] = "batch"):
        """
        Args:
            alpha: weight of the regularization term
            scales: number of scales to use
            reduction_type: either "batch" or "image"
        """
        super().__init__()
        self.__data_loss = MiDaSMSELoss(reduction_type=reduction_type)
        self.__regularization_loss = GradientLoss(scales=scales, reduction_type=reduction_type)
        self.__alpha = alpha
        self.__prediction_ssi = None

    def forward(
        self,
        prediction: Float[Tensor, "1 32 mult"],
        target: Float[Tensor, "1 32 mult"],
        mask: Bool[Tensor, "1 32 mult"],
    ) -> Float[Tensor, "0"]:
        """
        Args:
            prediction: predicted depth map (unnormalized)
            target: ground truth depth map (normalized)
            mask: mask of valid pixels
        Returns:
            scale and shift invariant loss
        """
        scale, shift = normalized_depth_scale_and_shift(prediction, target, mask)
        self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
        total = self.__data_loss(self.__prediction_ssi, target, mask)
        # if self.__alpha > 0:
        #     total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)
        return total

    def __get_prediction_ssi(self):
        """
        scale and shift invariant prediction
        from https://arxiv.org/pdf/1907.01341.pdf equation 1
        """
        return self.__prediction_ssi

    prediction_ssi = property(__get_prediction_ssi)
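

# A minimal smoke test, runnable as a script (the shapes and values below are
# assumptions for illustration; the module itself does not fix a resolution):
if __name__ == "__main__":
    torch.manual_seed(0)
    gt_depth = torch.rand(2, 32, 32)
    # a "normalized" prediction that differs from gt_depth by a scale and a shift
    pred_depth = (gt_depth - 0.25) / 4.0
    valid = torch.ones_like(gt_depth, dtype=torch.bool)
    loss_fn = ScaleAndShiftInvariantLoss(alpha=0.5, scales=4)
    loss = loss_fn(pred_depth, gt_depth, valid)
    # after the internal scale/shift alignment the data term should be ~0
    print(f"scale-and-shift invariant loss: {loss.item():.6f}")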