Spaces:
Runtime error
Runtime error
File size: 4,689 Bytes
2df809d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import torch
import torch.nn.functional as F
def postprocess(out, depth_mode, conf_mode, pos_z=False):
    """
    split a dense prediction-head output into 3D points and confidence

    Dim 1 of ``out`` is moved to the last position. Channels 0..2 are passed
    to ``reg_dense_depth`` (with ``depth_mode`` and ``pos_z``); when
    ``conf_mode`` is not None, channel 3 is passed to ``reg_dense_conf``.

    Returns:
        dict with key "pts3d" and, only when ``conf_mode`` is given, "conf".
    """
    channels_last = out.permute(0, 2, 3, 1)  # B,H,W,C
    points = reg_dense_depth(channels_last[..., 0:3], mode=depth_mode, pos_z=pos_z)
    result = {"pts3d": points}

    if conf_mode is None:
        return result

    result["conf"] = reg_dense_conf(channels_last[..., 3], mode=conf_mode)
    return result
def postprocess_rgb(out, eps=1e-6):
    """
    map a prediction-head output to rgb values in the open interval (-1, 1)

    Dim 1 of ``out`` is moved last, a sigmoid is applied, the result is
    squeezed into [eps, 1 - eps] and then rescaled affinely to be centred
    on zero.

    Returns:
        dict with the single key "rgb".
    """
    channels_last = out.permute(0, 2, 3, 1)  # B,H,W,3
    # keep the sigmoid strictly away from 0 and 1
    bounded = torch.sigmoid(channels_last) * (1 - 2 * eps) + eps
    centred = (bounded - 0.5) * 2
    return {"rgb": centred}
