from copy import copy, deepcopy
import torch
import torch.nn as nn
from dust3r.inference import get_pred_pts3d, find_opt_scaling
from dust3r.utils.geometry import (
inv,
geotrf,
normalize_pointcloud,
normalize_pointcloud_group,
)
from dust3r.utils.geometry import (
get_group_pointcloud_depth,
get_group_pointcloud_center_scale,
weighted_procrustes,
)
# from gsplat import rasterization
import numpy as np
import lpips
from dust3r.utils.camera import (
pose_encoding_to_camera,
camera_to_pose_encoding,
relative_pose_absT_quatR,
)
def Sum(*losses_and_masks):
loss, mask = losses_and_masks[0]
if loss.ndim > 0:
        # we are actually returning the loss for every pixel
return losses_and_masks
else:
# we are returning the global loss
for loss2, mask2 in losses_and_masks[1:]:
loss = loss + loss2
return loss
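# Illustrative sketch (hypothetical `_demo_*` helper, not part of the training
# pipeline): `Sum` passes per-pixel (loss, mask) pairs through untouched, but
# adds the losses up once they are already reduced to scalars.
def _demo_sum():
    per_pixel = (torch.rand(2, 4, 4), torch.ones(2, 4, 4, dtype=torch.bool))
    assert isinstance(Sum(per_pixel, per_pixel), tuple)  # pairs returned as-is
    total = Sum((torch.tensor(1.0), None), (torch.tensor(2.0), None))
    assert float(total) == 3.0  # scalar losses are summed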
class BaseCriterion(nn.Module):
def __init__(self, reduction="mean"):
super().__init__()
self.reduction = reduction
class LLoss(BaseCriterion):
"""L-norm loss"""
def forward(self, a, b):
assert (
a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
), f"Bad shape = {a.shape}"
dist = self.distance(a, b)
if self.reduction == "none":
return dist
if self.reduction == "sum":
return dist.sum()
if self.reduction == "mean":
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
raise ValueError(f"bad {self.reduction=} mode")
def distance(self, a, b):
raise NotImplementedError()
class L21Loss(LLoss):
"""Euclidean distance between 3d points"""
def distance(self, a, b):
        return torch.norm(a - b, dim=-1)  # Euclidean distance over the last dim
L21 = L21Loss()
class MSELoss(LLoss):
def distance(self, a, b):
return (a - b) ** 2
MSE = MSELoss()
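# Illustrative sketch (hypothetical helper): both criteria expect (..., C)
# inputs with 1 <= C <= 3 and reduce with "mean" by default; L21 yields the
# per-point Euclidean distance, MSE the per-coordinate squared error.
def _demo_pointwise_criteria():
    a = torch.zeros(2, 5, 3)
    b = torch.ones(2, 5, 3)
    assert torch.isclose(L21(a, b), torch.tensor(3.0).sqrt())  # ||(1,1,1)|| = sqrt(3)
    assert torch.isclose(MSE(a, b), torch.tensor(1.0))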
class Criterion(nn.Module):
def __init__(self, criterion=None):
super().__init__()
assert isinstance(
criterion, BaseCriterion
), f"{criterion} is not a proper criterion!"
self.criterion = copy(criterion)
def get_name(self):
return f"{type(self).__name__}({self.criterion})"
def with_reduction(self, mode="none"):
res = loss = deepcopy(self)
while loss is not None:
assert isinstance(loss, Criterion)
loss.criterion.reduction = mode # make it return the loss for each sample
            loss = loss._loss2  # we assume loss is a MultiLoss chain
return res
class MultiLoss(nn.Module):
"""Easily combinable losses (also keep track of individual loss values):
loss = MyLoss1() + 0.1*MyLoss2()
Usage:
Inherit from this class and override get_name() and compute_loss()
"""
def __init__(self):
super().__init__()
self._alpha = 1
self._loss2 = None
def compute_loss(self, *args, **kwargs):
raise NotImplementedError()
def get_name(self):
raise NotImplementedError()
def __mul__(self, alpha):
assert isinstance(alpha, (int, float))
res = copy(self)
res._alpha = alpha
return res
__rmul__ = __mul__ # same
def __add__(self, loss2):
assert isinstance(loss2, MultiLoss)
res = cur = copy(self)
# find the end of the chain
while cur._loss2 is not None:
cur = cur._loss2
cur._loss2 = loss2
return res
def __repr__(self):
name = self.get_name()
if self._alpha != 1:
name = f"{self._alpha:g}*{name}"
if self._loss2:
name = f"{name} + {self._loss2}"
return name
def forward(self, *args, **kwargs):
loss = self.compute_loss(*args, **kwargs)
if isinstance(loss, tuple):
loss, details = loss
elif loss.ndim == 0:
details = {self.get_name(): float(loss)}
else:
details = {}
loss = loss * self._alpha
if self._loss2:
loss2, details2 = self._loss2(*args, **kwargs)
loss = loss + loss2
details |= details2
return loss, details
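# Illustrative sketch of the advertised algebra (the `_DemoConst` loss is
# hypothetical; real subclasses receive (gts, preds) batches):
class _DemoConst(MultiLoss):
    def __init__(self, value):
        super().__init__()
        self.value = value
    def get_name(self):
        return f"Const({self.value})"
    def compute_loss(self):
        return torch.tensor(self.value)
def _demo_multiloss():
    combined = _DemoConst(1.0) + 0.1 * _DemoConst(2.0)
    loss, details = combined()
    assert torch.isclose(loss, torch.tensor(1.2))
    assert "Const(2.0)" in details  # each term's (unscaled) value is tracked by name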
class SSIM(nn.Module):
"""Layer to compute the SSIM loss between a pair of images"""
def __init__(self):
super(SSIM, self).__init__()
self.mu_x_pool = nn.AvgPool2d(3, 1)
self.mu_y_pool = nn.AvgPool2d(3, 1)
self.sig_x_pool = nn.AvgPool2d(3, 1)
self.sig_y_pool = nn.AvgPool2d(3, 1)
self.sig_xy_pool = nn.AvgPool2d(3, 1)
self.refl = nn.ReflectionPad2d(1)
self.C1 = 0.01**2
self.C2 = 0.03**2
def forward(self, x, y):
x = self.refl(x)
y = self.refl(y)
mu_x = self.mu_x_pool(x)
mu_y = self.mu_y_pool(y)
sigma_x = self.sig_x_pool(x**2) - mu_x**2
sigma_y = self.sig_y_pool(y**2) - mu_y**2
sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
SSIM_d = (mu_x**2 + mu_y**2 + self.C1) * (sigma_x + sigma_y + self.C2)
return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
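# Illustrative sketch (hypothetical helper): thanks to the reflection padding,
# the (1 - SSIM) / 2 map keeps the input's spatial size, lies in [0, 1], and
# is exactly zero for identical images.
def _demo_ssim():
    ssim = SSIM()
    x = torch.rand(2, 3, 16, 16)
    loss_map = ssim(x, x.clone())
    assert loss_map.shape == x.shape
    assert torch.allclose(loss_map, torch.zeros_like(loss_map), atol=1e-6)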
class RGBLoss(Criterion, MultiLoss):
def __init__(self, criterion):
super().__init__(criterion)
self.ssim = SSIM()
def img_loss(self, a, b):
return self.criterion(a, b)
def compute_loss(self, gts, preds, **kw):
gt_rgbs = [gt["img"].permute(0, 2, 3, 1) for gt in gts]
pred_rgbs = [pred["rgb"] for pred in preds]
ls = [
self.img_loss(pred_rgb, gt_rgb)
for pred_rgb, gt_rgb in zip(pred_rgbs, gt_rgbs)
]
details = {}
self_name = type(self).__name__
for i, l in enumerate(ls):
details[self_name + f"_rgb/{i+1}"] = float(l)
details[f"pred_rgb_{i+1}"] = pred_rgbs[i]
rgb_loss = sum(ls) / len(ls)
return rgb_loss, details
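# Illustrative sketch (hypothetical helper with placeholder tensors): RGBLoss
# consumes view dicts where the GT "img" is channel-first while the predicted
# "rgb" is channel-last, and reports one scalar per view in the details.
def _demo_rgb_loss():
    loss_fn = RGBLoss(L21Loss())
    gts = [{"img": torch.rand(2, 3, 16, 16)} for _ in range(2)]
    preds = [{"rgb": torch.rand(2, 16, 16, 3)} for _ in range(2)]
    loss, details = loss_fn(gts, preds)
    assert loss.ndim == 0 and "RGBLoss_rgb/1" in details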
class DepthScaleShiftInvLoss(BaseCriterion):
"""scale and shift invariant loss"""
def __init__(self, reduction="none"):
super().__init__(reduction)
def forward(self, pred, gt, mask):
assert pred.shape == gt.shape and pred.ndim == 3, f"Bad shape = {pred.shape}"
dist = self.distance(pred, gt, mask)
# assert dist.ndim == a.ndim - 1 # one dimension less
if self.reduction == "none":
return dist
if self.reduction == "sum":
return dist.sum()
if self.reduction == "mean":
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
raise ValueError(f"bad {self.reduction=} mode")
def normalize(self, x, mask):
x_valid = x[mask]
splits = mask.sum(dim=(1, 2)).tolist()
x_valid_list = torch.split(x_valid, splits)
shift = [x.mean() for x in x_valid_list]
x_valid_centered = [x - m for x, m in zip(x_valid_list, shift)]
scale = [x.abs().mean() for x in x_valid_centered]
scale = torch.stack(scale)
shift = torch.stack(shift)
x = (x - shift.view(-1, 1, 1)) / scale.view(-1, 1, 1).clamp(min=1e-6)
return x
def distance(self, pred, gt, mask):
pred = self.normalize(pred, mask)
gt = self.normalize(gt, mask)
return torch.abs((pred - gt)[mask])
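# Illustrative sketch (hypothetical helper): remapping the prediction by a
# positive scale and a shift leaves the loss unchanged, which is the affine
# invariance this criterion promises for relative-depth supervision.
def _demo_shift_scale_invariance():
    crit = DepthScaleShiftInvLoss()
    pred, gt = torch.rand(2, 8, 8), torch.rand(2, 8, 8)
    mask = torch.ones(2, 8, 8, dtype=torch.bool)
    assert torch.allclose(crit(pred, gt, mask), crit(3.0 * pred + 1.0, gt, mask), atol=1e-5)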
class ScaleInvLoss(BaseCriterion):
"""scale invariant loss"""
def __init__(self, reduction="none"):
super().__init__(reduction)
def forward(self, pred, gt, mask):
assert pred.shape == gt.shape and pred.ndim == 4, f"Bad shape = {pred.shape}"
dist = self.distance(pred, gt, mask)
# assert dist.ndim == a.ndim - 1 # one dimension less
if self.reduction == "none":
return dist
if self.reduction == "sum":
return dist.sum()
if self.reduction == "mean":
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
raise ValueError(f"bad {self.reduction=} mode")
def distance(self, pred, gt, mask):
pred_norm_factor = (torch.norm(pred, dim=-1) * mask).sum(dim=(1, 2)) / mask.sum(
dim=(1, 2)
).clamp(min=1e-6)
gt_norm_factor = (torch.norm(gt, dim=-1) * mask).sum(dim=(1, 2)) / mask.sum(
dim=(1, 2)
).clamp(min=1e-6)
pred = pred / pred_norm_factor.view(-1, 1, 1, 1).clamp(min=1e-6)
gt = gt / gt_norm_factor.view(-1, 1, 1, 1).clamp(min=1e-6)
return torch.norm(pred - gt, dim=-1)[mask]
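# Illustrative sketch (hypothetical helper): only the average point norm per
# batch element matters, so a global positive rescaling of the prediction
# leaves the loss untouched.
def _demo_scale_invariance():
    crit = ScaleInvLoss()
    pred, gt = torch.rand(2, 8, 8, 3), torch.rand(2, 8, 8, 3)
    mask = torch.ones(2, 8, 8, dtype=torch.bool)
    assert torch.allclose(crit(pred, gt, mask), crit(5.0 * pred, gt, mask), atol=1e-5)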
class Regr3DPose(Criterion, MultiLoss):
"""Ensure that all 3D points are correct.
Asymmetric loss: view1 is supposed to be the anchor.
P1 = RT1 @ D1
P2 = RT2 @ D2
loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
= (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
"""
def __init__(
self,
criterion,
norm_mode="?avg_dis",
gt_scale=False,
sky_loss_value=2,
max_metric_scale=False,
):
super().__init__(criterion)
if norm_mode.startswith("?"):
            # do not normalize points from metric-scale datasets
self.norm_all = False
self.norm_mode = norm_mode[1:]
else:
self.norm_all = True
self.norm_mode = norm_mode
self.gt_scale = gt_scale
self.sky_loss_value = sky_loss_value
self.max_metric_scale = max_metric_scale
def get_norm_factor_point_cloud(
self, pts_self, pts_cross, valids, conf_self, conf_cross, norm_self_only=False
):
if norm_self_only:
norm_factor = normalize_pointcloud_group(
pts_self, self.norm_mode, valids, conf_self, ret_factor_only=True
)
else:
pts = [torch.cat([x, y], dim=2) for x, y in zip(pts_self, pts_cross)]
valids = [torch.cat([x, x], dim=2) for x in valids]
confs = [torch.cat([x, y], dim=2) for x, y in zip(conf_self, conf_cross)]
norm_factor = normalize_pointcloud_group(
pts, self.norm_mode, valids, confs, ret_factor_only=True
)
return norm_factor
def get_norm_factor_poses(self, gt_trans, pr_trans, not_metric_mask):
if self.norm_mode and not self.gt_scale:
gt_trans = [x[:, None, None, :].clone() for x in gt_trans]
valids = [torch.ones_like(x[..., 0], dtype=torch.bool) for x in gt_trans]
norm_factor_gt = (
normalize_pointcloud_group(
gt_trans,
self.norm_mode,
valids,
ret_factor_only=True,
)
.squeeze(-1)
.squeeze(-1)
)
else:
norm_factor_gt = torch.ones(
len(gt_trans), dtype=gt_trans[0].dtype, device=gt_trans[0].device
)
norm_factor_pr = norm_factor_gt.clone()
if self.norm_mode and not_metric_mask.sum() > 0 and not self.gt_scale:
pr_trans_not_metric = [
x[not_metric_mask][:, None, None, :].clone() for x in pr_trans
]
valids = [
torch.ones_like(x[..., 0], dtype=torch.bool)
for x in pr_trans_not_metric
]
norm_factor_pr_not_metric = (
normalize_pointcloud_group(
pr_trans_not_metric,
self.norm_mode,
valids,
ret_factor_only=True,
)
.squeeze(-1)
.squeeze(-1)
)
norm_factor_pr[not_metric_mask] = norm_factor_pr_not_metric
return norm_factor_gt, norm_factor_pr
def get_all_pts3d(
self,
gts,
preds,
dist_clip=None,
norm_self_only=False,
norm_pose_separately=False,
eps=1e-3,
camera1=None,
):
# everything is normalized w.r.t. camera of view1
in_camera1 = inv(gts[0]["camera_pose"]) if camera1 is None else inv(camera1)
gt_pts_self = [geotrf(inv(gt["camera_pose"]), gt["pts3d"]) for gt in gts]
gt_pts_cross = [geotrf(in_camera1, gt["pts3d"]) for gt in gts]
valids = [gt["valid_mask"].clone() for gt in gts]
camera_only = gts[0]["camera_only"]
if dist_clip is not None:
            # points that are too far away are marked invalid
dis = [gt_pt.norm(dim=-1) for gt_pt in gt_pts_cross]
valids = [valid & (dis <= dist_clip) for valid, dis in zip(valids, dis)]
pr_pts_self = [pred["pts3d_in_self_view"] for pred in preds]
pr_pts_cross = [pred["pts3d_in_other_view"] for pred in preds]
conf_self = [torch.log(pred["conf_self"]).detach().clip(eps) for pred in preds]
conf_cross = [torch.log(pred["conf"]).detach().clip(eps) for pred in preds]
if not self.norm_all:
if self.max_metric_scale:
B = valids[0].shape[0]
dist = [
torch.where(valid, torch.linalg.norm(gt_pt_cross, dim=-1), 0).view(
B, -1
)
for valid, gt_pt_cross in zip(valids, gt_pts_cross)
]
for d in dist:
gts[0]["is_metric"] = gts[0]["is_metric_scale"] & (
d.max(dim=-1).values < self.max_metric_scale
)
not_metric_mask = ~gts[0]["is_metric"]
else:
not_metric_mask = torch.ones_like(gts[0]["is_metric"])
# normalize 3d points
# compute the scale using only the self view point maps
if self.norm_mode and not self.gt_scale:
norm_factor_gt = self.get_norm_factor_point_cloud(
gt_pts_self,
gt_pts_cross,
valids,
conf_self,
conf_cross,
norm_self_only=norm_self_only,
)
else:
norm_factor_gt = torch.ones_like(
preds[0]["pts3d_in_other_view"][:, :1, :1, :1]
)
norm_factor_pr = norm_factor_gt.clone()
if self.norm_mode and not_metric_mask.sum() > 0 and not self.gt_scale:
norm_factor_pr_not_metric = self.get_norm_factor_point_cloud(
[pr_pt_self[not_metric_mask] for pr_pt_self in pr_pts_self],
[pr_pt_cross[not_metric_mask] for pr_pt_cross in pr_pts_cross],
[valid[not_metric_mask] for valid in valids],
[conf[not_metric_mask] for conf in conf_self],
[conf[not_metric_mask] for conf in conf_cross],
norm_self_only=norm_self_only,
)
norm_factor_pr[not_metric_mask] = norm_factor_pr_not_metric
norm_factor_gt = norm_factor_gt.clip(eps)
norm_factor_pr = norm_factor_pr.clip(eps)
gt_pts_self = [pts / norm_factor_gt for pts in gt_pts_self]
gt_pts_cross = [pts / norm_factor_gt for pts in gt_pts_cross]
pr_pts_self = [pts / norm_factor_pr for pts in pr_pts_self]
pr_pts_cross = [pts / norm_factor_pr for pts in pr_pts_cross]
        # [(Bx3, Bx4), (Bx3, Bx4), ...]: 3 for translation, 4 for quaternion
gt_poses = [
camera_to_pose_encoding(in_camera1 @ gt["camera_pose"]).clone()
for gt in gts
]
pr_poses = [pred["camera_pose"].clone() for pred in preds]
pose_norm_factor_gt = norm_factor_gt.clone().squeeze(2, 3)
pose_norm_factor_pr = norm_factor_pr.clone().squeeze(2, 3)
if norm_pose_separately:
gt_trans = [gt[:, :3] for gt in gt_poses]
pr_trans = [pr[:, :3] for pr in pr_poses]
pose_norm_factor_gt, pose_norm_factor_pr = self.get_norm_factor_poses(
gt_trans, pr_trans, not_metric_mask
)
elif any(camera_only):
gt_trans = [gt[:, :3] for gt in gt_poses]
pr_trans = [pr[:, :3] for pr in pr_poses]
pose_only_norm_factor_gt, pose_only_norm_factor_pr = (
self.get_norm_factor_poses(gt_trans, pr_trans, not_metric_mask)
)
pose_norm_factor_gt = torch.where(
camera_only[:, None], pose_only_norm_factor_gt, pose_norm_factor_gt
)
pose_norm_factor_pr = torch.where(
camera_only[:, None], pose_only_norm_factor_pr, pose_norm_factor_pr
)
gt_poses = [
(gt[:, :3] / pose_norm_factor_gt.clip(eps), gt[:, 3:]) for gt in gt_poses
]
pr_poses = [
(pr[:, :3] / pose_norm_factor_pr.clip(eps), pr[:, 3:]) for pr in pr_poses
]
pose_masks = (pose_norm_factor_gt.squeeze() > eps) & (
pose_norm_factor_pr.squeeze() > eps
)
if any(camera_only):
            # this is equivalent to a loss on the camera intrinsics
gt_pts_self = [
torch.where(
camera_only[:, None, None, None],
(gt / gt[..., -1:].clip(1e-6)).clip(-2, 2),
gt,
)
for gt in gt_pts_self
]
pr_pts_self = [
torch.where(
camera_only[:, None, None, None],
(pr / pr[..., -1:].clip(1e-6)).clip(-2, 2),
pr,
)
for pr in pr_pts_self
]
        # do not add a cross-view loss when there is only camera supervision
skys = [gt["sky_mask"] & ~valid for gt, valid in zip(gts, valids)]
return (
gt_pts_self,
gt_pts_cross,
pr_pts_self,
pr_pts_cross,
gt_poses,
pr_poses,
valids,
skys,
pose_masks,
{},
)
def get_all_pts3d_with_scale_loss(
self,
gts,
preds,
dist_clip=None,
norm_self_only=False,
norm_pose_separately=False,
eps=1e-3,
):
# everything is normalized w.r.t. camera of view1
in_camera1 = inv(gts[0]["camera_pose"])
gt_pts_self = [geotrf(inv(gt["camera_pose"]), gt["pts3d"]) for gt in gts]
gt_pts_cross = [geotrf(in_camera1, gt["pts3d"]) for gt in gts]
valids = [gt["valid_mask"].clone() for gt in gts]
camera_only = gts[0]["camera_only"]
if dist_clip is not None:
            # points that are too far away are marked invalid
dis = [gt_pt.norm(dim=-1) for gt_pt in gt_pts_cross]
valids = [valid & (dis <= dist_clip) for valid, dis in zip(valids, dis)]
pr_pts_self = [pred["pts3d_in_self_view"] for pred in preds]
pr_pts_cross = [pred["pts3d_in_other_view"] for pred in preds]
conf_self = [torch.log(pred["conf_self"]).detach().clip(eps) for pred in preds]
conf_cross = [torch.log(pred["conf"]).detach().clip(eps) for pred in preds]
if not self.norm_all:
if self.max_metric_scale:
B = valids[0].shape[0]
dist = [
torch.where(valid, torch.linalg.norm(gt_pt_cross, dim=-1), 0).view(
B, -1
)
for valid, gt_pt_cross in zip(valids, gt_pts_cross)
]
for d in dist:
gts[0]["is_metric"] = gts[0]["is_metric_scale"] & (
d.max(dim=-1).values < self.max_metric_scale
)
not_metric_mask = ~gts[0]["is_metric"]
else:
not_metric_mask = torch.ones_like(gts[0]["is_metric"])
# normalize 3d points
# compute the scale using only the self view point maps
if self.norm_mode and not self.gt_scale:
norm_factor_gt = self.get_norm_factor_point_cloud(
gt_pts_self[:1],
gt_pts_cross[:1],
valids[:1],
conf_self[:1],
conf_cross[:1],
norm_self_only=norm_self_only,
)
else:
norm_factor_gt = torch.ones_like(
preds[0]["pts3d_in_other_view"][:, :1, :1, :1]
)
if self.norm_mode:
norm_factor_pr = self.get_norm_factor_point_cloud(
pr_pts_self[:1],
pr_pts_cross[:1],
valids[:1],
conf_self[:1],
conf_cross[:1],
norm_self_only=norm_self_only,
)
else:
raise NotImplementedError
# only add loss to metric scale norm factor
if (~not_metric_mask).sum() > 0:
pts_scale_loss = torch.abs(
norm_factor_pr[~not_metric_mask] - norm_factor_gt[~not_metric_mask]
).mean()
else:
pts_scale_loss = 0.0
norm_factor_gt = norm_factor_gt.clip(eps)
norm_factor_pr = norm_factor_pr.clip(eps)
gt_pts_self = [pts / norm_factor_gt for pts in gt_pts_self]
gt_pts_cross = [pts / norm_factor_gt for pts in gt_pts_cross]
pr_pts_self = [pts / norm_factor_pr for pts in pr_pts_self]
pr_pts_cross = [pts / norm_factor_pr for pts in pr_pts_cross]
        # [(Bx3, Bx4), (Bx3, Bx4), ...]: 3 for translation, 4 for quaternion
gt_poses = [
camera_to_pose_encoding(in_camera1 @ gt["camera_pose"]).clone()
for gt in gts
]
pr_poses = [pred["camera_pose"].clone() for pred in preds]
pose_norm_factor_gt = norm_factor_gt.clone().squeeze(2, 3)
pose_norm_factor_pr = norm_factor_pr.clone().squeeze(2, 3)
if norm_pose_separately:
gt_trans = [gt[:, :3] for gt in gt_poses][:1]
pr_trans = [pr[:, :3] for pr in pr_poses][:1]
pose_norm_factor_gt, pose_norm_factor_pr = self.get_norm_factor_poses(
gt_trans, pr_trans, torch.ones_like(not_metric_mask)
)
elif any(camera_only):
gt_trans = [gt[:, :3] for gt in gt_poses][:1]
pr_trans = [pr[:, :3] for pr in pr_poses][:1]
pose_only_norm_factor_gt, pose_only_norm_factor_pr = (
self.get_norm_factor_poses(
gt_trans, pr_trans, torch.ones_like(not_metric_mask)
)
)
pose_norm_factor_gt = torch.where(
camera_only[:, None], pose_only_norm_factor_gt, pose_norm_factor_gt
)
pose_norm_factor_pr = torch.where(
camera_only[:, None], pose_only_norm_factor_pr, pose_norm_factor_pr
)
# only add loss to metric scale norm factor
if (~not_metric_mask).sum() > 0:
pose_scale_loss = torch.abs(
pose_norm_factor_pr[~not_metric_mask]
- pose_norm_factor_gt[~not_metric_mask]
).mean()
else:
pose_scale_loss = 0.0
gt_poses = [
(gt[:, :3] / pose_norm_factor_gt.clip(eps), gt[:, 3:]) for gt in gt_poses
]
pr_poses = [
(pr[:, :3] / pose_norm_factor_pr.clip(eps), pr[:, 3:]) for pr in pr_poses
]
pose_masks = (pose_norm_factor_gt.squeeze() > eps) & (
pose_norm_factor_pr.squeeze() > eps
)
if any(camera_only):
            # this is equivalent to a loss on the camera intrinsics
gt_pts_self = [
torch.where(
camera_only[:, None, None, None],
(gt / gt[..., -1:].clip(1e-6)).clip(-2, 2),
gt,
)
for gt in gt_pts_self
]
pr_pts_self = [
torch.where(
camera_only[:, None, None, None],
(pr / pr[..., -1:].clip(1e-6)).clip(-2, 2),
pr,
)
for pr in pr_pts_self
]
        # do not add a cross-view loss when there is only camera supervision
skys = [gt["sky_mask"] & ~valid for gt, valid in zip(gts, valids)]
return (
gt_pts_self,
gt_pts_cross,
pr_pts_self,
pr_pts_cross,
gt_poses,
pr_poses,
valids,
skys,
pose_masks,
{"scale_loss": pose_scale_loss + pts_scale_loss},
)
def compute_relative_pose_loss(
self, gt_trans, gt_quats, pr_trans, pr_quats, masks=None
):
if masks is None:
masks = torch.ones(len(gt_trans), dtype=torch.bool, device=gt_trans.device)
gt_trans_matrix1 = gt_trans[:, :, None, :].repeat(1, 1, gt_trans.shape[1], 1)[
masks
]
gt_trans_matrix2 = gt_trans[:, None, :, :].repeat(1, gt_trans.shape[1], 1, 1)[
masks
]
gt_quats_matrix1 = gt_quats[:, :, None, :].repeat(1, 1, gt_quats.shape[1], 1)[
masks
]
gt_quats_matrix2 = gt_quats[:, None, :, :].repeat(1, gt_quats.shape[1], 1, 1)[
masks
]
pr_trans_matrix1 = pr_trans[:, :, None, :].repeat(1, 1, pr_trans.shape[1], 1)[
masks
]
pr_trans_matrix2 = pr_trans[:, None, :, :].repeat(1, pr_trans.shape[1], 1, 1)[
masks
]
pr_quats_matrix1 = pr_quats[:, :, None, :].repeat(1, 1, pr_quats.shape[1], 1)[
masks
]
pr_quats_matrix2 = pr_quats[:, None, :, :].repeat(1, pr_quats.shape[1], 1, 1)[
masks
]
gt_rel_trans, gt_rel_quats = relative_pose_absT_quatR(
gt_trans_matrix1, gt_quats_matrix1, gt_trans_matrix2, gt_quats_matrix2
)
pr_rel_trans, pr_rel_quats = relative_pose_absT_quatR(
pr_trans_matrix1, pr_quats_matrix1, pr_trans_matrix2, pr_quats_matrix2
)
rel_trans_err = torch.norm(gt_rel_trans - pr_rel_trans, dim=-1)
rel_quats_err = torch.norm(gt_rel_quats - pr_rel_quats, dim=-1)
return rel_trans_err.mean() + rel_quats_err.mean()
def compute_pose_loss(self, gt_poses, pred_poses, masks=None):
"""
        gt_poses: list of (Bx3, Bx4)
        pred_poses: list of (Bx3, Bx4)
        masks: None, or B
"""
gt_trans = torch.stack([gt[0] for gt in gt_poses], dim=1) # BxNx3
        gt_quats = torch.stack([gt[1] for gt in gt_poses], dim=1)  # BxNx4
        pred_trans = torch.stack([pr[0] for pr in pred_poses], dim=1)  # BxNx3
pred_quats = torch.stack([pr[1] for pr in pred_poses], dim=1) # BxNx4
        if masks is None:
pose_loss = (
torch.norm(pred_trans - gt_trans, dim=-1).mean()
+ torch.norm(pred_quats - gt_quats, dim=-1).mean()
)
else:
if not any(masks):
return torch.tensor(0.0)
pose_loss = (
torch.norm(pred_trans - gt_trans, dim=-1)[masks].mean()
+ torch.norm(pred_quats - gt_quats, dim=-1)[masks].mean()
)
return pose_loss
def compute_loss(self, gts, preds, **kw):
(
gt_pts_self,
gt_pts_cross,
pred_pts_self,
pred_pts_cross,
gt_poses,
pr_poses,
masks,
skys,
pose_masks,
monitoring,
) = self.get_all_pts3d(gts, preds, **kw)
if self.sky_loss_value > 0:
assert (
self.criterion.reduction == "none"
), "sky_loss_value should be 0 if no conf loss"
masks = [mask | sky for mask, sky in zip(masks, skys)]
# self view loss and details
if "Quantile" in self.criterion.__class__.__name__:
# masks are overwritten taking into account self view losses
ls_self, masks = self.criterion(
pred_pts_self, gt_pts_self, masks, gts[0]["quantile"]
)
else:
ls_self = [
self.criterion(pred_pt[mask], gt_pt[mask])
for pred_pt, gt_pt, mask in zip(pred_pts_self, gt_pts_self, masks)
]
if self.sky_loss_value > 0:
assert (
self.criterion.reduction == "none"
), "sky_loss_value should be 0 if no conf loss"
for i, l in enumerate(ls_self):
ls_self[i] = torch.where(skys[i][masks[i]], self.sky_loss_value, l)
self_name = type(self).__name__
details = {}
for i in range(len(ls_self)):
details[self_name + f"_self_pts3d/{i+1}"] = float(ls_self[i].mean())
details[f"gt_img{i+1}"] = gts[i]["img"].permute(0, 2, 3, 1).detach()
details[f"self_conf_{i+1}"] = preds[i]["conf_self"].detach()
details[f"valid_mask_{i+1}"] = masks[i].detach()
if "img_mask" in gts[i] and "ray_mask" in gts[i]:
details[f"img_mask_{i+1}"] = gts[i]["img_mask"].detach()
details[f"ray_mask_{i+1}"] = gts[i]["ray_mask"].detach()
if "desc" in preds[i]:
details[f"desc_{i+1}"] = preds[i]["desc"].detach()
# cross view loss and details
camera_only = gts[0]["camera_only"]
pred_pts_cross = [pred_pts[~camera_only] for pred_pts in pred_pts_cross]
gt_pts_cross = [gt_pts[~camera_only] for gt_pts in gt_pts_cross]
masks_cross = [mask[~camera_only] for mask in masks]
skys_cross = [sky[~camera_only] for sky in skys]
if "Quantile" in self.criterion.__class__.__name__:
            # the quantile masks were already determined by the self-view
            # losses, so we pass None as the quantile here
ls_cross, _ = self.criterion(
pred_pts_cross, gt_pts_cross, masks_cross, None
)
else:
ls_cross = [
self.criterion(pred_pt[mask], gt_pt[mask])
for pred_pt, gt_pt, mask in zip(
pred_pts_cross, gt_pts_cross, masks_cross
)
]
if self.sky_loss_value > 0:
assert (
self.criterion.reduction == "none"
), "sky_loss_value should be 0 if no conf loss"
for i, l in enumerate(ls_cross):
ls_cross[i] = torch.where(
skys_cross[i][masks_cross[i]], self.sky_loss_value, l
)
for i in range(len(ls_cross)):
details[self_name + f"_pts3d/{i+1}"] = float(
ls_cross[i].mean() if ls_cross[i].numel() > 0 else 0
)
details[f"conf_{i+1}"] = preds[i]["conf"].detach()
ls = ls_self + ls_cross
masks = masks + masks_cross
details["is_self"] = [True] * len(ls_self) + [False] * len(ls_cross)
details["img_ids"] = (
np.arange(len(ls_self)).tolist() + np.arange(len(ls_cross)).tolist()
)
details["pose_loss"] = self.compute_pose_loss(gt_poses, pr_poses, pose_masks)
return Sum(*list(zip(ls, masks))), (details | monitoring)
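# Illustrative sketch (hypothetical helper): poses are lists of
# (translation, quaternion) pairs, one entry per view, and a perfect
# prediction (or an all-False mask) yields a zero pose loss.
def _demo_pose_loss():
    loss_fn = Regr3DPose(L21Loss())
    poses = [(torch.rand(2, 3), torch.rand(2, 4)) for _ in range(3)]
    assert float(loss_fn.compute_pose_loss(poses, poses)) == 0.0
    no_valid = torch.zeros(2, dtype=torch.bool)
    assert float(loss_fn.compute_pose_loss(poses, poses, no_valid)) == 0.0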
class Regr3DPoseBatchList(Regr3DPose):
"""Ensure that all 3D points are correct.
Asymmetric loss: view1 is supposed to be the anchor.
P1 = RT1 @ D1
P2 = RT2 @ D2
loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
= (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
"""
def __init__(
self,
criterion,
norm_mode="?avg_dis",
gt_scale=False,
sky_loss_value=2,
max_metric_scale=False,
):
super().__init__(
criterion, norm_mode, gt_scale, sky_loss_value, max_metric_scale
)
self.depth_only_criterion = DepthScaleShiftInvLoss()
self.single_view_criterion = ScaleInvLoss()
def reorg(self, ls_b, masks_b):
ids_split = [mask.sum(dim=(1, 2)) for mask in masks_b]
ls = [[] for _ in range(len(masks_b[0]))]
for i in range(len(ls_b)):
ls_splitted_i = torch.split(ls_b[i], ids_split[i].tolist())
for j in range(len(masks_b[0])):
ls[j].append(ls_splitted_i[j])
ls = [torch.cat(l) for l in ls]
return ls
def compute_loss(self, gts, preds, **kw):
(
gt_pts_self,
gt_pts_cross,
pred_pts_self,
pred_pts_cross,
gt_poses,
pr_poses,
masks,
skys,
pose_masks,
monitoring,
) = self.get_all_pts3d(gts, preds, **kw)
if self.sky_loss_value > 0:
assert (
self.criterion.reduction == "none"
), "sky_loss_value should be 0 if no conf loss"
masks = [mask | sky for mask, sky in zip(masks, skys)]
camera_only = gts[0]["camera_only"]
depth_only = gts[0]["depth_only"]
single_view = gts[0]["single_view"]
is_metric = gts[0]["is_metric"]
# self view loss and details
if "Quantile" in self.criterion.__class__.__name__:
raise NotImplementedError
else:
# list [(B, h, w, 3)] x num_views -> list [num_views, h, w, 3] x B
gt_pts_self_b = torch.unbind(torch.stack(gt_pts_self, dim=1), dim=0)
pred_pts_self_b = torch.unbind(torch.stack(pred_pts_self, dim=1), dim=0)
masks_b = torch.unbind(torch.stack(masks, dim=1), dim=0)
ls_self_b = []
for i in range(len(gt_pts_self_b)):
                if depth_only[i]:
                    # only relative depth is available (no intrinsics or any other supervision)
ls_self_b.append(
self.depth_only_criterion(
pred_pts_self_b[i][..., -1],
gt_pts_self_b[i][..., -1],
masks_b[i],
)
)
                elif single_view[i] and not is_metric[i]:
                    # single view with intrinsics, but no metric scale
ls_self_b.append(
self.single_view_criterion(
pred_pts_self_b[i], gt_pts_self_b[i], masks_b[i]
)
)
                else:  # multi-view, or metric single view
ls_self_b.append(
self.criterion(
pred_pts_self_b[i][masks_b[i]], gt_pts_self_b[i][masks_b[i]]
)
)
ls_self = self.reorg(ls_self_b, masks_b)
if self.sky_loss_value > 0:
assert (
self.criterion.reduction == "none"
), "sky_loss_value should be 0 if no conf loss"
for i, l in enumerate(ls_self):
ls_self[i] = torch.where(skys[i][masks[i]], self.sky_loss_value, l)
self_name = type(self).__name__
details = {}
for i in range(len(ls_self)):
details[self_name + f"_self_pts3d/{i+1}"] = float(ls_self[i].mean())
details[f"self_conf_{i+1}"] = preds[i]["conf_self"].detach()
details[f"gt_img{i+1}"] = gts[i]["img"].permute(0, 2, 3, 1).detach()
details[f"valid_mask_{i+1}"] = masks[i].detach()
if "img_mask" in gts[i] and "ray_mask" in gts[i]:
details[f"img_mask_{i+1}"] = gts[i]["img_mask"].detach()
details[f"ray_mask_{i+1}"] = gts[i]["ray_mask"].detach()
if "desc" in preds[i]:
details[f"desc_{i+1}"] = preds[i]["desc"].detach()
if "Quantile" in self.criterion.__class__.__name__:
            # the quantile masks were already determined by the self-view
            # losses, so we pass None as the quantile here
raise NotImplementedError
else:
gt_pts_cross_b = torch.unbind(
torch.stack(gt_pts_cross, dim=1)[~camera_only], dim=0
)
pred_pts_cross_b = torch.unbind(
torch.stack(pred_pts_cross, dim=1)[~camera_only], dim=0
)
masks_cross_b = torch.unbind(torch.stack(masks, dim=1)[~camera_only], dim=0)
ls_cross_b = []
for i in range(len(gt_pts_cross_b)):
if depth_only[~camera_only][i]:
ls_cross_b.append(
self.depth_only_criterion(
pred_pts_cross_b[i][..., -1],
gt_pts_cross_b[i][..., -1],
masks_cross_b[i],
)
)
elif single_view[~camera_only][i] and not is_metric[~camera_only][i]:
ls_cross_b.append(
self.single_view_criterion(
pred_pts_cross_b[i], gt_pts_cross_b[i], masks_cross_b[i]
)
)
else:
ls_cross_b.append(
self.criterion(
pred_pts_cross_b[i][masks_cross_b[i]],
gt_pts_cross_b[i][masks_cross_b[i]],
)
)
ls_cross = self.reorg(ls_cross_b, masks_cross_b)
if self.sky_loss_value > 0:
assert (
self.criterion.reduction == "none"
), "sky_loss_value should be 0 if no conf loss"
masks_cross = [mask[~camera_only] for mask in masks]
skys_cross = [sky[~camera_only] for sky in skys]
for i, l in enumerate(ls_cross):
ls_cross[i] = torch.where(
skys_cross[i][masks_cross[i]], self.sky_loss_value, l
)
for i in range(len(ls_cross)):
details[self_name + f"_pts3d/{i+1}"] = float(
ls_cross[i].mean() if ls_cross[i].numel() > 0 else 0
)
details[f"conf_{i+1}"] = preds[i]["conf"].detach()
ls = ls_self + ls_cross
masks = masks + masks_cross
details["is_self"] = [True] * len(ls_self) + [False] * len(ls_cross)
details["img_ids"] = (
np.arange(len(ls_self)).tolist() + np.arange(len(ls_cross)).tolist()
)
        # note: `i` is the last index left over from the loop above, so this
        # gates the pose loss with the img_mask of the final view
        pose_masks = pose_masks * gts[i]["img_mask"]
details["pose_loss"] = self.compute_pose_loss(gt_poses, pr_poses, pose_masks)
return Sum(*list(zip(ls, masks))), (details | monitoring)
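# Illustrative sketch (hypothetical helper) of `reorg` above: it regroups
# per-batch-item flattened losses (views concatenated within each item) into
# per-view tensors (batch items concatenated within each view).
def _demo_reorg():
    loss_fn = Regr3DPoseBatchList(L21Loss())
    masks_b = list(torch.ones(2, 3, 4, 4, dtype=torch.bool).unbind(0))  # 2 items, 3 views
    ls_b = [torch.arange(48.0) for _ in range(2)]  # 3 views x 16 valid pixels each
    per_view = loss_fn.reorg(ls_b, masks_b)
    assert len(per_view) == 3 and per_view[0].numel() == 32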
class ConfLoss(MultiLoss):
"""Weighted regression by learned confidence.
Assuming the input pixel_loss is a pixel-level regression loss.
Principle:
        high confidence means high conf = 10  ==> conf_loss = x * 10 - alpha*log(10)
        low  confidence means low  conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
alpha: hyperparameter
"""
def __init__(self, pixel_loss, alpha=1):
super().__init__()
assert alpha > 0
self.alpha = alpha
self.pixel_loss = pixel_loss.with_reduction("none")
def get_name(self):
return f"ConfLoss({self.pixel_loss})"
def get_conf_log(self, x):
return x, torch.log(x)
def compute_loss(self, gts, preds, **kw):
# compute per-pixel loss
losses_and_masks, details = self.pixel_loss(gts, preds, **kw)
if "is_self" in details and "img_ids" in details:
is_self = details["is_self"]
img_ids = details["img_ids"]
else:
is_self = [False] * len(losses_and_masks)
img_ids = list(range(len(losses_and_masks)))
# weight by confidence
conf_losses = []
for i in range(len(losses_and_masks)):
pred = preds[img_ids[i]]
conf_key = "conf_self" if is_self[i] else "conf"
if not is_self[i]:
camera_only = gts[0]["camera_only"]
conf, log_conf = self.get_conf_log(
pred[conf_key][~camera_only][losses_and_masks[i][1]]
)
else:
conf, log_conf = self.get_conf_log(
pred[conf_key][losses_and_masks[i][1]]
)
conf_loss = losses_and_masks[i][0] * conf - self.alpha * log_conf
conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
conf_losses.append(conf_loss)
if is_self[i]:
details[self.get_name() + f"_conf_loss_self/{img_ids[i]+1}"] = float(
conf_loss
)
else:
details[self.get_name() + f"_conf_loss/{img_ids[i]+1}"] = float(
conf_loss
)
details.pop("is_self", None)
details.pop("img_ids", None)
final_loss = sum(conf_losses) / len(conf_losses) * 2.0
if "pose_loss" in details:
final_loss = (
final_loss + details["pose_loss"].clip(max=0.3) * 5.0
) # , details
if "scale_loss" in details:
final_loss = final_loss + details["scale_loss"]
return final_loss, details
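# Illustrative sketch (hypothetical helper) of the weighting used above:
# conf_loss = loss * conf - alpha * log(conf), so confident pixels amplify
# their residual but earn a log bonus, and the two effects trade off.
def _demo_conf_weighting(alpha=1.0):
    residual = torch.tensor(0.05)
    for conf in (0.5, 2.0):
        conf_loss = residual * conf - alpha * torch.log(torch.tensor(conf))
        print(f"conf={conf:.1f} -> conf_loss={float(conf_loss):.3f}")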
class Regr3DPose_ScaleInv(Regr3DPose):
"""Same than Regr3D but invariant to depth shift.
if gt_scale == True: enforce the prediction to take the same scale than GT
"""
def get_all_pts3d(self, gts, preds):
# compute depth-normalized points
(
gt_pts_self,
gt_pts_cross,
pr_pts_self,
pr_pts_cross,
gt_poses,
pr_poses,
masks,
skys,
pose_masks,
monitoring,
) = super().get_all_pts3d(gts, preds)
# measure scene scale
_, gt_scale_self = get_group_pointcloud_center_scale(gt_pts_self, masks)
_, pred_scale_self = get_group_pointcloud_center_scale(pr_pts_self, masks)
_, gt_scale_cross = get_group_pointcloud_center_scale(gt_pts_cross, masks)
_, pred_scale_cross = get_group_pointcloud_center_scale(pr_pts_cross, masks)
        # prevent predictions from being in an unreasonable range
pred_scale_self = pred_scale_self.clip(min=1e-3, max=1e3)
pred_scale_cross = pred_scale_cross.clip(min=1e-3, max=1e3)
        # align the scale of the predicted and GT point clouds
if self.gt_scale:
pr_pts_self = [
pr_pt_self * gt_scale_self / pred_scale_self
for pr_pt_self in pr_pts_self
]
pr_pts_cross = [
pr_pt_cross * gt_scale_cross / pred_scale_cross
for pr_pt_cross in pr_pts_cross
]
else:
gt_pts_self = [gt_pt_self / gt_scale_self for gt_pt_self in gt_pts_self]
gt_pts_cross = [
gt_pt_cross / gt_scale_cross for gt_pt_cross in gt_pts_cross
]
pr_pts_self = [pr_pt_self / pred_scale_self for pr_pt_self in pr_pts_self]
pr_pts_cross = [
pr_pt_cross / pred_scale_cross for pr_pt_cross in pr_pts_cross
]
return (
gt_pts_self,
gt_pts_cross,
pr_pts_self,
pr_pts_cross,
gt_poses,
pr_poses,
masks,
skys,
pose_masks,
monitoring,
)