import cv2
import numpy as np
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor


def decompose_extrinsic_RT(E: torch.Tensor):
    """
    Decompose the standard 4x4 extrinsic matrix into its 3x4 [R|T] part.
    Batched I/O.
    """
    return E[:, :3, :]


def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Compose the standard 4x4 extrinsic matrix from a 3x4 [R|T] matrix.
    Batched I/O.
    """
    return torch.cat([
        RT,
        torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device).repeat(RT.shape[0], 1, 1)
    ], dim=1)


def camera_normalization(pivotal_pose: torch.Tensor, poses: torch.Tensor):
    # pivotal_pose: [1, 4, 4], poses: [N, 4, 4]
    canonical_camera_extrinsics = torch.tensor([[
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]], dtype=torch.float32, device=pivotal_pose.device)
    pivotal_pose_inv = torch.inverse(pivotal_pose)
    # Transform that maps the pivotal camera onto the canonical (identity) pose.
    camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)

    # Normalize all views into the pivotal camera's frame.
    poses = torch.bmm(camera_norm_matrix.repeat(poses.shape[0], 1, 1), poses)
    return poses


####### Pose update from delta

def rt2mat(R, T):
    mat = np.eye(4)
    mat[0:3, 0:3] = R
    mat[0:3, 3] = T
    return mat


def skew_sym_mat(x):
    device = x.device
    dtype = x.dtype
    ssm = torch.zeros(3, 3, device=device, dtype=dtype)
    ssm[0, 1] = -x[2]
    ssm[0, 2] = x[1]
    ssm[1, 0] = x[2]
    ssm[1, 2] = -x[0]
    ssm[2, 0] = -x[1]
    ssm[2, 1] = x[0]
    return ssm


def SO3_exp(theta):
    device = theta.device
    dtype = theta.dtype

    W = skew_sym_mat(theta)
    W2 = W @ W
    angle = torch.norm(theta)
    I = torch.eye(3, device=device, dtype=dtype)
    if angle < 1e-5:
        # Second-order Taylor expansion of the Rodrigues formula near zero.
        return I + W + 0.5 * W2
    else:
        return (
            I
            + (torch.sin(angle) / angle) * W
            + ((1 - torch.cos(angle)) / (angle**2)) * W2
        )


def V(theta):
    dtype = theta.dtype
    device = theta.device
    I = torch.eye(3, device=device, dtype=dtype)
    W = skew_sym_mat(theta)
    W2 = W @ W
    angle = torch.norm(theta)
    if angle < 1e-5:
        # Taylor expansion of the left Jacobian of SO(3) near zero.
        V = I + 0.5 * W + (1.0 / 6.0) * W2
    else:
        V = (
            I
            + W * ((1.0 - torch.cos(angle)) / (angle**2))
            + W2 * ((angle - torch.sin(angle)) / (angle**3))
        )
    return V


def SE3_exp(tau):
    dtype = tau.dtype
    device = tau.device

    rho = tau[:3]    # translational part
    theta = tau[3:]  # rotational part
    R = SO3_exp(theta)
    t = V(theta) @ rho

    T = torch.eye(4, device=device, dtype=dtype)
    T[:3, :3] = R
    T[:3, 3] = t
    return T


def update_pose(
    cam_trans_delta: Float[Tensor, "batch 3"],
    cam_rot_delta: Float[Tensor, "batch 3"],
    extrinsics: Float[Tensor, "batch 4 4"],
):
    # extrinsics is c2w; the SE(3) delta is applied on the w2c side, so invert first.
    bs = cam_trans_delta.shape[0]
    tau = torch.cat([cam_trans_delta, cam_rot_delta], dim=-1)
    T_w2c = extrinsics.inverse()

    new_w2c_list = []
    for i in range(bs):
        new_w2c = SE3_exp(tau[i]) @ T_w2c[i]
        new_w2c_list.append(new_w2c)
    new_w2c = torch.stack(new_w2c_list, dim=0)

    # Return the updated poses back in c2w convention.
    return new_w2c.inverse()
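
# A minimal usage sketch (illustrative, not part of the original module): refining a
# c2w pose by optimizing se(3) deltas through update_pose. The target pose, learning
# rate, and the toy matrix-difference objective are assumptions for demonstration.
def _example_update_pose_refinement(steps: int = 100):
    target = SE3_exp(torch.tensor([0.2, -0.1, 0.05, 0.01, 0.02, -0.03])).unsqueeze(0)
    extrinsics = torch.eye(4).unsqueeze(0)  # [1, 4, 4] initial c2w pose
    trans_delta = torch.zeros(1, 3, requires_grad=True)
    rot_delta = torch.zeros(1, 3, requires_grad=True)
    optimizer = torch.optim.Adam([trans_delta, rot_delta], lr=1e-2)
    for _ in range(steps):
        optimizer.zero_grad()
        refined = update_pose(trans_delta, rot_delta, extrinsics)
        loss = ((refined - target) ** 2).mean()  # toy objective on the pose itself
        loss.backward()
        optimizer.step()
    return loss.item()
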
####### Pose estimation

def inv(mat):
    """Invert a torch or numpy matrix."""
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')


def get_pnp_pose(pts3d, opacity, K, H, W, opacity_threshold=0.3):
    # Pixel grid of shape [H, W, 2], where pixels[y, x] = (x, y).
    pixels = np.mgrid[:W, :H].T.astype(np.float32)
    pts3d = pts3d.cpu().numpy()
    opacity = opacity.cpu().numpy()
    # Copy so that scaling the normalized intrinsics does not mutate the caller's K.
    K = K.cpu().numpy().copy()
    K[0, :] = K[0, :] * W
    K[1, :] = K[1, :] * H

    # Keep only confident points as 2D-3D correspondences for PnP.
    mask = opacity > opacity_threshold
    res = cv2.solvePnPRansac(pts3d[mask], pixels[mask], K, None,
                             iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
    success, R, T, inliers = res
    assert success

    R = cv2.Rodrigues(R)[0]  # rotation vector -> world-to-cam rotation matrix
    pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam-to-world
    return torch.from_numpy(pose.astype(np.float32))


def pose_auc(errors, thresholds):
    sort_idx = np.argsort(errors)
    errors = np.array(errors.copy())[sort_idx]
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(np.trapz(r, x=e) / t)
    return aucs


def rotation_6d_to_matrix(d6):
    """
    Converts the 6D rotation representation of Zhou et al. [1] to a rotation
    matrix using Gram-Schmidt orthogonalization per Section B of [1].
    Adapted from pytorch3d.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
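

# An illustrative round-trip check (an assumption, not part of the original module):
# the 6D representation of a rotation matrix is its first two rows (pytorch3d
# convention), and rotation_6d_to_matrix recovers the full matrix from them.
def _example_6d_round_trip():
    theta = torch.tensor([0.1, -0.2, 0.3])  # an arbitrary axis-angle vector
    R = SO3_exp(theta)                      # a valid rotation matrix
    d6 = R[:2, :].reshape(6)                # first two rows -> 6D representation
    R_rec = rotation_6d_to_matrix(d6)
    assert torch.allclose(R, R_rec, atol=1e-5)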