File size: 2,169 Bytes
e18f153
 
 
 
1843265
 
 
 
 
 
 
 
 
 
e18f153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from typing import List
import io

def fig2img(fig: plt.Figure) -> Image.Image:
  """Convert a Matplotlib figure to a PIL Image and return it.

  The figure is rendered with ``fig.savefig`` (matplotlib's configured
  default output format) into an in-memory buffer, and the figure is then
  closed through pyplot.

  Args:
    fig: Figure to render.
  Returns:
    Image.Image: PIL image opened from the rendered buffer.
  """
  buf = io.BytesIO()
  try:
    fig.savefig(buf)
  finally:
    # Close exactly this figure, even if saving failed. A bare plt.close()
    # would close pyplot's *current* figure, which may be unrelated to fig.
    plt.close(fig)
  buf.seek(0)
  # Image.open reads lazily from buf; the returned image keeps a reference
  # to it, so the buffer is intentionally not closed here.
  return Image.open(buf)

def show_tile_images(
    images: List[np.ndarray | Image.Image],
    width_parts: int,
    figsize = None,
    space = 0.0,
    pad = False,
    figcolor = 'white',
    titles: List[str] | None = None,
    title_color: str | None = None,
    title_background_color: str | None = None,
    title_font_size: int | None = None) -> Image.Image:
  '''
  Show images in a tile format
  Args:
    images: List of images to show. numpy arrays are drawn as-is; other
      images (PIL) are converted to RGB first. Must not be empty.
    width_parts: Number of images to show in a row
    figsize: Size of the figure; defaults to (16, 12 * number_of_rows)
    space: Space between images (used for subplot wspace/hspace and for
      tight_layout h_pad/w_pad)
    pad: If True, `space` is also used as the outer figure padding;
      otherwise the outer padding is 0
    figcolor: Background color of the figure
    titles: Titles of the images. May be shorter or longer than `images`:
      missing or empty titles are not drawn, extra titles are ignored.
    title_color: Color of the title
    title_background_color: Background color of the title
    title_font_size: Font size of the title
  Returns:
    Image: Image of the rendered figure
  Raises:
    ValueError: If `images` is empty.
  '''
  if not images:
    raise ValueError("images must contain at least one image")
  height = int(np.ceil(len(images) / width_parts))
  if figsize is None:
    # NOTE(review): the default width is fixed at 16 regardless of
    # width_parts; kept unchanged for backward compatibility.
    figsize = (8 * 2, 12 * height)
  fig, axs = plt.subplots(height, width_parts, figsize=figsize)
  fig.patch.set_facecolor(figcolor)
  # plt.subplots returns a bare Axes for a 1x1 grid, an ndarray otherwise.
  axes = axs.flatten() if isinstance(axs, np.ndarray) else [axs]

  # Pad titles with None so every image has an entry. A negative repeat
  # count yields [], and zip() below drops any surplus titles.
  titles = list(titles or [])
  titles += [None] * (len(images) - len(titles))

  # Title styling is the same for every tile, so build it once. Only
  # explicitly provided options are passed so matplotlib defaults apply.
  title_kwargs = {
    key: value
    for key, value in (
      ('color', title_color),
      ('backgroundcolor', title_background_color),
      ('fontsize', title_font_size),
    )
    if value is not None
  }

  for img, ax, title in zip(images, axes, titles):
    if title:
      ax.set_title(title, **title_kwargs)
    ax.imshow(img if isinstance(img, np.ndarray) else img.convert("RGB"))
    ax.axis('off')
  # Hide trailing grid cells left empty when len(images) is not a multiple
  # of width_parts; otherwise they render as blank framed axes.
  for ax in axes[len(images):]:
    ax.axis('off')

  fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=space, hspace=space)
  fig.tight_layout(h_pad=space, w_pad=space, pad=space if pad else 0)
  try:
    return fig2img(fig)
  finally:
    # Close exactly this figure (a no-op if fig2img already closed it).
    # Closing by identity avoids touching any other figure the caller has open.
    plt.close(fig)