|
import torch |
|
import torch.nn as nn |
|
from copy import copy, deepcopy |
|
|
|
from src.geometry.ptc_geometry import geotrf, inv, normalize_pointcloud, depthmap_to_pts3d |
|
|
|
|
|
|
|
|
|
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract the predicted 3D points from a prediction dict.

    Sources are tried in this order:
      1. ``pred['depth']`` together with ``pred['pseudo_focal']``: every entry of
         ``pred`` is forwarded to ``depthmap_to_pts3d`` as keyword arguments,
         with ``pp`` taken from ``gt['camera_intrinsics'][..., :2, 2]`` when that
         key exists (``None`` otherwise);
      2. ``pred['pts3d']``: used as-is;
      3. ``pred['pts3d_in_other_view']``: returned immediately, without any pose
         transform (only allowed with ``use_pose=True``).

    Args:
        gt: ground-truth dict; only ``'camera_intrinsics'`` is read.
        pred: prediction dict.
        use_pose: if True, the points of cases 1 and 2 are transformed with
            ``geotrf(pred['camera_pose'], pts3d)``.

    Returns:
        The predicted points.

    Raises:
        KeyError: if ``pred`` holds none of the supported point sources.
    """
    if 'depth' in pred and 'pseudo_focal' in pred:
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)
    elif 'pts3d' in pred:
        pts3d = pred['pts3d']
    elif 'pts3d_in_other_view' in pred:
        # returned as-is: no pose transform is applied to this source
        assert use_pose is True, "'pts3d_in_other_view' requires use_pose=True"
        return pred['pts3d_in_other_view']
    else:
        # previously fell through and crashed with an UnboundLocalError on `pts3d`
        raise KeyError(
            "pred has no 'depth'+'pseudo_focal', 'pts3d' or 'pts3d_in_other_view' entry")

    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None, "use_pose=True requires pred['camera_pose']"
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d
|
|
|
|
|
class LLoss(nn.Module):
    """Base class for point-wise distance losses.

    Subclasses implement ``distance(a, b)``, which must collapse exactly the last
    axis of its inputs. ``reduction`` then picks how the per-point distances are
    aggregated:
      * ``'none'``: returned unchanged;
      * ``'sum'``: summed;
      * ``'mean'``: averaged, with an empty input giving a 0-d zero instead of NaN.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, a, b):
        # inputs must match, be at least 2-D, and carry 1 to 3 coordinates on the last axis
        shape_ok = a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
        assert shape_ok, f'Bad shape = {a.shape}'

        per_point = self.distance(a, b)
        assert per_point.ndim == a.ndim - 1

        mode = self.reduction
        if mode == 'mean':
            if per_point.numel() == 0:
                return per_point.new_zeros(())
            return per_point.mean()
        if mode == 'sum':
            return per_point.sum()
        if mode == 'none':
            return per_point
        raise ValueError(f'bad {self.reduction=} mode')

    def distance(self, a, b):
        """Per-point distance between ``a`` and ``b``; must be overridden."""
        raise NotImplementedError()
|
|
|
|
|
class L21Loss(LLoss):
    """Euclidean (L2) distance between corresponding 3D points, reduced by LLoss."""

    def distance(self, a, b):
        diff = a - b
        return diff.norm(dim=-1)
