alexnasa's picture
Upload 243 files
2568013 verified
from typing import Optional
import torch
from einops import repeat
from jaxtyping import Float
from torch import Tensor
from .coordinate_conversion import generate_conversions
from .rendering import render_over_image
from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector
def draw_points(
    image: Float[Tensor, "3 height width"],
    points: Vector,
    color: Vector = [1, 1, 1],
    radius: Scalar = 1,
    inner_radius: Scalar = 0,
    num_msaa_passes: int = 1,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> Float[Tensor, "3 height width"]:
    """Render one marker per point on top of ``image``.

    A sample is covered by a point when its distance to that point lies in
    the closed interval ``[inner_radius, radius]``, so ``inner_radius == 0``
    gives a filled disk and a larger value gives a ring. ``points``,
    ``color``, ``radius`` and ``inner_radius`` are broadcast against each
    other along the point dimension.

    Args:
        image: Image to draw over; its device is used for all tensors.
        points: 2D point locations, mapped through the world-to-pixel
            conversion built from ``x_range`` / ``y_range`` before drawing.
        color: RGB color per point (or one color shared by all points).
        radius: Outer distance bound of each marker. The comparison happens
            after the points are converted, so this is presumably measured
            in pixels (assumes ``render_over_image`` samples in pixel
            space; confirm).
        inner_radius: Inner distance bound of each marker.
        num_msaa_passes: Forwarded to ``render_over_image`` as ``num_passes``.
        x_range: Forwarded to ``generate_conversions``.
        y_range: Forwarded to ``generate_conversions``.

    Returns:
        Whatever ``render_over_image`` produces for ``image`` and the RGBA
        sample function defined below.
    """
    device = image.device

    # Normalize every user-facing argument into a tensor on the image's device.
    points = sanitize_vector(points, 2, device)
    color = sanitize_vector(color, 3, device)
    radius = sanitize_scalar(radius, device)
    inner_radius = sanitize_scalar(inner_radius, device)

    # The leading dimensions must broadcast to one shared point count.
    (num_points,) = torch.broadcast_shapes(
        points.shape[0],
        color.shape[0],
        radius.shape,
        inner_radius.shape,
    )

    # Move the points from world space into pixel space.
    _, height, width = image.shape
    world_to_pixel, _ = generate_conversions((height, width), device, x_range, y_range)
    pixel_points = world_to_pixel(points)

    # Both of these are independent of the sample positions, so they are
    # built once rather than on every call of the sample function.
    per_point_color = color.broadcast_to((num_points, 3))
    point_index = torch.arange(num_points, device=device)

    def color_function(
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        # Distance from every sample to every point.
        distance = (xy[:, None] - pixel_points[None]).norm(dim=-1)

        # A sample is covered when it falls between the inner and outer radius.
        beyond_inner = distance >= inner_radius[None]
        within_outer = distance <= radius[None]
        covered = beyond_inner & within_outer

        # Where several points cover a sample, the highest point index wins.
        # Uncovered samples fall back to index 0, which is harmless because
        # their alpha below is zero.
        top_index = (covered * point_index).argmax(dim=1)
        rgb = per_point_color.index_select(0, top_index)

        # Alpha is 1 wherever at least one point covers the sample, else 0.
        alpha = covered.any(dim=1, keepdim=True).float()
        return torch.cat((rgb, alpha), dim=-1)

    return render_over_image(image, color_function, device, num_passes=num_msaa_passes)