from typing import Protocol, runtime_checkable
import torch
from einops import rearrange, reduce
from jaxtyping import Bool, Float
from torch import Tensor
@runtime_checkable
class ColorFunction(Protocol):
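    """Protocol for callables that map 2D pixel-space sample positions to RGBA colors with straight alpha."""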
def __call__(
self,
xy: Float[Tensor, "point 2"],
) -> Float[Tensor, "point 4"]: # RGBA color
pass
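# --- Illustrative example (not part of the original module) ---
# A minimal sketch of a conforming ColorFunction: it paints a hard-edged red disk
# and is fully transparent elsewhere. The name `disk_color_function` and the
# center/radius constants are assumptions chosen purely to demonstrate the
# expected (point, 2) -> (point, 4) mapping.
def disk_color_function(
    xy: Float[Tensor, "point 2"],
) -> Float[Tensor, "point 4"]:
    center = torch.tensor([128.0, 128.0], device=xy.device)
    radius = 64.0
    # Coverage is 1 inside the disk and 0 outside (straight alpha).
    alpha = ((xy - center).norm(dim=-1) < radius).float()
    rgb = torch.tensor([1.0, 0.0, 0.0], device=xy.device).expand(xy.shape[0], 3)
    return torch.cat((rgb, alpha[:, None]), dim=-1)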
def generate_sample_grid(
shape: tuple[int, int],
device: torch.device,
) -> Float[Tensor, "height width 2"]:
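    """Return the (x, y) coordinates of every pixel center for an image of the given (height, width)."""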
h, w = shape
x = torch.arange(w, device=device) + 0.5
y = torch.arange(h, device=device) + 0.5
x, y = torch.meshgrid(x, y, indexing="xy")
return torch.stack([x, y], dim=-1)
def detect_msaa_pixels(
image: Float[Tensor, "batch 4 height width"],
) -> Bool[Tensor, "batch height width"]:
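    """Flag every pixel that differs from at least one of its eight neighbors; these pixels need supersampling."""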
b, _, h, w = image.shape
mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device)
# Detect horizontal differences.
horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1)
mask[:, :, 1:] |= horizontal
mask[:, :, :-1] |= horizontal
# Detect vertical differences.
vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1)
mask[:, 1:, :] |= vertical
mask[:, :-1, :] |= vertical
# Detect diagonal (top left to bottom right) differences.
tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1)
mask[:, 1:, 1:] |= tlbr
mask[:, :-1, :-1] |= tlbr
# Detect diagonal (top right to bottom left) differences.
trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1)
mask[:, :-1, 1:] |= trbl
mask[:, 1:, :-1] |= trbl
return mask
def reduce_straight_alpha(
rgba: Float[Tensor, "batch 4 height width"],
) -> Float[Tensor, "batch 4"]:
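    """Collapse each RGBA patch to one straight-alpha color: alpha-weighted mean RGB and mean alpha."""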
color, alpha = rgba.split((3, 1), dim=1)
    # Color becomes the alpha-weighted average of the patch's RGB values.
weighted_color = reduce(color * alpha, "b c h w -> b c", "sum")
alpha_sum = reduce(alpha, "b c h w -> b c", "sum")
color = weighted_color / (alpha_sum + 1e-10)
# Alpha becomes mean alpha.
alpha = reduce(alpha, "b c h w -> b c", "mean")
return torch.cat((color, alpha), dim=-1)
@torch.no_grad()
def run_msaa_pass(
xy: Float[Tensor, "batch height width 2"],
color_function: ColorFunction,
scale: float,
subdivision: int,
remaining_passes: int,
device: torch.device,
batch_size: int = int(2**16),
) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha)
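    """Evaluate the color function at the given sample positions, then recursively subdivide and re-average the pixels flagged by detect_msaa_pixels."""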
# Sample the color function.
b, h, w, _ = xy.shape
color = [
color_function(batch)
for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size)
]
color = torch.cat(color, dim=0)
color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w)
# If any MSAA passes remain, subdivide.
if remaining_passes > 0:
mask = detect_msaa_pixels(color)
batch_index, row_index, col_index = torch.where(mask)
xy = xy[batch_index, row_index, col_index]
offsets = generate_sample_grid((subdivision, subdivision), device)
offsets = (offsets / subdivision - 0.5) * scale
color_fine = run_msaa_pass(
xy[:, None, None] + offsets,
color_function,
scale / subdivision,
subdivision,
remaining_passes - 1,
device,
batch_size=batch_size,
)
color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine)
return color
@torch.no_grad()
def render(
shape: tuple[int, int],
color_function: ColorFunction,
device: torch.device,
subdivision: int = 8,
num_passes: int = 2,
) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha)
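    """Render an RGBA (straight alpha) image of the given (height, width) by evaluating the color function at pixel centers with multisample anti-aliasing."""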
xy = generate_sample_grid(shape, device)
return run_msaa_pass(
xy[None],
color_function,
1.0,
subdivision,
num_passes,
device,
)[0]
def render_over_image(
image: Float[Tensor, "3 height width"],
color_function: ColorFunction,
device: torch.device,
subdivision: int = 8,
num_passes: int = 1,
) -> Float[Tensor, "3 height width"]:
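    """Render the color function with anti-aliasing and alpha-composite the result over an opaque RGB image."""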
_, h, w = image.shape
overlay = render(
(h, w),
color_function,
device,
subdivision=subdivision,
num_passes=num_passes,
)
color, alpha = overlay.split((3, 1), dim=0)
return image * (1 - alpha) + color * alpha
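# --- Illustrative example (not part of the original module) ---
# A hypothetical quick check that composites the disk_color_function sketch from
# above over a gray background; every name here beyond the module's own API is
# an assumption, not something the original file defines.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    background = torch.full((3, 256, 256), 0.5, device=device)
    composited = render_over_image(background, disk_color_function, device)
    print(composited.shape)  # torch.Size([3, 256, 256])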