|
import itertools |
|
from typing import Iterable, Literal, Optional, TypedDict |
|
|
|
import torch |
|
from einops import einsum, repeat |
|
from jaxtyping import Bool, Float |
|
from torch import Tensor |
|
from torch.utils.data.dataloader import default_collate |
|
|
|
from .projection import ( |
|
get_world_rays, |
|
homogenize_points, |
|
homogenize_vectors, |
|
intersect_rays, |
|
project_camera_space, |
|
) |
|
|
|
|
|
def _is_in_bounds( |
|
xy: Float[Tensor, "*batch 2"], |
|
epsilon: float = 1e-6, |
|
) -> Bool[Tensor, " *batch"]: |
|
"""Check whether the specified XY coordinates are within the normalized image plane, |
|
which has a range from 0 to 1 in each direction. |
|
""" |
|
return (xy >= -epsilon).all(dim=-1) & (xy <= 1 + epsilon).all(dim=-1) |
|
|
|
|
|
def _is_in_front_of_camera( |
|
xyz: Float[Tensor, "*batch 3"], |
|
epsilon: float = 1e-6, |
|
) -> Bool[Tensor, " *batch"]: |
|
"""Check whether the specified points in camera space are in front of the camera.""" |
|
return xyz[..., -1] > -epsilon |
|
|
|
|
|
def _is_positive_t( |
|
t: Float[Tensor, " *batch"], |
|
epsilon: float = 1e-6, |
|
) -> Bool[Tensor, " *batch"]: |
|
"""Check whether the specified t value is positive.""" |
|
return t > -epsilon |
|
|
|
|
|
class PointProjection(TypedDict):
    """Projection of one point per ray, as produced by the helpers in this module."""

    # Ray parameter of the projected point (the point is origin + t * direction).
    t: Float[Tensor, " *batch"]

    # XY coordinates of the projected point on the image plane.
    xy: Float[Tensor, "*batch 2"]

    # True where the point is in bounds, in front of the camera, and has positive t.
    # Where this is False, the other entries should not be relied upon.
    valid: Bool[Tensor, " *batch"]
|
|
|
|
|
def _intersect_image_coordinate(
    intrinsics: Float[Tensor, "*#batch 3 3"],
    origins: Float[Tensor, "*#batch 3"],
    directions: Float[Tensor, "*#batch 3"],
    dimension: Literal["x", "y"],
    coordinate_value: float,
) -> PointProjection:
    """Intersect the image-space projection of camera-space rays with an axis-aligned
    image line, i.e. the line where the chosen image coordinate ("x" or "y") equals
    coordinate_value.

    Returns the ray parameter t of the intersection, its image-space XY location, and
    a validity mask (in bounds, in front of the camera, positive t). Zero denominators
    are not guarded against, so t and xy may contain infinite or NaN entries.
    """
    # "Same" refers to the axis that is held fixed; "other" to the remaining image axis.
    axis = "xy".index(dimension)
    other_axis = 1 - axis

    focal_same = intrinsics[..., axis, axis]
    focal_other = intrinsics[..., other_axis, other_axis]
    center_same = intrinsics[..., axis, 2]
    center_other = intrinsics[..., other_axis, 2]

    origin_same = origins[..., axis]
    origin_other = origins[..., other_axis]
    origin_z = origins[..., 2]

    direction_same = directions[..., axis]
    direction_other = directions[..., other_axis]
    direction_z = directions[..., 2]

    # Target value with the principal point and focal length of the fixed axis removed.
    coefficient = (coordinate_value - center_same) / focal_same

    # Solve (origin_same + t * direction_same) = coefficient * (origin_z + t * direction_z)
    # for the ray parameter t.
    t = (coefficient * origin_z - origin_same) / (
        direction_same - coefficient * direction_z
    )

    # Value of the other image coordinate at the intersection.
    other_numerator = focal_other * (
        origin_other * (coefficient * direction_z - direction_same)
        + direction_other * (origin_same - coefficient * origin_z)
    )
    other_denominator = direction_z * origin_same - direction_same * origin_z
    value_other = center_other + other_numerator / other_denominator
    value_same = torch.ones_like(value_other) * coordinate_value

    # Place each component at the index of the image axis it belongs to.
    components = [None, None]
    components[axis] = value_same
    components[other_axis] = value_other
    xy = torch.stack(components, dim=-1)

    # Camera-space location of the intersection, used for the in-front-of-camera check.
    xyz = origins + t[..., None] * directions

    is_valid = _is_in_bounds(xy) & _is_in_front_of_camera(xyz) & _is_positive_t(t)
    return {
        "t": t,
        "xy": xy,
        "valid": is_valid,
    }
