File size: 4,396 Bytes
1ffed57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d35fce
1ffed57
 
 
a534ad2
 
1ffed57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aa42d5
1ffed57
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import math
import numpy as np
import albumentations
import matplotlib.pyplot as plt

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Human-readable class names; the plotting helpers below index this list with
# integer label / prediction values to build subplot titles.
CLASS_NAMES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

def get_inv_transforms():
    """Build the transform that undoes normalization so images can be plotted.

    A forward ``Normalize(mean=m, std=s)`` computes ``(x - m) / s``. Applying a
    second ``Normalize`` with ``mean = -m / s`` and ``std = 1 / s`` (and
    ``max_pixel_value=1.0`` so no extra scaling is applied) gives
    ``(y + m / s) * s = y * s + m``, i.e. the original pixel value.

    Returns:
        albumentations.Normalize: Transform object; call it as
        ``transform(image=img)['image']``.
    """
    # Per-channel statistics, kept in exactly the order this file has always used.
    # NOTE(review): the channel order here must match the forward normalization
    # applied when the data was loaded -- confirm against the training transforms.
    channel_means = (0.48215841, 0.44653091, 0.49139968)
    channel_stds = (0.24348513, 0.26158784, 0.24703223)

    inverse_mean = [-mean / std for mean, std in zip(channel_means, channel_stds)]
    inverse_std = [1 / std for std in channel_stds]

    return albumentations.Normalize(inverse_mean, inverse_std, max_pixel_value=1.0)

def plot_samples(train_loader, number_of_images):
    """Plot the first images of one batch taken from a data loader.

    One batch is pulled from ``train_loader``, each selected image is
    de-normalized with ``get_inv_transforms()`` and drawn in a grid that is
    five columns wide, titled with its class name.

    Args:
        train_loader (Object): Data loader; ``next(iter(train_loader))`` must
            yield an ``(images, labels)`` pair.
        number_of_images (int): How many images to draw. Capped at the number
            of images in the batch.
    """
    inv_transform = get_inv_transforms()

    plt.figure()
    images, labels = next(iter(train_loader))

    # Never request more images than the batch actually contains.
    count = min(number_of_images, len(images))

    x_count = 5
    # Round up so a partially filled last row still gets grid slots
    # (rounding down made plt.subplot fail for e.g. 12 images).
    y_count = max(1, math.ceil(count / x_count))

    # Samples are 0-based; only the subplot position is 1-based.
    for index in range(count):
        plt.subplot(y_count, x_count, index + 1)
        plt.title(CLASS_NAMES[int(labels[index])])
        plt.axis('off')
        image = np.array(images[index])
        # Move the first axis last for imshow (assumes channel-first images
        # -- TODO confirm against the dataset's transforms).
        image = np.transpose(image, (1, 2, 0))
        image = inv_transform(image=image)['image']
        plt.imshow(image)

def display_cifar_misclassified_data(data: list,
                                     number_of_samples: int = 10):
    """
    Function to plot images titled with their predicted class
    :param data: List of records. Element 0 of each record is the image tensor
                 and element 2 is the predicted class index (read via ``.item()``);
                 element 1 is not used here.
    :param number_of_samples: Number of images to print (capped at ``len(data)``)
    :return: The matplotlib figure holding the grid of images
    """
    fig = plt.figure(figsize=(10, 10))
    inv_transform = get_inv_transforms()

    # Do not index past the end of ``data`` when fewer records are available.
    count = min(number_of_samples, len(data))

    x_count = 5
    # Round up so a partially filled last row still gets grid slots
    # (rounding down made plt.subplot fail for e.g. 12 samples).
    y_count = max(1, math.ceil(count / x_count))

    for i in range(count):
        plt.subplot(y_count, x_count, i + 1)
        img = np.array(data[i][0].squeeze().to('cpu'))
        # Move the first axis last for imshow (assumes a channel-first image
        # after squeeze -- TODO confirm against the caller).
        img = np.transpose(img, (1, 2, 0))
        img = inv_transform(image=img)['image']
        plt.imshow(img)
        plt.title(r'Pred: ' + CLASS_NAMES[data[i][2].item()])
        plt.xticks([])
        plt.yticks([])

    return fig

def display_gradcam_output(data: list,
                           model,
                           target_layers,
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List of records. Element 0 of each record is the input tensor
                 fed to GradCam, element 1 the true class index and element 2
                 the predicted class index (both read via ``.item()``)
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam
    :param number_of_samples: Number of images to print (capped at ``len(data)``)
    :param transparency: Weight of Normal image when mixed with activations
    """
    # Plot configuration
    plt.figure(figsize=(10, 10))
    inv_transform = get_inv_transforms()

    # Do not index past the end of ``data`` when fewer records are available.
    count = min(number_of_samples, len(data))

    x_count = 5
    # Round up so a partially filled last row still gets grid slots
    # (rounding down made plt.subplot fail for e.g. 12 samples).
    y_count = max(1, math.ceil(count / x_count))

    # Create an object for GradCam
    cam = GradCAM(model=model, target_layers=target_layers)

    # Iterate over number of specified images
    for i in range(count):
        plt.subplot(y_count, x_count, i + 1)
        input_tensor = data[i][0]

        # Get the activations of the layer for the images
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        # Keep only the first map of the returned batch.
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image
        # (squeeze(0) assumes a leading batch dimension of size 1 -- TODO confirm)
        img = np.array(input_tensor.squeeze(0).to('cpu'))
        img = np.transpose(img, (1, 2, 0))
        img = inv_transform(image=img)['image']
        # show_cam_on_image expects a float image in the range [0, 1].
        rgb_img = np.clip(img, 0, 1)

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        # Display the images on the plot
        plt.imshow(visualization)
        # Put the two labels on separate lines; they used to be concatenated
        # with no separator ("Correct: catOutput: dog").
        true_name = CLASS_NAMES[data[i][1].item()]
        predicted_name = CLASS_NAMES[data[i][2].item()]
        plt.title(r"Correct: " + true_name + '\n' + 'Output: ' + predicted_name)
        plt.xticks([])
        plt.yticks([])