# AnySplat — src/loss/loss_point.py
# (header residue from the Hugging Face file viewer: uploaded by alexnasa,
#  commit "Upload 243 files", rev 2568013 — kept as a comment so the file parses)
import torch
import torch.nn as nn
from copy import copy, deepcopy
from src.geometry.ptc_geometry import geotrf, inv, normalize_pointcloud, depthmap_to_pts3d
# from torchmetrics.functional.regression import pearson_corrcoef
# from pytorch3d.loss import chamfer_distance
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract (or reconstruct) the predicted 3D point map from a prediction dict.

    Parameters
    ----------
    gt : dict
        Ground-truth batch; only ``'camera_intrinsics'`` is read (optionally),
        to get the principal point when reprojecting a depth map.
    pred : dict
        Must provide one of:
        * ``'depth'`` + ``'pseudo_focal'`` -> points reprojected from the depth map;
        * ``'pts3d'``                      -> points in this camera's own frame;
        * ``'pts3d_in_other_view'``        -> points already expressed in the other
          camera's frame (requires ``use_pose=True``).
    use_pose : bool
        If True, transform ``pts3d`` with ``pred['camera_pose']`` (must be present).

    Returns
    -------
    The 3D point map (same container type as stored in ``pred``).

    Raises
    ------
    ValueError
        If ``pred`` contains none of the recognized keys.
    """
    if 'depth' in pred and 'pseudo_focal' in pred:
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None  # principal point unknown: let depthmap_to_pts3d pick a default
        pts3d = depthmap_to_pts3d(**pred, pp=pp)
    elif 'pts3d' in pred:
        # pts3d already expressed in this camera's frame
        pts3d = pred['pts3d']
    elif 'pts3d_in_other_view' in pred:
        # pts3d from the other camera, already transformed -- nothing more to do
        assert use_pose is True
        return pred['pts3d_in_other_view']  # return!
    else:
        # BUGFIX: the original silently fell through here and crashed later with
        # an unbound-local NameError; fail with an explicit message instead.
        raise ValueError(f'cannot extract pts3d from prediction keys {list(pred)}')
    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)
    return pts3d
class LLoss(nn.Module):
    """Base class for per-element L-norm losses on point/value maps.

    Subclasses implement :meth:`distance`, which must drop exactly one
    (the last, <=3-channel) dimension. ``reduction`` is 'mean', 'sum' or 'none'.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, a, b):
        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
        per_elem = self.distance(a, b)
        # the distance must collapse exactly the last (channel) dimension
        assert per_elem.ndim == a.ndim - 1
        mode = self.reduction
        if mode == 'none':
            return per_elem
        if mode == 'sum':
            return per_elem.sum()
        if mode == 'mean':
            # guard against an empty selection (mean of nothing would be NaN)
            if per_elem.numel() > 0:
                return per_elem.mean()
            return per_elem.new_zeros(())
        raise ValueError(f'bad {self.reduction=} mode')

    def distance(self, a, b):
        raise NotImplementedError()
class L21Loss(LLoss):
    """ Euclidean distance between 3d points """

    def distance(self, a, b):
        # per-point L2 norm over the last (coordinate) axis
        diff = a - b
        return diff.norm(dim=-1)  # normalized L2 distance
class MultiLoss(nn.Module):
    """ Easily combinable losses (also keep track of individual loss values):
            loss = MyLoss1() + 0.1*MyLoss2()
        Usage:
            Inherit from this class and override get_name() and compute_loss()
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1      # scalar weight applied to this term
        self._loss2 = None   # next term of the additive chain (linked list)

    def compute_loss(self, *args, **kwargs):
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def __mul__(self, alpha):
        assert isinstance(alpha, (int, float))
        scaled = copy(self)
        scaled._alpha = alpha
        return scaled
    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        assert isinstance(loss2, MultiLoss)
        head = tail = copy(self)
        # walk to the end of the chain, then append the new term
        while tail._loss2 is not None:
            tail = tail._loss2
        tail._loss2 = loss2
        return head

    def __repr__(self):
        txt = self.get_name()
        if self._alpha != 1:
            txt = f'{self._alpha:g}*{txt}'
        if self._loss2:
            txt = f'{txt} + {self._loss2}'
        return txt

    def forward(self, *args, **kwargs):
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            # scalar loss with no details: record it under this term's name
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha
        if self._loss2:
            # recurse into the rest of the chain (each term applies its own alpha)
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            details |= details2
        return loss, details
class Criterion(nn.Module):
    """Wraps a pixel-level ``LLoss`` criterion so it can sit in a MultiLoss chain."""

    def __init__(self, criterion=None):
        super().__init__()
        # BUGFIX: the original assert message ended with `+bb()` where `bb` is
        # undefined, so a failing assert raised NameError instead of the message.
        assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
        self.criterion = copy(criterion)

    def get_name(self):
        return f'{type(self).__name__}({self.criterion})'

    def with_reduction(self, mode):
        """Return a deep copy of the whole chain with every criterion's
        reduction set to *mode*.

        BUGFIX: the original ignored *mode* and hard-coded 'none' (its only
        visible caller passes 'none', so behavior is unchanged for it).
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode  # e.g. 'none' -> per-sample losses
            loss = loss._loss2  # we assume loss is (also) a MultiLoss chain
        return res