|
|
|
|
|
def _compare_projections(
    intersections: Iterable[PointProjection],
    reduction: Literal["min", "max"],
) -> PointProjection:
    """Pick, per batch element, the valid candidate projection with the smallest
    ("min") or largest ("max") t value.

    Args:
        intersections: Candidate projections that share one batch shape. Any iterable
            is accepted, including generators.
        reduction: Which extreme of t to select.

    Returns:
        The selected t, xy and valid entries. Where no candidate is valid, t is +inf
        for "min" or -inf for "max", and valid is False.

    Raises:
        KeyError: If reduction is neither "min" nor "max".
    """
    # default_collate indexes into its argument, so arbitrary iterables (for example
    # generators) have to be materialized before they can be stacked.
    candidates = list(intersections)

    # Stack the candidates along a new leading dimension. The copies keep the masking
    # below from touching the callers' tensors.
    stacked = {key: value.clone() for key, value in default_collate(candidates).items()}
    t = stacked["t"]
    xy = stacked["xy"]
    valid = stacked["valid"]

    # Give invalid candidates a t value that can never win the reduction.
    lowest_priority = {
        "min": torch.inf,
        "max": -torch.inf,
    }[reduction]
    t[~valid] = lowest_priority

    # Reduce over the candidate dimension; selector holds the winning candidate index.
    reduced, selector = getattr(t, reduction)(dim=0)

    # Pick the xy and valid entries that belong to the winning candidate.
    return {
        "t": reduced,
        "xy": xy.gather(0, repeat(selector, "... -> () ... xy", xy=2))[0],
        "valid": valid.gather(0, selector[None])[0],
    }
