Spaces:
Sleeping
Sleeping
| import ipdb # noqa: F401 | |
| import torch | |
| from pytorch3d.transforms import Rotate, Translate | |
def intersect_skew_line_groups(p, r, mask=None):
    """Intersect each group of lines and report how far each line misses.

    Args:
        p: Points on the lines, shape (B, N, n_intersected_lines, 3).
        r: Direction vectors of the lines, same shape as ``p``.
        mask: Optional per-line weights, shape (B, N, n_intersected_lines).

    Returns:
        A 4-tuple ``(p_intersect, p_line_intersect, intersect_dist_squared, r)``:
        the estimated common point of each group, the point on every line
        nearest to that common point, the squared distance between the two,
        and the direction vectors as returned by
        ``intersect_skew_lines_high_dim``. All four entries are ``None`` when
        no common point could be computed.
    """
    common_point, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    if common_point is None:
        return (None,) * 4

    # Add a per-line axis so the common point can be compared with each line.
    target = common_point[..., None, :]
    _, p_line_intersect = point_line_distance(p, r, target.expand_as(p))
    intersect_dist_squared = (p_line_intersect - target).pow(2).sum(dim=-1)
    return common_point, p_line_intersect, intersect_dist_squared, r
def intersect_skew_lines_high_dim(p, r, mask=None):
    """Find the least-squares point nearest to a group of (possibly skew) lines.

    Implements the "nearest point to several lines" construction from
    https://en.wikipedia.org/wiki/Skew_lines in more than two dimensions.

    Args:
        p: A point on each line, shape (..., n_lines, dim). Callers in this
            file pass (B, N, n_lines, 3).
        r: Direction vector of each line, same shape as ``p``. It does not
            need to be unit length; it is L2-normalized here.
        mask: Optional per-line weights of shape (..., n_lines). A weight of
            zero removes the line from the solve. Defaults to all ones.

    Returns:
        ``(p_intersect, r)`` where ``p_intersect`` has shape (..., dim) and
        ``r`` is the normalized direction tensor. Returns ``(None, None)``
        when the solution contains NaNs.
    """
    import warnings

    dim = p.shape[-1]
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    # The projector below is only correct for unit-length headings.
    r = torch.nn.functional.normalize(r, dim=-1)
    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    # (I - r r^T) projects onto the subspace orthogonal to each line;
    # masked-out lines contribute a zero matrix.
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
    # Solve sum_i(I - r_i r_i^T) x = sum_i(I - r_i r_i^T) p_i in the
    # least-squares sense.
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
    if torch.any(torch.isnan(p_intersect)):
        warnings.warn(
            "intersect_skew_lines_high_dim: NaN in the least-squares "
            "intersection; returning (None, None).",
            RuntimeWarning,
            stacklevel=2,
        )
        return None, None
    return p_intersect, r
def point_line_distance(p1, r1, p2):
    """Distance from points ``p2`` to the lines through ``p1`` along ``r1``.

    Args:
        p1: A point on each line.
        r1: Direction of each line. The projection used here is only a true
            orthogonal projection when ``r1`` is unit length.
        p2: Query points, broadcastable against ``p1``.

    Returns:
        ``(d, line_pt_nearest)``: the norm of the component of ``p2 - p1``
        perpendicular to ``r1``, and the point obtained by removing that
        component from ``p2`` (the foot of the perpendicular on the line).
    """
    offset = p2 - p1
    # Signed length of the offset along the line direction.
    along = (offset * r1).sum(dim=-1, keepdim=True)
    perpendicular = offset - along * r1
    return perpendicular.norm(dim=-1), p2 - perpendicular
def compute_optical_axis_intersection(cameras):
    """Estimate where the optical axes of a batch of cameras meet.

    Each camera contributes one line: it starts at the camera center and
    passes through the unprojection of the principal point at depth 1.

    Args:
        cameras: A batched camera object exposing ``get_camera_center()``,
            ``principal_point``, ``unproject_points()`` and ``len()``
            (presumably PyTorch3D cameras; confirm against callers).

    Returns:
        ``(p_intersect, dist, p_line_intersect, pp2, r)``: the estimated
        intersection point, the distance from it to every camera center, the
        nearest point on each optical axis, the unprojected principal points,
        and the axis directions as returned by ``intersect_skew_line_groups``.
        ``p_intersect``, ``dist``, ``p_line_intersect`` and ``r`` are ``None``
        when no intersection could be computed.
    """
    centers = cameras.get_camera_center()
    unit_depth = torch.ones((len(cameras), 1), device=centers.device)
    # Principal point in NDC with depth 1 appended.
    optical_axis = torch.cat((cameras.principal_point, unit_depth), -1)
    pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
    # Keep only entries (i, i) of the first two axes of ``pp`` -- presumably
    # camera i paired with its own principal point; TODO confirm the shape.
    pp2 = torch.diagonal(pp, dim1=0, dim2=1).T

    # Two leading singleton axes: all cameras form a single group of lines.
    directions = (pp2 - centers)[None, None]
    centers = centers[None, None]

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )
    dist = None
    if p_intersect is not None:
        p_intersect = p_intersect.squeeze().unsqueeze(0)
        dist = (p_intersect - centers).norm(dim=-1)
    return p_intersect, dist, p_line_intersect, pp2, r
def first_camera_transform(cameras, rotation_only=True):
    """Return a copy of ``cameras`` re-expressed relative to the first camera.

    The inverse of the first camera's rotation (and, unless
    ``rotation_only`` is set, its translation as well) is composed in front
    of every camera's world-to-view transform.

    Args:
        cameras: Batched camera object with ``R``, ``T``, ``clone()`` and
            ``get_world_to_view_transform()``.
        rotation_only: When true, only the first camera's rotation is undone;
            otherwise its rotation followed by its translation is undone.

    Returns:
        A cloned camera object whose ``R`` and ``T`` are read back from the
        composed transform. The input ``cameras`` is not modified.
    """
    new_cameras = cameras.clone()
    world_to_view = new_cameras.get_world_to_view_transform()

    first_pose = Rotate(new_cameras.R[0].unsqueeze(0))
    if not rotation_only:
        first_pose = first_pose.compose(Translate(new_cameras.T[0].unsqueeze(0)))

    matrix = first_pose.inverse().compose(world_to_view).get_matrix()
    # The matrix is indexed as [:, 3, :3] for the translation, i.e. the
    # translation sits in the last row rather than the last column.
    new_cameras.R = matrix[:, :3, :3]
    new_cameras.T = matrix[:, 3, :3]
    return new_cameras