|
import cv2 |
|
import numpy as np |
|
import torch |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
import torch.nn.functional as F |
|
|
|
|
|
def decompose_extrinsic_RT(E: torch.Tensor):
    """Drop the homogeneous bottom row from batched 4x4 extrinsics.

    Input is (B, 4, 4); output is the (B, 3, 4) [R|t] part.
    """
    return E[:, :3]
|
|
|
|
|
def compose_extrinsic_RT(RT: torch.Tensor):
    """Append the homogeneous bottom row [0, 0, 0, 1] to batched (B, 3, 4) [R|t].

    Returns the standard-form (B, 4, 4) extrinsic matrices, matching the
    dtype and device of the input.
    """
    batch = RT.shape[0]
    # new_tensor inherits RT's dtype/device; expand broadcasts the single
    # row to every batch element without allocating per-batch copies.
    bottom_row = RT.new_tensor([0.0, 0.0, 0.0, 1.0]).expand(batch, 1, 4)
    return torch.cat((RT, bottom_row), dim=1)
|
|
|
|
|
def camera_normalization(pivotal_pose: torch.Tensor, poses: torch.Tensor):
    """Re-express all poses so the pivotal camera sits at the canonical pose.

    `pivotal_pose` is a (1, 4, 4) extrinsic; `poses` is a (B, 4, 4) batch.
    Returns the batch left-multiplied by the transform that maps the pivotal
    camera onto the identity (canonical) frame.
    """
    # Canonical target frame: identity extrinsics, batch of one.
    canonical = torch.eye(4, dtype=torch.float32,
                          device=pivotal_pose.device).unsqueeze(0)
    # Transform carrying the pivotal camera onto the canonical frame.
    norm_transform = torch.bmm(canonical, torch.inverse(pivotal_pose))
    # Apply the same correction to every pose in the batch.
    return torch.bmm(norm_transform.repeat(poses.shape[0], 1, 1), poses)
|
|
|
|
|
|
|
|
|
def rt2mat(R, T):
    """Assemble a 4x4 homogeneous transform from rotation R (3x3) and translation T (3,)."""
    pose = np.eye(4)
    pose[:3, :3] = R
    pose[:3, 3] = T
    return pose
|
|
|
|
|
def skew_sym_mat(x):
    """Build the 3x3 skew-symmetric (cross-product) matrix [x]_x of a 3-vector."""
    zero = torch.zeros((), device=x.device, dtype=x.dtype)
    return torch.stack([
        torch.stack([zero, -x[2], x[1]]),
        torch.stack([x[2], zero, -x[0]]),
        torch.stack([-x[1], x[0], zero]),
    ])
|
|
|
|
|
def SO3_exp(theta):
    """Exponential map from an axis-angle 3-vector to a rotation matrix.

    Uses Rodrigues' formula; falls back to a second-order Taylor expansion
    near zero angle to avoid division by (near-)zero.
    """
    eye = torch.eye(3, device=theta.device, dtype=theta.dtype)
    omega = skew_sym_mat(theta)
    omega_sq = omega @ omega
    angle = torch.norm(theta)

    if angle < 1e-5:
        # Small-angle Taylor expansion of exp([theta]_x).
        return eye + omega + 0.5 * omega_sq

    sin_coeff = torch.sin(angle) / angle
    cos_coeff = (1 - torch.cos(angle)) / (angle**2)
    return eye + sin_coeff * omega + cos_coeff * omega_sq
|
|
|
|
|
def V(theta):
    """Left Jacobian of SO(3) for an axis-angle vector `theta`.

    Maps the translational part of an se(3) twist into the group; a
    Taylor expansion is used near zero angle for numerical stability.
    """
    eye = torch.eye(3, device=theta.device, dtype=theta.dtype)
    W = skew_sym_mat(theta)
    W_sq = W @ W
    angle = torch.norm(theta)

    if angle < 1e-5:
        # Small-angle Taylor expansion.
        return eye + 0.5 * W + (1.0 / 6.0) * W_sq

    c1 = (1.0 - torch.cos(angle)) / (angle**2)
    c2 = (angle - torch.sin(angle)) / (angle**3)
    return eye + c1 * W + c2 * W_sq
|
|
|
|
|
def SE3_exp(tau):
    """Exponential map from a 6-vector twist (rho, theta) to a 4x4 SE(3) matrix.

    First three components of `tau` are the translational part, last three
    the rotational (axis-angle) part.
    """
    rho, theta = tau[:3], tau[3:]

    rotation = SO3_exp(theta)
    translation = V(theta) @ rho

    transform = torch.eye(4, device=tau.device, dtype=tau.dtype)
    transform[:3, :3] = rotation
    transform[:3, 3] = translation
    return transform
