|
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms

import modules.config as config
from models.custom_resnet import CustomResNet
from modules.visualize import plot_gradcam_images, plot_misclassified_images
|
|
|
|
|
TITLE = "CIFAR10 Image classification using a Custom ResNet Model"

DESCRIPTION = "Gradio App to infer using a Custom ResNet model and get GradCAM results"

# Example rows for gr.Interface. Columns follow the order of the Interface
# inputs (and of app_interface's parameters):
#   image path, #classes to show, show GradCAM, GradCAM layer, overlay factor,
#   show misclassified, #misclassified, GradCAM on misclassified, #GradCAM images.
_EXAMPLE_ROWS = (
    ("assets/images/airplane.jpg", 3, True, "layer3_x", 0.6, True, 5, True, 5),
    ("assets/images/bird.jpeg", 4, True, "layer3_x", 0.7, True, 10, True, 20),
    ("assets/images/car.jpg", 5, True, "layer3_x", 0.5, True, 15, True, 5),
    ("assets/images/cat.jpeg", 6, True, "layer3_x", 0.65, True, 20, True, 10),
    ("assets/images/deer.jpg", 7, False, "layer2", 0.75, True, 5, True, 5),
    ("assets/images/dog.jpg", 8, True, "layer2", 0.55, True, 10, True, 5),
    ("assets/images/frog.jpeg", 9, True, "layer2", 0.8, True, 15, True, 15),
    ("assets/images/horse.jpg", 10, False, "layer1_r1", 0.85, True, 20, True, 5),
    ("assets/images/ship.jpg", 3, True, "layer1_r1", 0.4, True, 5, True, 15),
    ("assets/images/truck.jpg", 4, True, "layer1_r1", 0.3, True, 5, True, 10),
)

# gr.Interface receives a list of lists, so turn every tuple row into a list.
examples = [list(row) for row in _EXAMPLE_ROWS]
|
|
|
|
|
|
|
|
|
model = CustomResNet()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# strict=False tolerates checkpoint/model key mismatches. Without a check, such a
# mismatch (e.g. prefixed keys from a wrapper) would go unnoticed and the app would
# serve randomly initialised weights, so surface it as a warning instead of ignoring it.
_load_result = model.load_state_dict(torch.load(config.GRADIO_MODEL_PATH, map_location=device), strict=False)
if _load_result.missing_keys:
    warnings.warn(
        f"Checkpoint {config.GRADIO_MODEL_PATH!r} does not match CustomResNet: "
        f"missing keys={_load_result.missing_keys}, "
        f"unexpected keys={_load_result.unexpected_keys}"
    )

model.to(device)

model.eval()

# NOTE(review): torch.load unpickles its input; only point these config paths at trusted files.
misclassified_image_data = torch.load(config.GRADIO_MISCLASSIFIED_PATH, map_location=device)

# Class names indexed by the model's output position.
classes = list(config.CIFAR_CLASSES)

# Layer names offered in the GradCAM dropdown; each is handled by get_target_layer().
model_layer_names = ["prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2"]
|
|
|
|
|
def get_target_layer(layer_name, net=None):
    """Return the GradCAM target layer list for ``layer_name``.

    Args:
        layer_name: Name of a block of the model; one of "prep", "layer1_x",
            "layer1_r1", "layer2", "layer3_x", "layer3_r2".
        net: Model to take the layer from. Defaults to the module-level ``model``.

    Returns:
        A one-element list containing the last sub-module of the named block
        (``getattr(net, layer_name)[-1]``), the list form passed to GradCAM's
        ``target_layers``.

    Raises:
        ValueError: If ``layer_name`` is not one of the supported names. (Returning
            None, as before, only failed later and opaquely inside GradCAM.)
    """
    # Keep in sync with model_layer_names (the dropdown choices).
    supported = ("prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2")
    if layer_name not in supported:
        raise ValueError(f"Unknown layer {layer_name!r}; expected one of {supported}")

    if net is None:
        net = model
    return [getattr(net, layer_name)[-1]]
