# scenedino/losses/reconstruction_loss.py
from typing import Any, Callable, Protocol

import torch
import torch.nn.functional as F
from torch import profiler

from scenedino.losses.base_loss import BaseLoss
from scenedino.common.errors import (
    compute_l1ssim,
    compute_edge_aware_smoothness,
)
EPS = 1e-5
# TODO: need wrappers around the different losses as an interface to the data variable
def make_reconstruction_error(
    criterion: str,
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Return an elementwise reconstruction error.

    The returned callable takes two tensors with channels in dim 1 and reduces
    that dim, e.g. mapping (N, C, H, W) inputs to an (N, H, W) error map.
    """
    match criterion:
        case "l1":
            return lambda a, b: F.l1_loss(a, b, reduction="none").mean(dim=1)
        case "l1+ssim":
            return compute_l1ssim
        case "l2":
            # Halved squared error (a common convention so the gradient is simply a - b).
            return lambda a, b: (F.mse_loss(a, b, reduction="none") / 2).mean(dim=1)
        case "cosine":
            return lambda a, b: 1 - F.cosine_similarity(a, b, dim=1)
        case _:
            raise ValueError(f"Unknown reconstruction error: {criterion}")
def make_regularization(
    config, ignore_invalid: bool
) -> Callable[[Any, int], torch.Tensor]:
    """Make a regularization function from the config.

    Args:
        config (dict): config dict with at least a "type" key.
        ignore_invalid (bool): currently unused; kept for interface compatibility.

    Returns:
        Callable[[Any, int], torch.Tensor]: regularization function taking the
            data dict and a scale index.
    """
match config["type"]:
case "edge_aware_smoothness":
def _wrapper(data, scale):
gt_img = data["rgb_gt"][..., :3]
depth = data["coarse"][scale]["depth"].permute(1, 0, 2, 3)
_, _, h, w = depth.shape
gt_img = (
gt_img.unsqueeze(-2).permute(0, 1, 4, 5, 2, 3).reshape(-1, 3, h, w)
)
            # Regularize mean-normalized disparity rather than raw depth.
            depth_input = 1 / depth.reshape(-1, 1, h, w).clamp(1e-3, 80)
            depth_input = depth_input / torch.mean(depth_input, dim=[2, 3], keepdim=True)
return compute_edge_aware_smoothness(
gt_img, depth_input, temperature=1
).mean()
return _wrapper
case "dino_edge_aware_smoothness":
def _wrapper(data, scale):
gt_img = data["rgb_gt"][..., :3]
dino = data["coarse"][scale]["dino_features"]
_, _, h, w, _, c_dino = dino.shape
gt_img = gt_img.unsqueeze(-2).permute(0, 1, 4, 5, 2, 3).reshape(-1, 3, h, w)
dino_input = dino.permute(0, 1, 4, 5, 2, 3).reshape(-1, c_dino, h, w)
return compute_edge_aware_smoothness(
gt_img, dino_input, temperature=25
).mean()
return _wrapper
case _:
raise ValueError(f"Unknown regularization type: {config['type']}")
class PolicyCallable(Protocol):
def __call__(self, invalids: torch.Tensor, **kwargs) -> torch.Tensor:
...
def strict_policy(invalids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Mark a ray invalid if it contains an invalid sample in every view."""
    invalid = torch.all(torch.any(invalids > 0.5, dim=-2), dim=-1).unsqueeze(-1)
    return invalid
def weight_guided_policy(invalids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Mark a ray invalid if, in every view, more than 0.9 of the rendering
    weight falls on invalid samples."""
    weights = kwargs["weights"]
    invalid = torch.all(
        (invalids.to(weights.dtype) * weights.unsqueeze(-1)).sum(-2) > 0.9,
        dim=-1,
        keepdim=True,
    )
    return invalid
def occ_and_weight_guided_policy(invalids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Weight-guided policy that additionally invalidates occluded rays."""
    weight_guided_invalid = weight_guided_policy(invalids, **kwargs)
    # occ == 1 indicates that a valid reprojection is possible, hence the negation.
    occ = kwargs["occ"]
    invalid = weight_guided_invalid | ~(occ.to(kwargs["weights"].dtype) > 0.5)
    return invalid
def weight_guided_diverse_policy(invalids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Weight-guided policy that additionally discards low-diversity rays,
    i.e. rays whose sampled colors barely vary (std below 0.01)."""
    rgb_samps = kwargs["rgb_samps"]
    ray_std = torch.std(rgb_samps, dim=-3).mean(-1)
    weights = kwargs["weights"]
    invalid = torch.all(
        ((invalids.to(torch.float32) * weights.unsqueeze(-1)).sum(-2) > 0.9)
        | (ray_std < 0.01),
        dim=-1,
        keepdim=True,
    )
    return invalid
def no_policy(invalids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Never mark any ray invalid."""
    # Same output shape as the other policies: the (samples, views) dims
    # collapse to a single boolean flag per ray.
    return torch.zeros(
        *invalids.shape[:-2], 1, dtype=torch.bool, device=invalids.device
    )
def invalid_policy(
    policy_name: str,
) -> PolicyCallable:
    match policy_name:
case "strict":
return strict_policy
case "weight_guided":
return weight_guided_policy
case "weight_guided_diverse":
return weight_guided_diverse_policy
case "occ_weight_guided":
return occ_and_weight_guided_policy
case None | "none":
return no_policy
        case _:
            raise ValueError(f"Unknown invalid policy: {policy_name}")
# TODO: scale all of them with a lambda factor
class ReconstructionLoss(BaseLoss):
def __init__(self, config, use_automasking: bool = False) -> None:
super().__init__(config)
if config.get("fine", None) is None:
self.rgb_fine_crit = None
else:
self.rgb_fine_crit = make_reconstruction_error(
config["fine"].get("criterion", "l2")
)
self.dino_fine_crit = make_reconstruction_error(
config["fine"].get("dino_criterion", "l2")
)
self.lambda_fine = config["fine"].get("lambda", 1)
if config.get("coarse", None) is None:
self.rgb_coarse_crit = None
else:
self.rgb_coarse_crit = make_reconstruction_error(
config["coarse"].get("criterion", "l2")
)
self.dino_coarse_crit = make_reconstruction_error(
config["coarse"].get("dino_criterion", "l2")
)
self.lambda_coarse = config["coarse"].get("lambda", 1)
self.invalid_policy = invalid_policy(config.get("invalid_policy", "strict"))
self.ignore_invalid = self.invalid_policy is not no_policy
self.regularizations: list[tuple] = []
for regularization_config in config["regularizations"]:
reg_fn = make_regularization(regularization_config, self.ignore_invalid)
self.regularizations.append(
(regularization_config["type"], reg_fn, regularization_config["lambda"])
)
self.median_thresholding = config.get("median_thresholding", False)
self.reconstruct_dino = config.get("reconstruct_dino", False)
self.lambda_dino_coarse = config.get("lambda_dino_coarse", 1)
self.lambda_dino_fine = config.get("lambda_dino_fine", 1)
self.temperature_dino = config.get("temperature_dino", 1)
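
    # A config sketch for this loss (keys inferred from the lookups in __init__
    # above; the concrete values are illustrative assumptions):
    #
    #     config = {
    #         "coarse": {"criterion": "l1+ssim", "dino_criterion": "cosine", "lambda": 1.0},
    #         "fine": None,  # leave unset/None to disable the fine head
    #         "invalid_policy": "weight_guided",
    #         "regularizations": [{"type": "edge_aware_smoothness", "lambda": 1e-3}],
    #         "reconstruct_dino": True,
    #         "median_thresholding": False,
    #     }
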
def get_loss_metric_names(self) -> list[str]:
loss_metric_names = ["rec_loss"]
if self.rgb_fine_crit is not None:
loss_metric_names.append("loss_rgb_fine")
if self.reconstruct_dino:
loss_metric_names.append("loss_dino_fine")
if self.rgb_coarse_crit is not None:
loss_metric_names.append("loss_rgb_coarse")
if self.reconstruct_dino:
loss_metric_names.append("loss_dino_coarse")
for regularization in self.regularizations:
loss_metric_names.append(regularization[0])
return loss_metric_names
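
    # With the config sketch above, this would yield (in order):
    #     ["rec_loss", "loss_rgb_coarse", "loss_dino_coarse", "edge_aware_smoothness"]
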
def __call__(self, data) -> dict[str, torch.Tensor]:
# print(data["dino_gt"].shape)
# print(data["coarse"][0]["dino_features"].shape)
with profiler.record_function("loss_computation"):
            n_scales = len(data["coarse"])
            # Default from the ground truth; at least one of the coarse/fine
            # criteria is expected to be configured anyway.
            loss_device = data["rgb_gt"].device
            if self.rgb_coarse_crit is not None:
                invalid_coarse = self.invalid_policy(
                    data["coarse"][0]["invalid"],
                    weights=data["coarse"][0]["weights"],
                    # rgb_samps=data["coarse"][0]["rgb_samps"],
                )
            if self.rgb_fine_crit is not None:
                invalid_fine = self.invalid_policy(
                    data["fine"][0]["invalid"],
                    weights=data["fine"][0]["weights"],
                    # rgb_samps=data["fine"][0]["rgb_samps"],
                )
losses = {
name: torch.tensor(0.0, device=loss_device)
for name in self.get_loss_metric_names()
}
for scale in range(n_scales):
if self.rgb_coarse_crit is not None:
coarse = data["coarse"][scale]
rgb_coarse = coarse["rgb"]
if "dino_features_downsampled" in coarse:
dino_coarse = coarse["dino_features_downsampled"]
else:
dino_coarse = coarse["dino_features"]
if self.rgb_fine_crit is not None:
fine = data["fine"][scale]
rgb_fine = fine["rgb"]
if "dino_features_downsampled" in fine:
dino_fine = fine["dino_features_downsampled"]
else:
dino_fine = fine["dino_features"]
if "dino_artifacts" in data:
dino_artifacts = data["dino_artifacts"].unsqueeze(-2).expand(dino_coarse.shape)
dino_coarse = dino_coarse + dino_artifacts
rgb_gt = data["rgb_gt"].unsqueeze(-2).expand(rgb_coarse.shape)
dino_gt = data["dino_gt"].unsqueeze(-2).expand(dino_coarse.shape)
                def rgb_loss(pred, gt, invalid, criterion):
                    # TODO: move the reshaping and selection to the wrapper, maybe other functions as well
                    b, pc, h, w, num_views, channels = pred.shape
                    # Fold the leading and view dims into one batch dim for the
                    # image-space criterion, then restore the layout with the
                    # view dim second to last.
                    loss = (
                        criterion(
                            pred.permute(0, 1, 4, 5, 2, 3).reshape(-1, channels, h, w),
                            gt.permute(0, 1, 4, 5, 2, 3).reshape(-1, channels, h, w),
                        )
                        .view(b, pc, num_views, h, w)
                        .permute(0, 1, 3, 4, 2)
                        .unsqueeze(-1)
                    )
                    # Per pixel, keep only the best-matching (minimum error) view.
                    loss = loss.amin(-2)
                    if self.ignore_invalid and invalid is not None:
                        loss = loss * (1 - invalid.to(torch.float32))
                    if self.median_thresholding:
                        # Keep only pixels whose error is at or below the per-sample median.
                        threshold = torch.median(loss.view(b, -1), dim=-1)[0].view(
                            -1, 1, 1, 1, 1
                        )
                        loss = loss[loss <= threshold]
                    return loss.mean()
                def dino_loss(pred, gt, invalid, criterion):
                    # TODO: move the reshaping and selection to the wrapper, maybe other functions as well
                    channels = pred.shape[-1]
                    # Flatten everything except the feature dim; the criterion reduces it.
                    loss = criterion(pred.reshape(-1, channels), gt.reshape(-1, channels))
                    # TODO: invalid feature handling
                    return loss.nanmean()
if self.rgb_coarse_crit is not None:
loss_coarse = rgb_loss(
rgb_coarse, rgb_gt, invalid_coarse, self.rgb_coarse_crit
)
losses["loss_rgb_coarse"] += loss_coarse.item()
losses["rec_loss"] += loss_coarse * self.lambda_coarse
if self.reconstruct_dino:
loss_coarse = dino_loss(
self.temperature_dino * dino_coarse, self.temperature_dino * dino_gt,
None, self.dino_coarse_crit
)
losses["loss_dino_coarse"] += loss_coarse.item()
losses["rec_loss"] += loss_coarse * self.lambda_coarse * self.lambda_dino_coarse
if self.rgb_fine_crit is not None:
loss_fine = rgb_loss(
rgb_fine, rgb_gt, invalid_fine, self.rgb_fine_crit
)
losses["loss_rgb_fine"] += loss_fine.item()
losses["rec_loss"] += loss_fine * self.lambda_fine
if self.reconstruct_dino:
loss_fine = dino_loss(
dino_fine, dino_gt, invalid_fine.unsqueeze(-1), self.dino_fine_crit
)
losses["loss_dino_fine"] += loss_fine.item()
losses["rec_loss"] += loss_fine * self.lambda_fine * self.lambda_dino_fine
for regularization in self.regularizations:
# TODO: make it properly work with the different scales
reg_loss = regularization[1](data, scale)
                    if reg_loss is not None:
                        losses[regularization[0]] += reg_loss.item()
                        losses["rec_loss"] += reg_loss * regularization[2]
losses = {name: value / n_scales for name, value in losses.items()}
return losses
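
if __name__ == "__main__":
    # Hedged smoke test with made-up shapes. It assumes BaseLoss.__init__ only
    # stores the config, and that the layout below matches the
    # (batch, patches, h, w, views, channels) convention used in __call__.
    b, pc, h, w, nv, ns, c, c_dino = 1, 2, 4, 4, 2, 8, 3, 16
    loss_fn = ReconstructionLoss(
        {
            "coarse": {"criterion": "l1"},
            "invalid_policy": "strict",
            "regularizations": [],
        }
    )
    data = {
        "coarse": [
            {
                "rgb": torch.rand(b, pc, h, w, nv, c),
                "dino_features": torch.rand(b, pc, h, w, nv, c_dino),
                "invalid": torch.zeros(b, pc, h, w, ns, nv),
                "weights": torch.rand(b, pc, h, w, ns),
            }
        ],
        "rgb_gt": torch.rand(b, pc, h, w, c),
        "dino_gt": torch.rand(b, pc, h, w, c_dino),
    }
    print(loss_fn(data))  # expected keys: "rec_loss", "loss_rgb_coarse"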