File size: 1,446 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import warnings
from pathlib import Path
from string import ascii_letters, digits, punctuation

import numpy as np
import torch
from einops import rearrange
from jaxtyping import Float
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor

from .layout import vcat
# All printable, non-whitespace ASCII characters: digits, then punctuation,
# then letters. Measuring this string gives a text height that does not
# depend on which characters a particular label happens to contain.
EXPECTED_CHARACTERS = "".join((digits, punctuation, ascii_letters))
def draw_label(
    text: str,
    font: Path,
    font_size: int,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Draw a black label on a white background with no border.

    Args:
        text: The label text to render.
        font: Path to a font file loadable by ``ImageFont.truetype``. If it
            cannot be loaded, Pillow's default font is used instead and a
            warning is emitted.
        font_size: Font size passed to Pillow.
        device: Device on which the returned tensor is allocated.

    Returns:
        A float32 tensor of shape (3, height, width) with values in [0, 1].
        The width is the horizontal extent of ``text``. The height is the
        vertical extent of ``EXPECTED_CHARACTERS``, so labels drawn with the
        same font and size share a common height regardless of their content.
    """
    try:
        loaded_font = ImageFont.truetype(str(font), font_size)
    except OSError:
        warnings.warn(
            f"Could not load font {font!s}; falling back to Pillow's default font.",
            stacklevel=2,
        )
        try:
            # Pillow >= 10.1 accepts a size for the default font.
            loaded_font = ImageFont.load_default(size=font_size)
        except TypeError:
            # Older Pillow: fixed-size default font only.
            loaded_font = ImageFont.load_default()

    # Horizontal extent comes from the actual text.
    left, _, right, _ = loaded_font.getbbox(text)
    # Vertical extent comes from a fixed character set so that every label
    # gets the same height and baseline.
    _, top, _, bottom = loaded_font.getbbox(EXPECTED_CHARACTERS)
    width = right - left
    height = bottom - top

    image = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(image)
    # getbbox reports the box relative to the drawing anchor. Offsetting the
    # anchor by (-left, -top) puts the glyphs inside the canvas. Drawing at
    # (0, 0) would clip the right/bottom edges by `left`/`top` pixels.
    draw.text((-left, -top), text, font=loaded_font, fill="black")

    pixels = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device)
    # PIL/NumPy arrays are (h, w, c); convert to channel-first.
    return rearrange(pixels, "h w c -> c h w")
def add_label(
    image: Float[Tensor, "3 height width"],
    label: str,
    font: Path = Path("assets/Inter-Regular.otf"),
    font_size: int = 24,
) -> Float[Tensor, "3 height_with_label width_with_label"]:
    """Combine a rendered text label with ``image`` using ``vcat``.

    Args:
        image: Channel-first image tensor, in the same (3, height, width)
            layout that ``draw_label`` produces. The label is rendered on
            ``image.device``.
        label: Text to render.
        font: Path to the label font. The default is a relative path, so it
            is resolved against the current working directory.
        font_size: Font size of the label.

    Returns:
        ``vcat(label_image, image, align="left", gap=4)``. This presumably
        stacks the label above the image with a 4-unit (likely pixel) gap;
        confirm against ``vcat``.
    """
    label_image = draw_label(label, font, font_size, image.device)
    return vcat(label_image, image, align="left", gap=4)
|