Mojo
commited on
Commit
Β·
45cd721
1
Parent(s):
c7e44af
Added files
Browse files- app.py +179 -4
- {app/assets β assets}/images/airplane.jpg +0 -0
- {app/assets β assets}/images/bird.jpeg +0 -0
- {app/assets β assets}/images/car.jpg +0 -0
- {app/assets β assets}/images/cat.jpeg +0 -0
- {app/assets β assets}/images/deer.jpg +0 -0
- {app/assets β assets}/images/dog.jpg +0 -0
- {app/assets β assets}/images/frog.jpeg +0 -0
- {app/assets β assets}/images/horse.jpg +0 -0
- {app/assets β assets}/images/ship.jpg +0 -0
- {app/assets β assets}/images/truck.jpg +0 -0
- modules/config.py +5 -1
app.py
CHANGED
@@ -1,10 +1,185 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
from modules.custom_resnet import CustomResNet
|
4 |
+
from modules.visualize import plot_gradcam_images, plot_misclassified_images
|
5 |
+
from pytorch_grad_cam import GradCAM
|
6 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
7 |
+
from torchvision import transforms
|
8 |
+
import modules.config as config
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
|
12 |
|
13 |
+
TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
|
14 |
+
DESCRIPTION = "Gradio App to infer using a Custom ResNet model and get GradCAM results"
|
15 |
+
examples = [
|
16 |
+
["assets/images/airplane.jpg", 3, True, "layer3_x", 0.6, True, 5, True, 5],
|
17 |
+
["assets/images/bird.jpeg", 4, True, "layer3_x", 0.7, True, 10, True, 20],
|
18 |
+
["assets/images/car.jpg", 5, True, "layer3_x", 0.5, True, 15, True, 5],
|
19 |
+
["assets/images/cat.jpeg", 6, True, "layer3_x", 0.65, True, 20, True, 10],
|
20 |
+
["assets/images/deer.jpg", 7, False, "layer2", 0.75, True, 5, True, 5],
|
21 |
+
["assets/images/dog.jpg", 8, True, "layer2", 0.55, True, 10, True, 5],
|
22 |
+
["assets/images/frog.jpeg", 9, True, "layer2", 0.8, True, 15, True, 15],
|
23 |
+
["assets/images/horse.jpg", 10, False, "layer1_r1", 0.85, True, 20, True, 5],
|
24 |
+
["assets/images/ship.jpg", 3, True, "layer1_r1", 0.4, True, 5, True, 15],
|
25 |
+
["assets/images/truck.jpg", 4, True, "layer1_r1", 0.3, True, 5, True, 10],
|
26 |
+
]
|
27 |
+
|
28 |
+
|
29 |
+
# load and initialise the model
|
30 |
+
|
31 |
+
model = CustomResNet()
|
32 |
+
|
33 |
+
# Define the device
|
34 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
+
# Using the checkpoint path present in config, load the trained model
|
36 |
+
model.load_state_dict(torch.load(config.GRADIO_MODEL_PATH, map_location=device), strict=False)
|
37 |
+
# Send model to CPU
|
38 |
+
model.to(device)
|
39 |
+
# Make the model in evaluation mode
|
40 |
+
model.eval()
|
41 |
+
|
42 |
+
# Load the misclassified images data
|
43 |
+
misclassified_image_data = torch.load(config.GRADIO_MISCLASSIFIED_PATH, map_location=device)
|
44 |
+
|
45 |
+
# Class Names
|
46 |
+
classes = list(config.CIFAR_CLASSES)
|
47 |
+
# Allowed model names
|
48 |
+
model_layer_names = ["prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2"]
|
49 |
+
|
50 |
+
|
51 |
+
def get_target_layer(layer_name):
|
52 |
+
"""Get target layer for visualization"""
|
53 |
+
if layer_name == "prep":
|
54 |
+
return [model.prep[-1]]
|
55 |
+
elif layer_name == "layer1_x":
|
56 |
+
return [model.layer1_x[-1]]
|
57 |
+
elif layer_name == "layer1_r1":
|
58 |
+
return [model.layer1_r1[-1]]
|
59 |
+
elif layer_name == "layer2":
|
60 |
+
return [model.layer2[-1]]
|
61 |
+
elif layer_name == "layer3_x":
|
62 |
+
return [model.layer3_x[-1]]
|
63 |
+
elif layer_name == "layer3_r2":
|
64 |
+
return [model.layer3_r2[-1]]
|
65 |
+
else:
|
66 |
+
return None
|
67 |
+
|
68 |
+
|
69 |
+
def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
|
70 |
+
""" "Given an input image, generate the prediction, confidence and display_image"""
|
71 |
+
mean = list(config.CIFAR_MEAN)
|
72 |
+
std = list(config.CIFAR_STD)
|
73 |
+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
74 |
+
|
75 |
+
with torch.no_grad():
|
76 |
+
orginal_img = input_image
|
77 |
+
input_image = transform(input_image).unsqueeze(0).to(device)
|
78 |
+
# print(f"Input Device: {input_image.device}")
|
79 |
+
model_output = model(input_image).to(device)
|
80 |
+
# print(f"Output Device: {outputs.device}")
|
81 |
+
output_exp = torch.exp(model_output).to(device)
|
82 |
+
# print(f"Output Exp Device: {o.device}")
|
83 |
+
|
84 |
+
output_numpy = np.squeeze(np.asarray(output_exp.numpy()))
|
85 |
+
# get indexes of probabilties in descending order
|
86 |
+
sorted_indexes = np.argsort(output_numpy)[::-1]
|
87 |
+
# sort the probabilities in descending order
|
88 |
+
# final_class = classes[o_np.argmax()]
|
89 |
+
|
90 |
+
confidences = {}
|
91 |
+
for _ in range(int(num_classes)):
|
92 |
+
# set the confidence of highest class with highest probability
|
93 |
+
confidences[classes[sorted_indexes[_]]] = float(output_numpy[sorted_indexes[_]])
|
94 |
+
|
95 |
+
# Show Grad Cam
|
96 |
+
if show_gradcam:
|
97 |
+
# Get the target layer
|
98 |
+
target_layers = get_target_layer(layer_name)
|
99 |
+
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
|
100 |
+
cam_generated = cam(input_tensor=input_image, targets=None)
|
101 |
+
cam_generated = cam_generated[0, :]
|
102 |
+
display_image = show_cam_on_image(orginal_img / 255, cam_generated, use_rgb=True, image_weight=transparency)
|
103 |
+
else:
|
104 |
+
display_image = orginal_img
|
105 |
+
|
106 |
+
return confidences, display_image
|
107 |
+
|
108 |
+
|
109 |
+
def app_interface(
|
110 |
+
input_image,
|
111 |
+
num_classes,
|
112 |
+
show_gradcam,
|
113 |
+
layer_name,
|
114 |
+
transparency,
|
115 |
+
show_misclassified,
|
116 |
+
num_misclassified,
|
117 |
+
show_gradcam_misclassified,
|
118 |
+
num_gradcam_misclassified,
|
119 |
+
):
|
120 |
+
"""Function which provides the Gradio interface"""
|
121 |
+
|
122 |
+
# Get the prediction for the input image along with confidence and display_image
|
123 |
+
confidences, display_image = generate_prediction(input_image, num_classes, show_gradcam, transparency, layer_name)
|
124 |
+
|
125 |
+
if show_misclassified:
|
126 |
+
misclassified_fig, misclassified_axs = plot_misclassified_images(
|
127 |
+
data=misclassified_image_data, class_label=classes, num_images=num_misclassified
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
misclassified_fig = None
|
131 |
+
|
132 |
+
if show_gradcam_misclassified:
|
133 |
+
gradcam_fig, gradcam_axs = plot_gradcam_images(
|
134 |
+
model=model,
|
135 |
+
data=misclassified_image_data,
|
136 |
+
class_label=classes,
|
137 |
+
# Use penultimate block of resnet18 layer 3 as the target layer for gradcam
|
138 |
+
# Decided using model summary so that dimensions > 7x7
|
139 |
+
target_layers=get_target_layer(layer_name),
|
140 |
+
targets=None,
|
141 |
+
num_images=num_gradcam_misclassified,
|
142 |
+
image_weight=transparency,
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
gradcam_fig = None
|
146 |
+
|
147 |
+
# # delete ununsed axises
|
148 |
+
# del misclassified_axs
|
149 |
+
# del gradcam_axs
|
150 |
+
|
151 |
+
return confidences, display_image, misclassified_fig, gradcam_fig
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
inference_app = gr.Interface(
|
157 |
+
app_interface,
|
158 |
+
inputs=[
|
159 |
+
# This accepts the image after resizing it to 32x32 which is what our model expects
|
160 |
+
gr.Image(shape=(32, 32)),
|
161 |
+
gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0, label="#Classes to show"),
|
162 |
+
gr.Checkbox(True, label="Show GradCAM Image"),
|
163 |
+
gr.Dropdown(model_layer_names, value="layer3_x", label="Visulalization Layer from Model"),
|
164 |
+
# How much should the image be overlayed on the original image
|
165 |
+
gr.Slider(0, 1, 0.6, label="Image Overlay Factor"),
|
166 |
+
gr.Checkbox(True, label="Show Misclassified Images?"),
|
167 |
+
gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#Misclassified images to show"),
|
168 |
+
gr.Checkbox(True, label="Visulize GradCAM for Misclassified images?"),
|
169 |
+
gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#GradCAM images to show"),
|
170 |
+
],
|
171 |
+
outputs=[
|
172 |
+
gr.Label(label="Confidences", container=True, show_label=True),
|
173 |
+
gr.Image(shape=(32, 32), label="Grad CAM/ Input Image", container=True, show_label=True).style(
|
174 |
+
width=256, height=256
|
175 |
+
),
|
176 |
+
gr.Plot(label="Misclassified images", container=True, show_label=True),
|
177 |
+
gr.Plot(label="Grad CAM of Misclassified images", container=True, show_label=True),
|
178 |
+
],
|
179 |
+
title=TITLE,
|
180 |
+
description=DESCRIPTION,
|
181 |
+
examples=examples,
|
182 |
+
)
|
183 |
+
inference_app.launch()
|
184 |
|
185 |
|
{app/assets β assets}/images/airplane.jpg
RENAMED
File without changes
|
{app/assets β assets}/images/bird.jpeg
RENAMED
File without changes
|
{app/assets β assets}/images/car.jpg
RENAMED
File without changes
|
{app/assets β assets}/images/cat.jpeg
RENAMED
File without changes
|
{app/assets β assets}/images/deer.jpg
RENAMED
File without changes
|
{app/assets β assets}/images/dog.jpg
RENAMED
File without changes
|
{app/assets β assets}/images/frog.jpeg
RENAMED
File without changes
|
{app/assets β assets}/images/horse.jpg
RENAMED
File without changes
|
{app/assets β assets}/images/ship.jpg
RENAMED
File without changes
|
{app/assets β assets}/images/truck.jpg
RENAMED
File without changes
|
modules/config.py
CHANGED
@@ -47,4 +47,8 @@ CIFAR_CLASSES = tuple(
|
|
47 |
"ship",
|
48 |
"truck",
|
49 |
]
|
50 |
-
)
|
|
|
|
|
|
|
|
|
|
47 |
"ship",
|
48 |
"truck",
|
49 |
]
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
GRADIO_MISCLASSIFIED_PATH = "./assets/model/Misclassified_Data.pt"
|
54 |
+
GRADIO_MODEL_PATH = "./assets/model/CustomResNet.pt"
|