import math
from typing import Any, Callable, Protocol

import torch
import torch.nn.functional as F
from torch import profiler

import kornia

from scenedino.common.util import kl_div, normalized_entropy
from scenedino.losses.base_loss import BaseLoss
from scenedino.common.errors import (
    alpha_consistency_uncert,
    compute_l1ssim,
    compute_edge_aware_smoothness,
    compute_3d_smoothness,
    compute_normalized_l1,
    depth_smoothness_regularization,
    depth_regularization,
    alpha_regularization,
    flow_regularization,
    kl_prop,
    max_alpha_inputframe_regularization,
    surfaceness_regularization,
    sdf_eikonal_regularization,
    weight_entropy_regularization,
    max_alpha_regularization,
    density_grid_regularization,
    alpha_consistency,
    entropy_based_smoothness,
)

EPS = 1e-5
# TODO: need wrappers around the different losses as an interface to the data variable
def make_reconstruction_error(
    criterion: str,
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Map a criterion name to a per-pixel error function that reduces the channel dim (dim=1)."""
    match criterion:
        case "l1":
            return lambda a, b: F.l1_loss(a, b, reduction="none").mean(dim=1)
        case "l1+ssim":
            return compute_l1ssim
        case "l2":
            return lambda a, b: (F.mse_loss(a, b, reduction="none") / 2).mean(dim=1)
        case "cosine":
            return lambda a, b: 1 - F.cosine_similarity(a, b, dim=1)
        case _:
            raise ValueError(f"Unknown reconstruction error: {criterion}")
def make_regularization(
    config, ignore_invalid: bool
) -> Callable[[Any, int], torch.Tensor]:
    """Make a regularization function from the config.

    Args:
        config (dict): config dict with a "type" key.
        ignore_invalid (bool): currently unused; kept for interface compatibility.

    Returns:
        Callable[[Any, int], torch.Tensor]: regularization function taking the
            data dict and a scale index.
    """
    match config["type"]:
        case "edge_aware_smoothness":

            def _wrapper(data, scale):
                gt_img = data["rgb_gt"][..., :3]
                depth = data["coarse"][scale]["depth"].permute(1, 0, 2, 3)
                _, _, h, w = depth.shape
                gt_img = (
                    gt_img.unsqueeze(-2).permute(0, 1, 4, 5, 2, 3).reshape(-1, 3, h, w)
                )
                # Smooth the mean-normalized inverse depth (disparity), guided by image edges.
                depth_input = 1 / depth.reshape(-1, 1, h, w).clamp(1e-3, 80)
                depth_input = depth_input / torch.mean(depth_input, dim=[2, 3], keepdim=True)
                return compute_edge_aware_smoothness(
                    gt_img, depth_input, temperature=1
                ).mean()

            return _wrapper
        case "dino_edge_aware_smoothness":

            def _wrapper(data, scale):
                gt_img = data["rgb_gt"][..., :3]
                dino = data["coarse"][scale]["dino_features"]
                _, _, h, w, _, c_dino = dino.shape
                gt_img = gt_img.unsqueeze(-2).permute(0, 1, 4, 5, 2, 3).reshape(-1, 3, h, w)
                dino_input = dino.permute(0, 1, 4, 5, 2, 3).reshape(-1, c_dino, h, w)
                return compute_edge_aware_smoothness(
                    gt_img, dino_input, temperature=25
                ).mean()

            return _wrapper
        case _:
            raise ValueError(f"Unknown regularization type: {config['type']}")
class PolicyCallable(Protocol):
    def __call__(self, invalids: torch.Tensor, **kwargs) -> torch.Tensor:
        ...
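

# Shape convention assumed by the policies below (inferred from how they are used
# in ReconstructionLoss): `invalids` carries per-sample, per-view invalid flags in
# its last two dims, (..., n_samples, n_views); `weights` are ray termination
# weights of shape (..., n_samples). Each policy returns a boolean mask of shape
# (..., 1) marking rays to exclude from the loss.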
def strict_policy(invalids: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Mark a ray invalid if, in every view, at least one sample along it is invalid."""
    invalid = torch.all(torch.any(invalids > 0.5, dim=-2), dim=-1).unsqueeze(-1)
    return invalid


def weight_guided_policy(invalids: torch.Tensor, **kwargs) -> torch.Tensor:
    """Mark a ray invalid if invalid samples carry more than 0.9 of the ray weight in every view."""
    weights = kwargs["weights"]
    invalid = torch.all(
        (invalids.to(weights.dtype) * weights.unsqueeze(-1)).sum(-2) > 0.9,
        dim=-1,
        keepdim=True,
    )
    return invalid


def occ_and_weight_guided_policy(invalids: torch.Tensor, **kwargs) -> torch.Tensor:
    weight_guided_invalid = weight_guided_policy(invalids, **kwargs)
    # occ = 1 indicates that a valid reprojection is possible, so it has to be negated.
    occ = kwargs["occ"]
    invalid = weight_guided_invalid | (~(occ.to(kwargs["weights"].dtype) > 0.5))
    return invalid


def weight_guided_diverse_policy(invalids: torch.Tensor, **kwargs) -> torch.Tensor:
    """Like weight_guided_policy, but also drop rays whose color samples are near-constant."""
    rgb_samps = kwargs["rgb_samps"]
    ray_std = torch.std(rgb_samps, dim=-3).mean(-1)
    weights = kwargs["weights"]
    invalid = torch.all(
        ((invalids.to(torch.float32) * weights.unsqueeze(-1)).sum(-2) > 0.9)
        | (ray_std < 0.01),
        dim=-1,
        keepdim=True,
    )
    return invalid


def no_policy(invalids: torch.Tensor, **kwargs) -> torch.Tensor:
    """Never mark anything invalid; return an all-False mask of the matching shape."""
    invalid = torch.zeros_like(
        torch.all(torch.any(invalids > 0.5, dim=-2), dim=-1).unsqueeze(-1),
        dtype=torch.bool,
    )
    return invalid
def invalid_policy(
    policy_name: str,
) -> PolicyCallable:
    """Map a policy name from the config to the corresponding policy function."""
    match policy_name:
        case "strict":
            return strict_policy
        case "weight_guided":
            return weight_guided_policy
        case "weight_guided_diverse":
            return weight_guided_diverse_policy
        case "occ_weight_guided":
            return occ_and_weight_guided_policy
        case None | "none":
            return no_policy
        case _:
            raise ValueError(f"Unknown invalid policy: {policy_name}")
# TODO: scale all of them with a lambda factor
class ReconstructionLoss(BaseLoss):
    def __init__(self, config, use_automasking: bool = False) -> None:
        super().__init__(config)

        if config.get("fine", None) is None:
            self.rgb_fine_crit = None
        else:
            self.rgb_fine_crit = make_reconstruction_error(
                config["fine"].get("criterion", "l2")
            )
            self.dino_fine_crit = make_reconstruction_error(
                config["fine"].get("dino_criterion", "l2")
            )
            self.lambda_fine = config["fine"].get("lambda", 1)

        if config.get("coarse", None) is None:
            self.rgb_coarse_crit = None
        else:
            self.rgb_coarse_crit = make_reconstruction_error(
                config["coarse"].get("criterion", "l2")
            )
            self.dino_coarse_crit = make_reconstruction_error(
                config["coarse"].get("dino_criterion", "l2")
            )
            self.lambda_coarse = config["coarse"].get("lambda", 1)

        self.invalid_policy = invalid_policy(config.get("invalid_policy", "strict"))
        self.ignore_invalid = self.invalid_policy is not no_policy

        self.regularizations: list[tuple] = []
        for regularization_config in config["regularizations"]:
            reg_fn = make_regularization(regularization_config, self.ignore_invalid)
            self.regularizations.append(
                (regularization_config["type"], reg_fn, regularization_config["lambda"])
            )

        self.median_thresholding = config.get("median_thresholding", False)

        self.reconstruct_dino = config.get("reconstruct_dino", False)
        self.lambda_dino_coarse = config.get("lambda_dino_coarse", 1)
        self.lambda_dino_fine = config.get("lambda_dino_fine", 1)
        self.temperature_dino = config.get("temperature_dino", 1)
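
    # Example config (illustrative; the keys match the lookups in __init__):
    #   {
    #       "coarse": {"criterion": "l1+ssim", "lambda": 1},
    #       "fine": {"criterion": "l1+ssim", "lambda": 1},
    #       "invalid_policy": "weight_guided",
    #       "regularizations": [{"type": "edge_aware_smoothness", "lambda": 0.001}],
    #       "reconstruct_dino": True,
    #       "lambda_dino_coarse": 1,
    #   }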
    def get_loss_metric_names(self) -> list[str]:
        loss_metric_names = ["rec_loss"]
        if self.rgb_fine_crit is not None:
            loss_metric_names.append("loss_rgb_fine")
            if self.reconstruct_dino:
                loss_metric_names.append("loss_dino_fine")
        if self.rgb_coarse_crit is not None:
            loss_metric_names.append("loss_rgb_coarse")
            if self.reconstruct_dino:
                loss_metric_names.append("loss_dino_coarse")
        for regularization in self.regularizations:
            loss_metric_names.append(regularization[0])
        return loss_metric_names
    def __call__(self, data) -> dict[str, torch.Tensor]:
        # print(data["dino_gt"].shape)
        # print(data["coarse"][0]["dino_features"].shape)
        with profiler.record_function("loss_computation"):
            n_scales = len(data["coarse"])

            if self.rgb_coarse_crit is not None:
                invalid_coarse = self.invalid_policy(
                    data["coarse"][0]["invalid"],
                    weights=data["coarse"][0]["weights"],
                    # rgb_samps=data["coarse"][0]["rgb_samps"],
                )
                loss_device = invalid_coarse.device
            if self.rgb_fine_crit is not None:
                invalid_fine = self.invalid_policy(
                    data["fine"][0]["invalid"],
                    weights=data["fine"][0]["weights"],
                    # rgb_samps=data["fine"][0]["rgb_samps"],
                )
                loss_device = invalid_fine.device

            losses = {
                name: torch.tensor(0.0, device=loss_device)
                for name in self.get_loss_metric_names()
            }

            for scale in range(n_scales):
                if self.rgb_coarse_crit is not None:
                    coarse = data["coarse"][scale]
                    rgb_coarse = coarse["rgb"]
                    if "dino_features_downsampled" in coarse:
                        dino_coarse = coarse["dino_features_downsampled"]
                    else:
                        dino_coarse = coarse["dino_features"]
                if self.rgb_fine_crit is not None:
                    fine = data["fine"][scale]
                    rgb_fine = fine["rgb"]
                    if "dino_features_downsampled" in fine:
                        dino_fine = fine["dino_features_downsampled"]
                    else:
                        dino_fine = fine["dino_features"]

                if "dino_artifacts" in data:
                    dino_artifacts = data["dino_artifacts"].unsqueeze(-2).expand(dino_coarse.shape)
                    dino_coarse = dino_coarse + dino_artifacts

                # NOTE: these expansions rely on the coarse predictions being present.
                rgb_gt = data["rgb_gt"].unsqueeze(-2).expand(rgb_coarse.shape)
                dino_gt = data["dino_gt"].unsqueeze(-2).expand(dino_coarse.shape)

                def rgb_loss(pred, gt, invalid, criterion):
                    # TODO: move the reshaping and selection to the wrapper, maybe other functions as well
                    b, pc, h, w, num_views, channels = pred.shape
                    loss = (
                        criterion(
                            pred.permute(0, 1, 4, 5, 2, 3).reshape(-1, channels, h, w),
                            gt.permute(0, 1, 4, 5, 2, 3).reshape(-1, channels, h, w),
                        )
                        .view(b, pc, num_views, h, w)
                        .permute(0, 1, 3, 4, 2)
                        .unsqueeze(-1)
                    )
                    # Take the minimum error over the source views for every pixel.
                    loss = loss.amin(-2)

                    if self.ignore_invalid and invalid is not None:
                        loss = loss * (1 - invalid.to(torch.float32))

                    if self.median_thresholding:
                        threshold = torch.median(loss.view(b, -1), dim=-1)[0].view(
                            -1, 1, 1, 1, 1
                        )
                        loss = loss[loss <= threshold]

                    return loss.mean()

                def dino_loss(pred, gt, invalid, criterion):
                    # TODO: move the reshaping and selection to the wrapper, maybe other functions as well
                    channels = pred.shape[-1]
                    loss = criterion(
                        pred.reshape(-1, channels),
                        gt.reshape(-1, channels),
                    )
                    # TODO: invalid feature handling
                    return loss.nanmean()

                if self.rgb_coarse_crit is not None:
                    loss_coarse = rgb_loss(
                        rgb_coarse, rgb_gt, invalid_coarse, self.rgb_coarse_crit
                    )
                    losses["loss_rgb_coarse"] += loss_coarse.item()
                    losses["rec_loss"] += loss_coarse * self.lambda_coarse

                    if self.reconstruct_dino:
                        loss_coarse = dino_loss(
                            self.temperature_dino * dino_coarse,
                            self.temperature_dino * dino_gt,
                            None,
                            self.dino_coarse_crit,
                        )
                        losses["loss_dino_coarse"] += loss_coarse.item()
                        losses["rec_loss"] += loss_coarse * self.lambda_coarse * self.lambda_dino_coarse

                if self.rgb_fine_crit is not None:
                    loss_fine = rgb_loss(
                        rgb_fine, rgb_gt, invalid_fine, self.rgb_fine_crit
                    )
                    losses["loss_rgb_fine"] += loss_fine.item()
                    losses["rec_loss"] += loss_fine * self.lambda_fine

                    if self.reconstruct_dino:
                        loss_fine = dino_loss(
                            dino_fine, dino_gt, invalid_fine.unsqueeze(-1), self.dino_fine_crit
                        )
                        losses["loss_dino_fine"] += loss_fine.item()
                        losses["rec_loss"] += loss_fine * self.lambda_fine * self.lambda_dino_fine

                for regularization in self.regularizations:
                    # TODO: make it properly work with the different scales
                    reg_loss = regularization[1](data, scale)
                    if reg_loss:
                        losses[regularization[0]] += reg_loss.item()
                        losses["rec_loss"] += reg_loss * regularization[2]

            # Average the accumulated per-scale terms over the number of scales.
            losses = {name: value / n_scales for name, value in losses.items()}

        return losses
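

# Usage sketch (illustrative; key and shape conventions inferred from the code
# above, with b=batch, pc=patches, h/w=patch size, S=samples per ray, V=views):
#   loss_fn = ReconstructionLoss({
#       "coarse": {"criterion": "l1+ssim", "lambda": 1},
#       "invalid_policy": "weight_guided",
#       "regularizations": [],
#   })
#   data = {
#       "coarse": [{
#           "rgb": ...,            # (b, pc, h, w, V, 3)
#           "invalid": ...,        # (b, pc, h, w, S, V)
#           "weights": ...,        # (b, pc, h, w, S)
#           "dino_features": ...,  # (b, pc, h, w, V, c_dino)
#       }],
#       "rgb_gt": ...,   # (b, pc, h, w, 3)
#       "dino_gt": ...,  # (b, pc, h, w, c_dino)
#   }
#   losses = loss_fn(data)  # e.g. {"rec_loss": ..., "loss_rgb_coarse": ...}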