File size: 2,671 Bytes
229755d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
923fe1a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import matplotlib.pyplot as plt
from torchvision import transforms


def plot_class_label_counts(data_loader, classes):
    class_counts = {}
    for class_name in classes:
        class_counts[class_name] = 0
    for _, batch_label in data_loader:
        for label in batch_label:
            class_counts[classes[label.item()]] += 1

    fig = plt.figure()
    plt.suptitle("Class Distribution")
    plt.bar(range(len(class_counts)), list(class_counts.values()))
    plt.xticks(range(len(class_counts)), list(class_counts.keys()), rotation=90)
    plt.tight_layout()
    plt.show()


def plot_data_samples(data_loader, classes):
    batch_data, batch_label = next(iter(data_loader))

    fig = plt.figure()
    plt.suptitle("Data Samples with Labels post Transforms")
    for i in range(12):
        plt.subplot(3, 4, i + 1)
        plt.tight_layout()
        # unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
        unnormalized = transforms.Normalize(
            (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
        )(batch_data[i])
        plt.imshow(transforms.ToPILImage()(unnormalized))
        plt.title(
            classes[batch_label[i].item()],
        )

        plt.xticks([])
        plt.yticks([])


def plot_model_training_curves(train_accs, test_accs, train_losses, test_losses):
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    axs[0, 0].plot(train_losses)
    axs[0, 0].set_title("Training Loss")
    axs[1, 0].plot(train_accs)
    axs[1, 0].set_title("Training Accuracy")
    axs[0, 1].plot(test_losses)
    axs[0, 1].set_title("Test Loss")
    axs[1, 1].plot(test_accs)
    axs[1, 1].set_title("Test Accuracy")
    plt.plot()


def plot_incorrect_preds(incorrect, classes, num_imgs):
    # num_imgs is a multiple of 5
    assert num_imgs % 5 == 0
    assert len(incorrect) >= num_imgs

    # incorrect (data, target, pred, output)
    print(f"Total Incorrect Predictions {len(incorrect)}")
    fig = plt.figure(figsize=(10, num_imgs // 2))
    plt.suptitle("Target | Predicted Label")
    for i in range(num_imgs):
        plt.subplot(num_imgs // 5, 5, i + 1, aspect="auto")

        # unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
        unnormalized = transforms.Normalize(
            (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
        )(incorrect[i][0])
        plt.imshow(transforms.ToPILImage()(unnormalized))
        plt.title(
            f"{classes[incorrect[i][1].item()]}|{classes[incorrect[i][2].item()]}",
            # fontsize=8,
        )
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()