# Hugging Face Spaces status: Running on T4
import warnings

import ipdb  # noqa: F401
import torch
from pytorch3d.transforms import Rotate, Translate
def intersect_skew_line_groups(p, r, mask=None):
    """Find the least-squares intersection point of each group of lines.

    Args:
        p: Points on the lines. The original inline comment describes shape
            (B, N, n_intersected_lines, 3).
        r: Line directions, same shape as ``p``. They are L2-normalised inside
            ``intersect_skew_lines_high_dim`` before use.
        mask: Optional per-line weights, described as shape
            (B, N, n_intersected_lines). ``None`` weights every line equally.

    Returns:
        ``(p_intersect, p_line_intersect, intersect_dist_squared, r)``:
        the least-squares intersection of each group; the point on each line
        closest to that intersection; the squared distance between the two;
        and the normalised directions. All four are ``None`` when
        ``intersect_skew_lines_high_dim`` reports failure.
    """
    p_intersect, r_unit = intersect_skew_lines_high_dim(p, r, mask=mask)
    if p_intersect is None:
        return None, None, None, None

    # Insert a line axis so each group's single point lines up with every line.
    p_intersect_per_line = p_intersect.unsqueeze(-2)
    _, nearest_on_line = point_line_distance(
        p, r_unit, p_intersect_per_line.expand_as(p)
    )
    offset = nearest_on_line - p_intersect_per_line
    dist_sq = offset.pow(2).sum(dim=-1)
    return p_intersect, nearest_on_line, dist_sq, r_unit
def intersect_skew_lines_high_dim(p, r, mask=None):
    """Least-squares intersection point of a group of lines in any dimension.

    Implements the n-dimensional formulation from
    https://en.wikipedia.org/wiki/Skew_lines . The point ``x`` minimising
    ``sum_i m_i * ||(I - r_i r_i^T)(x - p_i)||^2`` satisfies
    ``(sum_i m_i (I - r_i r_i^T)) x = sum_i m_i (I - r_i r_i^T) p_i``,
    which is solved here with ``torch.linalg.lstsq``.

    Args:
        p: Points on the lines, shape (..., n_lines, dim).
        r: Line directions, same shape as ``p``. They are L2-normalised here.
        mask: Optional per-line weights of shape ``p.shape[:-1]``. Defaults
            to all ones.

    Returns:
        ``(p_intersect, r)``. ``p_intersect`` is the solution with the line
        axis reduced away (last axis ``dim``), and ``r`` holds the normalised
        directions. Returns ``(None, None)`` and emits a ``warnings.warn`` if
        the solution contains NaNs.
    """
    dim = p.shape[-1]
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    # Unit directions make r r^T a projector onto each line's direction.
    r = torch.nn.functional.normalize(r, dim=-1)

    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    # Per-line projector onto the orthogonal complement of r, weighted by mask;
    # shape (..., n_lines, dim, dim).
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    # Summing over dim=-3 sums over the line axis of the projector stack.
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    if torch.isnan(p_intersect).any():
        # Report failure to the caller instead of returning NaN coordinates.
        # This replaces a stdout print and unreachable debugger code.
        warnings.warn(
            "intersect_skew_lines_high_dim: least-squares solution contains NaN; "
            "returning (None, None).",
            stacklevel=2,
        )
        return None, None
    return p_intersect, r
def point_line_distance(p1, r1, p2):
    """Distance from ``p2`` to the line through ``p1`` with direction ``r1``.

    The formula assumes ``r1`` is unit length; it is not normalised here.
    Inputs broadcast over leading dimensions, and the last axis holds the
    coordinates.

    Returns:
        ``(d, line_pt_nearest)``. ``d`` is the distance, with the last axis
        reduced. ``line_pt_nearest`` is the orthogonal projection of ``p2``
        onto the line.
    """
    offset = p2 - p1
    along = (offset * r1).sum(dim=-1, keepdim=True)
    # Component of the offset perpendicular to the line direction.
    perpendicular = offset - along * r1
    foot = p2 - perpendicular
    return perpendicular.norm(dim=-1), foot
def compute_optical_axis_intersection(cameras):
    """Least-squares intersection of the cameras' optical axes.

    Each camera contributes the ray from its center through the world-space
    unprojection of its principal point at depth 1 (``from_ndc=True``). All
    rays are intersected as a single group via ``intersect_skew_line_groups``.

    Args:
        cameras: A batched cameras object exposing ``get_camera_center``,
            ``principal_point`` and ``unproject_points``. Presumably a
            PyTorch3D cameras batch; TODO confirm against callers.

    Returns:
        ``(p_intersect, dist, p_line_intersect, pp2, r)``:
        - ``p_intersect``: the intersection point, squeezed and given one
          leading axis.
        - ``dist``: the distance from that point to every camera center.
        - ``p_line_intersect``: the closest point on each axis.
        - ``pp2``: the unprojected principal points.
        - ``r``: the normalised axis directions.
        ``p_intersect``, ``dist``, ``p_line_intersect`` and ``r`` are ``None``
        if the intersection could not be computed.
    """
    cam_centers = cameras.get_camera_center()
    principal_points = cameras.principal_point

    # Query (x, y, depth=1) at each camera's principal point.
    depth = torch.ones((len(cameras), 1), device=cam_centers.device)
    axis_queries = torch.cat((principal_points, depth), -1)

    unprojected = cameras.unproject_points(
        axis_queries, from_ndc=True, world_coordinates=True
    )
    # Keep entry (i, i) of the first two axes. This assumes unproject_points
    # returns a camera-by-query grid here, so row i is principal point i
    # unprojected by camera i.
    pp2 = torch.diagonal(unprojected, dim1=0, dim2=1).T

    # Treat all cameras as one group of lines by adding two leading axes.
    line_points = cam_centers[None, None]
    line_dirs = (pp2 - cam_centers)[None, None]
    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=line_points, r=line_dirs, mask=None
    )

    if p_intersect is None:
        return p_intersect, None, p_line_intersect, pp2, r

    p_intersect = p_intersect.squeeze().unsqueeze(0)
    dist = (p_intersect - line_points).norm(dim=-1)
    return p_intersect, dist, p_line_intersect, pp2, r
def first_camera_transform(cameras, rotation_only=True):
    """Re-express a batch of cameras relative to the first camera.

    The function prepends the inverse of camera 0's transform to every
    camera's world-to-view transform. With ``rotation_only=True`` that is
    ``Rotate(R[0])``. Otherwise it is ``Rotate(R[0])`` composed with
    ``Translate(T[0])``. The intent appears to be that camera 0 ends up with
    an identity rotation, assuming PyTorch3D's world-to-view transform is
    ``Rotate(R)`` followed by ``Translate(T)`` (verify).

    Args:
        cameras: Batched cameras object. The edits are applied to
            ``cameras.clone()``.
        rotation_only: If True, undo only camera 0's rotation. Otherwise also
            undo its translation.

    Returns:
        The cloned cameras object with ``R`` and ``T`` replaced.
    """
    out = cameras.clone()
    rot0 = Rotate(out.R[0].unsqueeze(0))
    if rotation_only:
        undo_first = rot0.inverse()
    else:
        undo_first = rot0.compose(Translate(out.T[0].unsqueeze(0))).inverse()

    world_to_view = undo_first.compose(out.get_world_to_view_transform())
    mat = world_to_view.get_matrix()
    # Rotation is read from the top-left 3x3 block and translation from the
    # last row, following the row-vector 4x4 layout used by this code.
    out.R = mat[:, :3, :3]
    out.T = mat[:, 3, :3]
    return out