|
|
|
|
|
class MultiLoss(nn.Module):
    """ Easily combinable losses (also keep track of individual loss values):

        loss = MyLoss1() + 0.1*MyLoss2()

    Usage:
        Inherit from this class and override get_name() and compute_loss()

    Combined losses form a singly linked chain through ``_loss2``; each node
    carries its own weight ``_alpha``. ``*`` and ``+`` return new objects and
    never modify their operands.
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        raise NotImplementedError()

    def get_name(self):
        raise NotImplementedError()

    def _clone(self):
        """Shallow-copy this node, giving the copy its own submodule registry.

        ``copy.copy`` on an ``nn.Module`` shares the internal ``_modules`` dict
        with the original. Assigning a module to ``_loss2`` stores it in that
        dict, so re-linking a plain copy would also re-link the original and
        every other copy made from it.
        """
        res = copy(self)
        res.__dict__['_modules'] = copy(self._modules)
        return res

    def __mul__(self, alpha):
        assert isinstance(alpha, (int, float))
        res = self._clone()
        res._alpha = alpha
        return res

    __rmul__ = __mul__

    def __add__(self, loss2):
        assert isinstance(loss2, MultiLoss)
        # Clone every node of our own chain, so that attaching `loss2` at the tail
        # cannot mutate `self` or any loss previously linked to it.
        res = cur = self._clone()
        while cur._loss2 is not None:
            cur._loss2 = cur._loss2._clone()
            cur = cur._loss2
        cur._loss2 = loss2
        return res

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f'{self._alpha:g}*{name}'
        if self._loss2 is not None:
            name = f'{name} + {self._loss2}'
        return name

    def forward(self, *args, **kwargs):
        """Evaluate the whole chain on the same arguments.

        Returns:
            (loss, details): the alpha-weighted sum of all chained losses, and a
            dict merging every node's details. A node whose compute_loss returns a
            0-d tensor (not a tuple) reports it under ``get_name()``; the reported
            value is taken before the alpha weighting.
        """
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        if self._loss2 is not None:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            # build a new dict rather than mutating the one compute_loss returned
            details = details | details2

        return loss, details
|
|
|
|
|
class Criterion(nn.Module):
    """Wrapper around a point-wise criterion (an ``LLoss`` instance)."""

    def __init__(self, criterion=None):
        super().__init__()
        # the undefined `bb()` helper used to be appended to this message, turning
        # a failed check into a NameError instead of an AssertionError
        assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
        self.criterion = copy(criterion)

    def get_name(self):
        return f'{type(self).__name__}({self.criterion})'

    def with_reduction(self, mode='none'):
        """Return a deep copy whose criteria (along any ``_loss2`` chain) use ``mode``.

        Args:
            mode: reduction to assign to every ``criterion.reduction``
                (previously ignored and always set to ``'none'``).

        Returns:
            The modified deep copy; ``self`` is left untouched.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode
            # `_loss2` only exists when the Criterion is also a MultiLoss
            loss = getattr(loss, '_loss2', None)
        return res
