# Source: vmem/extern/CUT3R/src/dust3r/heads/postprocess.py
# Commit 2df809d by liguang0115: "Add initial project structure with core files, configurations, and sample images"
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import torch
import torch.nn.functional as F
def postprocess(out, depth_mode, conf_mode, pos_z=False):
    """
    Extract 3D points (and optionally a confidence map) from head output.

    Args:
        out: prediction head output; it is permuted with (0, 2, 3, 1) to move
            channels last (presumably B,C,H,W -> B,H,W,C — TODO confirm).
            Channels 0:3 hold the raw points, channel 3 the raw confidence
            (only read when ``conf_mode`` is not None).
        depth_mode: ``(mode, vmin, vmax)`` tuple forwarded to ``reg_dense_depth``.
        conf_mode: ``(mode, vmin, vmax)`` tuple forwarded to ``reg_dense_conf``,
            or None to skip the confidence channel.
        pos_z: forwarded to ``reg_dense_depth``.

    Returns:
        dict with key ``"pts3d"`` and, when ``conf_mode`` is given, ``"conf"``.
    """
    channels_last = out.permute(0, 2, 3, 1)
    result = {
        "pts3d": reg_dense_depth(
            channels_last[..., 0:3], mode=depth_mode, pos_z=pos_z
        )
    }
    if conf_mode is None:
        return result
    result["conf"] = reg_dense_conf(channels_last[..., 3], mode=conf_mode)
    return result
def postprocess_rgb(out, eps=1e-6):
    """
    Map raw RGB head output to colour values in [-1 + 2*eps, 1 - 2*eps].

    The logits are squashed with a sigmoid, rescaled into [eps, 1 - eps] so
    the endpoints are never reached, and then mapped linearly onto [-1, 1].

    Args:
        out: head output; permuted with (0, 2, 3, 1) to move channels last
            (presumably B,C,H,W -> B,H,W,C — TODO confirm).
        eps: margin kept away from the sigmoid's limits.

    Returns:
        dict with the single key ``"rgb"``.
    """
    channels_last = out.permute(0, 2, 3, 1)
    squashed = eps + (1 - 2 * eps) * torch.sigmoid(channels_last)
    return dict(rgb=2 * (squashed - 0.5))