|
|
|
|
|
def _compute_point_projection(
    xyz: Float[Tensor, "*#batch 3"],
    t: Float[Tensor, "*#batch"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
) -> PointProjection:
    """Project camera-space points onto the image plane and bundle the result with the
    given ray parameters t and a validity mask.

    A projection is valid when it is in bounds, the point is in front of the camera,
    and t is positive.
    """
    image_xy = project_camera_space(xyz, intrinsics)
    is_valid = _is_in_bounds(image_xy) & _is_in_front_of_camera(xyz) & _is_positive_t(t)
    return {
        "t": t,
        "xy": image_xy,
        "valid": is_valid,
    }
|
|
|
|
|
class RaySegmentProjection(TypedDict):
    """Start and end of a projected ray segment, as returned by project_rays."""

    # Ray parameters at the start and end of the segment.
    t_min: Float[Tensor, " *batch"]
    t_max: Float[Tensor, " *batch"]

    # Image-space XY coordinates at the start and end of the segment.
    xy_min: Float[Tensor, "*batch 2"]
    xy_max: Float[Tensor, "*batch 2"]

    # True where both selected segment endpoints are valid projections. Where this is
    # False, the values above should not be relied upon.
    overlaps_image: Bool[Tensor, " *batch"]
|
|
|
|
|
def project_rays(
    origins: Float[Tensor, "*#batch 3"],
    directions: Float[Tensor, "*#batch 3"],
    extrinsics: Float[Tensor, "*#batch 4 4"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
    near: Optional[Float[Tensor, "*#batch"]] = None,
    far: Optional[Float[Tensor, "*#batch"]] = None,
    epsilon: float = 1e-6,
) -> RaySegmentProjection:
    """Project rays into a camera and find the start and end of each ray's projection.

    The rays are transformed by the inverse of extrinsics and intersected with the four
    image-frame lines x = 0, x = 1, y = 0 and y = 1. Each ray's start is the projection
    of its near point (t = 0, or t = near if given) when that projection is valid, and
    otherwise the valid frame intersection with the smallest t. Each ray's end is the
    projection of its far point (t = infinity, or t = far if given) when valid, and
    otherwise the valid frame intersection with the largest t.

    Args:
        origins: Ray origins.
        directions: Ray directions.
        extrinsics: Camera matrices whose inverse maps the rays into camera space.
        intrinsics: Camera intrinsics used for the projection.
        near: Optional ray parameter at which the segment starts.
        far: Optional ray parameter at which the segment ends.
        epsilon: Tolerance for the zero-depth and at-camera checks. Only used when
            near is None.

    Returns:
        The ray parameters and image coordinates of the segment endpoints, plus a mask
        that is True where both selected endpoints are valid.
    """
    # Bring the rays into the camera's coordinate frame.
    world_to_cam = torch.linalg.inv(extrinsics)
    origins = einsum(
        world_to_cam, homogenize_points(origins), "... i j, ... j -> ... i"
    )[..., :3]
    directions = einsum(
        world_to_cam, homogenize_vectors(directions), "... i j, ... j -> ... i"
    )[..., :3]

    # Intersect the projected rays with the four lines that bound the image.
    frame_lines = (("x", 0.0), ("x", 1.0), ("y", 0.0), ("y", 1.0))
    frame_hits = tuple(
        _intersect_image_coordinate(intrinsics, origins, directions, axis, value)
        for axis, value in frame_lines
    )
    first_frame_hit = _compare_projections(frame_hits, "min")
    last_frame_hit = _compare_projections(frame_hits, "max")
    batch_shape = first_frame_hit["t"].shape

    if near is not None:
        # A near value was given, so the segment starts at t = near.
        segment_start = _compute_point_projection(
            origins + near[..., None] * directions,
            near.broadcast_to(batch_shape),
            intrinsics,
        )
    else:
        # The segment starts at t = 0, i.e. at the ray origin. Origins within epsilon
        # of the camera position are swapped for the ray direction before projecting.
        # Any other origin whose depth is below epsilon is marked invalid.
        start_points = origins.clone()
        depth_below_epsilon = start_points[..., -1] < epsilon
        at_camera = start_points.norm(dim=-1) < epsilon
        start_points[at_camera] = directions[at_camera]
        segment_start = _compute_point_projection(
            start_points,
            torch.zeros_like(first_frame_hit["t"]),
            intrinsics,
        )
        segment_start["valid"][depth_below_epsilon & ~at_camera] = False

    if far is not None:
        # A far value was given, so the segment ends at t = far.
        segment_end = _compute_point_projection(
            origins + far[..., None] * directions,
            far.broadcast_to(batch_shape),
            intrinsics,
        )
    else:
        # The segment ends at t = infinity; the direction itself is projected.
        segment_end = _compute_point_projection(
            directions,
            torch.full_like(first_frame_hit["t"], torch.inf),
            intrinsics,
        )

    result = {
        "t_min": torch.empty_like(segment_start["t"]),
        "t_max": torch.empty_like(segment_end["t"]),
        "xy_min": torch.empty_like(segment_start["xy"]),
        "xy_max": torch.empty_like(segment_end["xy"]),
        "overlaps_image": torch.empty_like(segment_start["valid"]),
    }

    # Fill the result in four passes, one per combination of (start valid?, end valid?).
    # An invalid endpoint is replaced by the matching frame intersection.
    for start_is_valid, end_is_valid in itertools.product((True, False), repeat=2):
        start_mask = segment_start["valid"] if start_is_valid else ~segment_start["valid"]
        end_mask = segment_end["valid"] if end_is_valid else ~segment_end["valid"]
        mask = start_mask & end_mask

        lower = segment_start if start_is_valid else first_frame_hit
        upper = segment_end if end_is_valid else last_frame_hit

        result["t_min"][mask] = lower["t"][mask]
        result["xy_min"][mask] = lower["xy"][mask]
        result["t_max"][mask] = upper["t"][mask]
        result["xy_max"][mask] = upper["xy"][mask]
        result["overlaps_image"][mask] = (lower["valid"] & upper["valid"])[mask]

    return result
|
|
|
|
|
# NOTE(review): this repeats the RaySegmentProjection definition that appears before
# project_rays in this file. The second definition rebinds the name to an identical
# class, so one of the two could be removed.
class RaySegmentProjection(TypedDict):
    """Start and end of a projected ray segment, as returned by project_rays."""

    # Ray parameters at the start and end of the segment.
    t_min: Float[Tensor, " *batch"]
    t_max: Float[Tensor, " *batch"]

    # Image-space XY coordinates at the start and end of the segment.
    xy_min: Float[Tensor, "*batch 2"]
    xy_max: Float[Tensor, "*batch 2"]

    # True where both selected segment endpoints are valid projections. Where this is
    # False, the values above should not be relied upon.
    overlaps_image: Bool[Tensor, " *batch"]
|
|
|
|
|
def lift_to_3d(
    origins: Float[Tensor, "*#batch 3"],
    directions: Float[Tensor, "*#batch 3"],
    xy: Float[Tensor, "*#batch 2"],
    extrinsics: Float[Tensor, "*#batch 4 4"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
) -> Float[Tensor, "*batch 3"]:
    """Find the 3D positions that correspond to 2D points on epipolar lines.

    The epipolar lines are defined by origins and directions. The extrinsics and
    intrinsics belong to the images that the 2D points lie on.
    """
    # Turn each 2D point into a world-space ray through its camera, then intersect
    # that ray with the ray that defines the epipolar line.
    pixel_origins, pixel_directions = get_world_rays(xy, extrinsics, intrinsics)
    points = intersect_rays(origins, directions, pixel_origins, pixel_directions)
    return points
|
|
|
|
|
def get_depth(
    origins: Float[Tensor, "*#batch 3"],
    directions: Float[Tensor, "*#batch 3"],
    xy: Float[Tensor, "*#batch 2"],
    extrinsics: Float[Tensor, "*#batch 4 4"],
    intrinsics: Float[Tensor, "*#batch 3 3"],
) -> Float[Tensor, " *batch"]:
    """Find the depths that correspond to 2D points on epipolar lines.

    The epipolar lines are defined by origins and directions. The extrinsics and
    intrinsics belong to the images that the 2D points lie on. Each depth is the
    Euclidean distance from the ray origin to the lifted 3D point.
    """
    points = lift_to_3d(origins, directions, xy, extrinsics, intrinsics)
    offsets = points - origins
    return offsets.norm(dim=-1)
|
|