|
from typing import Protocol, runtime_checkable |
|
|
|
import torch |
|
from einops import rearrange, reduce |
|
from jaxtyping import Bool, Float |
|
from torch import Tensor |
|
|
|
|
|
@runtime_checkable
class ColorFunction(Protocol):
    """Structural type for the user-supplied shading callback.

    Any callable accepting a ``(point, 2)`` tensor of (x, y) sample positions
    and returning a ``(point, 4)`` tensor of colors satisfies this protocol.
    The four output channels are consumed elsewhere as RGB (first three) plus
    alpha (last one).
    """

    def __call__(
        self,
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        """Return one RGBA color per input (x, y) sample position."""
        ...
|
|
|
|
|
def generate_sample_grid(
    shape: tuple[int, int],
    device: torch.device,
) -> Float[Tensor, "height width 2"]:
    """Return the (x, y) center of every cell of a ``height x width`` grid.

    Entry ``[row, col]`` holds ``(col + 0.5, row + 0.5)``: x varies along the
    width axis and y along the height axis.

    Args:
        shape: ``(height, width)`` of the grid.
        device: Device on which the result is allocated.
    """
    height, width = shape
    row_centers = torch.arange(height, device=device) + 0.5
    col_centers = torch.arange(width, device=device) + 0.5
    # "ij" indexing with (rows, cols) yields (height, width) grids, the same
    # layout "xy" indexing produces with the operands swapped.
    grid_y, grid_x = torch.meshgrid(row_centers, col_centers, indexing="ij")
    return torch.stack((grid_x, grid_y), dim=-1)
|
|
|
|
|
def detect_msaa_pixels(
    image: Float[Tensor, "batch 4 height width"],
) -> Bool[Tensor, "batch height width"]:
    """Flag pixels whose color differs from any of their 8 neighbours.

    Two adjacent pixels (horizontally, vertically, or diagonally) are
    considered different when any channel compares unequal (exact ``!=``).
    Both pixels of every differing pair are flagged.

    Returns:
        Boolean mask of shape ``(batch, height, width)``; ``True`` marks a
        pixel that should be refined.
    """
    batch, _, height, width = image.shape
    mask = image.new_zeros((batch, height, width), dtype=torch.bool)

    whole, drop_last, drop_first = slice(None), slice(None, -1), slice(1, None)
    # Each entry pairs two (row slice, col slice) views that line up every
    # pixel with one neighbour direction.
    neighbour_pairs = (
        ((whole, drop_first), (whole, drop_last)),  # horizontal
        ((drop_first, whole), (drop_last, whole)),  # vertical
        ((drop_first, drop_first), (drop_last, drop_last)),  # top-left / bottom-right
        ((drop_last, drop_first), (drop_first, drop_last)),  # top-right / bottom-left
    )

    for (rows_a, cols_a), (rows_b, cols_b) in neighbour_pairs:
        differs = (
            image[:, :, rows_a, cols_a] != image[:, :, rows_b, cols_b]
        ).any(dim=1)
        mask[:, rows_a, cols_a] |= differs
        mask[:, rows_b, cols_b] |= differs

    return mask
|
|
|
|
|
def reduce_straight_alpha(
    rgba: Float[Tensor, "batch 4 height width"],
) -> Float[Tensor, "batch 4"]:
    """Average straight (non-premultiplied) RGBA samples over height/width.

    The RGB result is the alpha-weighted mean of the sample colors, i.e.
    ``sum(color * alpha) / sum(alpha)``. The alpha result is the plain mean
    of the sample alphas. Where the alpha sum is exactly zero (all samples
    fully transparent), the RGB result is 0.

    Returns:
        Tensor of shape ``(batch, 4)`` with channels RGB followed by alpha.
    """
    color, alpha = rgba.split((3, 1), dim=1)

    weighted_color = (color * alpha).sum(dim=(2, 3))
    alpha_sum = alpha.sum(dim=(2, 3))
    color = weighted_color / (alpha_sum + 1e-10)
    # In float16 the 1e-10 epsilon rounds away, so an all-transparent pixel
    # computes 0 / 0 = NaN. Pin those pixels to 0, which is what the epsilon
    # already yields in float32.
    color = torch.where(alpha_sum == 0, torch.zeros_like(color), color)

    alpha = alpha.mean(dim=(2, 3))

    return torch.cat((color, alpha), dim=-1)
|
|
|
|
|
@torch.no_grad()
def run_msaa_pass(
    xy: Float[Tensor, "batch height width 2"],
    color_function: ColorFunction,
    scale: float,
    subdivision: int,
    remaining_passes: int,
    device: torch.device,
    batch_size: int = 2**16,
) -> Float[Tensor, "batch 4 height width"]:
    """Evaluate ``color_function`` at ``xy`` and recursively refine edges.

    Args:
        xy: Sample positions, one (x, y) pair per grid cell.
        color_function: Maps a ``(point, 2)`` tensor of positions to a
            ``(point, 4)`` tensor of RGBA colors.
        scale: Side length of the square footprint each sample in ``xy``
            represents; refinement sub-samples are the cell centers of a
            ``subdivision x subdivision`` partition of that square.
        subdivision: Number of sub-samples per axis for a refined pixel.
        remaining_passes: Number of further refinement levels to run.
        device: Device on which the sub-sample offsets are created.
        batch_size: Maximum number of points per ``color_function`` call.

    Returns:
        Colors of shape ``(batch, 4, height, width)``. Pixels flagged by
        ``detect_msaa_pixels`` are replaced by ``reduce_straight_alpha`` of
        their refined sub-samples.
    """
    b, h, w, _ = xy.shape

    # Evaluate in chunks so a single color_function call never sees more
    # than batch_size points.
    flat_xy = rearrange(xy, "b h w xy -> (b h w) xy")
    color = torch.cat(
        [color_function(chunk) for chunk in flat_xy.split(batch_size)], dim=0
    )
    color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w)

    if remaining_passes <= 0:
        return color

    batch_index, row_index, col_index = torch.where(detect_msaa_pixels(color))
    # Nothing to refine: skip the recursive pass, which would otherwise call
    # color_function on an empty batch of points.
    if batch_index.numel() == 0:
        return color

    edge_xy = xy[batch_index, row_index, col_index]

    # Offsets are cell centers spanning (-scale/2, scale/2) around each point.
    offsets = generate_sample_grid((subdivision, subdivision), device)
    offsets = (offsets / subdivision - 0.5) * scale

    color_fine = run_msaa_pass(
        edge_xy[:, None, None] + offsets,
        color_function,
        scale / subdivision,
        subdivision,
        remaining_passes - 1,
        device,
        batch_size=batch_size,
    )
    color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine)

    return color
