import torch
import torch.nn as nn
from copy import copy, deepcopy
from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
from dust3r.utils.geometry import inv, geotrf, depthmap_to_pts3d
from dust3r.utils.camera import pose_encoding_to_camera
class BaseCriterion(nn.Module):
def __init__(self, reduction="mean"):
super().__init__()
self.reduction = reduction
class Criterion(nn.Module):
def __init__(self, criterion=None):
super().__init__()
assert isinstance(
criterion, BaseCriterion
), f"{criterion} is not a proper criterion!"
self.criterion = copy(criterion)
def get_name(self):
return f"{type(self).__name__}({self.criterion})"
def with_reduction(self, mode="none"):
res = loss = deepcopy(self)
while loss is not None:
assert isinstance(loss, Criterion)
loss.criterion.reduction = mode # make it return the loss for each sample
            loss = loss._loss2  # walk the chain; each link is assumed to be a MultiLoss
return res
class MultiLoss(nn.Module):
"""Easily combinable losses (also keep track of individual loss values):
loss = MyLoss1() + 0.1*MyLoss2()
Usage:
Inherit from this class and override get_name() and compute_loss()
"""
def __init__(self):
super().__init__()
self._alpha = 1
self._loss2 = None
def compute_loss(self, *args, **kwargs):
raise NotImplementedError()
def get_name(self):
raise NotImplementedError()
def __mul__(self, alpha):
assert isinstance(alpha, (int, float))
res = copy(self)
res._alpha = alpha
return res
__rmul__ = __mul__ # same
def __add__(self, loss2):
assert isinstance(loss2, MultiLoss)
res = cur = copy(self)
while cur._loss2 is not None:
cur = cur._loss2
cur._loss2 = loss2
return res
def __repr__(self):
name = self.get_name()
if self._alpha != 1:
name = f"{self._alpha:g}*{name}"
if self._loss2:
name = f"{name} + {self._loss2}"
return name
def forward(self, *args, **kwargs):
loss = self.compute_loss(*args, **kwargs)
if isinstance(loss, tuple):
loss, details = loss
elif loss.ndim == 0:
details = {self.get_name(): float(loss)}
else:
details = {}
loss = loss * self._alpha
if self._loss2:
loss2, details2 = self._loss2(*args, **kwargs)
loss = loss + loss2
details |= details2
return loss, details
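# Usage sketch (illustrative, not from the original file): MultiLoss instances
# combine with `+` and scalar `*`, as the class docstring describes. `MyLoss1`
# and `MyLoss2` are hypothetical subclasses implementing get_name() and
# compute_loss().
#
#   total = MyLoss1() + 0.1 * MyLoss2()
#   loss, details = total(gt1, gt2, pred1, pred2)  # details: per-term values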
class LLoss(BaseCriterion):
"""L-norm loss"""
def forward(self, a, b):
assert (
a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
), f"Bad shape = {a.shape}"
dist = self.distance(a, b)
if self.reduction == "none":
return dist
if self.reduction == "sum":
return dist.sum()
if self.reduction == "mean":
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
raise ValueError(f"bad {self.reduction=} mode")
def distance(self, a, b):
raise NotImplementedError()
class L21Loss(LLoss):
"""Euclidean distance between 3d points"""
def distance(self, a, b):
        return torch.norm(a - b, dim=-1)  # Euclidean (L2) distance over the xyz dim
L21 = L21Loss()
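# Minimal sketch of the criterion on its own (assumed (N, 3) point tensors):
#
#   a, b = torch.randn(8, 3), torch.randn(8, 3)
#   per_point = L21Loss(reduction="none")(a, b)  # (8,) Euclidean distances
#   scalar = L21(a, b)                           # default "mean" reduction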
def get_pred_pts3d(gt, pred, use_pose=False):
if "depth" in pred and "pseudo_focal" in pred:
try:
pp = gt["camera_intrinsics"][..., :2, 2]
except KeyError:
pp = None
pts3d = depthmap_to_pts3d(**pred, pp=pp)
elif "pts3d" in pred:
        # pts3d predicted in this view's own camera frame
pts3d = pred["pts3d"]
elif "pts3d_in_other_view" in pred:
# pts3d from the other camera, already transformed
assert use_pose is True
return pred["pts3d_in_other_view"] # return!
if use_pose:
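        # transform the self-view prediction into the reference frame using the
        # predicted camera pose; note this overrides any pts3d computed above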
camera_pose = pred.get("camera_pose")
pts3d = pred.get("pts3d_in_self_view")
assert camera_pose is not None
assert pts3d is not None
pts3d = geotrf(pose_encoding_to_camera(camera_pose), pts3d)
return pts3d
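# Usage sketch (hypothetical dict; keys follow the branches above): a network
# output carrying a self-view point map plus a pose encoding is mapped into
# the reference frame:
#
#   pred = {"pts3d_in_self_view": pts, "camera_pose": pose_enc}
#   pts_in_view1 = get_pred_pts3d(gt, pred, use_pose=True)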
def Sum(losses, masks, conf=None):
    loss = losses[0]
    if loss.ndim > 0:
        # we are actually returning the loss for every pixel
        if conf is not None:
            return losses, masks, conf
        return losses, masks
    else:
        # we are returning the global loss
        for loss2 in losses[1:]:
            loss = loss + loss2
        return loss
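# Behavior note: with reduction="none" criteria, Sum forwards the per-pixel
# losses (plus masks and, optionally, confidences) untouched so that
# ConfLoss_t below can weight them; with scalar losses it returns their sum.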
def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
norm_mode, dis_mode = norm_mode.split("_")
nan_pts = []
nnzs = []
if norm_mode == "avg":
# gather all points together (joint normalization)
for i, pt in enumerate(pts):
nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
nan_pts.append(nan_pt)
nnzs.append(nnz)
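            # with fix_first, only the first view defines the normalization factor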
if fix_first:
break
all_pts = torch.cat(nan_pts, dim=1)
# compute distance to origin
all_dis = all_pts.norm(dim=-1)
if dis_mode == "dis":
pass # do nothing
elif dis_mode == "log1p":
all_dis = torch.log1p(all_dis)
else:
raise ValueError(f"bad {dis_mode=}")
        # sum(nnzs) keeps the per-sample (B,) count of valid points
        norm_factor = all_dis.sum(dim=1) / (sum(nnzs) + 1e-8)
else:
raise ValueError(f"Not implemented {norm_mode=}")
norm_factor = norm_factor.clip(min=1e-8)
while norm_factor.ndim < pts[0].ndim:
norm_factor.unsqueeze_(-1)
return norm_factor
def normalize_pointcloud_t(
    pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
):
    # `gt` is kept for API compatibility: GT and predicted point maps are
    # normalized the same way, dividing every view by a shared per-sample factor
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    res = [pt / norm_factor for pt in pts]
    return res, norm_factor
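# Normalization sketch (assumed shapes): two (B, H, W, 3) point maps divided
# by the average distance-to-origin of the first map (fix_first=True):
#
#   pts_n, factor = normalize_pointcloud_t(
#       [pts3d_1, pts3d_2], "avg_dis", valids=[mask_1, mask_2]
#   )
#   # factor: (B, 1, 1, 1), shared across views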
@torch.no_grad()
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
# set invalid points to NaN
_zs = []
for i in range(len(zs)):
valid_mask = valid_masks[i] if valid_masks is not None else None
_z = invalid_to_nans(zs[i], valid_mask).reshape(len(zs[i]), -1)
_zs.append(_z)
_zs = torch.cat(_zs, dim=-1)
# compute median depth overall (ignoring nans)
if quantile == 0.5:
shift_z = torch.nanmedian(_zs, dim=-1).values
else:
shift_z = torch.nanquantile(_zs, quantile, dim=-1)
return shift_z # (B,)
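# Sketch: given per-view (B, H, W) depth maps, the shift is the per-sample
# median (or `quantile`) depth taken jointly over all views:
#
#   shift_z = get_joint_pointcloud_depth([z1, z2], [m1, m2])  # (B,)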
@torch.no_grad()
def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
# set invalid points to NaN
_pts = []
for i in range(len(pts)):
valid_mask = valid_masks[i] if valid_masks is not None else None
_pt = invalid_to_nans(pts[i], valid_mask).reshape(len(pts[i]), -1, 3)
_pts.append(_pt)
_pts = torch.cat(_pts, dim=1)
# compute median center
_center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
if z_only:
_center[..., :2] = 0 # do not center X and Y
# compute median norm
_norm = ((_pts - _center) if center else _pts).norm(dim=-1)
scale = torch.nanmedian(_norm, dim=1).values
return _center[:, None, :, :], scale[:, None, None, None]
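# Sketch: for a list of (B, H, W, 3) maps, returns a (B, 1, 1, 3) median
# center (X/Y zeroed when z_only=True) and a (B, 1, 1, 1) median-norm scale:
#
#   center, scale = get_joint_pointcloud_center_scale([p1, p2], [m1, m2])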
class Regr3D_t(Criterion, MultiLoss):
def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
super().__init__(criterion)
self.norm_mode = norm_mode
self.gt_scale = gt_scale
self.fix_first = fix_first
def get_all_pts3d_t(self, gts, preds, dist_clip=None):
# everything is normalized w.r.t. camera of view1
in_camera1 = inv(gts[0]["camera_pose"])
gt_pts = []
valids = []
pr_pts = []
for i, gt in enumerate(gts):
            # in_camera1: (B, 4, 4); gt["pts3d"]: (B, H, W, 3)
gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
valid = gt["valid_mask"].clone()
if dist_clip is not None:
                # points that are too far away are marked invalid
dis = gt["pts3d"].norm(dim=-1)
valid = valid & (dis <= dist_clip)
valids.append(valid)
pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))
            # NOTE: a paired variant built pr_pts as (pr_pts_l, pr_pts_r) from
            # preds[i][0] / preds[i-1][1]; compute_frame_loss below still
            # assumes that (left, right) decomposition when it unpacks pred_pts.
if self.norm_mode:
pr_pts, pr_factor = normalize_pointcloud_t(
pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
)
else:
pr_factor = None
if self.norm_mode and not self.gt_scale:
gt_pts, gt_factor = normalize_pointcloud_t(
gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
)
else:
gt_factor = None
return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}
def compute_frame_loss(self, gts, preds, **kw):
gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
self.get_all_pts3d_t(gts, preds, **kw)
)
        # assumes the paired decomposition pred_pts == (pred_pts_l, pred_pts_r);
        # see the NOTE in get_all_pts3d_t above
        pred_pts_l, pred_pts_r = pred_pts
loss_all = []
mask_all = []
conf_all = []
loss_left = 0
loss_right = 0
pred_conf_l = 0
pred_conf_r = 0
        for i in range(len(gt_pts)):
            # left branch: this frame predicted as the reference view
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])
                # accumulate scalars to compare reference vs target losses
                if i != 0:
                    loss_left += float(frame_loss.detach().mean())
                    pred_conf_l += float(preds[i][0]["conf"].detach().mean())
            # right branch: this frame predicted as the target view
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])
                if i != len(gt_pts) - 1:
                    loss_right += float(frame_loss.detach().mean())
                    pred_conf_r += float(preds[i - 1][1]["conf"].detach().mean())
        if pr_factor is not None and gt_factor is not None:
            # penalize predicted norm factors that exceed the GT ones
            over = pr_factor > gt_factor
            factor_diff = (pr_factor - gt_factor)[over]
        else:
            factor_diff = None
        if factor_diff is not None and factor_diff.numel() > 0:
            factor_loss = factor_diff.abs().mean()
        else:
            factor_loss = 0.0
        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "_loss_left": float(loss_left),
            self_name + "_loss_right": float(loss_right),
            self_name + "_conf_left": float(pred_conf_l),
            self_name + "_conf_right": float(pred_conf_r),
        }
return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
class ConfLoss_t(MultiLoss):
"""Weighted regression by learned confidence.
Assuming the input pixel_loss is a pixel-level regression loss.
Principle:
high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
alpha: hyperparameter
"""
def __init__(self, pixel_loss, alpha=1):
super().__init__()
assert alpha > 0
self.alpha = alpha
self.pixel_loss = pixel_loss.with_reduction("none")
def get_name(self):
return f"ConfLoss({self.pixel_loss})"
def get_conf_log(self, x):
return x, torch.log(x)
def compute_frame_loss(self, gts, preds, **kw):
# compute per-pixel loss
(losses, masks, confs), details, loss_factor = (
self.pixel_loss.compute_frame_loss(gts, preds, **kw)
)
# weight by confidence
conf_losses = []
conf_sum = 0
for i in range(len(losses)):
conf, log_conf = self.get_conf_log(confs[i][masks[i]])
conf_sum += conf.mean()
conf_loss = losses[i] * conf - self.alpha * log_conf
            # keep a tensor even when the mask is empty so torch.stack works below
            conf_loss = (
                conf_loss.mean() if conf_loss.numel() > 0 else conf.new_zeros(())
            )
conf_losses.append(conf_loss)
conf_losses = torch.stack(conf_losses) * 2.0
conf_loss_mean = conf_losses.mean()
return (
conf_loss_mean,
dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss_2=float(conf_losses[1]),
                conf_mean=float(conf_sum / len(losses)),
**details,
),
loss_factor,
)
class Regr3D_t_ShiftInv(Regr3D_t):
"""Same than Regr3D but invariant to depth shift."""
def get_all_pts3d_t(self, gts, preds):
# compute unnormalized points
gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
super().get_all_pts3d_t(gts, preds)
)
        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]
        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]
# compute median depth
gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]
# subtract the median depth
for i in range(len(gt_pts)):
gt_pts[i][..., 2] -= gt_shift_z
        for i in range(len(pred_pts)):
            pred_pts[i][..., 2] -= pred_shift_z
monitoring = dict(
monitoring,
gt_shift_z=gt_shift_z.mean().detach(),
pred_shift_z=pred_shift_z.mean().detach(),
)
return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleInv(Regr3D_t):
"""Same than Regr3D but invariant to depth shift.
if gt_scale == True: enforce the prediction to take the same scale than GT
"""
def get_all_pts3d_t(self, gts, preds):
# compute depth-normalized points
gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
super().get_all_pts3d_t(gts, preds)
)
# measure scene scale
        pred_pts_all = [x.clone() for x in pred_pts]
_, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
_, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)
        # keep the predicted scale within a sane range
pred_scale = pred_scale.clip(min=1e-3, max=1e3)
        # rescale the point clouds
        if self.gt_scale:
            # force the predictions onto the ground-truth scale
            for i in range(len(pred_pts)):
                pred_pts[i] *= gt_scale / pred_scale
        else:
            # bring predictions and GT to a common unit scale
            for i in range(len(pred_pts)):
                pred_pts[i] /= pred_scale
            for i in range(len(gt_pts)):
                gt_pts[i] /= gt_scale
monitoring = dict(
monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
)
return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    # via the MRO, Regr3D_t_ShiftInv's shift is applied first,
    # then Regr3D_t_ScaleInv's rescaling
pass
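# End-to-end sketch (assumptions: `gts` / `preds` are per-view structures
# shaped as get_all_pts3d_t and compute_frame_loss expect; the alpha value is
# illustrative):
#
#   criterion = ConfLoss_t(
#       Regr3D_t_ScaleShiftInv(L21, norm_mode="avg_dis"), alpha=0.2
#   )
#   loss, details, factor_loss = criterion.compute_frame_loss(gts, preds)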