import ipdb # noqa: F401
import torch
from pytorch3d.transforms import Rotate, Translate


def intersect_skew_line_groups(p, r, mask=None):
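    """Intersect each group of skew lines in the least-squares sense and report,
    for every line, the nearest point on that line and its squared distance to
    the common intersection point."""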
# p, r both of shape (B, N, n_intersected_lines, 3)
# mask of shape (B, N, n_intersected_lines)
p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
if p_intersect is None:
return None, None, None, None
_, p_line_intersect = point_line_distance(
p, r, p_intersect[..., None, :].expand_as(p)
)
intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
dim=-1
)
return p_intersect, p_line_intersect, intersect_dist_squared, r


def intersect_skew_lines_high_dim(p, r, mask=None):
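    """Least-squares intersection point of the lines x = p_i + t * r_i in
    arbitrary dimension, found by solving sum_i (I - r_i r_i^T)(x - p_i) = 0."""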
    # Implements https://en.wikipedia.org/wiki/Skew_lines in more than two dimensions.
dim = p.shape[-1]
    if mask is None:
        mask = torch.ones_like(p[..., 0])

    # Make sure the heading vectors are l2-normalized.
    r = torch.nn.functional.normalize(r, dim=-1)
eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
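    # Per-line projection matrix (I - r r^T) onto the plane orthogonal to each
    # direction; masked-out lines are zeroed and drop out of the solve.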
I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
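    # Solve (sum_i (I - r_i r_i^T)) x = sum_i (I - r_i r_i^T) p_i for the point x
    # closest to all lines in the least-squares sense.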
p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
    # Degenerate configurations can make the solve return NaNs; bail out early.
    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        return None, None
return p_intersect, r


def point_line_distance(p1, r1, p2):
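    """Distance from each point p2 to the line through p1 with unit direction r1,
    together with the nearest point on that line (the projection of p2)."""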
df = p2 - p1
proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
line_pt_nearest = p2 - proj_vector
d = (proj_vector).norm(dim=-1)
return d, line_pt_nearest


def compute_optical_axis_intersection(cameras):
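    """Find the point closest to all cameras' optical axes (the rays from each
    camera center through its principal point), its distance to every camera
    center, and the nearest point on each axis."""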
centers = cameras.get_camera_center()
principal_points = cameras.principal_point
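    # Principal point of each camera in NDC, with a third coordinate of 1 used
    # as the depth when unprojecting along the optical axis.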
one_vec = torch.ones((len(cameras), 1), device=centers.device)
optical_axis = torch.cat((principal_points, one_vec), -1)
pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
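    # unproject_points evaluates every camera against every point; the diagonal
    # keeps each camera's own unprojected principal point.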
pp2 = torch.diagonal(pp, dim1=0, dim2=1).T
directions = pp2 - centers
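    # Add the (B, N) batch dimensions expected by intersect_skew_line_groups.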
centers = centers.unsqueeze(0).unsqueeze(0)
directions = directions.unsqueeze(0).unsqueeze(0)
p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
p=centers, r=directions, mask=None
)
if p_intersect is None:
dist = None
else:
p_intersect = p_intersect.squeeze().unsqueeze(0)
dist = (p_intersect - centers).norm(dim=-1)
return p_intersect, dist, p_line_intersect, pp2, r


def first_camera_transform(cameras, rotation_only=True):
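    """Return a copy of the cameras with the world frame re-aligned to the first
    camera: its rotation (and, when rotation_only is False, also its translation)
    is factored out of every world-to-view transform."""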
new_cameras = cameras.clone()
new_transform = new_cameras.get_world_to_view_transform()
tR = Rotate(new_cameras.R[0].unsqueeze(0))
if rotation_only:
t = tR.inverse()
else:
tT = Translate(new_cameras.T[0].unsqueeze(0))
t = tR.compose(tT).inverse()
new_transform = t.compose(new_transform)
new_cameras.R = new_transform.get_matrix()[:, :3, :3]
new_cameras.T = new_transform.get_matrix()[:, 3, :3]
return new_cameras