|
|
|
|
|
class ConfLoss(MultiLoss):
    """ Weighted regression by learned confidence.

    ``pixel_loss`` must provide ``with_reduction``; it is switched to ``'none'``
    so that calling it returns per-pixel residuals together with the masks used
    to select them: ``((loss1, msk1), (loss2, msk2)), details``.

    Per selected pixel, with residual x and predicted confidence conf:

        conf_loss = x * conf - alpha * log(conf)

        conf = 10  ==> conf_loss = x * 10 - alpha*log(10)   (residual up-weighted)
        conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)   (residual down-weighted,
                                                             paid for by the log term)

    alpha: hyperparameter (> 0) weighting the log-confidence term.
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction('none')

    def get_name(self):
        return f'ConfLoss({self.pixel_loss})'

    def get_conf_log(self, x):
        return x, torch.log(x)

    @staticmethod
    def _mean_or_zero(x):
        """Mean of ``x``. An empty tensor's mean would be NaN, so its sum is
        returned instead: a 0-d zero that stays attached to the autograd graph
        (the old Python-int 0 broke ``.backward()`` when both views were empty)."""
        return x.mean() if x.numel() > 0 else x.sum()

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
        # NOTE(review): `force=` is not accepted by the builtin print; this relies on a
        # project-level print override (e.g. distributed-aware) — confirm it is installed.
        if loss1.numel() == 0:
            print('NO VALID POINTS in img1', force=True)
        if loss2.numel() == 0:
            print('NO VALID POINTS in img2', force=True)

        # confidences are gathered with the same masks as the residuals
        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
        conf_loss1 = self._mean_or_zero(loss1 * conf1 - self.alpha * log_conf1)
        conf_loss2 = self._mean_or_zero(loss2 * conf2 - self.alpha * log_conf2)

        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)
|
|
|
|
|
class Regr3D(nn.Module):
    """ Ensure that all 3D points are correct.

    Asymmetric loss: view1 is supposed to be the anchor.

    P1 = RT1 @ D1
    P2 = RT2 @ D2
    loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
    loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
          = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)

    ``forward`` compares the given predicted and ground-truth points directly.
    A point contributes to its view's mean Euclidean error only if:
      * its ground-truth distance to the origin lies inside the per-sample
        ``trim_quantiles`` range (outlier trimming);
      * that distance is <= ``dist_clip`` (when ``dist_clip`` is given);
      * its confidence is >= ``conf_thresh`` (when a confidence map is given).
    """

    def __init__(self, norm_mode='avg_dis', alpha=0.2, gt_scale=False,
                 conf_thresh=3, trim_quantiles=(0.002, 0.998)):
        super().__init__()
        self.norm_mode = norm_mode
        self.alpha = alpha  # not read by this class's forward
        self.gt_scale = gt_scale
        self.conf_thresh = conf_thresh
        # (low, high) quantiles of the per-sample gt distance kept by the trimming
        self.trim_quantiles = tuple(trim_quantiles)

    def get_conf_log(self, x):
        return x, torch.log(x)

    @staticmethod
    def _mean_or_zero(x):
        """Mean of ``x``; an empty tensor gives a graph-connected 0 instead of NaN."""
        return x.mean() if x.numel() > 0 else x.sum()

    def _valid_mask(self, gt_pts, conf, dist_clip):
        """Boolean mask of the points of one view that take part in the loss."""
        dis = gt_pts.norm(dim=-1)

        # per-sample outlier trimming on the ground-truth distance to the origin;
        # q must share the input's dtype for torch.quantile
        flat = dis.reshape(dis.shape[0], -1)
        q = torch.tensor(self.trim_quantiles, dtype=flat.dtype, device=flat.device)
        lo, hi = torch.quantile(flat, q, dim=1)  # each of shape (batch,)
        bcast = (-1,) + (1,) * (dis.ndim - 1)
        valid = (dis >= lo.view(bcast)) & (dis <= hi.view(bcast))

        # previously these masks were computed and then overwritten (dist_clip had no effect)
        if dist_clip is not None:
            valid = valid & (dis <= dist_clip)
        if conf is not None:
            valid = valid & (conf >= self.conf_thresh)
        return valid

    def forward(self, gt_pts1, gt_pts2, pr_pts1, pr_pts2, conf1=None, conf2=None, dist_clip=None, disable_view1=False):
        """Mean Euclidean error over the valid points of both views.

        Args:
            gt_pts1, gt_pts2: ground-truth points (coordinates on the last axis).
            pr_pts1, pr_pts2: predicted points, same shapes as the ground truth.
            conf1, conf2: optional per-point confidences, shaped like the points
                without their last axis; points below ``conf_thresh`` are ignored.
            dist_clip: optional maximum ground-truth distance to the origin.
            disable_view1: if True, only the view-2 term is returned.

        Returns:
            0-d tensor. A view without any valid point contributes 0 (not NaN).
        """
        # masks are computed from the un-normalized ground truth
        valid1 = self._valid_mask(gt_pts1, conf1, dist_clip)
        valid2 = self._valid_mask(gt_pts2, conf2, dist_clip)

        if self.norm_mode:
            pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
        if self.norm_mode and not self.gt_scale:
            gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)

        loss1 = torch.norm(pr_pts1 - gt_pts1, dim=-1)[valid1]
        loss2 = torch.norm(pr_pts2 - gt_pts2, dim=-1)[valid2]

        if disable_view1:
            return self._mean_or_zero(loss2)
        return self._mean_or_zero(loss1) + self._mean_or_zero(loss2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|