File size: 4,417 Bytes
2568013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import Protocol, runtime_checkable

import torch
from einops import rearrange, reduce
from jaxtyping import Bool, Float
from torch import Tensor


@runtime_checkable
class ColorFunction(Protocol):
    """Structural interface for a shader-like callable.

    Implementations map a batch of 2-D sample coordinates (in pixel space)
    to one RGBA color per point. Alpha is straight (non-premultiplied) —
    see how `reduce_straight_alpha` and `render_over_image` consume it.
    """

    def __call__(
        self,
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:  # RGBA color
        pass


def generate_sample_grid(
    shape: tuple[int, int],
    device: torch.device,
) -> Float[Tensor, "height width 2"]:
    """Return per-pixel sample coordinates for an image of the given shape.

    Each output entry holds (x, y) at the pixel's center, i.e. pixel (row, col)
    maps to (col + 0.5, row + 0.5).
    """
    height, width = shape
    # indexing="ij" yields (height, width) grids: ys varies along rows,
    # xs varies along columns — same layout the "xy" form produces for
    # meshgrid(x, y).
    ys, xs = torch.meshgrid(
        torch.arange(height, device=device) + 0.5,
        torch.arange(width, device=device) + 0.5,
        indexing="ij",
    )
    return torch.stack((xs, ys), dim=-1)


def detect_msaa_pixels(
    image: Float[Tensor, "batch 4 height width"],
) -> Bool[Tensor, "batch height width"]:
    """Flag pixels that differ (in any channel) from at least one neighbor.

    Neighbors are compared along four directions — horizontal, vertical, and
    both diagonals — and both pixels of every differing pair are marked, so
    the mask covers each side of every color edge.
    """
    batch, _, height, width = image.shape
    edges = torch.zeros((batch, height, width), dtype=torch.bool, device=image.device)

    # Each entry pairs two (row, col) slice tuples selecting overlapping crops
    # whose elements are neighbors along one direction.
    neighbor_pairs = (
        # Horizontal: (r, c) vs (r, c - 1).
        ((slice(None), slice(1, None)), (slice(None), slice(None, -1))),
        # Vertical: (r, c) vs (r - 1, c).
        ((slice(1, None), slice(None)), (slice(None, -1), slice(None))),
        # Diagonal, top-left to bottom-right: (r, c) vs (r - 1, c - 1).
        ((slice(1, None), slice(1, None)), (slice(None, -1), slice(None, -1))),
        # Diagonal, top-right to bottom-left: (r, c + 1) vs (r + 1, c).
        ((slice(None, -1), slice(1, None)), (slice(1, None), slice(None, -1))),
    )

    for (rows_a, cols_a), (rows_b, cols_b) in neighbor_pairs:
        differs = (image[:, :, rows_a, cols_a] != image[:, :, rows_b, cols_b]).any(dim=1)
        edges[:, rows_a, cols_a] |= differs
        edges[:, rows_b, cols_b] |= differs

    return edges


def reduce_straight_alpha(
    rgba: Float[Tensor, "batch 4 height width"],
) -> Float[Tensor, "batch 4"]:
    """Collapse the spatial dimensions of straight-alpha RGBA to one color.

    The RGB channels become an alpha-weighted spatial average; the alpha
    channel becomes the plain spatial mean.
    """
    rgb = rgba[:, :3]
    alpha = rgba[:, 3:]

    # Alpha-weighted mean of the color channels. The epsilon keeps the
    # division finite when a region is fully transparent (alpha sum of zero).
    weighted_sum = (rgb * alpha).sum(dim=(-2, -1))
    alpha_total = alpha.sum(dim=(-2, -1))
    mean_rgb = weighted_sum / (alpha_total + 1e-10)

    # Combined coverage is simply the mean alpha over the region.
    mean_alpha = alpha.mean(dim=(-2, -1))

    return torch.cat((mean_rgb, mean_alpha), dim=-1)


@torch.no_grad()
def run_msaa_pass(
    xy: Float[Tensor, "batch height width 2"],
    color_function: ColorFunction,
    scale: float,
    subdivision: int,
    remaining_passes: int,
    device: torch.device,
    batch_size: int = int(2**16),
) -> Float[Tensor, "batch 4 height width"]:  # color (RGBA with straight alpha)
    """Sample `color_function` on a grid, recursively refining edge pixels.

    Pixels whose color differs from a neighbor are re-sampled on a
    `subdivision` x `subdivision` sub-grid (recursing `remaining_passes`
    times) and replaced with the reduced sub-grid color.
    """
    b, h, w, _ = xy.shape

    # Evaluate the color function in chunks of at most `batch_size` points
    # to bound peak memory, then reassemble into image layout.
    flat_xy = rearrange(xy, "b h w xy -> (b h w) xy")
    sampled = [color_function(chunk) for chunk in flat_xy.split(batch_size)]
    color = rearrange(torch.cat(sampled, dim=0), "(b h w) c -> b c h w", b=b, h=h, w=w)

    # Base case: no more refinement passes.
    if remaining_passes <= 0:
        return color

    # Refine only the pixels sitting on a color edge.
    edge_batch, edge_row, edge_col = torch.where(detect_msaa_pixels(color))
    centers = xy[edge_batch, edge_row, edge_col]

    # Sub-sample offsets centered on each edge pixel, covering its footprint.
    offsets = generate_sample_grid((subdivision, subdivision), device)
    offsets = (offsets / subdivision - 0.5) * scale

    refined = run_msaa_pass(
        centers[:, None, None] + offsets,
        color_function,
        scale / subdivision,
        subdivision,
        remaining_passes - 1,
        device,
        batch_size=batch_size,
    )
    color[edge_batch, :, edge_row, edge_col] = reduce_straight_alpha(refined)

    return color


@torch.no_grad()
def render(
    shape: tuple[int, int],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 2,
) -> Float[Tensor, "4 height width"]:  # color (RGBA with straight alpha)
    """Render `color_function` as an anti-aliased RGBA image of `shape`.

    Samples once per pixel center, then applies `num_passes` of adaptive
    MSAA refinement with the given `subdivision` factor.
    """
    samples = generate_sample_grid(shape, device)
    rendered = run_msaa_pass(
        samples[None],  # add a singleton batch dimension
        color_function,
        1.0,  # initial footprint: one pixel
        subdivision,
        num_passes,
        device,
    )
    return rendered[0]


def render_over_image(
    image: Float[Tensor, "3 height width"],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 1,
) -> Float[Tensor, "3 height width"]:
    """Render `color_function` and composite it over `image`.

    The overlay is rendered at the image's resolution and blended using
    standard "over" compositing with straight (non-premultiplied) alpha.
    """
    _, height, width = image.shape
    rgba = render(
        (height, width),
        color_function,
        device,
        subdivision=subdivision,
        num_passes=num_passes,
    )
    rgb = rgba[:3]
    alpha = rgba[3:]  # shape (1, h, w); broadcasts over the RGB channels
    return image * (1 - alpha) + rgb * alpha