import numpy as np
import torch
from pytorch3d.renderer import FoVPerspectiveCameras
from pytorch3d.transforms import quaternion_to_matrix
def generate_random_rotations(N=1, device="cpu"):
    # Sample N uniformly random rotations by normalizing Gaussian quaternions.
    q = torch.randn(N, 4, device=device)
    q = q / q.norm(dim=-1, keepdim=True)
    return quaternion_to_matrix(q)
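
# Hedged sanity check (illustrative, not part of the original file): sampled
# matrices should be proper rotations, i.e. R @ R.T == I and det(R) == +1.
def _check_random_rotations():
    R = generate_random_rotations(N=4)
    eye = torch.eye(3).expand(4, 3, 3)
    assert torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5)
    assert torch.allclose(torch.det(R), torch.ones(4), atol=1e-5)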
def symmetric_orthogonalization(x):
    """Maps 9D input vectors onto SO(3) via symmetric orthogonalization.

    x: should have size [batch_size, 9]

    Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3).
    """
    m = x.view(-1, 3, 3)
    # torch.svd is deprecated; torch.linalg.svd returns V^T directly.
    u, s, vt = torch.linalg.svd(m)
    # Flip the sign of the last row of V^T wherever det(U V^T) is -1, so the
    # result is a proper rotation (det = +1) rather than a reflection.
    det = torch.det(u @ vt)
    det = det.view(-1, 1, 1)
    vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1)
    r = u @ vt
    return r
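
# Hedged usage sketch (illustrative, not part of the original file): an
# unconstrained 9D vector is projected to the nearest rotation, and projecting
# a matrix that is already a rotation should return it unchanged.
def _check_symmetric_orthogonalization():
    x = torch.randn(5, 9)
    R = symmetric_orthogonalization(x)
    assert torch.allclose(torch.det(R), torch.ones(5), atol=1e-4)
    assert torch.allclose(symmetric_orthogonalization(R.reshape(5, 9)), R, atol=1e-4)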
def get_permutations(num_images):
    # All ordered pairs (j, i) of distinct image indices.
    permutations = []
    for i in range(0, num_images):
        for j in range(0, num_images):
            if i != j:
                permutations.append((j, i))
    return permutations
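
# For example, get_permutations(3) returns all 6 ordered pairs:
# [(1, 0), (2, 0), (0, 1), (2, 1), (0, 2), (1, 2)]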
def n_to_np_rotations(num_frames, n_rots):
    # Convert N absolute rotations into relative rotations R_i^T @ R_j for
    # every ordered pair of frames.
    R_pred_rel = []
    permutations = get_permutations(num_frames)
    for i, j in permutations:
        R_pred_rel.append(n_rots[i].T @ n_rots[j])
    R_pred_rel = torch.stack(R_pred_rel)
    return R_pred_rel
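
# Hedged sanity check (illustrative, not part of the original file): pairwise
# relative rotations are invariant to a global rotation G applied to every
# frame, since (G R_i).T @ (G R_j) = R_i.T @ R_j.
def _check_relative_rotation_invariance():
    R = generate_random_rotations(N=3)
    G = generate_random_rotations(N=1)[0]
    R_rel = n_to_np_rotations(3, R)
    R_rel_global = n_to_np_rotations(3, G @ R)
    assert torch.allclose(R_rel, R_rel_global, atol=1e-5)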
def compute_angular_error_batch(rotation1, rotation2):
    # Geodesic distance between two batches of rotation matrices, in degrees.
    R_rel = np.einsum("Bij,Bjk ->Bik", rotation2, rotation1.transpose(0, 2, 1))
    t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
    theta = np.arccos(np.clip(t, -1, 1))
    return theta * 180 / np.pi
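
# Hedged usage sketch (illustrative, not part of the original file): the error
# between a batch and itself is 0 degrees; against the same batch rotated by
# 90 degrees about z, it is 90 degrees.
def _check_angular_error():
    R = generate_random_rotations(N=4).numpy()
    assert np.allclose(compute_angular_error_batch(R, R), 0, atol=1e-3)
    Rz = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    R_rot = np.einsum("ij,Bjk->Bik", Rz, R)
    assert np.allclose(compute_angular_error_batch(R, R_rot), 90, atol=1e-3)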
# A should be GT, B should be predicted
def compute_optimal_alignment(A, B):
    """
    Compute the optimal scale s, rotation R, and translation t that minimize:
    || A - (s * B @ R + t) || ^ 2

    Reference: Umeyama (TPAMI 91).

    Args:
        A (torch.Tensor): (N, 3).
        B (torch.Tensor): (N, 3).

    Returns:
        A_hat (torch.Tensor): aligned points s * B @ R + t (N, 3).
        s (torch.Tensor): scale.
        R (torch.Tensor): rotation matrix (3, 3).
        t (torch.Tensor): translation (3,).
    """
    A_bar = A.mean(0)
    B_bar = B.mean(0)
    # normally with R @ B, this would be A @ B.T
    H = (B - B_bar).T @ (A - A_bar)
    U, S, Vh = torch.linalg.svd(H, full_matrices=True)
    # If det(U @ Vh) is -1, flip the sign of the last singular direction so
    # that R is a proper rotation rather than a reflection.
    s = torch.linalg.det(U @ Vh)
    S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
    variance = torch.sum((B - B_bar) ** 2)
    scale = 1 / variance * torch.trace(torch.diag(S) @ S_prime)
    R = U @ S_prime @ Vh
    t = A_bar - scale * B_bar @ R
    A_hat = scale * B @ R + t
    return A_hat, scale, R, t
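
# Hedged sanity check (illustrative, not part of the original file): build A
# from B with a known similarity transform and confirm the Umeyama alignment
# reproduces A and recovers the scale.
def _check_optimal_alignment():
    B = torch.randn(100, 3)
    R_true = generate_random_rotations(N=1)[0]
    A = 0.7 * B @ R_true + torch.tensor([1.0, -2.0, 0.5])
    A_hat, scale, R, t = compute_optimal_alignment(A, B)
    assert torch.allclose(A_hat, A, atol=1e-3)
    assert torch.allclose(scale, torch.tensor(0.7), atol=1e-3)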
def compute_optimal_translation_alignment(T_A, T_B, R_B):
    """
    Assumes right-multiplied rotation matrices.

    E.g., for world2cam R and T, a world coordinate is transformed to the
    camera coordinate system via X_cam = X_world @ R + T (row vectors),
    equivalently X_cam = R.T @ X_world + T (column vectors).

    Finds s, t that minimize || T_A - (s * T_B + R_B.T @ t) ||^2.

    Args:
        T_A (torch.Tensor): Target translations (N, 3).
        T_B (torch.Tensor): Initial translations (N, 3).
        R_B (torch.Tensor): Initial rotations (N, 3, 3).

    Returns:
        T_A_hat (torch.Tensor): s * T_B + R_B.T @ t (N, 3).
        s (torch.Tensor): scale (1,).
        t (torch.Tensor): translation (1, 3).
    """
    n = len(T_A)
    # Work with column vectors of shape (N, 3, 1).
    T_A = T_A.unsqueeze(2)
    T_B = T_B.unsqueeze(2)
    # Closed-form least-squares solution for the scale s...
    A = torch.sum(T_B * T_A)
    B = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) @ (R_B @ T_A).sum(0) / n
    C = torch.sum(T_B * T_B)
    D = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0)
    E = (D * D).sum() / n
    s = (A - B.sum()) / (C - E.sum())
    # ...and for the shared translation t given s.
    t = (R_B @ (T_A - s * T_B)).sum(0) / n
    T_A_hat = s * T_B + R_B.transpose(1, 2) @ t
    return T_A_hat.squeeze(2), s, t.transpose(1, 0)
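
# Hedged sanity check (illustrative, not part of the original file): construct
# T_A from T_B with a known scale and shared translation, then verify that the
# closed-form solution recovers both exactly.
def _check_translation_alignment():
    R_B = generate_random_rotations(N=8)
    T_B = torch.randn(8, 3)
    t_true = torch.tensor([0.3, -1.0, 2.0]).view(3, 1)
    T_A = 2.5 * T_B + (R_B.transpose(1, 2) @ t_true).squeeze(2)
    T_A_hat, s, t = compute_optimal_translation_alignment(T_A, T_B, R_B)
    assert torch.allclose(T_A_hat, T_A, atol=1e-3)
    assert torch.allclose(s, torch.tensor(2.5), atol=1e-3)
    assert torch.allclose(t, t_true.T, atol=1e-3)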
def get_error(predict_rotations, R_pred, T_pred, R_gt, T_gt, gt_scene_scale):
    if predict_rotations:
        # Align predicted camera centers to GT with a full similarity
        # transform, then report per-camera center errors normalized by the
        # scene scale.
        cameras_gt = FoVPerspectiveCameras(R=R_gt, T=T_gt)
        cc_gt = cameras_gt.get_camera_center()
        cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred)
        cc_pred = cameras_pred.get_camera_center()
        A_hat, _, _, _ = compute_optimal_alignment(cc_gt, cc_pred)
        norm = torch.linalg.norm(cc_gt - A_hat, dim=1) / gt_scene_scale
        norms = norm.detach().cpu().numpy().tolist()
        return norms, A_hat
    else:
        # Rotations are held fixed; only a scale and a shared translation are
        # aligned.
        T_A_hat, _, _ = compute_optimal_translation_alignment(T_gt, T_pred, R_pred)
        norm = torch.linalg.norm(T_gt - T_A_hat, dim=1) / gt_scene_scale
        norms = norm.detach().cpu().numpy().tolist()
        return norms, T_A_hat
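
# Hedged usage sketch (illustrative, not part of the original file): compare
# random "predicted" cameras against ground truth in the PyTorch3D convention.
def _demo_get_error():
    N = 4
    R_gt = generate_random_rotations(N)
    T_gt = torch.randn(N, 3)
    R_pred = generate_random_rotations(N)
    T_pred = torch.randn(N, 3)
    norms, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale=1.0)
    print("normalized camera-center errors:", norms)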