import warnings

import torch
from pytorch3d.transforms import Rotate, Translate


def intersect_skew_line_groups(p, r, mask=None):
    """Intersect groups of lines in the least-squares sense.

    Args:
        p: A point on each line, shape (B, N, n_intersected_lines, 3).
        r: Direction of each line, same shape as ``p``. Directions need not be
            unit length; they are normalized internally.
        mask: Optional per-line multiplier of shape (B, N, n_intersected_lines);
            a line with mask 0 does not contribute. ``None`` uses every line.

    Returns:
        ``(p_intersect, p_line_intersect, intersect_dist_squared, r)``:
        the least-squares intersection point of each group (B, N, 3), the point
        on each line closest to that intersection (same shape as ``p``), the
        squared distance between the two (B, N, n_intersected_lines), and the
        unit-normalized directions. All four are ``None`` if the least-squares
        solve produced NaNs.
    """
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    if p_intersect is None:
        return None, None, None, None
    # Broadcast each group's intersection point to every line of that group.
    _, p_line_intersect = point_line_distance(
        p, r, p_intersect[..., None, :].expand_as(p)
    )
    intersect_dist_squared = (
        (p_line_intersect - p_intersect[..., None, :]) ** 2
    ).sum(dim=-1)
    return p_intersect, p_line_intersect, intersect_dist_squared, r


def intersect_skew_lines_high_dim(p, r, mask=None):
    """Find the point minimizing the summed squared distance to a set of lines.

    Implements the "more than two dimensions" construction from
    https://en.wikipedia.org/wiki/Skew_lines, i.e. solves
    ``(sum_i (I - r_i r_i^T)) x = sum_i (I - r_i r_i^T) p_i``
    with ``torch.linalg.lstsq``.

    Args:
        p: A point on each line, shape (..., n_lines, dim).
        r: Direction of each line, shape (..., n_lines, dim); normalized here.
        mask: Optional per-line multiplier, shape (..., n_lines). ``None``
            means all ones.

    Returns:
        ``(p_intersect, r)``: the solution of shape (..., dim) and the
        unit-normalized directions, or ``(None, None)`` if the solution
        contains NaNs (a ``RuntimeWarning`` is emitted in that case).
    """
    dim = p.shape[-1]
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    # The projector I - r r^T below is only correct for unit-length directions.
    r = torch.nn.functional.normalize(r, dim=-1)

    # A plain (dim, dim) identity broadcasts against any number of leading
    # batch dimensions without introducing extra ones.
    eye = torch.eye(dim, device=p.device, dtype=p.dtype)
    # Per-line projector onto the complement of the line direction, scaled by
    # the mask: shape (..., n_lines, dim, dim).
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    # dim=-3 is the n_lines axis of the (..., n_lines, dim, 1|dim) stacks.
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    if torch.any(torch.isnan(p_intersect)):
        warnings.warn(
            "intersect_skew_lines_high_dim: least-squares solution contains NaN;"
            " returning (None, None).",
            RuntimeWarning,
            stacklevel=2,
        )
        return None, None
    return p_intersect, r


def point_line_distance(p1, r1, p2):
    """Distance from points ``p2`` to the lines through ``p1`` along ``r1``.

    Args:
        p1: A point on each line, shape (..., dim).
        r1: Line directions, shape (..., dim). The projection formula assumes
            these are unit length.
        p2: Query points, broadcastable against ``p1``.

    Returns:
        ``(d, line_pt_nearest)``: the distance of shape (...) and the point on
        each line nearest to the corresponding ``p2``, shape (..., dim).
    """
    df = p2 - p1
    # Strip the component of (p2 - p1) along the line; the remainder is the
    # perpendicular offset from the line to p2.
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = proj_vector.norm(dim=-1)
    return d, line_pt_nearest


def compute_optical_axis_intersection(cameras):
    """Least-squares intersection point of the cameras' optical axes.

    Each optical axis is the line from a camera's center through its principal
    point unprojected (``from_ndc=True``, world coordinates) with 1 as the
    third ``xy_depth`` coordinate.

    Args:
        cameras: Batch of N PyTorch3D cameras (uses ``get_camera_center``,
            ``principal_point``, ``unproject_points`` and ``len``).

    Returns:
        ``(p_intersect, dist, p_line_intersect, pp2, r)``: the intersection
        point of shape (1, 3); the distance from each camera center to it,
        shape (1, 1, N); the closest point on each axis to the intersection;
        the unprojected principal points (N, 3); and the unit axis directions.
        If the intersection could not be computed, every entry except ``pp2``
        is ``None``.
    """
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point
    # Third xy_depth column of ones, matching the principal points' dtype and
    # device so the concatenation is always consistent.
    one_vec = torch.ones_like(principal_points[:, :1])
    optical_axis = torch.cat((principal_points, one_vec), -1)
    pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
    # NOTE(review): pp is presumably (N, N, 3) -- every camera unprojects every
    # point -- so the diagonal keeps camera i's own principal point -> (N, 3).
    pp2 = torch.diagonal(pp, dim1=0, dim2=1).T
    directions = pp2 - centers

    # Treat all N axes as a single group of lines: shape (1, 1, N, 3).
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)
    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )
    if p_intersect is None:
        dist = None
    else:
        p_intersect = p_intersect.squeeze().unsqueeze(0)
        dist = (p_intersect - centers).norm(dim=-1)
    return p_intersect, dist, p_line_intersect, pp2, r


def first_camera_transform(cameras, rotation_only=True):
    """Re-express every camera relative to the first camera's pose.

    The inverse of camera 0's rotation (or of its rotation followed by its
    translation, when ``rotation_only`` is False) is composed in front of each
    camera's world-to-view transform. Under PyTorch3D's
    ``X_view = X_world @ R + T`` convention, camera 0 therefore ends up with an
    identity rotation (and a zero translation when ``rotation_only`` is False).

    Args:
        cameras: Batch of PyTorch3D cameras; a clone is modified, not the input.
        rotation_only: If True, undo only camera 0's rotation; otherwise undo
            its full rigid transform.

    Returns:
        The cloned cameras with updated ``R`` and ``T``.
    """
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()
    tR = Rotate(new_cameras.R[0].unsqueeze(0))
    if rotation_only:
        t = tR.inverse()
    else:
        tT = Translate(new_cameras.T[0].unsqueeze(0))
        t = tR.compose(tT).inverse()
    new_transform = t.compose(new_transform)
    matrix = new_transform.get_matrix()
    # Row-vector 4x4 layout: rotation in the upper-left 3x3 block,
    # translation in the first three entries of the last row.
    new_cameras.R = matrix[:, :3, :3]
    new_cameras.T = matrix[:, 3, :3]
    return new_cameras