Spaces:
Configuration error
Configuration error
| import torch | |
def compute_rotation_matrix_from_ortho6d(poses):
    """
    Map the continuous 6-D rotation representation to rotation matrices.

    Code from
    https://github.com/papagina/RotationContinuity
    On the Continuity of Rotation Representations in Neural Networks
    Zhou et al. CVPR19
    https://zhouyisjtu.github.io/project_rotation/rotation.html

    Args:
        poses: tensor sliced as ``poses[:, 0:3]`` and ``poses[:, 3:6]``;
            presumably of shape (batch, 6). The first three entries give the
            direction of the first axis. The last three give a second
            direction that only selects the plane of the first two axes.

    Returns:
        Tensor of shape (batch, 3, 3) whose columns are the first, second
        and third axes, in that order.
    """
    first_dir, second_dir = poses[:, 0:3], poses[:, 3:6]
    a = normalize_vector(first_dir)
    # Third axis: unit normal to the plane spanned by the two input directions.
    c = normalize_vector(cross_product(a, second_dir))
    # Second axis completes the frame as c x a. a and c are unit and
    # orthogonal, so no renormalization is needed. If the inputs are parallel,
    # the cross product is zero, normalize_vector returns zero, and b and c
    # are both zero.
    b = cross_product(c, a)
    # Place the three (batch, 3) axes side by side as matrix columns.
    return torch.stack((a, b, c), dim=2)
def robust_compute_rotation_matrix_from_ortho6d(poses):
    """
    Build rotation matrices from a 6-D representation, weighting both
    predicted directions equally.

    Instead of making the 2nd vector orthogonal to the first (Gram-Schmidt),
    build an orthonormal pair placed symmetrically around the bisector of the
    two predicted directions.

    Args:
        poses: tensor sliced as ``poses[:, 0:3]`` (first direction) and
            ``poses[:, 3:6]`` (second direction); presumably of shape
            (batch, 6).

    Returns:
        Tensor of shape (batch, 3, 3) whose columns are the x, y and z axes.

    Raises:
        AssertionError: if any matrix has a negative determinant. This cannot
            happen mathematically (see below), so it is only a sanity check.
            Like any assert, it is skipped under ``python -O``.
    """
    x_raw = poses[:, 0:3]  # batch*3
    y_raw = poses[:, 3:6]  # batch*3
    x = normalize_vector(x_raw)  # batch*3
    y = normalize_vector(y_raw)  # batch*3
    # For unit x and y, (x + y) . (x - y) = |x|^2 - |y|^2 = 0, so the bisector
    # and the difference direction are orthogonal.
    middle = normalize_vector(x + y)
    orthmid = normalize_vector(x - y)
    # For unit m and o, (m + o) . (m - o) = 0 as well, so the new x and y are
    # orthogonal and lie symmetrically about the bisector.
    x = normalize_vector(middle + orthmid)
    y = normalize_vector(middle - orthmid)
    z = normalize_vector(cross_product(x, y))
    matrix = torch.stack((x, y, z), dim=2)  # batch*3*3
    # Since z is parallel to x cross y, det([x y z]) = (x cross y) . z >= 0, so
    # the result is never a reflection. If the inputs are (anti)parallel, z is
    # zero and det is 0.
    # Do the sanity check in one batched call. A per-sample loop fed to
    # torch.stack fails on an empty batch. Detach so the check adds no
    # autograd graph.
    assert not (torch.det(matrix.detach()) < 0).any()
    return matrix
def normalize_vector(v, eps=1e-8):
    """
    Scale each slice of ``v`` along dim 1 to unit Euclidean length.

    Args:
        v: tensor whose dim 1 holds the vector components; typically of shape
            (batch, dim).
        eps: lower bound on the norm used as the divisor. All-zero vectors
            therefore stay zero instead of becoming NaN. Defaults to 1e-8.

    Returns:
        Tensor with the same shape as ``v``.
    """
    # keepdim=True gives (batch, 1), which broadcasts over dim 1 without an
    # explicit view/expand.
    v_mag = v.pow(2).sum(1, keepdim=True).sqrt()
    # NOTE(review): in float16, 1e-8 is below the smallest subnormal (~6e-8),
    # so it likely rounds to 0 and gives no protection. Pass a larger eps for
    # half-precision inputs.
    v_mag = torch.clamp(v_mag, min=eps)
    return v / v_mag
def cross_product(u, v):
    """
    Row-wise 3-D cross product ``u x v``.

    Args:
        u, v: tensors indexed as ``[:, 0..2]``; presumably of shape (batch, 3).

    Returns:
        Tensor of shape (batch, 3) whose row b is ``u[b] x v[b]``.
    """
    n = u.shape[0]
    ux, uy, uz = u[:, 0], u[:, 1], u[:, 2]
    vx, vy, vz = v[:, 0], v[:, 1], v[:, 2]
    cx = uy * vz - uz * vy
    cy = uz * vx - ux * vz
    cz = ux * vy - uy * vx
    # Turn each (batch,) component into a column, then join the columns.
    return torch.cat([c.view(n, 1) for c in (cx, cy, cz)], 1)