Spaces:
Sleeping
Sleeping
import math | |
import numpy as np | |
import albumentations | |
import matplotlib.pyplot as plt | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
# Human-readable label for each class index, used for plot titles
CLASS_NAMES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
def get_inv_transforms():
    """Method to get the transform that reverses normalization before plotting

    The transform is itself a ``Normalize`` step whose per-channel mean is
    ``-mean / std`` and whose per-channel std is ``1 / std``, with
    ``max_pixel_value`` fixed at 1.0.

    Returns:
        Object: Object to apply the inverse-normalization augmentation
    """
    # Per-channel statistics the inverse is derived from
    # NOTE(review): channel order must match the forward normalization - confirm
    channel_means = (0.48215841, 0.44653091, 0.49139968)
    channel_stds = (0.24348513, 0.26158784, 0.24703223)

    inverse_means = [-mean / std for mean, std in zip(channel_means, channel_stds)]
    inverse_stds = [1 / std for std in channel_stds]
    return albumentations.Normalize(inverse_means, inverse_stds, max_pixel_value=1.0)
def plot_samples(train_loader, number_of_images):
    """Method to plot samples of augmented images from the first batch of a loader

    Images are laid out on a grid five columns wide, each titled with its
    class name.

    Args:
        train_loader (Object): Object of data loader class to get images; each
            batch must unpack into ``(images, labels)``
        number_of_images (int): Number of images to plot; capped at the number
            of images in the first batch
    """
    inv_transform = get_inv_transforms()
    plt.figure()

    images, labels = next(iter(train_loader))
    # Never request more samples than the batch actually holds
    number_of_images = min(number_of_images, len(images))

    x_count = 5
    # Round up so a partially filled last row still gets its own grid row
    y_count = max(1, math.ceil(number_of_images / x_count))

    for index in range(number_of_images):
        # Subplot positions are 1-based while batch indices are 0-based
        plt.subplot(y_count, x_count, index + 1)
        plt.title(CLASS_NAMES[int(labels[index])])
        plt.axis('off')

        image = np.array(images[index])
        # Move axis 0 to the end (assumes channel-first images - TODO confirm)
        image = np.transpose(image, (1, 2, 0))
        image = inv_transform(image=image)['image']
        # Keep values in the displayable [0, 1] range, as the Grad-CAM helper does
        plt.imshow(np.clip(image, 0, 1))
def display_cifar_misclassified_data(data: list,
                                     number_of_samples: int = 10):
    """
    Function to plot images titled with their predicted class
    :param data: List of tuples where element 0 is the image tensor and
                 element 2 is the predicted class index (element 1 is
                 presumably the ground-truth label; it is not used here)
    :param number_of_samples: Number of images to print; capped at len(data)
    :return: Matplotlib figure holding the grid of images
    """
    fig = plt.figure(figsize=(10, 10))
    inv_transform = get_inv_transforms()

    # Never request more samples than are available
    number_of_samples = min(number_of_samples, len(data))

    x_count = 5
    # Round up so a partially filled last row still gets its own grid row
    y_count = max(1, math.ceil(number_of_samples / x_count))

    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)

        img = np.array(data[i][0].squeeze().to('cpu'))
        # Move axis 0 to the end (assumes channel-first images - TODO confirm)
        img = np.transpose(img, (1, 2, 0))
        img = inv_transform(image=img)['image']
        # Keep values in the displayable [0, 1] range, as the Grad-CAM helper does
        plt.imshow(np.clip(img, 0, 1))

        plt.title(r'Pred: ' + CLASS_NAMES[data[i][2].item()])
        plt.xticks([])
        plt.yticks([])
    return fig
def display_gradcam_output(data: list,
                           model,
                           target_layers,
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List of tuples where element 0 is the input tensor, element 1
                 the correct class index and element 2 the predicted class index
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam
    :param number_of_samples: Number of images to print; capped at len(data)
    :param transparency: Weight of Normal image when mixed with activations
    :return: Matplotlib figure holding the grid of overlays
    """
    # Plot configuration
    fig = plt.figure(figsize=(10, 10))
    inv_transform = get_inv_transforms()

    # Never request more samples than are available
    number_of_samples = min(number_of_samples, len(data))

    x_count = 5
    # Round up so a partially filled last row still gets its own grid row
    y_count = max(1, math.ceil(number_of_samples / x_count))

    # Create an object for GradCam
    cam = GradCAM(model=model, target_layers=target_layers)

    # Iterate over number of specified images
    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        input_tensor = data[i][0]

        # Get the activations of the layer for the images
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        # Keep the first map of the returned batch
        # (presumably the input holds a single image - confirm)
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image
        img = np.array(input_tensor.squeeze(0).to('cpu'))
        img = np.transpose(img, (1, 2, 0))
        img = inv_transform(image=img)['image']
        rgb_img = np.clip(img, 0, 1)

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        # Display the images on the plot
        plt.imshow(visualization)
        # Put the two labels on separate lines so they do not run together
        plt.title(r"Correct: " + CLASS_NAMES[data[i][1].item()] + '\n' + 'Output: ' + CLASS_NAMES[data[i][2].item()])
        plt.xticks([])
        plt.yticks([])
    return fig