|
|
|
|
|
def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
    """Classify an image and optionally overlay its GradCAM heat-map.

    Args:
        input_image: Image as a NumPy array with 0-255 pixel values (app_interface
            passes the 32x32 resized image); assumed HxWx3 RGB -- TODO confirm.
        num_classes: How many top-ranked classes to report. Converted with ``int()``
            and clamped to the range [0, number of model outputs].
        show_gradcam: If True, return the GradCAM overlay instead of the raw image.
        transparency: Passed as ``image_weight`` to ``show_cam_on_image``.
        layer_name: Model layer used as the GradCAM target (see ``get_target_layer``).

    Returns:
        ``(confidences, display_image)``. ``confidences`` maps class name to
        ``exp(model output)`` for the top ``num_classes`` classes, highest first.
        ``display_image`` is the GradCAM overlay, or the unmodified input image.
    """
    mean = list(config.CIFAR_MEAN)
    std = list(config.CIFAR_STD)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    original_img = input_image
    input_tensor = transform(input_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # exp() of the output; presumably the model emits log-probabilities, which
        # makes these probabilities -- TODO confirm against CustomResNet.
        output_exp = torch.exp(model(input_tensor))

    # .cpu() is required: calling .numpy() on a CUDA tensor raises.
    output_numpy = np.squeeze(output_exp.cpu().numpy())
    sorted_indexes = np.argsort(output_numpy)[::-1]

    # Clamp so an out-of-range count neither raises IndexError nor turns into a
    # negative slice.
    top_k = max(0, min(int(num_classes), len(sorted_indexes)))
    confidences = {classes[idx]: float(output_numpy[idx]) for idx in sorted_indexes[:top_k]}

    if show_gradcam:
        # GradCAM back-propagates through the model, so it must run outside torch.no_grad().
        cam = GradCAM(model=model, target_layers=get_target_layer(layer_name))
        cam_generated = cam(input_tensor=input_tensor, targets=None)[0, :]
        display_image = show_cam_on_image(original_img / 255, cam_generated, use_rgb=True, image_weight=transparency)
    else:
        display_image = original_img

    return confidences, display_image
|
|
|
|
|
def app_interface(
    input_image,
    num_classes,
    show_gradcam,
    layer_name,
    transparency,
    show_misclassified,
    num_misclassified,
    show_gradcam_misclassified,
    num_gradcam_misclassified,
):
    """Gradio callback: classify the input image and build the optional plots.

    Parameters are the values of the Interface inputs, in order.

    Returns:
        ``(confidences, display_image, misclassified_fig, gradcam_fig)``. Each figure
        is None when its checkbox is off.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # Gradio passes None when the form is submitted without an image; without this
        # check the failure would surface as an obscure PIL error.
        raise gr.Error("Please provide an input image.")

    # The model works on 32x32 inputs.
    resized_img = np.array(resize_image_pil(input_image, 32, 32))

    confidences, display_image = generate_prediction(resized_img, num_classes, show_gradcam, transparency, layer_name)

    misclassified_fig = None
    if show_misclassified:
        # Slider values may arrive as floats; the counts are image counts, so pass ints.
        misclassified_fig, _ = plot_misclassified_images(
            data=misclassified_image_data, class_label=classes, num_images=int(num_misclassified)
        )

    gradcam_fig = None
    if show_gradcam_misclassified:
        gradcam_fig, _ = plot_gradcam_images(
            model=model,
            data=misclassified_image_data,
            class_label=classes,
            target_layers=get_target_layer(layer_name),
            targets=None,
            num_images=int(num_gradcam_misclassified),
            image_weight=transparency,
        )

    return confidences, display_image, misclassified_fig, gradcam_fig
|
|
|
def resize_image_pil(image, new_width, new_height):
    """Fit an image into a ``new_width`` x ``new_height`` box, keeping its aspect ratio.

    The image is scaled by the smaller of the two axis ratios using nearest-neighbour
    resampling. It is then cropped to the target box anchored at the top-left corner.
    Because the crop box can extend past the scaled image, a non-matching aspect
    ratio leaves a padded strip on the right or bottom edge.

    Args:
        image: Anything ``np.array`` can turn into an array accepted by
            ``PIL.Image.fromarray``.
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A ``PIL.Image`` of size ``(new_width, new_height)``.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size

    scale = min(new_width / width, new_height / height)
    # round() instead of int(): int() truncates products that land just below the
    # target, e.g. 49 * (32 / 49) == 31.999..., which would leave a spurious 1-px
    # padded strip. max(1, ...) keeps very elongated images from collapsing to size 0.
    scaled_size = (max(1, round(width * scale)), max(1, round(height * scale)))

    resized = img.resize(scaled_size, Image.NEAREST)
    return resized.crop((0, 0, new_width, new_height))
|
|
|
|
|
|
|
# Gradio UI: the inputs map positionally onto app_interface's parameters, and the
# outputs onto its four return values.
inference_app = gr.Interface(
    app_interface,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0, label="#Classes to show"),
        gr.Checkbox(True, label="Show GradCAM Image"),
        gr.Dropdown(model_layer_names, value="layer3_x", label="Visualization Layer from Model"),
        gr.Slider(0, 1, 0.6, label="Image Overlay Factor"),
        gr.Checkbox(True, label="Show Misclassified Images?"),
        gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#Misclassified images to show"),
        gr.Checkbox(True, label="Visualize GradCAM for Misclassified images?"),
        gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#GradCAM images to show"),
    ],
    outputs=[
        gr.Label(label="Confidences", container=True, show_label=True),
        gr.Image(label="Grad CAM/ Input Image", container=True, show_label=True, height=256, width=256),
        gr.Plot(label="Misclassified images", container=True, show_label=True),
        gr.Plot(label="Grad CAM of Misclassified images", container=True, show_label=True),
    ],
    title=TITLE,
    description=DESCRIPTION,
    examples=examples,
)

# Launch only when run as a script (`python app.py`), not when the module is imported.
if __name__ == "__main__":
    inference_app.launch()
|
|
|
|
|
|