def postprocess_pose(out, mode, inverse=False):
    """
    Extract a pose (translation + quaternion) from prediction head output.

    Args:
        out: tensor whose last axis holds at least 7 values: ``[..., 0:3]`` is
            read as the raw translation and ``[..., 3:7]`` as the raw quaternion.
        mode: ``(mode, vmin, vmax)`` tuple. Only unbounded ranges
            (``vmin == -inf`` and ``vmax == inf``) are supported; ``mode`` is
            one of ``"linear"``, ``"square"`` or ``"exp"``.
        inverse: for "square"/"exp", divide the translation by the forward
            scale factor (computed from its current norm) instead of
            multiplying by it.

    Returns:
        For "square"/"exp": tensor of shape ``(..., 7)`` holding the rescaled
        translation followed by the quaternion normalised to unit length with a
        non-negative real part (see ``standardize_quaternion``).
        For "linear": only the raw translation, shape ``(..., 3)``.

    Raises:
        AssertionError: if the range is bounded.
        ValueError: if ``mode`` is not one of the supported names.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds
    trans = out[..., 0:3]
    quats = out[..., 3:7]

    if mode == "linear":
        # NOTE(review): unlike the other modes this drops the quaternion and
        # returns only the translation; kept as-is so callers relying on the
        # (..., 3) shape keep working.
        if no_bounds:
            return trans  # [-inf, +inf]
        return trans.clip(min=vmin, max=vmax)

    d = trans.norm(dim=-1, keepdim=True)

    # Forward: the translation norm d becomes d**2 ("square") or expm1(d)
    # ("exp"). Inverse: the translation is divided by that same factor.
    # The clip guards against division by zero for near-zero translations.
    if mode == "square":
        if inverse:
            scale = d / d.square().clip(min=1e-8)
        else:
            scale = d.square() / d.clip(min=1e-8)
    elif mode == "exp":
        if inverse:
            scale = d / torch.expm1(d).clip(min=1e-8)
        else:
            scale = torch.expm1(d) / d.clip(min=1e-8)
    else:
        # Previously an unknown mode fell through and crashed with an
        # UnboundLocalError on `scale`; fail explicitly like the siblings.
        raise ValueError(f"bad {mode=}")

    trans = trans * scale
    quats = standardize_quaternion(quats)

    return torch.cat([trans, quats], dim=-1)
def postprocess_pose_conf(out):
    """
    Turn raw pose-confidence logits into sigmoid probabilities.

    Args:
        out: head output; permuted with (0, 2, 3, 1) to move channels last
            (presumably B,1,H,W -> B,H,W,1 — TODO confirm).

    Returns:
        dict with the single key ``"pose_conf"``, values in [0, 1].
    """
    return {"pose_conf": out.permute(0, 2, 3, 1).sigmoid()}
def postprocess_desc(out, depth_mode, conf_mode, desc_dim, double_channel=False):
    """
    Split a descriptor-head output into 3D points, confidences and descriptors.

    After moving channels last, the channels are consumed in this order:
      - 3 channels: ``pts3d``
      - 1 channel:  ``conf``       (only when ``conf_mode`` is not None)
      - if ``double_channel``:
          - 3 channels: ``pts3d_self``
          - 1 channel:  ``conf_self`` (only when ``conf_mode`` is not None)
      - ``desc_dim`` channels: ``desc`` (normalised via ``reg_desc``)
      - 1 channel:  ``desc_conf``

    Args:
        out: prediction head output; permuted with (0, 2, 3, 1)
            (presumably B,C,H,W -> B,H,W,C — TODO confirm).
        depth_mode: ``(mode, vmin, vmax)`` tuple forwarded to ``reg_dense_depth``.
        conf_mode: ``(mode, vmin, vmax)`` tuple forwarded to ``reg_dense_conf``,
            or None. Note it is also used for ``desc_conf`` unconditionally.
        desc_dim: number of descriptor channels.
        double_channel: whether a second point map (+ confidence) follows the
            first one.

    Returns:
        dict with keys ``pts3d``, optionally ``conf``, ``pts3d_self`` and
        ``conf_self``, and always ``desc`` and ``desc_conf``.

    Raises:
        AssertionError: if the channel count does not match the layout above.
    """
    has_conf = conf_mode is not None
    channels_last = out.permute(0, 2, 3, 1)

    res = {"pts3d": reg_dense_depth(channels_last[:, :, :, 0:3], mode=depth_mode)}
    cursor = 3  # index of the next unread channel

    if has_conf:
        res["conf"] = reg_dense_conf(channels_last[:, :, :, cursor], mode=conf_mode)
        cursor += 1

    if double_channel:
        res["pts3d_self"] = reg_dense_depth(
            channels_last[:, :, :, cursor : cursor + 3], mode=depth_mode
        )
        cursor += 3
        if has_conf:
            res["conf_self"] = reg_dense_conf(
                channels_last[:, :, :, cursor], mode=conf_mode
            )
            cursor += 1

    res["desc"] = reg_desc(channels_last[:, :, :, cursor : cursor + desc_dim], mode="norm")
    res["desc_conf"] = reg_dense_conf(
        channels_last[:, :, :, cursor + desc_dim], mode=conf_mode
    )
    # Every channel must have been consumed: descriptors plus one conf channel.
    assert cursor + desc_dim + 1 == channels_last.shape[-1]
    return res
def reg_desc(desc, mode="norm"):
    """
    Normalise descriptors to unit L2 norm along the last axis.

    Args:
        desc: descriptor tensor, normalised over its last dimension.
        mode: any string containing ``"norm"`` selects L2 normalisation.

    Returns:
        Tensor with the same shape as ``desc``. The norm is clamped at 1e-12
        (``F.normalize`` default), so all-zero vectors stay zero instead of
        becoming NaN as they did with a plain division by the norm.

    Raises:
        ValueError: if ``mode`` does not contain ``"norm"``.
    """
    if "norm" not in mode:
        raise ValueError(f"Unknown desc mode {mode}")
    return F.normalize(desc, p=2, dim=-1)
def reg_dense_depth(xyz, mode, pos_z=False):
    """
    Extract 3D points from prediction head output.

    Each raw xyz vector is split into a unit direction and a norm ``d``; the
    norm is then remapped according to ``mode`` (``d**2`` for "square",
    ``expm1(d)`` for "exp"). "linear" returns the input unchanged.

    Args:
        xyz: tensor whose last axis holds the 3 raw coordinates.
        mode: ``(mode, vmin, vmax)`` tuple; only unbounded ranges
            (``vmin == -inf`` and ``vmax == inf``) are supported.
        pos_z: if True (non-linear modes only), flip each vector by the sign of
            its last coordinate so that the output z is non-negative.

    Returns:
        Tensor with the same shape as ``xyz`` (for "linear", ``xyz`` itself).
        ``xyz`` is never modified in place.

    Raises:
        AssertionError: if the range is bounded.
        ValueError: if ``mode`` is unknown.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds

    if mode == "linear":
        # NOTE(review): pos_z is not applied in linear mode.
        if no_bounds:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)

    if pos_z:
        # Out-of-place on purpose: callers pass a slice (a view) of the
        # permuted head output, so `xyz *= sign` would silently overwrite
        # the caller's tensor.
        # NOTE(review): torch.sign is 0 where z == 0, which collapses that
        # point to the origin.
        sign = torch.sign(xyz[..., -1:])
        xyz = xyz * sign

    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)  # unit direction; clip avoids 0/0

    if mode == "square":
        return xyz * d.square()
    if mode == "exp":
        return xyz * torch.expm1(d)
    raise ValueError(f"bad {mode=}")
def reg_dense_conf(x, mode):
    """
    Map raw confidence values into the range described by ``mode``.

    Args:
        x: raw confidence tensor.
        mode: ``(mode, vmin, vmax)`` tuple, where ``mode`` is
            - "sigmoid": result is ``vmin + (vmax - vmin) * sigmoid(x)``;
            - "exp": result is ``vmin + exp(x)``, with the ``exp(x)`` term
              clipped at ``vmax - vmin`` so the result stays at most ``vmax``.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``mode`` is unknown.
    """
    mode, vmin, vmax = mode
    if mode == "sigmoid":
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    if mode == "exp":
        return vmin + torch.exp(x).clamp(max=vmax - vmin)
    raise ValueError(f"bad {mode=}")
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalise quaternions to unit length and put them in standard form,
    i.e. with a non-negative real part.

    q and -q represent the same rotation, so any quaternion whose real part
    is negative (after normalisation) is negated.

    Args:
        quaternions: Quaternions with real part first, as tensor of shape
            (..., 4). They need not be unit length; they are L2-normalised
            along the last axis first.

    Returns:
        Standardized unit quaternions as tensor of shape (..., 4).
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    real_is_negative = unit[..., :1] < 0
    return torch.where(real_is_negative, -unit, unit)