File size: 2,671 Bytes
229755d 923fe1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import matplotlib.pyplot as plt
from torchvision import transforms
def plot_class_label_counts(data_loader, classes):
    """Plot a bar chart of how many samples of each class ``data_loader`` yields.

    Args:
        data_loader: Iterable of ``(batch_data, batch_label)`` pairs. Each
            ``batch_label`` is iterated element-wise, and every element is
            converted to an ``int`` index into ``classes``.
        classes: Sequence of class names indexed by integer label.

    The resulting figure is displayed with ``plt.show()``.
    """
    # Pre-seed every class with 0 so classes that never occur still get a bar,
    # in the order given by ``classes``.
    class_counts = dict.fromkeys(classes, 0)
    for _, batch_label in data_loader:
        for label in batch_label:
            # int() accepts 0-d tensors, NumPy scalars and plain ints alike.
            class_counts[classes[int(label)]] += 1

    positions = range(len(class_counts))
    plt.figure()
    plt.suptitle("Class Distribution")
    plt.bar(positions, list(class_counts.values()))
    plt.xticks(positions, list(class_counts.keys()), rotation=90)
    plt.tight_layout()
    plt.show()
def plot_data_samples(data_loader, classes):
    """Show up to 12 images from the first batch of ``data_loader`` with labels.

    Images are un-normalized before display so they appear with their
    original colors, and they are arranged in a 3x4 grid.

    Args:
        data_loader: Iterable yielding ``(batch_data, batch_label)`` pairs.
            Only the first batch is used. ``batch_data[i]`` must be a
            normalized 3-channel image tensor (presumably C, H, W; TODO confirm).
        classes: Sequence of class names indexed by integer label.
    """
    batch_data, batch_label = next(iter(data_loader))

    # The inverse of Normalize(mean, std) is Normalize(-mean / std, 1 / std).
    # These constants correspond to mean ~= (0.4914, 0.4822, 0.4465) and
    # std ~= (0.247, 0.243, 0.261) (presumably CIFAR-10 statistics).
    unnormalize = transforms.Normalize(
        (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
    )
    to_pil = transforms.ToPILImage()

    plt.figure()
    plt.suptitle("Data Samples with Labels post Transforms")
    # Without this cap, a batch smaller than the 3x4 grid would raise IndexError.
    for i in range(min(12, len(batch_data))):
        plt.subplot(3, 4, i + 1)
        # Clamp float round-off into [0, 1] so ToPILImage's uint8 cast cannot wrap.
        image = unnormalize(batch_data[i]).clamp(0.0, 1.0)
        plt.imshow(to_pil(image))
        plt.title(classes[batch_label[i].item()])
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and its title exist.
    plt.tight_layout()
def plot_model_training_curves(train_accs, test_accs, train_losses, test_losses):
    """Plot training and test loss/accuracy histories on a 2x2 grid.

    Layout: losses on the top row and accuracies on the bottom row; training
    curves in the left column and test curves in the right column.

    Args:
        train_accs: Sequence of training accuracy values.
        test_accs: Sequence of test accuracy values.
        train_losses: Sequence of training loss values.
        test_losses: Sequence of test loss values.
    """
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    panels = (
        (axs[0, 0], train_losses, "Training Loss"),
        (axs[1, 0], train_accs, "Training Accuracy"),
        (axs[0, 1], test_losses, "Test Loss"),
        (axs[1, 1], test_accs, "Test Accuracy"),
    )
    for ax, values, title in panels:
        ax.plot(values)
        ax.set_title(title)
    # Tighten spacing so panel titles and tick labels do not overlap.
    fig.tight_layout()
def plot_incorrect_preds(incorrect, classes, num_imgs):
    """Display the first ``num_imgs`` misclassified samples in a 5-column grid.

    Each subplot is titled ``"<target>|<predicted>"`` using names from
    ``classes``. The total number of incorrect predictions is printed first.

    Args:
        incorrect: Sequence of records laid out as ``(data, target, pred, output)``.
            Only the first three fields are used. ``data`` is an image tensor
            that is un-normalized before display. ``target`` and ``pred`` are
            single-element tensors holding class indices.
        classes: Sequence of class names indexed by integer label.
        num_imgs: Number of images to show. It must be a positive multiple of 5
            and no larger than ``len(incorrect)``.

    Raises:
        ValueError: If ``num_imgs`` is not a positive multiple of 5, or if it
            exceeds the number of incorrect predictions.
    """
    # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
    if num_imgs <= 0 or num_imgs % 5 != 0:
        raise ValueError(f"num_imgs must be a positive multiple of 5, got {num_imgs}")
    if len(incorrect) < num_imgs:
        raise ValueError(
            f"num_imgs={num_imgs} exceeds the {len(incorrect)} incorrect predictions"
        )
    print(f"Total Incorrect Predictions {len(incorrect)}")

    # The inverse of Normalize(mean, std) is Normalize(-mean / std, 1 / std).
    # These constants correspond to mean ~= (0.4914, 0.4822, 0.4465) and
    # std ~= (0.247, 0.243, 0.261) (presumably CIFAR-10 statistics).
    unnormalize = transforms.Normalize(
        (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
    )
    to_pil = transforms.ToPILImage()
    rows = num_imgs // 5

    plt.figure(figsize=(10, num_imgs // 2))
    plt.suptitle("Target | Predicted Label")
    for i in range(num_imgs):
        record = incorrect[i]
        data, target, pred = record[0], record[1], record[2]
        plt.subplot(rows, 5, i + 1, aspect="auto")
        # Clamp float round-off into [0, 1] so ToPILImage's uint8 cast cannot wrap.
        image = unnormalize(data).clamp(0.0, 1.0)
        plt.imshow(to_pil(image))
        plt.title(f"{classes[target.item()]}|{classes[pred.item()]}")
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and its title exist.
    plt.tight_layout()