def postprocess_pose(out, mode, inverse=False):
    """
    extract pose from prediction head output

    Args:
        out: tensor whose last dimension holds the translation in entries
            0..2 and a quaternion in entries 3..6.
        mode: triple ``(name, vmin, vmax)``. ``name`` is one of "linear",
            "square" or "exp". The bounds must be (-inf, +inf); this is
            enforced with an assert.
        inverse: when True, apply the reciprocal of the norm rescaling.

    Returns:
        For "square" / "exp": the rescaled translation concatenated with the
        standardized quaternion along the last dimension.
        For "linear": only the translation slice ``out[..., 0:3]``.

    Raises:
        ValueError: if the mode name is not recognised.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds

    trans = out[..., 0:3]
    quats = out[..., 3:7]

    if mode == "linear":
        # NOTE(review): this branch returns the translation alone, without the
        # quaternion, unlike the other modes -- confirm this is intended.
        if no_bounds:
            return trans  # [-inf, +inf]
        return trans.clip(min=vmin, max=vmax)

    # Norm of the raw translation; the modes below rescale it non-linearly
    # while keeping the direction.
    d = trans.norm(dim=-1, keepdim=True)

    if mode == "square":
        if inverse:
            scale = d / d.square().clip(min=1e-8)
        else:
            scale = d.square() / d.clip(min=1e-8)
    elif mode == "exp":
        if inverse:
            scale = d / torch.expm1(d).clip(min=1e-8)
        else:
            scale = torch.expm1(d) / d.clip(min=1e-8)
    else:
        # Previously an unknown mode fell through and failed on an unbound
        # ``scale``; fail explicitly, like the other regressors in this file.
        raise ValueError(f"bad {mode=}")

    trans = trans * scale
    quats = standardize_quaternion(quats)

    return torch.cat([trans, quats], dim=-1)
def postprocess_pose_conf(out):
    """
    turn a prediction-head output into a pose confidence in (0, 1)

    Dim 1 of ``out`` is moved last and a sigmoid is applied element-wise.

    Returns:
        dict with the single key "pose_conf".
    """
    channels_last = out.permute(0, 2, 3, 1)  # B,H,W,1
    confidence = torch.sigmoid(channels_last)
    return {"pose_conf": confidence}
def postprocess_desc(out, depth_mode, conf_mode, desc_dim, double_channel=False):
    """
    extract 3D points, confidence and local descriptors from a head output

    Channel layout along dim 1 of ``out`` (moved last before slicing):
        [pts3d (3)] [conf (1, only if conf_mode is not None)]
        [pts3d_self (3)] [conf_self (1)]     <- only if double_channel
        [desc (desc_dim)] [desc_conf (1)]

    Returns:
        dict with "pts3d", "desc", "desc_conf" and, depending on the
        arguments, "conf", "pts3d_self" and "conf_self".

    NOTE(review): "desc_conf" is always computed with ``conf_mode``; since
    ``reg_dense_conf`` unpacks its mode, a None ``conf_mode`` would
    presumably fail there -- confirm callers always provide one.
    """
    feats = out.permute(0, 2, 3, 1)  # B,H,W,C
    has_conf = int(conf_mode is not None)
    group_width = 3 + has_conf  # channels taken by one pts3d(+conf) group

    result = {"pts3d": reg_dense_depth(feats[..., 0:3], mode=depth_mode)}
    if conf_mode is not None:
        result["conf"] = reg_dense_conf(feats[..., 3], mode=conf_mode)

    if double_channel:
        self_points = feats[..., group_width : group_width + 3]
        result["pts3d_self"] = reg_dense_depth(self_points, mode=depth_mode)
        if conf_mode is not None:
            self_conf = feats[..., 6 + has_conf]
            result["conf_self"] = reg_dense_conf(self_conf, mode=conf_mode)

    # First descriptor channel comes right after the point/confidence groups.
    start = group_width + int(double_channel) * group_width
    stop = start + desc_dim

    result["desc"] = reg_desc(feats[..., start:stop], mode="norm")
    result["desc_conf"] = reg_dense_conf(feats[..., stop], mode=conf_mode)
    # The descriptor confidence must be the very last channel.
    assert stop + 1 == feats.shape[-1]

    return result
def reg_desc(desc, mode="norm"):
    """
    normalise descriptors to unit L2 norm along the last dimension

    Args:
        desc: descriptor tensor.
        mode: must contain the substring "norm".

    Raises:
        ValueError: if ``mode`` does not contain "norm".
    """
    if "norm" not in mode:
        raise ValueError(f"Unknown desc mode {mode}")

    length = desc.norm(dim=-1, keepdim=True)
    return desc / length
def reg_dense_depth(xyz, mode, pos_z=False):
    """
    extract 3D points from prediction head output

    Args:
        xyz: tensor whose last dimension holds the raw point coordinates.
            It is never modified in place.
        mode: triple ``(name, vmin, vmax)``. ``name`` is one of "linear",
            "square" or "exp". The bounds must be (-inf, +inf); this is
            enforced with an assert.
        pos_z: when True (non-linear modes only), every point is multiplied
            by the sign of its last coordinate, so that coordinate becomes
            non-negative. A point whose last coordinate is exactly zero is
            mapped to zero because ``torch.sign(0) == 0``.

    Returns:
        "linear": ``xyz`` itself.
        "square": unit direction scaled by the squared norm.
        "exp":    unit direction scaled by ``expm1(norm)``.

    Raises:
        ValueError: if the mode name is not recognised.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float("inf")) and (vmax == float("inf"))
    assert no_bounds

    if mode == "linear":
        if no_bounds:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)

    if pos_z:
        # Out-of-place on purpose: ``xyz`` is typically a view of the head
        # output, and an in-place ``*=`` would silently overwrite the
        # caller's tensor.
        sign = torch.sign(xyz[..., -1:])
        xyz = xyz * sign

    # Split into direction and magnitude, then remap the magnitude.
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)

    if mode == "square":
        return xyz * d.square()

    if mode == "exp":
        return xyz * torch.expm1(d)

    raise ValueError(f"bad {mode=}")
def reg_dense_conf(x, mode):
    """
    map raw confidence logits into the range described by ``mode``

    Args:
        x: raw confidence tensor.
        mode: triple ``(name, vmin, vmax)``.
            "exp":     ``vmin + exp(x)``, with the exponential capped at
                       ``vmax - vmin``.
            "sigmoid": sigmoid rescaled affinely from (0, 1) to
                       (vmin, vmax).

    Raises:
        ValueError: if the mode name is not recognised.
    """
    mode, low, high = mode
    span = high - low

    if mode == "sigmoid":
        return span * torch.sigmoid(x) + low

    if mode == "exp":
        capped = x.exp().clip(max=span)
        return low + capped

    raise ValueError(f"bad {mode=}")
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalize quaternions to unit length and pick the representative whose
    real part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    # q and -q encode the same rotation; flip those with a negative real part.
    negative_real = unit[..., 0:1] < 0
    return torch.where(negative_real, -unit, unit)
|