|
|
|
|
|
def update_pose(cam_trans_delta: Float[Tensor, "batch 3"],
                cam_rot_delta: Float[Tensor, "batch 3"],
                extrinsics: Float[Tensor, "batch 4 4"],
                ):
    """Apply per-camera SE(3) increments to batched camera-to-world extrinsics.

    Each (translation, rotation) delta pair forms a twist that is exponentiated
    and left-multiplied onto the world-to-camera transform; the result is
    inverted back to camera-to-world form.
    """
    # Twist per camera: translation first, rotation second.
    tau = torch.cat([cam_trans_delta, cam_rot_delta], dim=-1)
    # Work in world-to-camera space for the left-multiplicative update.
    w2c = extrinsics.inverse()

    updated_w2c = torch.stack(
        [SE3_exp(tau[b]) @ w2c[b] for b in range(tau.shape[0])],
        dim=0,
    )
    return updated_w2c.inverse()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inv(mat):
    """Invert a square matrix, dispatching on numpy vs torch type.

    Raises ValueError for any other input type.
    """
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
|
|
|
|
|
def get_pnp_pose(pts3d, opacity, K, H, W, opacity_threshold=0.3):
    """Recover a camera-to-world pose with PnP + RANSAC on opaque pixels.

    Args:
        pts3d: per-pixel 3D points as a torch tensor — indexed by the same
            (H, W) mask as `opacity`, presumably shape (H, W, 3); TODO confirm.
        opacity: per-pixel opacity torch tensor, masked against the threshold.
        K: 3x3 intrinsics torch tensor in *normalized* units (scaled by W/H below).
        H, W: image height and width in pixels.
        opacity_threshold: pixels with opacity above this value are used.

    Returns:
        (4, 4) float32 torch tensor — the inverted (camera-to-world) pose.

    Raises:
        RuntimeError: if cv2.solvePnPRansac fails to find a pose.
    """
    # (W, H) grid transposed to (H, W, 2) so it aligns with per-pixel tensors.
    pixels = np.mgrid[:W, :H].T.astype(np.float32)
    pts3d = pts3d.cpu().numpy()
    opacity = opacity.cpu().numpy()
    # BUG FIX: .numpy() on a CPU tensor shares memory with the caller's
    # tensor, so the in-place scaling below would mutate the caller's
    # intrinsics. Copy first to keep this function side-effect free.
    K = K.cpu().numpy().copy()

    # Convert normalized intrinsics to pixel units.
    K[0, :] = K[0, :] * W
    K[1, :] = K[1, :] * H

    mask = opacity > opacity_threshold

    res = cv2.solvePnPRansac(pts3d[mask], pixels[mask], K, None,
                             iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
    success, R, T, inliers = res

    # Explicit raise instead of `assert`: asserts are stripped under -O.
    if not success:
        raise RuntimeError('solvePnPRansac failed to recover a pose')

    # Rodrigues vector -> rotation matrix; build w2c, then invert to c2w.
    R = cv2.Rodrigues(R)[0]
    pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])

    return torch.from_numpy(pose.astype(np.float32))
|
|
|
|
|
def pose_auc(errors, thresholds):
    """Area under the recall-vs-error curve, one value per threshold.

    Errors are sorted ascending; recall at the k-th error is (k+1)/N. Each
    AUC integrates recall up to the threshold and normalizes by it.
    """
    err = np.sort(np.array(errors.copy()))
    n = len(err)
    # Prepend the (0, 0) origin so the integral starts at zero error.
    recall = np.r_[0.0, (np.arange(n) + 1) / n]
    err = np.r_[0.0, err]

    aucs = []
    for thr in thresholds:
        cut = np.searchsorted(err, thr)
        # Extend the curve flat out to the threshold itself.
        ys = np.r_[recall[:cut], recall[cut - 1]]
        xs = np.r_[err[:cut], thr]
        aucs.append(np.trapz(ys, x=xs) / thr)
    return aucs
|
|
|
|
|
def rotation_6d_to_matrix(d6):
    """
    Convert the 6D rotation representation of Zhou et al. [1] to rotation
    matrices via Gram--Schmidt orthogonalization (Section B of [1]).
    Adapted from pytorch3d.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1 = d6[..., :3]
    a2 = d6[..., 3:]
    # First basis vector: normalized first half.
    b1 = F.normalize(a1, dim=-1)
    # Second: remove the component of a2 along b1, then normalize.
    projection = (b1 * a2).sum(-1, keepdim=True)
    b2 = F.normalize(a2 - projection * b1, dim=-1)
    # Third: completes the right-handed orthonormal frame.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)