class ConfLoss(MultiLoss):
    """ Weighted regression by learned confidence.
        Assuming the input pixel_loss is a pixel-level regression loss.
    Principle:
        high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
        alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        # force the wrapped loss to return unreduced per-pixel values so each
        # pixel can be individually weighted by its predicted confidence
        self.pixel_loss = pixel_loss.with_reduction('none')

    def get_name(self):
        return f'ConfLoss({self.pixel_loss})'

    def get_conf_log(self, x):
        return x, torch.log(x)

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        # compute per-pixel loss (plus the valid-pixel masks for each view)
        ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
        # BUGFIX: builtin print() has no `force` kwarg -- the original raised
        # TypeError exactly when it tried to warn about empty views.
        if loss1.numel() == 0:
            print('NO VALID POINTS in img1')
        if loss2.numel() == 0:
            print('NO VALID POINTS in img2')
        # weight by confidence
        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
        conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
        conf_loss2 = loss2 * conf2 - self.alpha * log_conf2
        # average + nan protection (in case of no valid pixels at all)
        conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
        conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0
        # NOTE(review): key names are inconsistent ('conf_loss_1' vs 'conf_loss2');
        # kept byte-identical since downstream logging may key on them.
        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)
class Regr3D(nn.Module):
    """ Ensure that all 3D points are correct.
        Asymmetric loss: view1 is supposed to be the anchor.
        P1 = RT1 @ D1
        P2 = RT2 @ D2
        loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
        loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
              = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
    """

    def __init__(self, norm_mode='avg_dis', alpha=0.2, gt_scale=False):
        super().__init__()
        self.norm_mode = norm_mode  # passed to normalize_pointcloud; falsy disables normalization
        self.alpha = alpha          # NOTE(review): unused by the active code path (kept for compat)
        self.gt_scale = gt_scale    # if True, GT keeps its own scale (only predictions normalized)

    def get_conf_log(self, x):
        return x, torch.log(x)

    def forward(self, gt_pts1, gt_pts2, pr_pts1, pr_pts2, conf1=None, conf2=None, dist_clip=None, disable_view1=False):
        """Masked, optionally-normalized per-pixel L2 loss between predicted and GT points.

        gt_pts*/pr_pts* are (B, H, W, 3) point maps; conf1/conf2 are optional
        (B, H, W) confidence maps; returns a scalar tensor.
        """
        dis1 = gt_pts1.norm(dim=-1)  # (B, H, W)
        dis2 = gt_pts2.norm(dim=-1)  # (B, H, W)
        if dist_clip is not None:
            # points that are too far away == invalid
            valid1 = dis1 <= dist_clip
            valid2 = dis2 <= dist_clip
        else:
            # keep only points whose norm lies within the 0.2%..99.8% quantile
            # range of each batch element (original comment wrongly said 1%/99%)
            dis1_flat = dis1.view(dis1.shape[0], -1)
            dis2_flat = dis2.view(dis2.shape[0], -1)
            q = torch.tensor([0.002, 0.998])
            quantiles_1 = torch.quantile(dis1_flat, q.to(dis1_flat.device), dim=1)
            quantiles_2 = torch.quantile(dis2_flat, q.to(dis2_flat.device), dim=1)
            valid1 = (dis1 >= quantiles_1[0].view(-1, 1, 1)) & (dis1 <= quantiles_1[1].view(-1, 1, 1))
            valid2 = (dis2 >= quantiles_2[0].view(-1, 1, 1)) & (dis2 <= quantiles_2[1].view(-1, 1, 1))
        # set min confidence to 3
        # BUGFIX: conf1/conf2 default to None but were used unconditionally,
        # crashing any call that omitted them; now the filter is optional.
        if conf1 is not None:
            valid1 = valid1 & (conf1 >= 3)
        if conf2 is not None:
            valid2 = valid2 & (conf2 >= 3)
        # normalize 3d points
        if self.norm_mode:
            pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
            if not self.gt_scale:
                gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)
        # per-pixel Euclidean error, restricted to valid pixels
        loss1 = torch.norm(pr_pts1 - gt_pts1, dim=-1)[valid1]
        loss2 = torch.norm(pr_pts2 - gt_pts2, dim=-1)[valid2]
        # NOTE(review): .mean() over an empty selection yields NaN; callers are
        # expected to guarantee >=1 valid pixel per view (behavior kept as-is)
        if disable_view1:
            return loss2.mean()
        return loss1.mean() + loss2.mean()
# def forward(self, gt_pts1, gt_pts2, pr_pts1, pr_pts2, conf1=None, conf2=None, dist_clip=None, disable_view1=False):
# # valid1 = valid2 = torch.ones_like(conf1, dtype=torch.bool)
# if dist_clip is not None:
# # points that are too far-away == invalid
# dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
# dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
# valid1 = (dis1 <= dist_clip)
# valid2 = (dis2 <= dist_clip)
# else:
# dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
# dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
#
# # only keep the points norm whithin the range of 1% to 99% of each batch
# # Flatten along the H and W dimensions
# dis1_flat = dis1.view(dis1.shape[0], -1)
# dis2_flat = dis2.view(dis2.shape[0], -1)
#
# # Compute the 0.1% and 99.9% quantiles for each batch
# quantiles_1 = torch.quantile(dis1_flat, torch.tensor([0.1, 0.9]).to(dis1_flat.device), dim=1)
# quantiles_2 = torch.quantile(dis2_flat, torch.tensor([0.1, 0.9]).to(dis2_flat.device), dim=1)
# # quantiles_1 = torch.quantile(dis1_flat, torch.tensor([0.002, 0.998]).to(dis1_flat.device), dim=1)
# # quantiles_2 = torch.quantile(dis2_flat, torch.tensor([0.002, 0.998]).to(dis2_flat.device), dim=1)
#
# # Create masks based on the quantiles
# valid1 = (dis1 >= quantiles_1[0].view(-1, 1, 1)) & (dis1 <= quantiles_1[1].view(-1, 1, 1))
# valid2 = (dis2 >= quantiles_2[0].view(-1, 1, 1)) & (dis2 <= quantiles_2[1].view(-1, 1, 1))
#
# # set min opacity to 3
# valid1 = valid1 & (conf1 >= 0.2)
# valid2 = valid2 & (conf2 >= 0.2)
#
# # # normalize 3d points
# # if self.norm_mode:
# # pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
# # if self.norm_mode and not self.gt_scale:
# # gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)
#
# # L1 loss
# # loss1 = (pr_pts1[..., -1] - gt_pts1[..., -1]).abs()
# # loss2 = (pr_pts2[..., -1] - gt_pts2[..., -1]).abs()
#
# # L2 loss
# loss1 = torch.norm(pr_pts1 - gt_pts1, dim=-1)
# loss2 = torch.norm(pr_pts2 - gt_pts2, dim=-1)
# loss1, loss2 = loss1[valid1], loss2[valid2]
#
# # Pearson correlation coefficient loss
# # pr_pts1, pr_pts2 = pr_pts1[valid1], pr_pts2[valid2]
# # gt_pts1, gt_pts2 = gt_pts1[valid1], gt_pts2[valid2]
# # loss1 = 1 - pearson_corrcoef(pr_pts1.view(-1, 3), gt_pts1.view(-1, 3))
# # loss2 = 1 - pearson_corrcoef(pr_pts2.view(-1, 3), gt_pts2.view(-1, 3))
#
# # # Chamfer distance loss
# # pr_pts = torch.cat([pr_pts1.flatten(1, 2), pr_pts2.flatten(1, 2)], dim=1)
# # gt_pts = torch.cat([gt_pts1.flatten(1, 2), gt_pts2.flatten(1, 2)], dim=1)
# # valid_mask = torch.cat([valid1.flatten(1, 2), valid2.flatten(1, 2)], dim=1)
# # nan_pts_pr, nnz = invalid_to_zeros(pr_pts, valid_mask, ndim=3)
# # nan_pts_gt, nnz = invalid_to_zeros(gt_pts, valid_mask, ndim=3)
# #
# # loss, _ = chamfer_distance(nan_pts_pr, nan_pts_gt, batch_reduction=None, point_reduction=None)
# # loss1, loss2 = loss[0], loss[1]
# # return loss1.sum() / valid_mask.sum()
#
# if disable_view1:
# return loss2.mean()
# return loss1.mean() + loss2.mean()
#
# # conf1, conf2 = conf1[valid1], conf2[valid2]
# # conf1, conf2 = conf1.softmax(dim=-1), conf2.softmax(dim=-1)
# # loss1 = (loss1 * conf1).sum()
# # loss2 = (loss2 * conf2).sum()
# # return loss1 + loss2
# #
# # # weight by confidence
# # conf1, log_conf1 = self.get_conf_log(conf1[valid1])
# # conf2, log_conf2 = self.get_conf_log(conf2[valid2])
# # conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
# # conf_loss2 = loss2 * conf2 - self.alpha * log_conf2
# #
# # # average + nan protection (in case of no valid pixels at all)
# # conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
# # conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0
# #
# # return conf_loss1 + conf_loss2