# alexnasa's picture
# Upload 243 files
# 2568013 verified
from typing import Optional
import torch
from einops import einsum, rearrange, repeat
from jaxtyping import Float
from torch import Tensor
from ...geometry.projection import unproject
from ..annotation import add_label
from .lines import draw_lines
from .types import Scalar, sanitize_scalar
def draw_cameras(
    resolution: int,
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    color: Float[Tensor, "batch 3"],
    near: Optional[Scalar] = None,
    far: Optional[Scalar] = None,
    margin: float = 0.1,  # relative to AABB
    frustum_scale: float = 0.05,  # relative to the (cubic) scene span
) -> Float[Tensor, "3 3 height width"]:
    """Draw camera frustums projected onto the three axis-aligned planes.

    For each axis, every point is projected onto the plane spanned by the two
    other axes (taken in cyclic order) and drawn into a fresh
    ``(3, resolution, resolution)`` float image, which is then labeled.

    Each camera appears as a small frustum in its own ``color``. The frustum
    consists of four lines from the camera origin to four corners, plus the
    rectangle joining those corners. When ``near`` and/or ``far`` are given,
    the corresponding frustum rectangles are drawn in gray (0.25). When both
    are given, the edges connecting them are also drawn.

    Args:
        resolution: Side length in pixels of each (unlabeled) projection image.
        extrinsics: Per-camera 4x4 matrices. ``[:, :3, 3]`` is used as the
            camera origin and ``[:, :3, :3]`` rotates ray directions.
        intrinsics: Per-camera 3x3 intrinsics, forwarded to ``unproject``.
        color: Per-camera RGB color for the frustum lines.
        near: Optional near-plane depth(s), sanitized with ``sanitize_scalar``.
        far: Optional far-plane depth(s), sanitized with ``sanitize_scalar``.
        margin: Padding added around the cubic scene bounds, relative to the
            largest side of the AABB.
        frustum_scale: Depth of the drawn frustum as a fraction of the scene span.

    Returns:
        The three labeled projection images stacked along a new first dimension.
    """
    device = extrinsics.device

    # Compute scene bounds. compute_aabb sanitizes near/far itself.
    minima, maxima = compute_aabb(extrinsics, intrinsics, near, far)
    scene_minima, scene_maxima = compute_equal_aabb_with_margin(
        minima, maxima, margin=margin
    )
    span = (scene_maxima - scene_minima).max()

    # Compute frustum locations. Tying the frustum depth to the scene span keeps
    # the frustums a constant size relative to the drawing.
    corner_depth = (span * frustum_scale)[None]
    frustum_corners = unproject_frustum_corners(extrinsics, intrinsics, corner_depth)

    # Sanitize near/far exactly as compute_aabb does before unprojecting them.
    # Otherwise a CPU or integer depth tensor would be combined with the
    # cameras' tensors inside unproject_frustum_corners, which fails when the
    # cameras live on another device.
    near_corners = None
    if near is not None:
        near_corners = unproject_frustum_corners(
            extrinsics, intrinsics, sanitize_scalar(near, device)
        )
    far_corners = None
    if far is not None:
        far_corners = unproject_frustum_corners(
            extrinsics, intrinsics, sanitize_scalar(far, device)
        )

    # Project the cameras onto each axis-aligned plane.
    projections = []
    for projected_axis in range(3):
        image = torch.zeros(
            (3, resolution, resolution),
            dtype=torch.float32,
            device=device,
        )

        # The two remaining axes, in cyclic order, span the image plane.
        image_x_axis = (projected_axis + 1) % 3
        image_y_axis = (projected_axis + 2) % 3

        def project(points: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 2"]:
            x = points[..., image_x_axis]
            y = points[..., image_y_axis]
            return torch.stack([x, y], dim=-1)

        # Per-plane drawing window: the projected cubic scene bounds.
        x_range, y_range = torch.stack(
            (project(scene_minima), project(scene_maxima)), dim=-1
        )

        # Draw near and far planes. Rolling along the corner dimension pairs each
        # corner with its predecessor, which closes the rectangle.
        if near_corners is not None:
            projected_near_corners = project(near_corners)
            image = draw_lines(
                image,
                rearrange(projected_near_corners, "b p xy -> (b p) xy"),
                rearrange(projected_near_corners.roll(1, 1), "b p xy -> (b p) xy"),
                color=0.25,
                width=2,
                x_range=x_range,
                y_range=y_range,
            )
        if far_corners is not None:
            projected_far_corners = project(far_corners)
            image = draw_lines(
                image,
                rearrange(projected_far_corners, "b p xy -> (b p) xy"),
                rearrange(projected_far_corners.roll(1, 1), "b p xy -> (b p) xy"),
                color=0.25,
                width=2,
                x_range=x_range,
                y_range=y_range,
            )
        if near_corners is not None and far_corners is not None:
            # Connect each near corner to the matching far corner.
            image = draw_lines(
                image,
                rearrange(projected_near_corners, "b p xy -> (b p) xy"),
                rearrange(projected_far_corners, "b p xy -> (b p) xy"),
                color=0.25,
                width=2,
                x_range=x_range,
                y_range=y_range,
            )

        # Draw the camera frustums themselves. For every camera there are two
        # groups (r=2) of four lines (p=4): origin -> corner, then
        # previous corner -> corner. Starts are reordered to "(b r p)" so that
        # they line up with the ends and the per-camera colors.
        projected_origins = project(extrinsics[:, :3, 3])
        projected_frustum_corners = project(frustum_corners)
        start = [
            repeat(projected_origins, "b xy -> (b p) xy", p=4),
            rearrange(projected_frustum_corners.roll(1, 1), "b p xy -> (b p) xy"),
        ]
        start = rearrange(torch.cat(start, dim=0), "(r b p) xy -> (b r p) xy", r=2, p=4)
        image = draw_lines(
            image,
            start,
            repeat(projected_frustum_corners, "b p xy -> (b r p) xy", r=2),
            color=repeat(color, "b c -> (b r p) c", r=2, p=4),
            width=2,
            x_range=x_range,
            y_range=y_range,
        )

        x_name = "XYZ"[image_x_axis]
        y_name = "XYZ"[image_y_axis]
        image = add_label(image, f"{x_name}{y_name} Projection")

        # TODO: Draw axis indicators.
        projections.append(image)

    return torch.stack(projections)
def compute_aabb(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    near: Optional[Scalar] = None,
    far: Optional[Scalar] = None,
) -> tuple[
    Float[Tensor, "3"],  # minima of the scene
    Float[Tensor, "3"],  # maxima of the scene
]:
    """Return the axis-aligned bounding box that encloses the camera frustums.

    The box always contains every camera origin, i.e. the translation column
    ``extrinsics[:, :3, 3]``. For each of ``near`` and ``far`` that is
    provided, the value is sanitized with ``sanitize_scalar`` on the
    extrinsics' device. The four frustum corners at that depth are then
    included as well. ``near`` is processed before ``far``.

    Returns:
        ``(minima, maxima)``: the smallest and largest coordinate along each axis.
    """
    device = extrinsics.device

    point_sets = [extrinsics[:, :3, 3]]
    for plane_depth in (near, far):
        if plane_depth is None:
            continue
        plane_corners = unproject_frustum_corners(
            extrinsics, intrinsics, sanitize_scalar(plane_depth, device)
        )
        # Flatten (camera, corner) into a single list of points.
        point_sets.append(rearrange(plane_corners, "b p xyz -> (b p) xyz"))

    all_points = torch.cat(point_sets, dim=0)
    lowest = torch.min(all_points, dim=0).values
    highest = torch.max(all_points, dim=0).values
    return lowest, highest
def compute_equal_aabb_with_margin(
    minima: Float[Tensor, "*#batch 3"],
    maxima: Float[Tensor, "*#batch 3"],
    margin: float = 0.1,
) -> tuple[
    Float[Tensor, "*batch 3"],  # minima of the scene
    Float[Tensor, "*batch 3"],  # maxima of the scene
]:
    """Grow each box into a cube with the same center, padded by a relative margin.

    The cube's side is the box's largest side multiplied by ``1 + margin``.
    Leading batch dimensions are supported, and ``minima``/``maxima`` may
    broadcast against each other. Each box in the batch gets its own cube.

    Args:
        minima: Lower corner(s) of the box(es).
        maxima: Upper corner(s) of the box(es).
        margin: Fractional padding applied to the largest side.

    Returns:
        ``(scene_minima, scene_maxima)``: lower and upper corners of the cube(s).
    """
    midpoint = (maxima + minima) * 0.5

    # Reduce over the xyz axis only. A bare ``.max()`` would also reduce over
    # the batch dimensions, giving every box the span of the largest one.
    # keepdim keeps the result broadcastable against ``midpoint``.
    span = (maxima - minima).max(dim=-1, keepdim=True).values * (1 + margin)

    scene_minima = midpoint - 0.5 * span
    scene_maxima = midpoint + 0.5 * span
    return scene_minima, scene_maxima
def unproject_frustum_corners(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    depth: Float[Tensor, "#batch"],
) -> Float[Tensor, "batch 4 3"]:
    """Return the four frustum corners of each camera at the given z-depth.

    The image corners (0, 0), (1, 0), (1, 1), (0, 1) are unprojected through
    ``intrinsics``, normalized so their camera-space z is 1, rotated by
    ``extrinsics[..., :3, :3]``, then scaled by ``depth`` and offset by the
    camera origin ``extrinsics[:, :3, 3]``. ``depth`` is therefore a z-depth,
    not a Euclidean distance.

    The corner coordinates lie in [0, 1]. This presumably matches the
    normalized image coordinates that ``unproject`` expects; confirm against
    its definition.

    Args:
        extrinsics: Per-camera 4x4 matrices.
        intrinsics: Per-camera 3x3 intrinsics.
        depth: Depth per camera, shape ``(batch,)`` or ``(1,)``. A 0-d tensor
            or a Python number is also accepted and used for every camera.

    Returns:
        Corners of shape ``(batch, 4, 3)``, ordered so that consecutive corners
        (cyclically) share an edge of the rectangle.
    """
    device = extrinsics.device

    # Get coordinates for the corners. Following them in a circle makes a rectangle.
    xy = torch.linspace(0, 1, 2, device=device)
    xy = torch.stack(torch.meshgrid(xy, xy, indexing="xy"), dim=-1)
    xy = rearrange(xy, "i j xy -> (i j) xy")
    # Reorder the grid [(0,0), (1,0), (0,1), (1,1)] into a loop around the edge.
    xy = xy[torch.tensor([0, 1, 3, 2], device=device)]

    # Get ray directions in camera space.
    directions = unproject(
        xy,
        torch.ones(1, dtype=torch.float32, device=device),
        rearrange(intrinsics, "b i j -> b () i j"),
    )

    # Divide by the z coordinate so that multiplying by depth will produce orthographic
    # depth (z depth) as opposed to Euclidean depth (distance from the camera).
    directions = directions / directions[..., -1:]
    directions = einsum(extrinsics[..., :3, :3], directions, "b i j, b r j -> b r i")

    origins = rearrange(extrinsics[:, :3, 3], "b xyz -> b () xyz")

    # Callers (e.g. draw_cameras) pass Optional[Scalar] depths, which may be 0-d
    # or plain numbers that the "b -> b ()" pattern below rejects. Promote those
    # to length 1, and move the depth onto the cameras' device. Depths with more
    # than one dimension are still rejected by rearrange.
    depth = torch.as_tensor(depth, device=device)
    if depth.ndim == 0:
        depth = depth[None]
    depth = rearrange(depth, "b -> b () ()")
    return origins + depth * directions