Spaces:
Runtime error
Runtime error
from matplotlib import pyplot as plt | |
import numpy as np | |
import torchvision | |
def imshow(
    dataloader,
    title=None,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Display the first batch of images from *dataloader* as a single grid.

    The first batch is tiled with ``torchvision.utils.make_grid``, the
    per-channel normalization is undone (``pixel * std + mean``), values are
    clipped to ``[0, 1]`` and the result is shown with matplotlib.

    Args:
        dataloader: Iterable whose first item unpacks into exactly two
            values; the first is the image batch handed to ``make_grid``
            and the second is ignored.
        title: Optional title for the plot; omitted when ``None``.
        mean: Per-channel mean added back to the image. The default is the
            value that was previously hard-coded here (the commonly used
            ImageNet statistics).
        std: Per-channel standard deviation the image is multiplied by.
            The default is the previously hard-coded value.

    Raises:
        StopIteration: If *dataloader* yields no batches.
    """
    inputs, _ = next(iter(dataloader))
    grid = torchvision.utils.make_grid(inputs)

    # detach() + cpu() make the conversion work for tensors that require
    # grad or live on an accelerator; both are no-ops for plain CPU tensors.
    # The transpose moves the leading channel axis last, which is the
    # layout plt.imshow expects.
    image = grid.detach().cpu().numpy().transpose((1, 2, 0))

    # Undo the normalization, then clip so imshow receives floats in [0, 1].
    image = np.asarray(std) * image + np.asarray(mean)
    image = np.clip(image, 0, 1)

    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.show()
    plt.pause(0.001)  # pause a bit so that plots are updated