Mojo commited on
Commit
45cd721
Β·
1 Parent(s): c7e44af

Added files

Browse files
app.py CHANGED
@@ -1,10 +1,185 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return f"Hello {name}!"
 
 
 
 
 
 
5
 
6
 
7
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
8
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
 
1
  import gradio as gr
2
 
3
+ from modules.custom_resnet import CustomResNet
4
+ from modules.visualize import plot_gradcam_images, plot_misclassified_images
5
+ from pytorch_grad_cam import GradCAM
6
+ from pytorch_grad_cam.utils.image import show_cam_on_image
7
+ from torchvision import transforms
8
+ import modules.config as config
9
+ import numpy as np
10
+ import torch
11
 
12
 
13
+ TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
14
+ DESCRIPTION = "Gradio App to infer using a Custom ResNet model and get GradCAM results"
15
+ examples = [
16
+ ["assets/images/airplane.jpg", 3, True, "layer3_x", 0.6, True, 5, True, 5],
17
+ ["assets/images/bird.jpeg", 4, True, "layer3_x", 0.7, True, 10, True, 20],
18
+ ["assets/images/car.jpg", 5, True, "layer3_x", 0.5, True, 15, True, 5],
19
+ ["assets/images/cat.jpeg", 6, True, "layer3_x", 0.65, True, 20, True, 10],
20
+ ["assets/images/deer.jpg", 7, False, "layer2", 0.75, True, 5, True, 5],
21
+ ["assets/images/dog.jpg", 8, True, "layer2", 0.55, True, 10, True, 5],
22
+ ["assets/images/frog.jpeg", 9, True, "layer2", 0.8, True, 15, True, 15],
23
+ ["assets/images/horse.jpg", 10, False, "layer1_r1", 0.85, True, 20, True, 5],
24
+ ["assets/images/ship.jpg", 3, True, "layer1_r1", 0.4, True, 5, True, 15],
25
+ ["assets/images/truck.jpg", 4, True, "layer1_r1", 0.3, True, 5, True, 10],
26
+ ]
27
+
28
+
29
+ # load and initialise the model
30
+
31
+ model = CustomResNet()
32
+
33
+ # Define the device
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ # Using the checkpoint path present in config, load the trained model
36
+ model.load_state_dict(torch.load(config.GRADIO_MODEL_PATH, map_location=device), strict=False)
37
+ # Send model to CPU
38
+ model.to(device)
39
+ # Make the model in evaluation mode
40
+ model.eval()
41
+
42
+ # Load the misclassified images data
43
+ misclassified_image_data = torch.load(config.GRADIO_MISCLASSIFIED_PATH, map_location=device)
44
+
45
+ # Class Names
46
+ classes = list(config.CIFAR_CLASSES)
47
+ # Allowed model names
48
+ model_layer_names = ["prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2"]
49
+
50
+
51
+ def get_target_layer(layer_name):
52
+ """Get target layer for visualization"""
53
+ if layer_name == "prep":
54
+ return [model.prep[-1]]
55
+ elif layer_name == "layer1_x":
56
+ return [model.layer1_x[-1]]
57
+ elif layer_name == "layer1_r1":
58
+ return [model.layer1_r1[-1]]
59
+ elif layer_name == "layer2":
60
+ return [model.layer2[-1]]
61
+ elif layer_name == "layer3_x":
62
+ return [model.layer3_x[-1]]
63
+ elif layer_name == "layer3_r2":
64
+ return [model.layer3_r2[-1]]
65
+ else:
66
+ return None
67
+
68
+
69
+ def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
70
+ """ "Given an input image, generate the prediction, confidence and display_image"""
71
+ mean = list(config.CIFAR_MEAN)
72
+ std = list(config.CIFAR_STD)
73
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
74
+
75
+ with torch.no_grad():
76
+ orginal_img = input_image
77
+ input_image = transform(input_image).unsqueeze(0).to(device)
78
+ # print(f"Input Device: {input_image.device}")
79
+ model_output = model(input_image).to(device)
80
+ # print(f"Output Device: {outputs.device}")
81
+ output_exp = torch.exp(model_output).to(device)
82
+ # print(f"Output Exp Device: {o.device}")
83
+
84
+ output_numpy = np.squeeze(np.asarray(output_exp.numpy()))
85
+ # get indexes of probabilties in descending order
86
+ sorted_indexes = np.argsort(output_numpy)[::-1]
87
+ # sort the probabilities in descending order
88
+ # final_class = classes[o_np.argmax()]
89
+
90
+ confidences = {}
91
+ for _ in range(int(num_classes)):
92
+ # set the confidence of highest class with highest probability
93
+ confidences[classes[sorted_indexes[_]]] = float(output_numpy[sorted_indexes[_]])
94
+
95
+ # Show Grad Cam
96
+ if show_gradcam:
97
+ # Get the target layer
98
+ target_layers = get_target_layer(layer_name)
99
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
100
+ cam_generated = cam(input_tensor=input_image, targets=None)
101
+ cam_generated = cam_generated[0, :]
102
+ display_image = show_cam_on_image(orginal_img / 255, cam_generated, use_rgb=True, image_weight=transparency)
103
+ else:
104
+ display_image = orginal_img
105
+
106
+ return confidences, display_image
107
+
108
+
109
+ def app_interface(
110
+ input_image,
111
+ num_classes,
112
+ show_gradcam,
113
+ layer_name,
114
+ transparency,
115
+ show_misclassified,
116
+ num_misclassified,
117
+ show_gradcam_misclassified,
118
+ num_gradcam_misclassified,
119
+ ):
120
+ """Function which provides the Gradio interface"""
121
+
122
+ # Get the prediction for the input image along with confidence and display_image
123
+ confidences, display_image = generate_prediction(input_image, num_classes, show_gradcam, transparency, layer_name)
124
+
125
+ if show_misclassified:
126
+ misclassified_fig, misclassified_axs = plot_misclassified_images(
127
+ data=misclassified_image_data, class_label=classes, num_images=num_misclassified
128
+ )
129
+ else:
130
+ misclassified_fig = None
131
+
132
+ if show_gradcam_misclassified:
133
+ gradcam_fig, gradcam_axs = plot_gradcam_images(
134
+ model=model,
135
+ data=misclassified_image_data,
136
+ class_label=classes,
137
+ # Use penultimate block of resnet18 layer 3 as the target layer for gradcam
138
+ # Decided using model summary so that dimensions > 7x7
139
+ target_layers=get_target_layer(layer_name),
140
+ targets=None,
141
+ num_images=num_gradcam_misclassified,
142
+ image_weight=transparency,
143
+ )
144
+ else:
145
+ gradcam_fig = None
146
+
147
+ # # delete ununsed axises
148
+ # del misclassified_axs
149
+ # del gradcam_axs
150
+
151
+ return confidences, display_image, misclassified_fig, gradcam_fig
152
+
153
+
154
+
155
+
156
+ inference_app = gr.Interface(
157
+ app_interface,
158
+ inputs=[
159
+ # This accepts the image after resizing it to 32x32 which is what our model expects
160
+ gr.Image(shape=(32, 32)),
161
+ gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0, label="#Classes to show"),
162
+ gr.Checkbox(True, label="Show GradCAM Image"),
163
+ gr.Dropdown(model_layer_names, value="layer3_x", label="Visulalization Layer from Model"),
164
+ # How much should the image be overlayed on the original image
165
+ gr.Slider(0, 1, 0.6, label="Image Overlay Factor"),
166
+ gr.Checkbox(True, label="Show Misclassified Images?"),
167
+ gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#Misclassified images to show"),
168
+ gr.Checkbox(True, label="Visulize GradCAM for Misclassified images?"),
169
+ gr.Slider(value=10, maximum=25, minimum=5, step=5.0, precision=0, label="#GradCAM images to show"),
170
+ ],
171
+ outputs=[
172
+ gr.Label(label="Confidences", container=True, show_label=True),
173
+ gr.Image(shape=(32, 32), label="Grad CAM/ Input Image", container=True, show_label=True).style(
174
+ width=256, height=256
175
+ ),
176
+ gr.Plot(label="Misclassified images", container=True, show_label=True),
177
+ gr.Plot(label="Grad CAM of Misclassified images", container=True, show_label=True),
178
+ ],
179
+ title=TITLE,
180
+ description=DESCRIPTION,
181
+ examples=examples,
182
+ )
183
+ inference_app.launch()
184
 
