# AnySplat/src/misc/cam_utils.py
# Camera utilities: extrinsic compose/decompose, pose normalization,
# SE(3) pose updates from deltas, PnP pose estimation, pose AUC, 6D rotations.
import cv2
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
import torch.nn.functional as F
def decompose_extrinsic_RT(E: torch.Tensor):
    """
    Strip the homogeneous bottom row from a batch of extrinsic matrices.

    Args:
        E: (B, 4, 4) extrinsic matrices.

    Returns:
        (B, 3, 4) [R|t] matrices (a view into ``E``, not a copy).
    """
    return E[:, :3, :]
def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Append the homogeneous [0, 0, 0, 1] row to a batch of [R|t] matrices.

    Args:
        RT: (B, 3, 4) [R|t] matrices.

    Returns:
        (B, 4, 4) standard-form extrinsic matrices.
    """
    batch = RT.shape[0]
    bottom = torch.tensor([0, 0, 0, 1], dtype=RT.dtype, device=RT.device)
    bottom = bottom.view(1, 1, 4).repeat(batch, 1, 1)
    return torch.cat((RT, bottom), dim=1)
def camera_normalization(pivotal_pose: torch.Tensor, poses: torch.Tensor):
    """
    Re-express all poses relative to a pivotal pose.

    Args:
        pivotal_pose: (1, 4, 4) pose that becomes the canonical (identity) frame.
        poses: (N, 4, 4) poses to normalize.

    Returns:
        (N, 4, 4) normalized poses; the pivotal pose maps to the identity.
    """
    # The canonical frame is the identity, so the normalizing transform is
    # just the inverse of the pivotal pose (left-multiplied onto every pose).
    canonical = torch.tensor([[
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]], dtype=torch.float32, device=pivotal_pose.device)
    norm_matrix = torch.bmm(canonical, torch.inverse(pivotal_pose))
    n_views = poses.shape[0]
    return torch.bmm(norm_matrix.repeat(n_views, 1, 1), poses)
####### Pose update from delta
def rt2mat(R, T):
    """Pack rotation ``R`` (3x3) and translation ``T`` (3,) into a 4x4 homogeneous transform."""
    out = np.eye(4)
    out[:3, :3], out[:3, 3] = R, T
    return out
def skew_sym_mat(x):
    """Return the 3x3 skew-symmetric (cross-product) matrix [x]_x of a 3-vector."""
    a, b, c = x[0], x[1], x[2]
    m = torch.zeros(3, 3, device=x.device, dtype=x.dtype)
    m[0, 1], m[0, 2] = -c, b
    m[1, 0], m[1, 2] = c, -a
    m[2, 0], m[2, 1] = -b, a
    return m
def SO3_exp(theta):
    """
    Exponential map so(3) -> SO(3) via Rodrigues' formula.

    Args:
        theta: (3,) axis-angle rotation vector.

    Returns:
        (3, 3) rotation matrix.
    """
    W = skew_sym_mat(theta)
    W2 = W @ W
    angle = torch.norm(theta)
    I = torch.eye(3, device=theta.device, dtype=theta.dtype)
    if angle < 1e-5:
        # Second-order Taylor expansion; avoids division by a tiny angle.
        return I + W + 0.5 * W2
    coeff_w = torch.sin(angle) / angle
    coeff_w2 = (1.0 - torch.cos(angle)) / (angle ** 2)
    return I + coeff_w * W + coeff_w2 * W2
def V(theta):
    """
    Left Jacobian of SO(3) for the axis-angle vector ``theta``.

    Used to map the translational part of an se(3) tangent vector into SE(3).
    """
    I = torch.eye(3, device=theta.device, dtype=theta.dtype)
    W = skew_sym_mat(theta)
    W2 = W @ W
    angle = torch.norm(theta)
    if angle < 1e-5:
        # Taylor expansion near zero angle.
        return I + 0.5 * W + (1.0 / 6.0) * W2
    coeff_w = (1.0 - torch.cos(angle)) / (angle ** 2)
    coeff_w2 = (angle - torch.sin(angle)) / (angle ** 3)
    return I + coeff_w * W + coeff_w2 * W2
def SE3_exp(tau):
    """
    Exponential map se(3) -> SE(3).

    Args:
        tau: (6,) tangent vector, translation part first: [rho(3), theta(3)].

    Returns:
        (4, 4) homogeneous transform.
    """
    rho, theta = tau[:3], tau[3:]
    T = torch.eye(4, device=tau.device, dtype=tau.dtype)
    T[:3, :3] = SO3_exp(theta)
    T[:3, 3] = V(theta) @ rho
    return T
def update_pose(cam_trans_delta: Float[Tensor, "batch 3"],
                cam_rot_delta: Float[Tensor, "batch 3"],
                extrinsics: Float[Tensor, "batch 4 4"],
                ):
    """
    Apply per-camera se(3) deltas to a batch of camera-to-world extrinsics.

    The deltas are expressed in the world-to-camera frame: each c2w pose is
    inverted, left-multiplied by exp(tau), then inverted back to c2w.

    Args:
        cam_trans_delta: (B, 3) translational deltas.
        cam_rot_delta: (B, 3) rotational (axis-angle) deltas.
        extrinsics: (B, 4, 4) camera-to-world matrices.

    Returns:
        (B, 4, 4) updated camera-to-world matrices.
    """
    tau = torch.cat([cam_trans_delta, cam_rot_delta], dim=-1)
    T_w2c = extrinsics.inverse()
    updated = [SE3_exp(t) @ w2c for t, w2c in zip(tau, T_w2c)]
    return torch.stack(updated, dim=0).inverse()
####### Pose estimation
def inv(mat):
    """Invert a square matrix, dispatching on whether it is a torch tensor or numpy array."""
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
def get_pnp_pose(pts3d, opacity, K, H, W, opacity_threshold=0.3):
    """
    Estimate a camera-to-world pose from per-pixel 3D points via PnP + RANSAC.

    Args:
        pts3d: (H, W, 3) tensor of 3D points, one per pixel.
        opacity: (H, W) tensor; only pixels above ``opacity_threshold`` are used.
        K: (3, 3) intrinsics tensor normalized to [0, 1] image coordinates.
        H, W: image height and width in pixels.
        opacity_threshold: minimum opacity for a pixel to participate.

    Returns:
        (4, 4) float32 camera-to-world pose tensor.
    """
    pixels = np.mgrid[:W, :H].T.astype(np.float32)  # (H, W, 2) pixel coordinates
    pts3d = pts3d.cpu().numpy()
    opacity = opacity.cpu().numpy()
    # BUGFIX: for a CPU tensor, .cpu() is a no-op and .numpy() shares memory,
    # so scaling K below would silently mutate the caller's intrinsics.
    # Copy first so the scaling stays local to this function.
    K = K.cpu().numpy().copy()
    K[0, :] = K[0, :] * W  # denormalize to pixel units
    K[1, :] = K[1, :] * H
    mask = opacity > opacity_threshold
    res = cv2.solvePnPRansac(pts3d[mask], pixels[mask], K, None,
                             iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
    success, R, T, inliers = res
    assert success
    R = cv2.Rodrigues(R)[0]  # rotation vector -> world-to-cam rotation matrix
    pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam-to-world
    return torch.from_numpy(pose.astype(np.float32))
def pose_auc(errors, thresholds):
    """
    Area under the cumulative-recall-vs-error curve, normalized per threshold.

    Args:
        errors: sequence of per-pose errors.
        thresholds: error cutoffs at which to evaluate the AUC.

    Returns:
        List of AUC values in [0, 1], one per threshold.
    """
    errs = np.sort(np.array(errors))
    recall = np.arange(1, len(errs) + 1) / len(errs)
    # Prepend the origin so the curve starts at (0, 0).
    errs = np.concatenate(([0.0], errs))
    recall = np.concatenate(([0.0], recall))
    aucs = []
    for thr in thresholds:
        idx = np.searchsorted(errs, thr)
        # Extend the curve flat out to the threshold before integrating.
        r = np.concatenate((recall[:idx], [recall[idx - 1]]))
        e = np.concatenate((errs[:idx], [thr]))
        aucs.append(np.trapz(r, x=e) / thr)
    return aucs
def rotation_6d_to_matrix(d6):
    """
    Convert the 6D rotation representation of Zhou et al. [1] to rotation
    matrices via Gram--Schmidt orthogonalization (Section B of [1]).
    Adapted from pytorch3d.

    Args:
        d6: (*, 6) tensor; first three values seed the x-axis, last three the y-axis.

    Returns:
        (*, 3, 3) rotation matrices.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    x_raw, y_raw = d6[..., :3], d6[..., 3:]
    x = F.normalize(x_raw, dim=-1)
    # Project out the component of y_raw along x, then renormalize.
    y = F.normalize(y_raw - (x * y_raw).sum(-1, keepdim=True) * x, dim=-1)
    z = torch.cross(x, y, dim=-1)
    return torch.stack((x, y, z), dim=-2)