|
|
|
|
|
@torch.no_grad()
def render(
    shape: tuple[int, int],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 2,
) -> Float[Tensor, "4 height width"]:
    """Render ``color_function`` into a ``(4, height, width)`` RGBA image.

    Each pixel is first sampled at its center. Up to ``num_passes`` rounds of
    edge refinement then follow, each using a ``subdivision x subdivision``
    sub-grid (see ``run_msaa_pass``).

    Args:
        shape: ``(height, width)`` of the output image.
        color_function: Callback mapping ``(point, 2)`` positions to
            ``(point, 4)`` RGBA colors.
        device: Device on which sample positions are created.
        subdivision: Sub-samples per axis for refined pixels.
        num_passes: Number of refinement levels.
    """
    pixel_centers = generate_sample_grid(shape, device).unsqueeze(0)
    batched_image = run_msaa_pass(
        pixel_centers,
        color_function,
        scale=1.0,
        subdivision=subdivision,
        remaining_passes=num_passes,
        device=device,
    )
    return batched_image.squeeze(0)
|
|
|
|
|
def render_over_image(
    image: Float[Tensor, "3 height width"],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 1,
) -> Float[Tensor, "3 height width"]:
    """Render ``color_function`` and alpha-composite it over ``image``.

    The overlay is rendered at the image's resolution via ``render``. It is
    blended with straight alpha: ``image * (1 - alpha) + color * alpha``.

    Args:
        image: RGB background of shape ``(3, height, width)``.
        color_function: Callback mapping ``(point, 2)`` positions to
            ``(point, 4)`` RGBA colors.
        device: Device used for rendering the overlay.
        subdivision: Sub-samples per axis for refined pixels.
        num_passes: Number of refinement levels.

    Returns:
        Composited RGB image on the same device as ``image``.
    """
    _, h, w = image.shape
    overlay = render(
        (h, w),
        color_function,
        device,
        subdivision=subdivision,
        num_passes=num_passes,
    )
    # The overlay lives wherever color_function put it (presumably `device`).
    # Move it next to the background so the blend works even when the two
    # devices differ. This is a no-op when they already match.
    overlay = overlay.to(image.device)

    color, alpha = overlay.split((3, 1), dim=0)
    return image * (1 - alpha) + color * alpha
|
|