185
 
{app/assets β†’ assets}/images/airplane.jpg RENAMED
File without changes
{app/assets β†’ assets}/images/bird.jpeg RENAMED
File without changes
{app/assets β†’ assets}/images/car.jpg RENAMED
File without changes
{app/assets β†’ assets}/images/cat.jpeg RENAMED
File without changes
{app/assets β†’ assets}/images/deer.jpg RENAMED
File without changes
{app/assets β†’ assets}/images/dog.jpg RENAMED
File without changes
{app/assets β†’ assets}/images/frog.jpeg RENAMED
File without changes
{app/assets β†’ assets}/images/horse.jpg RENAMED
File without changes
{app/assets β†’ assets}/images/ship.jpg RENAMED
File without changes
{app/assets β†’ assets}/images/truck.jpg RENAMED
File without changes
modules/config.py CHANGED
@@ -47,4 +47,8 @@ CIFAR_CLASSES = tuple(
47
  "ship",
48
  "truck",
49
  ]
50
- )
 
 
 
 
 
47
  "ship",
48
  "truck",
49
  ]
50
+ )
51
+
52
+
53
+ GRADIO_MISCLASSIFIED_PATH = "./assets/model/Misclassified_Data.pt"
54
+ GRADIO_MODEL_PATH = "./assets/model/CustomResNet.pt"