import torch
from einops import einsum, rearrange, reduce
from jaxtyping import Float
from scipy.spatial.transform import Rotation as R
from torch import Tensor


def interpolate_intrinsics(
    initial: Float[Tensor, "*#batch 3 3"],
    final: Float[Tensor, "*#batch 3 3"],
    t: Float[Tensor, " time_step"],
) -> Float[Tensor, "*batch time_step 3 3"]:
    """Linearly interpolate element-wise between two intrinsics matrices.

    A new ``time_step`` axis is inserted before the matrix axes; entry ``k``
    along it equals ``initial + (final - initial) * t[k]``.
    """
    initial = rearrange(initial, "... i j -> ... () i j")
    final = rearrange(final, "... i j -> ... () i j")
    t = rearrange(t, "t -> t () ()")
    return initial + (final - initial) * t


def intersect_rays(
    a_origins: Float[Tensor, "*#batch dim"],
    a_directions: Float[Tensor, "*#batch dim"],
    b_origins: Float[Tensor, "*#batch dim"],
    b_directions: Float[Tensor, "*#batch dim"],
) -> Float[Tensor, "*batch dim"]:
    """Compute the least-squares intersection of rays.

    Uses the math from here: https://math.stackexchange.com/a/1762491/286022

    The formula relies on the directions being unit-length; they are not
    normalized here. Works for any ray dimensionality ``dim``.
    """
    # Broadcast and stack the tensors.
    a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors(
        a_origins, a_directions, b_origins, b_directions
    )
    origins = torch.stack((a_origins, b_origins), dim=-2)
    directions = torch.stack((a_directions, b_directions), dim=-2)

    # Compute n_i * n_i^T - eye(dim) from the equation. The identity is sized
    # from the input instead of being hard-coded to 3.
    n = einsum(directions, directions, "... n i, ... n j -> ... n i j")
    dim = directions.shape[-1]
    n = n - torch.eye(dim, dtype=origins.dtype, device=origins.device)

    # Compute the left-hand side of the equation.
    lhs = reduce(n, "... n i j -> ... i j", "sum")

    # Compute the right-hand side of the equation.
    rhs = einsum(n, origins, "... n i j, ... n j -> ... n i")
    rhs = reduce(rhs, "... n i -> ... i", "sum")

    # Solve lhs @ p = rhs for p in the least-squares sense.
    return torch.linalg.lstsq(lhs, rhs).solution


def normalize(a: Float[Tensor, "*#batch dim"]) -> Float[Tensor, "*#batch dim"]:
    """Scale each vector along the last axis to unit Euclidean length.

    No epsilon is applied, so a zero-length vector produces NaNs.
    """
    return a / a.norm(dim=-1, keepdim=True)


def generate_coordinate_frame(
    y: Float[Tensor, "*#batch 3"],
    z: Float[Tensor, "*#batch 3"],
) -> Float[Tensor, "*batch 3 3"]:
    """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors.

    The X axis is ``y x z``. The axes are stored as the columns of the result.
    """
    y, z = torch.broadcast_tensors(y, z)
    # Pass dim explicitly: without it, cross() picks the first size-3 dimension,
    # which is the batch axis whenever the batch happens to hold 3 vectors.
    x = torch.cross(y, z, dim=-1)
    return torch.stack([x, y, z], dim=-1)


def generate_rotation_coordinate_frame(
    a: Float[Tensor, "*#batch 3"],
    b: Float[Tensor, "*#batch 3"],
    eps: float = 1e-4,
) -> Float[Tensor, "*batch 3 3"]:
    """Generate a coordinate frame where the Y direction is normal to the plane
    defined by unit vectors a and b. The other axes are arbitrary.

    ``eps`` is the tolerance on ``| |a . b| - 1 |`` below which a and b are
    treated as parallel. ``b`` is copied, so the caller's tensor is not modified.
    """
    # Broadcast first so the boolean masks below match b's batch shape.
    a, b = torch.broadcast_tensors(a, b)
    device = a.device

    # Replace every entry in b that's parallel to the corresponding entry in a with an
    # arbitrary vector. The second pass covers the case where a is itself parallel to
    # the first replacement vector.
    b = b.detach().clone()
    parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
    b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device)
    parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
    b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device)

    # Generate the coordinate frame. The initial cross product defines the plane.
    plane_normal = normalize(torch.cross(a, b, dim=-1))
    return generate_coordinate_frame(plane_normal, a)


