ossaili
na
9b43cf7
raw
history blame contribute delete
530 Bytes
from matplotlib import pyplot as plt
import numpy as np
import torchvision
def imshow(
    dataloader,
    title=None,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Display the first batch of *dataloader* as a single image grid.

    The batch is tiled with ``torchvision.utils.make_grid``. The per-channel
    normalization ``x_norm = (x - mean) / std`` is then undone, and the result
    is clipped to ``[0, 1]`` and drawn with matplotlib.

    Args:
        dataloader: An iterable whose first element unpacks into exactly two
            items ``(inputs, labels)``. Only ``inputs`` is used. Presumably a
            ``torch.utils.data.DataLoader`` yielding image tensors of shape
            ``(B, C, H, W)``; confirm against callers.
        title: Optional title for the plot; skipped when ``None``.
        mean: Per-channel mean that was subtracted during normalization.
            Defaults to the standard torchvision ImageNet RGB means, which
            this function previously hard-coded.
        std: Per-channel standard deviation that was divided out during
            normalization. Defaults to the standard torchvision ImageNet
            RGB standard deviations.

    Raises:
        StopIteration: If *dataloader* yields no batches.
    """
    inputs, _ = next(iter(dataloader))
    grid = torchvision.utils.make_grid(inputs)
    # make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C).
    # detach()/cpu() make this work for tensors that require grad or live on
    # an accelerator; both are no-ops for a plain CPU tensor.
    img = grid.detach().cpu().numpy().transpose((1, 2, 0))
    # Invert the normalization. mean/std broadcast over the last (channel) axis.
    img = np.asarray(std, dtype=float) * img + np.asarray(mean, dtype=float)
    # Matplotlib expects float RGB data in [0, 1]. Clamp anything outside it.
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.show()
    plt.pause(0.001)  # give the GUI event loop a moment to redraw the figure