|
from pathlib import Path |
|
from string import ascii_letters, digits, punctuation |
|
|
|
import numpy as np |
|
import torch |
|
from einops import rearrange |
|
from jaxtyping import Float |
|
from PIL import Image, ImageDraw, ImageFont |
|
from torch import Tensor |
|
|
|
from .layout import vcat |
|
|
|
# All non-whitespace printable ASCII characters (digits, punctuation, letters).
# draw_label measures the vertical extent of this string, so every label drawn
# with the same font and size gets the same height regardless of its own text.
EXPECTED_CHARACTERS = "".join((digits, punctuation, ascii_letters))
|
|
|
|
|
def draw_label(
    text: str,
    font: Path,
    font_size: int,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Draw a black label on a white background with no border.

    Args:
        text: Label text to render. A single line is assumed (it is measured
            with ``getbbox``) — TODO confirm callers never pass newlines.
        font: Path to a TrueType/OpenType font file. If Pillow cannot load it
            (``OSError``), Pillow's built-in default font is used instead.
        font_size: Font size passed to Pillow. It is also honoured by the
            fallback font when the installed Pillow supports
            ``load_default(size=...)``.
        device: Device on which the returned tensor is allocated.

    Returns:
        A float32 tensor of shape ``(3, height, width)`` with values in
        ``[0, 1]``. ``width`` is the horizontal extent of ``text``.
        ``height`` is the vertical extent of ``EXPECTED_CHARACTERS``, so labels
        drawn with the same font and size all share one height.
    """
    try:
        pil_font = ImageFont.truetype(str(font), font_size)
    except OSError:
        try:
            # Newer Pillow can scale its built-in font; honour font_size there.
            pil_font = ImageFont.load_default(size=font_size)
        except TypeError:
            # Older Pillow versions do not accept a size for the default font.
            pil_font = ImageFont.load_default()

    # getbbox() returns boxes relative to the drawing origin. They generally do
    # not start at (0, 0) because of glyph side bearings and the space between
    # the ascender line and the tallest glyph.
    left, _, right, _ = pil_font.getbbox(text)
    _, top, _, bottom = pil_font.getbbox(EXPECTED_CHARACTERS)
    width = right - left
    height = bottom - top

    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    # Offset the origin so the measured box lands exactly on the canvas.
    # Drawing at (0, 0) would shift the text by (left, top) and clip its
    # right and bottom edges (e.g. descenders).
    draw.text((-left, -top), text, font=pil_font, fill="black")

    pixels = np.array(canvas) / 255  # (h, w, 3), values in [0, 1]
    tensor = torch.tensor(pixels, dtype=torch.float32, device=device)
    return rearrange(tensor, "h w c -> c h w")
|
|
|
|
|
def add_label(
    image: Float[Tensor, "3 height width"],
    label: str,
    font: Path = Path("assets/Inter-Regular.otf"),
    font_size: int = 24,
) -> Float[Tensor, "3 height_with_label width_with_label"]:
    """Combine a rendered text label with ``image`` using ``vcat``.

    The label is rendered by ``draw_label`` on the same device as ``image``. It
    is passed to ``vcat`` before the image, with ``align="left"`` and
    ``gap=4``. Presumably this places the label above the image, separated by
    4 pixels — confirm against ``vcat``'s contract.

    Args:
        image: Channels-first image tensor ``(3, height, width)``, matching
            the layout ``draw_label`` produces.
        label: Text to render.
        font: Font file for the label. The default is a relative path, so it
            is resolved against the current working directory. If it cannot be
            loaded, ``draw_label`` falls back to Pillow's default font.
        font_size: Font size for the label.

    Returns:
        The combined channels-first tensor produced by ``vcat``.
    """
    rendered_label = draw_label(label, font, font_size, image.device)
    return vcat(
        rendered_label,
        image,
        align="left",
        gap=4,
    )
|
|