def matrix_to_euler(
    rotations: Float[Tensor, "*batch 3 3"],
    pattern: str,
) -> Float[Tensor, "*batch 3"]:
    """Convert rotation matrices to Euler angles using SciPy.

    ``pattern`` is passed straight to ``Rotation.as_euler`` (e.g. ``"YXZ"``).
    The conversion goes through a detached NumPy copy on the CPU, so no gradient
    flows through it; the result is returned with the input's dtype and device.
    """
    *batch, _, _ = rotations.shape
    flat = rotations.reshape(-1, 3, 3)
    angles_np = R.from_matrix(flat.detach().cpu().numpy()).as_euler(pattern)
    angles = torch.tensor(angles_np, dtype=flat.dtype, device=flat.device)
    return angles.reshape(*batch, 3)


def euler_to_matrix(
    rotations: Float[Tensor, "*batch 3"],
    pattern: str,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert Euler angles to rotation matrices using SciPy.

    ``pattern`` is passed straight to ``Rotation.from_euler`` (e.g. ``"YXZ"``).
    The conversion goes through a detached NumPy copy on the CPU, so no gradient
    flows through it; the result is returned with the input's dtype and device.
    """
    *batch, _ = rotations.shape
    flat = rotations.reshape(-1, 3)
    matrix_np = R.from_euler(pattern, flat.detach().cpu().numpy()).as_matrix()
    matrices = torch.tensor(matrix_np, dtype=flat.dtype, device=flat.device)
    return matrices.reshape(*batch, 3, 3)


def extrinsics_to_pivot_parameters(
    extrinsics: Float[Tensor, "*#batch 4 4"],
    pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"],
    pivot_point: Float[Tensor, "*#batch 3"],
) -> Float[Tensor, "*batch 5"]:
    """Convert the extrinsics to a representation with 5 degrees of freedom:

    1. Distance from pivot point in the "X" (look cross pivot axis) direction.
    2. Distance from pivot point in the "Y" (pivot axis) direction.
    3. Distance from pivot point in the Z (look) direction
    4. Angle in plane
    5. Twist (rotation not in plane)

    The third column of the extrinsics' rotation is used as the look vector and
    the fourth column as the origin.
    """
    # The pivot coordinate frame's Y axis (column 1) is normal to the plane.
    pivot_axis = pivot_coordinate_frame[..., :, 1]

    # Compute the translation elements of the pivot parametrization.
    translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2])
    origin = extrinsics[..., :3, 3]
    delta = pivot_point - origin
    translation = einsum(translation_frame, delta, "... i j, ... i -> ... j")

    # Add the rotation elements of the pivot parametrization. The middle ("X")
    # Euler angle is discarded.
    inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3]
    y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1)

    return torch.cat([translation, y[..., None], z[..., None]], dim=-1)


def pivot_parameters_to_extrinsics(
    parameters: Float[Tensor, "*#batch 5"],
    pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"],
    pivot_point: Float[Tensor, "*#batch 3"],
) -> Float[Tensor, "*batch 4 4"]:
    """Convert 5-DoF pivot parameters back into 4x4 extrinsics.

    ``parameters`` holds three translation components followed by the in-plane
    angle and the twist, as produced by ``extrinsics_to_pivot_parameters``. The
    middle ("X") Euler angle is set to zero when rebuilding the rotation.
    """
    translation, y, z = parameters.split((3, 1, 1), dim=-1)

    euler = torch.cat((y, torch.zeros_like(y), z), dim=-1)
    rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ")

    # The pivot coordinate frame's Y axis (column 1) is normal to the plane.
    pivot_axis = pivot_coordinate_frame[..., :, 1]

    translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2])
    delta = einsum(translation_frame, translation, "... i j, ... j -> ... i")
    origin = pivot_point - delta

    # Start from identity so the bottom row is already [0, 0, 0, 1]; clone because
    # broadcast_to returns a view that must not be written in place.
    *batch, _ = origin.shape
    extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device)
    extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone()
    extrinsics[..., :3, :3] = rotation
    extrinsics[..., :3, 3] = origin
    return extrinsics


def interpolate_circular(
    a: Float[Tensor, "*#batch"],
    b: Float[Tensor, "*#batch"],
    t: Float[Tensor, "*#batch"],
) -> Float[Tensor, " *batch"]:
    """Interpolate between angles a and b (radians) along the shorter arc.

    Both angles are wrapped into one turn, then a is optionally shifted by one
    full turn in whichever direction brings it closest to b before the linear
    interpolation. The result is not wrapped back into a single turn.
    """
    a, b, t = torch.broadcast_tensors(a, b, t)

    tau = 2 * torch.pi
    a = a % tau
    b = b % tau

    # Candidate start angles: a itself, or a shifted by one turn either way.
    a_left = a - tau
    a_right = a + tau
    d = (b - a).abs()
    d_left = (b - a_left).abs()
    d_right = (b - a_right).abs()

    # Pick the candidate closest to b; ties fall through to the right shift.
    use_d = (d < d_left) & (d < d_right)
    use_d_left = (d_left < d_right) & (~use_d)
    start = torch.where(use_d, a, torch.where(use_d_left, a_left, a_right))

    return start + (b - start) * t


def interpolate_pivot_parameters(
    initial: Float[Tensor, "*#batch 5"],
    final: Float[Tensor, "*#batch 5"],
    t: Float[Tensor, " time_step"],
) -> Float[Tensor, "*batch time_step 5"]:
    """Interpolate pivot parameters over the time steps in t.

    The three translation components are interpolated linearly; the two angular
    components are interpolated with ``interpolate_circular``.
    """
    initial = rearrange(initial, "... d -> ... () d")
    final = rearrange(final, "... d -> ... () d")
    t = rearrange(t, "t -> t ()")
    ti, ri = initial.split((3, 2), dim=-1)
    tf, rf = final.split((3, 2), dim=-1)

    t_lerp = ti + (tf - ti) * t
    r_lerp = interpolate_circular(ri, rf, t)

    return torch.cat((t_lerp, r_lerp), dim=-1)


@torch.no_grad()
def interpolate_extrinsics(
    initial: Float[Tensor, "*#batch 4 4"],
    final: Float[Tensor, "*#batch 4 4"],
    t: Float[Tensor, " time_step"],
    eps: float = 1e-4,
) -> Float[Tensor, "*batch time_step 4 4"]:
    """Interpolate extrinsics by rotating around their "focus point," which is the
    least-squares intersection between the look vectors of the initial and final
    extrinsics.

    The computation runs in float64 under ``torch.no_grad()``; the result is
    always returned as float32. ``eps`` is the tolerance used to decide whether
    two look vectors are parallel.
    """
    initial = initial.type(torch.float64)
    final = final.type(torch.float64)
    t = t.type(torch.float64)

    # Broadcast up front so the boolean masks below can index every tensor
    # derived from either input.
    initial, final = torch.broadcast_tensors(initial, final)

    # Based on the dot product between the look vectors, pick from one of two cases:
    # 1. Look vectors are parallel: interpolate about their origins' midpoint.
    # 2. Look vectors aren't parallel: interpolate about their focus point.
    initial_look = initial[..., :3, 2]
    final_look = final[..., :3, 2]
    dot_products = einsum(initial_look, final_look, "... i, ... i -> ...")
    parallel_mask = (dot_products.abs() - 1).abs() < eps

    # Pick focus points.
    initial_origin = initial[..., :3, 3]
    final_origin = final[..., :3, 3]
    pivot_point = 0.5 * (initial_origin + final_origin)
    not_parallel = ~parallel_mask
    # Skip the solve entirely when every pair is parallel, rather than handing
    # an empty batch to the least-squares solver.
    if not_parallel.any():
        pivot_point[not_parallel] = intersect_rays(
            initial_origin[not_parallel],
            initial_look[not_parallel],
            final_origin[not_parallel],
            final_look[not_parallel],
        )

    # Convert to pivot parameters.
    pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps)
    initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point)
    final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point)

    # Interpolate the pivot parameters.
    interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t)

    # Convert back.
    return pivot_parameters_to_extrinsics(
        interpolated_params.type(torch.float32),
        rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32),
        rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32),
    )