Commit 1ffed57
Parent: afa3231
Added feature for misclassified images
Files changed:
- app.py            +40 -10
- dataset.py        +70 -0
- requirements.txt  +10 -3
- visualize.py      +116 -0
app.py
CHANGED
@@ -1,4 +1,9 @@
 import torch
+import dataset
+import visualize
+import albumentations
+from utils import get_misclassified_data
+from albumentations.pytorch import ToTensorV2
 from torchvision import transforms
 import numpy as np
 import gradio as gr
@@ -8,12 +13,20 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
 from resnet import ResNet18
 import gradio as gr
 
+cuda = torch.cuda.is_available()
+device = 'cuda' if cuda else 'cpu'
+
 model = ResNet18()
-model.load_state_dict(torch.load("model.pth", map_location=torch.device(
+model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)), strict=False)
+
+# dataloader arguments - something you'll fetch these from cmdprmt
+dataloader_args = dict(shuffle=True, batch_size=128, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)
+
+test_loader = dataset.get_test_data_loader(**dataloader_args)
 
 inv_normalize = transforms.Normalize(
-    mean=[-0.
-    std=[1/0.
+    mean=[-0.48215841/0.24348513, -0.44653091/0.26158784, -0.49139968/0.24703223],
+    std=[1/0.24348513, 1/0.26158784, 1/0.24703223]
 )
 classes = ('plane', 'car', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck')
@@ -39,14 +52,20 @@ def resize_image_pil(image, new_width, new_height):
 
     return resized
 
-def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_number = -1,
+def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_number = -1,
+              top_predictions=3, is_missclassified_images=True, num_missclassified_images=10):
     input_img = resize_image_pil(input_img, 32, 32)
 
     input_img = np.array(input_img)
     org_img = input_img
     input_img = input_img.reshape((32, 32, 3))
-
-
+    transforms = albumentations.Compose(
+        # Normalize
+        [albumentations.Normalize([0.49139968, 0.48215841, 0.44653091],
+                                  [0.24703223, 0.24348513, 0.26158784]),
+         # Convert to tensor
+         ToTensorV2()])
+    input_img = transforms(image = input_img)['image']
     input_img = input_img
     input_img = input_img.unsqueeze(0)
     outputs = model(input_img)
@@ -68,8 +87,16 @@ def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_numb
 
     # Pick the top n predictions
     top_n_confidences = dict(list(sorted_confidences.items())[:top_predictions])
+
+    if is_missclassified_images:
+        # Get the misclassified data from test dataset
+        misclassified_data = get_misclassified_data(model, device, test_loader)
+        # Plot the misclassified data
+        visualize.display_cifar_misclassified_data(misclassified_data, number_of_samples=num_missclassified_images)
+    else:
+        missclassified_images = None
 
-    return classes[prediction[0].item()], visualization, top_n_confidences
+    return classes[prediction[0].item()], visualization, top_n_confidences, missclassified_images
 
 title = "CIFAR10 trained on ResNet18 Model with GradCAM"
 description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
@@ -90,12 +117,15 @@ demo = gr.Interface(
         gr.Checkbox(label="Show GradCAM"),
         gr.Slider(0, 1, value = 0.5, label="Overall Opacity of Image"),
         gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?"),
-        gr.Slider(2, 10, value=3, step=1, label="Number of Top Classes")
+        gr.Slider(2, 10, value=3, step=1, label="Number of Top Classes"),
+        gr.Checkbox(label="Show Misclassified Images"),
+        gr.Slider(5, 40, value=10, step=5, label="Number of Missclassified Images")
     ],
     outputs = [
         "text",
         gr.Image(width=256, height=256, label="Output"),
-        gr.Label(label="Top Classes")
+        gr.Label(label="Top Classes"),
+        gr.Plot(label="Missclassified Images")
     ],
     title = title,
     description = description,
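The new code path imports get_misclassified_data from utils, which is not touched by this commit, so its definition is not shown. A minimal sketch of what such a helper presumably looks like, assuming it returns (image, correct label, predicted label) triples in the order visualize.py indexes them:

import torch

def get_misclassified_data(model, device, test_loader):
    # Hypothetical sketch of the helper imported from utils (not in this commit).
    # Collects one (image, target, prediction) triple per wrong prediction,
    # matching how visualize.py reads data[i][0], data[i][1], data[i][2].
    model.eval()
    misclassified = []
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            preds = model(images).argmax(dim=1)
            for image, target, pred in zip(images, targets, preds):
                if pred != target:
                    # Keep a batch dim of 1; visualize.py squeezes it back out.
                    misclassified.append((image.unsqueeze(0), target, pred))
    return misclassified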
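For reference, the inv_normalize values above use the standard un-normalize trick: Normalize computes (x - mean) / std, so with mean = -mu/sigma and std = 1/sigma it evaluates to x * sigma + mu, undoing the forward normalization. A quick single-channel sanity check (the values are one CIFAR10 channel's statistics, rounded):

import torch
from torchvision import transforms

mu, sigma = 0.4914, 0.2470  # one CIFAR10 channel's mean and std (rounded)
normalize = transforms.Normalize(mean=[mu], std=[sigma])
inv_normalize = transforms.Normalize(mean=[-mu / sigma], std=[1 / sigma])

x = torch.rand(1, 8, 8)  # one channel, 8x8 image
# Round-tripping through normalize then inv_normalize recovers the input.
assert torch.allclose(inv_normalize(normalize(x)), x, atol=1e-6)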
dataset.py
ADDED
@@ -0,0 +1,70 @@
+import numpy as np
+import albumentations
+from torchvision import datasets
+from albumentations.pytorch import ToTensorV2
+from torch.utils.data import Dataset, DataLoader
+
+class CIFAR10Data(Dataset):
+    def __init__(self, dataset, transforms=None) -> None:
+        self.dataset = dataset
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        image, label = self.dataset[index]
+
+        image = np.array(image)
+
+        if self.transforms:
+            image = self.transforms(image=image)['image']
+
+        return image, label
+
+def _get_test_transforms():
+    test_transforms = albumentations.Compose([albumentations.Normalize([0.49139968, 0.48215841, 0.44653091],
+                                                                       [0.24703223, 0.24348513, 0.26158784]),
+                                              ToTensorV2()])
+    return test_transforms
+
+def _get_data(is_train, is_download):
+    """Method to get data for training or testing
+
+    Args:
+        is_train (bool): True if data is for training else false
+        is_download (bool): True to download dataset from iternet
+
+    Returns:
+        object: Oject of dataset
+    """
+    data = datasets.CIFAR10('../data', train=is_train, download=is_download)
+    return data
+
+def _get_data_loader(data, **kwargs):
+    """Method to get data loader.
+
+    Args:
+        data (object): Oject of dataset
+
+    Returns:
+        object: Object of DataLoader class used to feed data to neural network model
+    """
+    loader = DataLoader(data, **kwargs)
+    return loader
+
+def get_test_data_loader(**kwargs):
+    """Method to get data loader for testing
+
+    Args:
+        batch_size (int): Number of images in a batch
+
+    Returns:
+        object: Object of DataLoader class used to feed data to neural network model
+    """
+
+    test_transforms = _get_test_transforms()
+    test_data = _get_data(is_train=False, is_download=True)
+    test_data = CIFAR10Data(test_data, test_transforms)
+    test_loader = _get_data_loader(data=test_data, **kwargs)
+    return test_loader
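A quick usage sketch of the new loader; the arguments mirror the CPU branch of the dataloader_args that app.py builds:

import dataset

# Same arguments app.py uses when CUDA is unavailable.
test_loader = dataset.get_test_data_loader(shuffle=True, batch_size=64)

images, labels = next(iter(test_loader))
print(images.shape)  # torch.Size([64, 3, 32, 32]) after Normalize + ToTensorV2
print(labels[:5])    # integer class indices into the CIFAR10 label set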
requirements.txt
CHANGED
@@ -1,6 +1,7 @@
 aiofiles==23.2.1
 aiohttp==3.9.5
 aiosignal==1.3.1
+albumentations==1.4.6
 altair==5.3.0
 annotated-types==0.6.0
 anyio==4.3.0
@@ -30,6 +31,7 @@ httpcore==1.0.5
 httpx==0.27.0
 huggingface-hub==0.22.2
 idna==3.7
+imageio==2.34.1
 importlib_resources==6.4.0
 intel-openmp==2021.4.0
 ipykernel==6.29.4
@@ -42,6 +44,7 @@ jsonschema-specifications==2023.12.1
 jupyter_client==8.6.1
 jupyter_core==5.7.2
 kiwisolver==1.4.5
+lazy_loader==0.4
 lightning==2.2.3
 lightning-utilities==0.11.2
 markdown-it-py==3.0.0
@@ -56,6 +59,7 @@ nest-asyncio==1.6.0
 networkx==3.2.1
 numpy==1.26.3
 opencv-python==4.9.0.80
+opencv-python-headless==4.9.0.80
 orjson==3.10.1
 packaging==24.0
 pandas==2.2.2
@@ -74,6 +78,7 @@ python-dateutil==2.9.0.post0
 python-multipart==0.0.9
 pytorch-lightning==2.2.3
 pytz==2024.1
+pywin32==306
 PyYAML==6.0.1
 pyzmq==26.0.2
 referencing==0.35.0
@@ -81,6 +86,7 @@ requests==2.31.0
 rich==13.7.1
 rpds-py==0.18.0
 ruff==0.4.2
+scikit-image==0.23.2
 scikit-learn==1.4.2
 scipy==1.13.0
 semantic-version==2.10.0
@@ -93,14 +99,15 @@ starlette==0.37.2
 sympy==1.12
 tbb==2021.11.0
 threadpoolctl==3.4.0
+tifffile==2024.5.10
 tomlkit==0.12.0
 toolz==0.12.1
-torch==2.3.0
+torch==2.3.0+cu121
 torch-lr-finder==0.2.1
-torchaudio==2.3.0
+torchaudio==2.3.0+cu121
 torchmetrics==1.3.2
 torchsummary==1.5.1
-torchvision==0.18.0
+torchvision==0.18.0+cu121
 tornado==6.4
 tqdm==4.66.2
 traitlets==5.14.3
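Note that the torch, torchaudio, and torchvision pins now carry the +cu121 local version tag. Those wheels are published on PyTorch's own index rather than PyPI, so installing this file presumably needs something like pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 (the exact invocation depends on the deployment setup).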
visualize.py
ADDED
@@ -0,0 +1,116 @@
+import math
+import numpy as np
+import albumentations
+import matplotlib.pyplot as plt
+
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+
+CLASS_NAMES= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
+
+def get_inv_transforms():
+    """Method to get transform to inverse the effect of normalization for ploting
+
+    Returns:
+        _Object: Object to apply image augmentations
+    """
+    # Normalize image
+    inv_transforms = albumentations.Normalize([-0.48215841/0.24348513, -0.44653091/0.26158784, -0.49139968/0.24703223],
+                                              [1/0.24348513, 1/0.26158784, 1/0.24703223], max_pixel_value=1.0)
+    return inv_transforms
+
+def plot_samples(train_loader, number_of_images):
+    """Method to plot samples of augmented images
+
+    Args:
+        train_loader (Object): Object of data loader class to get images
+    """
+    inv_transform = get_inv_transforms()
+
+    figure = plt.figure()
+    x_count = 5
+    y_count = 1 if number_of_images <= 5 else math.floor(number_of_images / x_count)
+    images, labels = next(iter(train_loader))
+
+    for index in range(1, number_of_images + 1):
+        plt.subplot(y_count, x_count, index)
+        plt.title(CLASS_NAMES[labels[index].numpy()])
+        plt.axis('off')
+        image = np.array(images[index])
+        image = np.transpose(image, (1, 2, 0))
+        image = inv_transform(image=image)['image']
+        plt.imshow(image)
+
+def display_cifar_misclassified_data(data: list,
+                                     number_of_samples: int = 10):
+    """
+    Function to plot images with labels
+    :param data: List[Tuple(image, label)]
+    :param number_of_samples: Number of images to print
+    """
+    fig = plt.figure(figsize=(10, 10))
+    inv_transform = get_inv_transforms()
+
+    x_count = 5
+    y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+    for i in range(number_of_samples):
+        plt.subplot(y_count, x_count, i + 1)
+        img = np.array(data[i][0].squeeze().to('cpu'))
+        img = np.transpose(img, (1, 2, 0))
+        img = inv_transform(image=img)['image']
+        plt.imshow(img)
+        plt.title(r"Correct: " + CLASS_NAMES[data[i][1].item()] + '\n' + 'Output: ' + CLASS_NAMES[data[i][2].item()])
+        plt.xticks([])
+        plt.yticks([])
+
+def display_gradcam_output(data: list,
+                           model,
+                           target_layers,
+                           targets=None,
+                           number_of_samples: int = 10,
+                           transparency: float = 0.60):
+    """
+    Function to visualize GradCam output on the data
+    :param data: List[Tuple(image, label)]
+    :param classes: Name of classes in the dataset
+    :param inv_normalize: Mean and Standard deviation values of the dataset
+    :param model: Model architecture
+    :param target_layers: Layers on which GradCam should be executed
+    :param targets: Classes to be focused on for GradCam
+    :param number_of_samples: Number of images to print
+    :param transparency: Weight of Normal image when mixed with activations
+    """
+    # Plot configuration
+    fig = plt.figure(figsize=(10, 10))
+    inv_transform = get_inv_transforms()
+
+    x_count = 5
+    y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+    # Create an object for GradCam
+    cam = GradCAM(model=model, target_layers=target_layers)
+
+    # Iterate over number of specified images
+    for i in range(number_of_samples):
+        plt.subplot(y_count, x_count, i + 1)
+        input_tensor = data[i][0]
+
+        # Get the activations of the layer for the images
+        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
+        grayscale_cam = grayscale_cam[0, :]
+
+        # Get back the original image
+        img = np.array(input_tensor.squeeze(0).to('cpu'))
+        img = np.transpose(img, (1, 2, 0))
+        img = inv_transform(image=img)['image']
+        rgb_img = np.clip(img, 0, 1)
+
+        # Mix the activations on the original image
+        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
+
+        # Display the images on the plot
+        plt.imshow(visualization)
+        plt.title(r"Correct: " + CLASS_NAMES[data[i][1].item()] + '\n' + 'Output: ' + CLASS_NAMES[data[i][2].item()])
+        plt.xticks([])
+        plt.yticks([])
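display_gradcam_output is added here but not yet called from app.py. An end-to-end usage sketch tying the new modules together; utils.get_misclassified_data sits outside this commit, and layer4[-1] as the GradCAM target is an assumption about the repo's ResNet18:

import torch
import dataset
import visualize
from resnet import ResNet18
from utils import get_misclassified_data  # helper not included in this commit

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNet18().to(device)
model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)), strict=False)

test_loader = dataset.get_test_data_loader(shuffle=True, batch_size=64)
misclassified = get_misclassified_data(model, device, test_loader)

# Plot misclassified samples, then overlay GradCAM heatmaps on the same data.
visualize.display_cifar_misclassified_data(misclassified, number_of_samples=10)
visualize.display_gradcam_output(misclassified, model,
                                 target_layers=[model.layer4[-1]],  # assumed layer name
                                 number_of_samples=10, transparency=0.6)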