File size: 6,752 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
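"""This file contains useful layout utilities for images. They are:
- add_border: Add a border to an image.
- cat/hcat/vcat: Join images by arranging them in a line. If the images have different
sizes, they are aligned as specified (start, end, center). An optional gap can be
inserted between neighbouring images.
- overlay: Paste a smaller image onto a larger one at a given alignment.
- resize: Resize an image to a given shape, or to a given width or height while
keeping the aspect ratio.
Images are assumed to be float32 tensors with shape (channel, height, width).
"""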
from typing import Any, Generator, Iterable, Literal, Optional, Union
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
# Where an image sits within the space available to it along one axis.
Alignment = Literal["start", "center", "end"]

# Direction along which images are laid out one after another.
Axis = Literal["horizontal", "vertical"]

# Accepted color specifications: a scalar, an iterable of scalars, or a tensor
# (jaxtyping shapes "#channel" or ""). _sanitize_color normalizes all of these.
Color = Union[
    int, float,
    Iterable[int], Iterable[float],
    Float[Tensor, "#channel"], Float[Tensor, ""],
]
def _sanitize_color(color: Color) -> Float[Tensor, "#channel"]:
    """Normalize a user-supplied color into a 1-D float32 tensor.

    Tensors are first converted to plain Python values with ``tolist()``.
    Any iterable is then materialized into a list, while a lone scalar is
    wrapped in a one-element list before building the tensor.
    """
    values = color.tolist() if isinstance(color, torch.Tensor) else color
    values = list(values) if isinstance(values, Iterable) else [values]
    return torch.tensor(values, dtype=torch.float32)
def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]:
    """Yield the items of *iterable* with *delimiter* inserted between each pair.

    The delimiter is never emitted before the first item or after the last,
    so an empty iterable yields nothing and a single item is yielded alone.
    """
    # enumerate avoids an unguarded next(), whose StopIteration on empty input
    # would be turned into RuntimeError inside a generator (PEP 479).
    for index, item in enumerate(iterable):
        if index > 0:
            yield delimiter
        yield item
def _get_main_dim(main_axis: Axis) -> int:
    """Return the tensor dimension that runs along *main_axis*.

    Images are (channel, height, width), so "horizontal" maps to dimension 2
    (width) and "vertical" to dimension 1 (height). Unknown axes raise KeyError.
    """
    dim_for_axis = dict(horizontal=2, vertical=1)
    return dim_for_axis[main_axis]
def _get_cross_dim(main_axis: Axis) -> int:
    """Return the tensor dimension perpendicular to *main_axis*.

    Images are (channel, height, width), so the cross dimension of a
    "horizontal" layout is 1 (height) and of a "vertical" layout is 2 (width).
    Unknown axes raise KeyError.
    """
    dim_for_axis = dict(horizontal=1, vertical=2)
    return dim_for_axis[main_axis]
def _compute_offset(base: int, overlay: int, align: Alignment) -> slice:
    """Return the slice of length *overlay* that positions it inside *base*.

    "start" places it at index 0, "end" makes it flush with the end of *base*,
    and "center" splits the leftover space, rounding the leading gap down.
    Fails an assertion if *overlay* exceeds *base*; raises KeyError for an
    unknown alignment.
    """
    assert base >= overlay
    slack = base - overlay
    offset = dict(start=0, center=slack // 2, end=slack)[align]
    return slice(offset, offset + overlay)
def overlay(
    base: Float[Tensor, "channel base_height base_width"],
    overlay: Float[Tensor, "channel overlay_height overlay_width"],
    main_axis: Axis,
    main_axis_alignment: Alignment,
    cross_axis_alignment: Alignment,
) -> Float[Tensor, "channel base_height base_width"]:
    """Paste *overlay* onto a copy of *base* at the requested alignment.

    Args:
        base: Background image; it is cloned, never modified in place.
        overlay: Image to paste; must be no larger than *base* in height and width.
        main_axis: Axis ("horizontal" or "vertical") that *main_axis_alignment*
            refers to.
        main_axis_alignment: Placement of *overlay* along *main_axis*.
        cross_axis_alignment: Placement of *overlay* along the perpendicular axis.

    Returns:
        A new tensor shaped like *base* with *overlay* written into the aligned
        region.
    """
    # The overlay must be smaller than the base.
    _, base_height, base_width = base.shape
    _, overlay_height, overlay_width = overlay.shape
    assert base_height >= overlay_height and base_width >= overlay_width

    # Slices along the main and cross dimensions.
    main_dim = _get_main_dim(main_axis)
    cross_dim = _get_cross_dim(main_axis)
    main_slice = _compute_offset(
        base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment
    )
    cross_slice = _compute_offset(
        base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment
    )

    # Build a (channel, height, width) index whose positions match the tensor
    # dims, and pass it as a tuple: indexing with a *list* of slices relies on
    # a legacy "treat sequence as tuple" special case that NumPy has removed.
    selector = [slice(None), None, None]
    selector[main_dim] = main_slice
    selector[cross_dim] = cross_slice
    result = base.clone()
    result[tuple(selector)] = overlay
    return result
def cat(
    main_axis: Axis,
    *images: Float[Tensor, "channel _ _"],
    align: Alignment = "center",
    gap: int = 8,
    gap_color: Color = 1,
) -> Float[Tensor, "channel height width"]:
    """Arrange images in a line. The interface resembles a CSS div with flexbox.

    Args:
        main_axis: "horizontal" to place images side by side, "vertical" to
            stack them.
        *images: Images shaped (channel, height, width); at least one is required.
        align: Cross-axis placement of images smaller than the largest one.
        gap: Separator thickness in pixels between neighbouring images; values
            <= 0 disable separators.
        gap_color: Color used both for separators and for padding around
            smaller images.

    Returns:
        The concatenated float32 image on the device of the first image.

    Raises:
        ValueError: If no images are given.
    """
    if not images:
        raise ValueError("cat() requires at least one image.")
    device = images[0].device
    gap_color = _sanitize_color(gap_color).to(device)
    # Per-channel color broadcast over (height, width); shared by padding and gaps.
    fill = gap_color[:, None, None]

    # Every image is padded to the largest size along the cross axis.
    cross_dim = _get_cross_dim(main_axis)
    cross_axis_length = max(image.shape[cross_dim] for image in images)

    padded_images = []
    for image in images:
        padded_shape = list(image.shape)
        padded_shape[cross_dim] = cross_axis_length
        base = torch.ones(padded_shape, dtype=torch.float32, device=device) * fill
        # Main-axis extents already match, so "start" there is a no-op.
        padded_images.append(overlay(base, image, main_axis, "start", align))

    if gap > 0:
        c, _, _ = images[0].shape
        # separator_size is (height, width); tensor dim 1/2 maps to index 0/1.
        separator_size = [gap, gap]
        separator_size[cross_dim - 1] = cross_axis_length
        separator = (
            torch.ones((c, *separator_size), dtype=torch.float32, device=device)
            * fill
        )
        padded_images = list(_intersperse(padded_images, separator))

    return torch.cat(padded_images, dim=_get_main_dim(main_axis))
def hcat(
    *images: Float[Tensor, "channel _ _"],
    align: Literal["start", "center", "end", "top", "bottom"] = "start",
    gap: int = 8,
    gap_color: Color = 1,
) -> Float[Tensor, "channel height width"]:
    """Shorthand for a horizontal linear concatenation.

    *align* positions images vertically (the cross axis of a horizontal
    layout), so "top" and "bottom" are accepted as aliases for "start" and
    "end". The remaining arguments are forwarded to ``cat``.
    """
    return cat(
        "horizontal",
        *images,
        align={
            "start": "start",
            "center": "center",
            "end": "end",
            "top": "start",
            "bottom": "end",
        }[align],
        gap=gap,
        gap_color=gap_color,
    )
def vcat(
    *images: Float[Tensor, "channel _ _"],
    align: Literal["start", "center", "end", "left", "right"] = "start",
    gap: int = 8,
    gap_color: Color = 1,
) -> Float[Tensor, "channel height width"]:
    """Shorthand for a vertical linear concatenation.

    *align* positions images horizontally (the cross axis of a vertical
    layout), so "left" and "right" are accepted as aliases for "start" and
    "end". The remaining arguments are forwarded to ``cat``.
    """
    return cat(
        "vertical",
        *images,
        align={
            "start": "start",
            "center": "center",
            "end": "end",
            "left": "start",
            "right": "end",
        }[align],
        gap=gap,
        gap_color=gap_color,
    )
def add_border(
    image: Float[Tensor, "channel height width"],
    border: int = 8,
    color: Color = 1,
) -> Float[Tensor, "channel new_height new_width"]:
    """Surround *image* with a solid frame *border* pixels thick on every side.

    The frame is filled with *color* (broadcast across the whole canvas before
    the image is pasted in the middle). The result is a new float32 tensor on
    the same device as *image*.
    """
    fill = _sanitize_color(color).to(image)
    channels, height, width = image.shape
    framed = torch.empty(
        (channels, height + 2 * border, width + 2 * border),
        dtype=torch.float32,
        device=image.device,
    )
    framed[:] = fill[:, None, None]
    framed[:, border:border + height, border:border + width] = image
    return framed
def resize(
    image: Float[Tensor, "channel height width"],
    shape: Optional[tuple[int, int]] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
) -> Float[Tensor, "channel new_height new_width"]:
    """Resize *image* with antialiased bilinear interpolation.

    Exactly one of *shape*, *width* or *height* must be given.

    Args:
        image: Image shaped (channel, height, width).
        shape: Target (height, width).
        width: Target width; the height is scaled to preserve the aspect ratio.
        height: Target height; the width is scaled to preserve the aspect ratio.

    Returns:
        The resized image shaped (channel, new_height, new_width).

    Raises:
        ValueError: If not exactly one of *shape*, *width*, *height* is given.
    """
    # Explicit check instead of assert, which is stripped under `python -O`.
    num_specified = sum(arg is not None for arg in (shape, width, height))
    if num_specified != 1:
        raise ValueError("Specify exactly one of shape, width, or height.")
    _, h, w = image.shape
    if width is not None:
        # Keep the truncating aspect-ratio math, but never collapse the derived
        # side to zero pixels (interpolate rejects empty outputs).
        shape = (max(1, int(h * width / w)), width)
    elif height is not None:
        shape = (height, max(1, int(w * height / h)))
    return F.interpolate(
        image[None],
        shape,
        mode="bilinear",
        align_corners=False,
        antialias=True,  # a bool flag; the filter itself is selected by `mode`
    )[0]
|