|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
|
|
|
|
def convert_back_image(image):
    """Undo CIFAR-10 normalisation and return the image in channels-last order.

    Each channel ``i`` of ``image`` is mapped through
    ``channel * std[i] + mean[i]`` using the hard-coded statistics below, the
    result is clipped to ``[0, 1]`` and the first axis is moved to the end.

    Args:
        image: A torch tensor whose first axis indexes channels (at most
            three, one per entry in the mean/std tuples). The tensor may
            require grad or live on a non-CPU device; it is never modified.

    Returns:
        A ``float32`` NumPy array with axes reordered from ``(0, 1, 2)`` to
        ``(1, 2, 0)`` and values in ``[0, 1]``.

    Raises:
        IndexError: If ``image`` has more than three channels.
    """
    cifar10_mean = (0.4914, 0.4822, 0.4471)
    cifar10_std = (0.2469, 0.2433, 0.2615)

    # ``Tensor.numpy()`` refuses tensors that require grad or sit on an
    # accelerator; ``detach()``/``cpu()`` are no-ops for plain CPU tensors.
    # ``astype`` copies, so the caller's tensor storage is left untouched.
    pixels = image.detach().cpu().numpy().astype(np.float32)

    for channel in range(pixels.shape[0]):
        pixels[channel] = pixels[channel] * cifar10_std[channel] + cifar10_mean[channel]

    return np.transpose(pixels.clip(0, 1), (1, 2, 0))
|
|
|
|
|
def plot_sample_training_images(batch_data, batch_label, class_label, num_images=30):
    """Plot a grid of sample images from a training batch.

    Images are laid out five per row. Each one is passed through
    ``convert_back_image`` before display and titled with
    ``class_label[label]``.

    Args:
        batch_data: Indexable batch of image tensors accepted by
            ``convert_back_image``.
        batch_label: Indexable batch of labels; each entry must provide
            ``.item()``.
        class_label: Lookup from the label value to its display name.
        num_images: Maximum number of images to draw; capped at
            ``len(batch_data)``.

    Returns:
        The ``(fig, axs)`` pair produced by ``plt.subplots``.
    """
    num_images = min(num_images, len(batch_data))

    num_cols = 5
    # Keep at least one row: Matplotlib rejects a grid with zero rows, which
    # is what an empty batch would otherwise request.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))

    # np.ravel walks the grid row by row, matching the 1-based numbering that
    # plt.subplot would use, and works for both 1-D and 2-D ``axs``.
    for index, ax in enumerate(np.ravel(axs)):
        if index < num_images:
            ax.imshow(convert_back_image(batch_data[index]))
            ax.set_title(class_label[batch_label[index].item()])
        # Applied to every cell so that unused trailing cells do not show an
        # empty frame. Ticks are cleared too, so they stay hidden if a caller
        # later switches the axis back on.
        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])

    # Lay out once, after all titles exist, instead of on every iteration.
    fig.tight_layout()

    return fig, axs
|
|
|
|
|
def plot_train_test_metrics(results):
    """Draw train/test loss and accuracy curves side by side.

    Args:
        results: Mapping with the keys ``"train_loss"``, ``"train_acc"``,
            ``"test_loss"`` and ``"test_acc"``; each value is handed to
            ``Axes.plot`` as a single series.

    Returns:
        The ``(fig, axs)`` pair from ``plt.subplots(1, 2)``; ``axs[0]`` holds
        the loss panel and ``axs[1]`` the accuracy panel.
    """
    # Look the series up before creating the figure so that a missing key
    # fails without leaving an empty figure behind.
    loss_train = results["train_loss"]
    acc_train = results["train_acc"]
    loss_test = results["test_loss"]
    acc_test = results["test_acc"]

    panels = (
        ("Loss", loss_train, loss_test),
        ("Accuracy", acc_train, acc_test),
    )

    fig, axs = plt.subplots(1, 2, figsize=(16, 8))

    for ax, (title, train_series, test_series) in zip(axs, panels):
        ax.plot(train_series, label="Train")
        ax.plot(test_series, label="Test")
        ax.set_title(title)
        ax.legend(loc="upper right")

    return fig, axs
