|
from typing import Optional, Protocol, runtime_checkable |
|
|
|
import torch |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from .types import Pair, sanitize_pair |
|
|
|
|
|
@runtime_checkable
class ConversionFunction(Protocol):
    """Structural type for a callable that converts 2D points between frames.

    Any callable taking a tensor whose trailing dimension holds ``(x, y)``
    pairs, and returning a tensor with the same ``*batch 2`` layout, matches
    this protocol.
    """

    def __call__(
        self,
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        ...
|
|
|
|
|
def generate_conversions(
    shape: tuple[int, int],
    device: torch.device,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> tuple[
    ConversionFunction,
    ConversionFunction,
]:
    """Build a pair of mutually inverse world<->pixel coordinate converters.

    World x spans ``x_range`` and world y spans ``y_range``; pixel x spans
    ``[0, w]`` and pixel y spans ``[0, h]``. Each axis is mapped affinely.

    Args:
        shape: Image size as ``(h, w)``.
        device: Device on which the range and scale tensors are created.
        x_range: ``(min, max)`` of world x. Defaults to ``(0, w)``.
        y_range: ``(min, max)`` of world y. Defaults to ``(0, h)``.

    Returns:
        ``(convert_world_to_pixel, convert_pixel_to_world)``. Each accepts and
        returns tensors whose last dimension holds ``(x, y)``.

    Raises:
        ValueError: If either range has zero extent (min == max), which would
            make the world-to-pixel mapping divide by zero.
    """
    h, w = shape

    # NOTE(review): sanitize_pair is assumed to return a length-2 tensor
    # ``(min, max)`` on ``device`` -- confirm against .types.
    x_range = sanitize_pair((0, w) if x_range is None else x_range, device)
    y_range = sanitize_pair((0, h) if y_range is None else y_range, device)

    # Stacking along the last dim gives [[x_min, y_min], [x_max, y_max]];
    # unpacking the first dim yields per-axis minima and maxima in (x, y) order.
    minima, maxima = torch.stack((x_range, y_range), dim=-1)

    # Hoisted: the world extent is invariant across calls.
    extent = maxima - minima
    if bool((extent == 0).any()):
        raise ValueError(
            "x_range and y_range must have non-zero extent, got "
            f"minima={minima.tolist()} and maxima={maxima.tolist()}"
        )

    # Pixel extent in (x, y) order, matching the column order of minima/maxima.
    wh = torch.tensor((w, h), dtype=torch.float32, device=device)

    def convert_world_to_pixel(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        return (xy - minima) / extent * wh

    def convert_pixel_to_world(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        return xy / wh * extent + minima

    return convert_world_to_pixel, convert_pixel_to_world
|
|