Spaces:
Sleeping
Sleeping
File size: 4,396 Bytes
1ffed57 3d35fce 1ffed57 a534ad2 1ffed57 6aa42d5 1ffed57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import math
import numpy as np
import albumentations
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
CLASS_NAMES= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def get_inv_transforms():
"""Method to get transform to inverse the effect of normalization for ploting
Returns:
_Object: Object to apply image augmentations
"""
# Normalize image
inv_transforms = albumentations.Normalize([-0.48215841/0.24348513, -0.44653091/0.26158784, -0.49139968/0.24703223],
[1/0.24348513, 1/0.26158784, 1/0.24703223], max_pixel_value=1.0)
return inv_transforms
def plot_samples(train_loader, number_of_images):
"""Method to plot samples of augmented images
Args:
train_loader (Object): Object of data loader class to get images
"""
inv_transform = get_inv_transforms()
figure = plt.figure()
x_count = 5
y_count = 1 if number_of_images <= 5 else math.floor(number_of_images / x_count)
images, labels = next(iter(train_loader))
for index in range(1, number_of_images + 1):
plt.subplot(y_count, x_count, index)
plt.title(CLASS_NAMES[labels[index].numpy()])
plt.axis('off')
image = np.array(images[index])
image = np.transpose(image, (1, 2, 0))
image = inv_transform(image=image)['image']
plt.imshow(image)
def display_cifar_misclassified_data(data: list,
number_of_samples: int = 10):
"""
Function to plot images with labels
:param data: List[Tuple(image, label)]
:param number_of_samples: Number of images to print
"""
fig = plt.figure(figsize=(10, 10))
inv_transform = get_inv_transforms()
x_count = 5
y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
for i in range(number_of_samples):
plt.subplot(y_count, x_count, i + 1)
img = np.array(data[i][0].squeeze().to('cpu'))
img = np.transpose(img, (1, 2, 0))
img = inv_transform(image=img)['image']
plt.imshow(img)
plt.title(r'Pred: ' + CLASS_NAMES[data[i][2].item()])
plt.xticks([])
plt.yticks([])
return fig
def display_gradcam_output(data: list,
model,
target_layers,
targets=None,
number_of_samples: int = 10,
transparency: float = 0.60):
"""
Function to visualize GradCam output on the data
:param data: List[Tuple(image, label)]
:param classes: Name of classes in the dataset
:param inv_normalize: Mean and Standard deviation values of the dataset
:param model: Model architecture
:param target_layers: Layers on which GradCam should be executed
:param targets: Classes to be focused on for GradCam
:param number_of_samples: Number of images to print
:param transparency: Weight of Normal image when mixed with activations
"""
# Plot configuration
fig = plt.figure(figsize=(10, 10))
inv_transform = get_inv_transforms()
x_count = 5
y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
# Create an object for GradCam
cam = GradCAM(model=model, target_layers=target_layers)
# Iterate over number of specified images
for i in range(number_of_samples):
plt.subplot(y_count, x_count, i + 1)
input_tensor = data[i][0]
# Get the activations of the layer for the images
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]
# Get back the original image
img = np.array(input_tensor.squeeze(0).to('cpu'))
img = np.transpose(img, (1, 2, 0))
img = inv_transform(image=img)['image']
rgb_img = np.clip(img, 0, 1)
# Mix the activations on the original image
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
# Display the images on the plot
plt.imshow(visualization)
plt.title(r"Correct: " + CLASS_NAMES[data[i][1].item()] + 'Output: ' + CLASS_NAMES[data[i][2].item()])
plt.xticks([])
plt.yticks([]) |