|
|
|
|
|
def plot_misclassified_images(data, class_label, num_images=10):
    """Plot misclassified images with their actual and predicted labels.

    Images are laid out five per row. Each one is passed through
    ``convert_back_image`` before display.

    Args:
        data: Mapping with the keys ``"ground_truths"``, ``"predicted_vals"``
            and ``"images"``. Each value is indexable by position, and its
            entries must provide ``.cpu()``.
        class_label: Lookup from a label value to its display name.
        num_images: Maximum number of images to draw; capped at
            ``len(data["ground_truths"])``.

    Returns:
        The ``(fig, axs)`` pair produced by ``plt.subplots``.
    """
    num_images = min(num_images, len(data["ground_truths"]))

    num_cols = 5
    # Keep at least one row: Matplotlib rejects a grid with zero rows, which
    # is what an empty ``data`` would otherwise request.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

    # np.ravel walks the grid row by row, matching the 1-based numbering that
    # plt.subplot would use, and works for both 1-D and 2-D ``axs``.
    for index, ax in enumerate(np.ravel(axs)):
        if index < num_images:
            label = data["ground_truths"][index].cpu().item()
            pred = data["predicted_vals"][index].cpu().item()
            image = data["images"][index].cpu()

            ax.imshow(convert_back_image(image))
            ax.set_title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
        # Applied to every cell so that unused trailing cells do not show an
        # empty frame. Ticks are cleared too, so they stay hidden if a caller
        # later switches the axis back on.
        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])

    # Lay out once, after all titles exist, instead of on every iteration.
    fig.tight_layout()

    return fig, axs
|
|
|
|
|
|
|
def plot_gradcam_images(
    model,
    data,
    class_label,
    target_layers,
    targets=None,
    num_images=10,
    image_weight=0.25,
):
    """Show Grad-CAM overlays for misclassified images.

    For each image a Grad-CAM map is computed (with ``aug_smooth`` and
    ``eigen_smooth`` enabled) and blended onto the de-normalised image via
    ``show_cam_on_image``. Overlays are laid out five per row and titled with
    the actual and predicted class names.

    Args:
        model: Model handed to ``GradCAM``.
        data: Mapping with the keys ``"ground_truths"``, ``"predicted_vals"``
            and ``"images"``. Each value is indexable by position, and its
            entries must provide ``.cpu()``.
        class_label: Lookup from a label value to its display name.
        target_layers: Layers handed to ``GradCAM``.
        targets: Forwarded unchanged to the CAM call for every image.
            NOTE(review): presumably ``None`` lets the library pick the
            top-scoring class — confirm against the pytorch-grad-cam docs.
        num_images: Maximum number of images to draw; capped at
            ``len(data["ground_truths"])``.
        image_weight: Forwarded to ``show_cam_on_image``.

    Returns:
        The ``(fig, axs)`` pair produced by ``plt.subplots``.
    """
    num_images = min(num_images, len(data["ground_truths"]))

    num_cols = 5
    # Keep at least one row: Matplotlib rejects a grid with zero rows, which
    # is what an empty ``data`` would otherwise request.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

    cam = GradCAM(model=model, target_layers=target_layers)

    # np.ravel walks the grid row by row, matching the 1-based numbering that
    # plt.subplot would use, and works for both 1-D and 2-D ``axs``.
    for index, ax in enumerate(np.ravel(axs)):
        if index < num_images:
            label = data["ground_truths"][index].cpu().item()
            pred = data["predicted_vals"][index].cpu().item()
            image = data["images"][index].cpu()

            # The CAM call takes a batch, so add a leading batch axis and
            # take the single resulting map back out.
            grad_cam_output = cam(
                input_tensor=image.unsqueeze(0),
                targets=targets,
                aug_smooth=True,
                eigen_smooth=True,
            )
            grad_cam_output = grad_cam_output[0, :]

            overlayed_image = show_cam_on_image(
                convert_back_image(image),
                grad_cam_output,
                use_rgb=True,
                image_weight=image_weight,
            )

            ax.imshow(overlayed_image)
            ax.set_title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
        # Applied to every cell so that unused trailing cells do not show an
        # empty frame. Ticks are cleared too, so they stay hidden if a caller
        # later switches the axis back on.
        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])

    # Lay out once, after all titles exist, instead of on every iteration.
    fig.tight_layout()

    return fig, axs