File size: 5,123 Bytes
1ffed57
 
 
 
 
a534ad2
9d912f9
 
 
 
 
 
 
 
 
1ffed57
 
 
9d912f9
1ffed57
 
 
 
 
 
9d912f9
 
 
 
cebd76b
 
ffe5b98
9d912f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ffed57
f60ff45
9d912f9
 
 
 
 
1ffed57
 
 
 
 
 
 
9d912f9
 
 
 
 
 
 
47b7204
 
 
 
 
 
 
 
30b6d93
 
 
 
 
 
1ffed57
cebd76b
 
 
1ffed57
cebd76b
30b6d93
f60ff45
9d912f9
 
 
3c39c38
 
 
 
 
 
 
 
 
 
9d912f9
 
 
47b7204
 
6373a51
30b6d93
1ffed57
 
f60ff45
9d912f9
 
 
 
1ffed57
f60ff45
9d912f9
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import dataset
import albumentations
from utils import get_misclassified_data
from albumentations.pytorch import ToTensorV2
from visualize import display_cifar_misclassified_data
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18
import gradio as gr

# Use the GPU when one is present; otherwise everything runs on the CPU.
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'

# Build the network and load the trained weights from model.pth.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): strict=False silently skips missing/unexpected state-dict keys,
# so a mismatched checkpoint loads without any error — confirm this is intended.
# NOTE(review): the model itself is never moved with .to(device); the loaded
# weights are copied into the parameters wherever ResNet18() created them.
model = ResNet18()
model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)), strict=False)
# The model is only used for inference in this app: switch to eval mode so
# layers such as BatchNorm/Dropout (if ResNet18 uses them) use their learned
# inference behaviour instead of per-batch statistics / random masks.
model.eval()

# DataLoader arguments (hard-coded here; they could be exposed as command-line
# options). GPU runs use a larger batch, worker processes and pinned memory.
dataloader_args = dict(shuffle=True, batch_size=128, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

test_loader = dataset.get_test_data_loader(**dataloader_args)

# Class names, indexed by the model's output position.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Collect the misclassified test samples once at start-up so that every
# request to the interface can reuse them instead of re-running the test set.
misclassified_data = get_misclassified_data(model, device, test_loader)

def resize_image_pil(image, new_width, new_height, resample=Image.NEAREST):
    """Return ``image`` as a PIL image of exactly ``new_width`` x ``new_height``.

    The aspect ratio is preserved. The image is scaled by the smaller of the
    width and height ratios so that it fits inside the target box. It is then
    cropped at the top-left corner to ``(new_width, new_height)``.

    Because the scaled image is never larger than the box, that final crop
    does not discard pixels for non-square inputs. Instead it extends past the
    right or bottom edge, and PIL fills that region (presumably with zeros,
    i.e. black) — see PIL ``Image.crop``.

    Args:
        image: anything ``np.array`` can turn into an image array (e.g. the
            numpy array from a ``gr.Image`` input, or a PIL image).
        new_width: target width in pixels.
        new_height: target height in pixels.
        resample: PIL resampling filter. It defaults to ``Image.NEAREST``,
            which matches the previous behaviour.

    Returns:
        PIL.Image.Image: the resized and cropped image.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size

    # Fit inside the target box while keeping the aspect ratio.
    scale = min(new_width / width, new_height / height)

    # Clamp each side to at least one pixel. For very elongated inputs, int()
    # can round the short side down to 0, which PIL's resize rejects.
    scaled_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    resized = img.resize(scaled_size, resample)

    # Bring the result to the exact requested size (pads rather than crops).
    return resized.crop((0, 0, new_width, new_height))

def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_number = -1,
              top_predictions=3, is_misclassified_images=True, num_misclassified_images=10):
    """Classify one image with ``model`` and build the outputs of the Gradio interface.

    Args:
        input_img: image from the ``gr.Image`` input (array-like). It is
            resized to 32x32 and converted to RGB before inference.
        is_grad_cam: whether to compute a Grad-CAM overlay.
        transparency: passed to ``show_cam_on_image`` as ``image_weight``,
            the weight of the original image in the overlay.
        target_layer_number: index into ``model.layer2``, used as the Grad-CAM
            target layer.
        top_predictions: how many of the highest-confidence classes to return.
        is_misclassified_images: whether to plot the misclassified test images.
        num_misclassified_images: how many misclassified samples to plot.

    Returns:
        tuple: a tuple with four elements:
            - the predicted class name;
            - the Grad-CAM overlay, or None;
            - a ``{class name: probability}`` dict for the top classes;
            - the misclassified-image plot, or None.

    Raises:
        gr.Error: if no image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please upload an image.")

    # Slider values may arrive as floats; indexing and slicing below need ints.
    target_layer_number = int(target_layer_number)
    top_predictions = int(top_predictions)
    num_misclassified_images = int(num_misclassified_images)

    # Resize to 32x32 and force three channels, so that RGBA or grayscale
    # uploads still yield a (32, 32, 3) uint8 array.
    org_img = np.array(resize_image_pil(input_img, 32, 32).convert("RGB"))

    # Normalise with fixed per-channel mean/std (presumably the CIFAR-10
    # training statistics), then convert to a torch tensor.
    preprocess = albumentations.Compose([
        albumentations.Normalize([0.49139968, 0.48215841, 0.44653091],
                                 [0.24703223, 0.24348513, 0.26158784]),
        ToTensorV2(),
    ])
    model.eval()
    model_device = next(model.parameters()).device
    input_tensor = preprocess(image=org_img)['image'].unsqueeze(0).to(model_device)

    # The plain prediction needs no gradients. Grad-CAM below runs its own
    # forward/backward pass with autograd enabled.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs, dim=1)[0]
    confidences = {name: float(probabilities[i]) for i, name in enumerate(classes)}
    prediction = int(outputs.argmax(dim=1)[0])

    visualization = None
    if is_grad_cam:
        target_layers = [model.layer2[target_layer_number]]
        cam = GradCAM(model=model, target_layers=target_layers)
        # targets=None lets GradCAM explain the highest-scoring class.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True,
                                          image_weight=transparency)

    # Keep only the top-N classes, ordered by descending confidence.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_n_confidences = dict(ranked[:top_predictions])

    misclassified_images = None
    if is_misclassified_images:
        misclassified_images = display_cifar_misclassified_data(
            misclassified_data, number_of_samples=num_misclassified_images)

    return classes[prediction], visualization, top_n_confidences, misclassified_images

title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"

# One example per class image. All examples share the same control settings:
# Grad-CAM on, opacity 0.5, last layer, top-3 classes, 10 misclassified images.
examples = [
    [image_file, True, 0.5, -1, 3, True, 10]
    for image_file in ("cat.jpg", "dog.jpg", "bird.jpg", "car.jpg", "deer.jpg",
                       "frog.jpg", "horse.jpg", "plane.jpg", "ship.jpg", "truck.jpg")
]

# The input components map positionally onto the parameters of inference(),
# and the output components onto its returned 4-tuple.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Checkbox(label="Show GradCAM"),
        gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
        gr.Slider(2, 10, value=3, step=1, label="Number of Top Classes"),
        gr.Checkbox(label="Show Misclassified Images"),
        gr.Slider(5, 40, value=10, step=5, label="Number of Misclassified Images"),
    ],
    outputs=[
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(label="Top Classes"),
        gr.Plot(label="Misclassified Images"),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()