Spaces:
Sleeping
Sleeping
import numpy as np | |
import math | |
import PIL | |
def postprocess(x):
    """Convert values in [0, 1] to uint8 pixel values in [0, 255].

    Values are scaled by 255, clipped to [0, 255] and then truncated (not
    rounded) to ``np.uint8``.

    Args:
        x: array-like of floats, nominally in [0, 1].

    Returns:
        ``np.ndarray`` of dtype ``np.uint8`` with the same shape as ``x``.
    """
    # asarray first: with a plain Python list, ``255 * x`` would repeat the
    # list 255 times instead of scaling its values.
    scaled = np.clip(255 * np.asarray(x), 0, 255)
    # np.cast was deprecated in NumPy 1.25 and removed in 2.0; astype is the
    # supported equivalent and likewise truncates toward zero.
    return np.asarray(scaled).astype(np.uint8)
def tile(X, rows, cols):
    """Arrange a stack of images into a ``rows`` x ``cols`` mosaic.

    Images are placed in row-major order: image ``k`` lands in grid cell
    ``(k // cols, k % cols)``. Cells with no corresponding image stay zero,
    and images beyond ``rows * cols`` are ignored.

    Args:
        X: array indexed as ``X[k]`` -> one image of shape
            ``(X.shape[1], X.shape[2], X.shape[3])``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        Array of shape ``(rows * H, cols * W, C)`` with ``X``'s dtype.
    """
    n_images = X.shape[0]
    height, width, channels = X.shape[1], X.shape[2], X.shape[3]
    canvas = np.zeros((rows * height, cols * width, channels), dtype=X.dtype)

    grid_cells = ((r, c) for r in range(rows) for c in range(cols))
    for r, c in grid_cells:
        position = r * cols + c
        if position >= n_images:
            continue  # more cells than images: leave this cell blank
        top = r * height
        left = c * width
        canvas[top:top + height, left:left + width, :] = X[position, ...]
    return canvas
def plot_batch(X, out_path):
    """Save a batch of images as a single tiled image file.

    The images are converted to uint8 with ``postprocess`` and laid out by
    ``tile`` on a ceil(sqrt(N)) x ceil(sqrt(N)) grid. Unused grid cells are
    left black.

    Args:
        X: array of shape (N, H, W, C) with values nominally in [0, 1].
            When C > 3, three distinct channels are chosen at random so the
            result has three channels.
        out_path: destination passed to Pillow's ``Image.save``.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so
    # ``PIL.Image.fromarray`` can fail with AttributeError; import it here.
    from PIL import Image

    n_channels = X.shape[3]
    if n_channels > 3:
        # Sample without replacement so the three displayed channels are
        # distinct (sampling with replacement could repeat a channel).
        channels = np.random.choice(n_channels, size=3, replace=False)
        X = X[:, :, :, channels]
    X = postprocess(X)
    side = math.ceil(math.sqrt(X.shape[0]))
    canvas = tile(X, side, side)
    if canvas.shape[-1] == 1:
        # Drop only the channel axis for single-channel data; a blanket
        # np.squeeze would also collapse a canvas height or width of 1.
        canvas = canvas[..., 0]
    Image.fromarray(canvas).save(out_path)