# Source: AnySplat, src/visualization/drawing/coordinate_conversion.py
# (uploaded by alexnasa in "Upload 243 files", commit 2568013, verified)
from typing import Optional, Protocol, runtime_checkable
import torch
from jaxtyping import Float
from torch import Tensor
from .types import Pair, sanitize_pair
@runtime_checkable
class ConversionFunction(Protocol):
    """Structural type for a callable that maps 2-D points to 2-D points.

    An implementation receives a tensor whose trailing dimension of size 2
    holds ``(x, y)`` coordinates and returns a tensor with the same layout.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    ConversionFunction)`` is supported. That check only verifies that ``obj``
    has a ``__call__`` attribute; it does not validate the call signature.
    """

    def __call__(
        self,
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        ...
def generate_conversions(
    shape: tuple[int, int],
    device: torch.device,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> tuple[
    ConversionFunction,  # conversion from world coordinates to pixel coordinates
    ConversionFunction,  # conversion from pixel coordinates to world coordinates
]:
    """Build a pair of affine maps between world and pixel coordinates.

    World x values are mapped linearly so that ``x_range[0] -> 0`` and
    ``x_range[1] -> w``. World y values are mapped so that ``y_range[0] -> 0``
    and ``y_range[1] -> h``, where ``(h, w) = shape``. An omitted range
    defaults to the pixel extent itself (``(0, w)`` / ``(0, h)``), so both
    maps are then the identity up to float rounding. A range whose first
    value exceeds its second is accepted and flips that axis.

    Args:
        shape: Image size as ``(height, width)``.
        device: Device on which the range and size tensors are created.
        x_range: World-space pair mapped onto pixel x in ``[0, w]``.
        y_range: World-space pair mapped onto pixel y in ``[0, h]``.

    Returns:
        ``(world_to_pixel, pixel_to_world)``. Both take a tensor whose last
        dimension holds ``(x, y)`` and broadcast over any leading dimensions.

    Raises:
        ValueError: If either range has zero extent (both ends equal). Such a
            range would otherwise make ``world_to_pixel`` divide by zero and
            silently return inf/NaN coordinates.
    """
    h, w = shape
    x_range = sanitize_pair((0, w) if x_range is None else x_range, device)
    y_range = sanitize_pair((0, h) if y_range is None else y_range, device)

    # Stacking the two 2-element ranges along the last axis gives rows
    # (x_first, y_first) and (x_second, y_second). Each row therefore lines up
    # with an (x, y) point and broadcasts against the inputs.
    minima, maxima = torch.stack((x_range, y_range), dim=-1)

    # The extent is loop-invariant for the returned closures, so compute it once.
    span = maxima - minima
    # Reject degenerate ranges up front instead of producing inf/NaN later.
    # bool() reads the result back to the host; this happens once per
    # generate_conversions call, not once per converted point.
    if bool((span == 0).any()):
        raise ValueError(
            "x_range and y_range must have nonzero extent, got "
            f"x_range={x_range.tolist()} and y_range={y_range.tolist()}."
        )

    wh = torch.tensor((w, h), dtype=torch.float32, device=device)

    def convert_world_to_pixel(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        # Normalize into [0, 1] over each range, then scale to the image size.
        return (xy - minima) / span * wh

    def convert_pixel_to_world(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        # Exact algebraic inverse of convert_world_to_pixel.
        return xy / wh * span + minima

    return convert_world_to_pixel, convert_pixel_to_world