from pathlib import Path
from string import ascii_letters, digits, punctuation

import numpy as np
import torch
from einops import rearrange
from jaxtyping import Float
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor

from .layout import vcat

# Reference glyph set whose vertical extent sets every label's height, so labels
# with different text (e.g. "a" vs "Ag") come out the same height.
EXPECTED_CHARACTERS = digits + punctuation + ascii_letters


def draw_label(
    text: str,
    font: Path,
    font_size: int,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Draw a black label on a white background with no border.

    Args:
        text: The label text.
        font: Path to a font file loadable by ``ImageFont.truetype``. If loading
            raises ``OSError``, Pillow's default font is used instead; note that
            ``font_size`` is not passed to that fallback.
        font_size: Font size passed to ``ImageFont.truetype``.
        device: Device of the returned tensor.

    Returns:
        A float32 RGB image with values in ``[0, 1]``, channels first. Its width
        is the horizontal extent of ``text``; its height is the vertical extent
        of ``EXPECTED_CHARACTERS`` in the chosen font.
    """
    try:
        pil_font = ImageFont.truetype(str(font), font_size)
    except OSError:
        pil_font = ImageFont.load_default()

    # Width comes from the text itself.
    left, _, right, _ = pil_font.getbbox(text)
    width = right - left

    # Height comes from a fixed reference set so all labels share one height.
    _, top, _, bottom = pil_font.getbbox(EXPECTED_CHARACTERS)
    height = bottom - top

    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    # getbbox() reports extents relative to the drawing origin, so shift the
    # origin by (-left, -top). Drawing at (0, 0) would offset the glyphs by
    # (left, top) and clip their right/bottom edges against the canvas.
    draw.text((-left, -top), text, font=pil_font, fill="black")

    pixels = torch.tensor(
        np.array(canvas) / 255, dtype=torch.float32, device=device
    )
    return rearrange(pixels, "h w c -> c h w")


def add_label(
    image: Float[Tensor, "3 height width"],
    label: str,
    font: Path = Path("assets/Inter-Regular.otf"),
    font_size: int = 24,
) -> Float[Tensor, "3 height_with_label width_with_label"]:
    """Attach a rendered text label to ``image`` via ``vcat``.

    The label (rendered on ``image``'s device by ``draw_label``) is passed to
    ``vcat`` first and ``image`` second, with ``align="left"`` and ``gap=4``.
    The exact placement and gap fill are defined by ``vcat``; presumably the
    label ends up above the image.
    """
    return vcat(
        draw_label(label, font, font_size, image.device),
        image,
        align="left",
        gap=4,
    )