AkashDataScience commited on
Commit
1ffed57
·
1 Parent(s): afa3231

Added feature for misclassified images

Browse files
Files changed (4) hide show
  1. app.py +40 -10
  2. dataset.py +70 -0
  3. requirements.txt +10 -3
  4. visualize.py +116 -0
app.py CHANGED
@@ -1,4 +1,9 @@
1
- import torch, torchvision
 
 
 
 
 
2
  from torchvision import transforms
3
  import numpy as np
4
  import gradio as gr
@@ -8,12 +13,20 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
8
  from resnet import ResNet18
9
  import gradio as gr
10
 
 
 
 
11
  model = ResNet18()
12
- model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
 
 
 
 
 
13
 
14
  inv_normalize = transforms.Normalize(
15
- mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
16
- std=[1/0.23, 1/0.23, 1/0.23]
17
  )
18
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
19
  'dog', 'frog', 'horse', 'ship', 'truck')
@@ -39,14 +52,20 @@ def resize_image_pil(image, new_width, new_height):
39
 
40
  return resized
41
 
42
- def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_number = -1, top_predictions=3):
 
43
  input_img = resize_image_pil(input_img, 32, 32)
44
 
45
  input_img = np.array(input_img)
46
  org_img = input_img
47
  input_img = input_img.reshape((32, 32, 3))
48
- transform = transforms.ToTensor()
49
- input_img = transform(input_img)
 
 
 
 
 
50
  input_img = input_img
51
  input_img = input_img.unsqueeze(0)
52
  outputs = model(input_img)
@@ -68,8 +87,16 @@ def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_numb
68
 
69
  # Pick the top n predictions
70
  top_n_confidences = dict(list(sorted_confidences.items())[:top_predictions])
 
 
 
 
 
 
 
 
71
 
72
- return classes[prediction[0].item()], visualization, top_n_confidences
73
 
74
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
75
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
@@ -90,12 +117,15 @@ demo = gr.Interface(
90
  gr.Checkbox(label="Show GradCAM"),
91
  gr.Slider(0, 1, value = 0.5, label="Overall Opacity of Image"),
92
  gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?"),
93
- gr.Slider(2, 10, value=3, step=1, label="Number of Top Classes")
 
 
94
  ],
95
  outputs = [
96
  "text",
97
  gr.Image(width=256, height=256, label="Output"),
98
- gr.Label(label="Top Classes")
 
99
  ],
100
  title = title,
101
  description = description,
 
1
+ import torch
2
+ import dataset
3
+ import visualize
4
+ import albumentations
5
+ from utils import get_misclassified_data
6
+ from albumentations.pytorch import ToTensorV2
7
  from torchvision import transforms
8
  import numpy as np
9
  import gradio as gr
 
13
  from resnet import ResNet18
14
  import gradio as gr
15
 
16
+ cuda = torch.cuda.is_available()
17
+ device = 'cuda' if cuda else 'cpu'
18
+
19
  model = ResNet18()
20
+ model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)), strict=False)
21
+
22
+ # dataloader arguments - normally these would be read from the command line
23
+ dataloader_args = dict(shuffle=True, batch_size=128, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)
24
+
25
+ test_loader = dataset.get_test_data_loader(**dataloader_args)
26
 
27
  inv_normalize = transforms.Normalize(
28
+ mean=[-0.48215841/0.24348513, -0.44653091/0.26158784, -0.49139968/0.24703223],
29
+ std=[1/0.24348513, 1/0.26158784, 1/0.24703223]
30
  )
31
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
32
  'dog', 'frog', 'horse', 'ship', 'truck')
 
52
 
53
  return resized
54
 
55
+ def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_number = -1,
56
+ top_predictions=3, is_missclassified_images=True, num_missclassified_images=10):
57
  input_img = resize_image_pil(input_img, 32, 32)
58
 
59
  input_img = np.array(input_img)
60
  org_img = input_img
61
  input_img = input_img.reshape((32, 32, 3))
62
+ transforms = albumentations.Compose(
63
+ # Normalize
64
+ [albumentations.Normalize([0.49139968, 0.48215841, 0.44653091],
65
+ [0.24703223, 0.24348513, 0.26158784]),
66
+ # Convert to tensor
67
+ ToTensorV2()])
68
+ input_img = transforms(image = input_img)['image']
69
  input_img = input_img
70
  input_img = input_img.unsqueeze(0)
71
  outputs = model(input_img)
 
87
 
88
  # Pick the top n predictions
89
  top_n_confidences = dict(list(sorted_confidences.items())[:top_predictions])
90
+
91
+ if is_missclassified_images:
92
+ # Get the misclassified data from test dataset
93
+ misclassified_data = get_misclassified_data(model, device, test_loader)
94
+ # Plot the misclassified data
95
+ visualize.display_cifar_misclassified_data(misclassified_data, number_of_samples=num_missclassified_images)
96
+ else:
97
+ missclassified_images = None
98
 
99
+ return classes[prediction[0].item()], visualization, top_n_confidences, missclassified_images
100
 
101
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
102
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
 
117
  gr.Checkbox(label="Show GradCAM"),
118
  gr.Slider(0, 1, value = 0.5, label="Overall Opacity of Image"),
119
  gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?"),
120
+ gr.Slider(2, 10, value=3, step=1, label="Number of Top Classes"),
121
+ gr.Checkbox(label="Show Misclassified Images"),
122
+ gr.Slider(5, 40, value=10, step=5, label="Number of Missclassified Images")
123
  ],
124
  outputs = [
125
  "text",
126
  gr.Image(width=256, height=256, label="Output"),
127
+ gr.Label(label="Top Classes"),
128
+ gr.Plot(label="Missclassified Images")
129
  ],
130
  title = title,
131
  description = description,
dataset.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import albumentations
3
+ from torchvision import datasets
4
+ from albumentations.pytorch import ToTensorV2
5
+ from torch.utils.data import Dataset, DataLoader
6
+
7
class CIFAR10Data(Dataset):
    """Dataset adapter that converts images to numpy arrays and applies transforms.

    Wraps an indexable dataset whose items are ``(image, label)`` pairs. The
    optional ``transforms`` callable is invoked with the keyword ``image`` and
    must return a mapping holding the result under the ``'image'`` key.
    """

    def __init__(self, dataset, transforms=None) -> None:
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        # Same number of samples as the wrapped dataset
        return len(self.dataset)

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        pixels = np.array(sample)

        # Without a (truthy) transform the raw array is returned unchanged
        if not self.transforms:
            return pixels, target

        return self.transforms(image=pixels)['image'], target
24
+
25
def _get_test_transforms():
    """Method to build the transform pipeline applied to test images

    Returns:
        object: albumentations Compose that normalizes with the per-channel
            mean/std below and then converts the image to a tensor
    """
    channel_mean = [0.49139968, 0.48215841, 0.44653091]
    channel_std = [0.24703223, 0.24348513, 0.26158784]
    steps = [
        albumentations.Normalize(channel_mean, channel_std),
        ToTensorV2(),
    ]
    return albumentations.Compose(steps)
30
+
31
def _get_data(is_train, is_download):
    """Method to get the CIFAR10 data for training or testing

    Args:
        is_train (bool): True if data is for training else False
        is_download (bool): True to download the dataset from the internet

    Returns:
        object: Object of dataset, stored under '../data'
    """
    return datasets.CIFAR10(root='../data', train=is_train, download=is_download)
43
+
44
+ def _get_data_loader(data, **kwargs):
45
+ """Method to get data loader.
46
+
47
+ Args:
48
+ data (object): Oject of dataset
49
+
50
+ Returns:
51
+ object: Object of DataLoader class used to feed data to neural network model
52
+ """
53
+ loader = DataLoader(data, **kwargs)
54
+ return loader
55
+
56
def get_test_data_loader(**kwargs):
    """Method to get data loader for testing

    Args:
        **kwargs: Data loader options, forwarded to _get_data_loader

    Returns:
        object: Object of DataLoader class used to feed data to neural network model
    """
    pipeline = _get_test_transforms()
    # Test split, downloaded if it is not already available
    raw_test_set = _get_data(is_train=False, is_download=True)
    wrapped_test_set = CIFAR10Data(raw_test_set, pipeline)
    return _get_data_loader(data=wrapped_test_set, **kwargs)
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  aiofiles==23.2.1
2
  aiohttp==3.9.5
3
  aiosignal==1.3.1
 
4
  altair==5.3.0
5
  annotated-types==0.6.0
6
  anyio==4.3.0
@@ -30,6 +31,7 @@ httpcore==1.0.5
30
  httpx==0.27.0
31
  huggingface-hub==0.22.2
32
  idna==3.7
 
33
  importlib_resources==6.4.0
34
  intel-openmp==2021.4.0
35
  ipykernel==6.29.4
@@ -42,6 +44,7 @@ jsonschema-specifications==2023.12.1
42
  jupyter_client==8.6.1
43
  jupyter_core==5.7.2
44
  kiwisolver==1.4.5
 
45
  lightning==2.2.3
46
  lightning-utilities==0.11.2
47
  markdown-it-py==3.0.0
@@ -56,6 +59,7 @@ nest-asyncio==1.6.0
56
  networkx==3.2.1
57
  numpy==1.26.3
58
  opencv-python==4.9.0.80
 
59
  orjson==3.10.1
60
  packaging==24.0
61
  pandas==2.2.2
@@ -74,6 +78,7 @@ python-dateutil==2.9.0.post0
74
  python-multipart==0.0.9
75
  pytorch-lightning==2.2.3
76
  pytz==2024.1
 
77
  PyYAML==6.0.1
78
  pyzmq==26.0.2
79
  referencing==0.35.0
@@ -81,6 +86,7 @@ requests==2.31.0
81
  rich==13.7.1
82
  rpds-py==0.18.0
83
  ruff==0.4.2
 
84
  scikit-learn==1.4.2
85
  scipy==1.13.0
86
  semantic-version==2.10.0
@@ -93,14 +99,15 @@ starlette==0.37.2
93
  sympy==1.12
94
  tbb==2021.11.0
95
  threadpoolctl==3.4.0
 
96
  tomlkit==0.12.0
97
  toolz==0.12.1
98
- torch==2.3.0
99
  torch-lr-finder==0.2.1
100
- torchaudio==2.3.0
101
  torchmetrics==1.3.2
102
  torchsummary==1.5.1
103
- torchvision==0.18.0
104
  tornado==6.4
105
  tqdm==4.66.2
106
  traitlets==5.14.3
 
1
  aiofiles==23.2.1
2
  aiohttp==3.9.5
3
  aiosignal==1.3.1
4
+ albumentations==1.4.6
5
  altair==5.3.0
6
  annotated-types==0.6.0
7
  anyio==4.3.0
 
31
  httpx==0.27.0
32
  huggingface-hub==0.22.2
33
  idna==3.7
34
+ imageio==2.34.1
35
  importlib_resources==6.4.0
36
  intel-openmp==2021.4.0
37
  ipykernel==6.29.4
 
44
  jupyter_client==8.6.1
45
  jupyter_core==5.7.2
46
  kiwisolver==1.4.5
47
+ lazy_loader==0.4
48
  lightning==2.2.3
49
  lightning-utilities==0.11.2
50
  markdown-it-py==3.0.0
 
59
  networkx==3.2.1
60
  numpy==1.26.3
61
  opencv-python==4.9.0.80
62
+ opencv-python-headless==4.9.0.80
63
  orjson==3.10.1
64
  packaging==24.0
65
  pandas==2.2.2
 
78
  python-multipart==0.0.9
79
  pytorch-lightning==2.2.3
80
  pytz==2024.1
81
+ pywin32==306
82
  PyYAML==6.0.1
83
  pyzmq==26.0.2
84
  referencing==0.35.0
 
86
  rich==13.7.1
87
  rpds-py==0.18.0
88
  ruff==0.4.2
89
+ scikit-image==0.23.2
90
  scikit-learn==1.4.2
91
  scipy==1.13.0
92
  semantic-version==2.10.0
 
99
  sympy==1.12
100
  tbb==2021.11.0
101
  threadpoolctl==3.4.0
102
+ tifffile==2024.5.10
103
  tomlkit==0.12.0
104
  toolz==0.12.1
105
+ torch==2.3.0+cu121
106
  torch-lr-finder==0.2.1
107
+ torchaudio==2.3.0+cu121
108
  torchmetrics==1.3.2
109
  torchsummary==1.5.1
110
+ torchvision==0.18.0+cu121
111
  tornado==6.4
112
  tqdm==4.66.2
113
  traitlets==5.14.3
visualize.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import albumentations
4
+ import matplotlib.pyplot as plt
5
+
6
+ from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.image import show_cam_on_image
8
+
9
+ CLASS_NAMES= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
10
+
11
def get_inv_transforms():
    """Method to get transform to inverse the effect of normalization for plotting

    The test pipeline normalizes with mean (0.49139968, 0.48215841, 0.44653091)
    and std (0.24703223, 0.24348513, 0.26158784). Undoing ``(x - m) / s`` with
    another Normalize requires mean ``-m / s`` and std ``1 / s`` for the *same*
    channel, so both lists are derived from one pair of tuples instead of being
    typed out by hand (the hand-written values had the channels rotated).

    Returns:
        Object: albumentations Normalize transform that undoes the mean/std
            normalization
    """
    mean = (0.49139968, 0.48215841, 0.44653091)
    std = (0.24703223, 0.24348513, 0.26158784)

    inv_mean = [-m / s for m, s in zip(mean, std)]
    inv_std = [1 / s for s in std]

    # max_pixel_value=1.0 stops albumentations from scaling mean/std by 255
    inv_transforms = albumentations.Normalize(inv_mean, inv_std, max_pixel_value=1.0)
    return inv_transforms
21
+
22
def plot_samples(train_loader, number_of_images):
    """Method to plot samples of augmented images

    Args:
        train_loader (Object): Object of data loader class to get images
        number_of_images (int): Number of images to plot, taken from the start
            of the first batch (capped at the batch size)

    Returns:
        Object: Matplotlib figure containing the plotted samples
    """
    inv_transform = get_inv_transforms()

    figure = plt.figure()
    images, labels = next(iter(train_loader))

    # Never ask for more images than the batch actually holds
    number_of_images = min(number_of_images, len(images))

    x_count = 5
    # Round up so a partially filled last row still gets grid cells
    y_count = max(1, math.ceil(number_of_images / x_count))

    for index in range(number_of_images):
        # Subplot positions are 1-based while batch indices are 0-based
        axis = figure.add_subplot(y_count, x_count, index + 1)
        axis.set_title(CLASS_NAMES[int(labels[index])])
        axis.axis('off')
        image = np.array(images[index])
        # Move the first axis last before de-normalizing and plotting
        image = np.transpose(image, (1, 2, 0))
        image = inv_transform(image=image)['image']
        axis.imshow(image)

    return figure
43
+
44
def display_cifar_misclassified_data(data: list,
                                     number_of_samples: int = 10):
    """
    Function to plot images with labels
    :param data: List of (image, label, prediction) entries; the image is a
                 tensor and label/prediction expose ``.item()``
    :param number_of_samples: Number of images to print (capped at len(data))
    :return: Matplotlib figure holding the grid, so callers can render it
    """
    fig = plt.figure(figsize=(10, 10))
    inv_transform = get_inv_transforms()

    # Fewer misclassified samples than requested must not raise IndexError
    number_of_samples = min(number_of_samples, len(data))

    x_count = 5
    # Round up so counts that are not a multiple of 5 still fit in the grid
    y_count = max(1, math.ceil(number_of_samples / x_count))

    for i in range(number_of_samples):
        image_tensor, label, prediction = data[i][0], data[i][1], data[i][2]

        axis = fig.add_subplot(y_count, x_count, i + 1)
        img = np.array(image_tensor.squeeze().to('cpu'))
        # Move the first axis last before de-normalizing and plotting
        img = np.transpose(img, (1, 2, 0))
        img = inv_transform(image=img)['image']
        axis.imshow(img)
        axis.set_title(r"Correct: " + CLASS_NAMES[label.item()] + '\n' + 'Output: ' + CLASS_NAMES[prediction.item()])
        axis.set_xticks([])
        axis.set_yticks([])

    return fig
66
+
67
def display_gradcam_output(data: list,
                           model,
                           target_layers,
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List of (image, label, prediction) entries; the image is the
                 tensor fed to GradCam and label/prediction expose ``.item()``
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam
    :param number_of_samples: Number of images to print (capped at len(data))
    :param transparency: Weight of Normal image when mixed with activations
    :return: Matplotlib figure holding the grid, so callers can render it
    """
    # Plot configuration
    fig = plt.figure(figsize=(10, 10))
    inv_transform = get_inv_transforms()

    # Fewer samples than requested must not raise IndexError
    number_of_samples = min(number_of_samples, len(data))

    x_count = 5
    # Round up so counts that are not a multiple of 5 still fit in the grid
    y_count = max(1, math.ceil(number_of_samples / x_count))

    # Create an object for GradCam
    cam = GradCAM(model=model, target_layers=target_layers)

    # Iterate over number of specified images
    for i in range(number_of_samples):
        axis = fig.add_subplot(y_count, x_count, i + 1)
        input_tensor = data[i][0]

        # Get the activations of the layer for the images
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image
        img = np.array(input_tensor.squeeze(0).to('cpu'))
        img = np.transpose(img, (1, 2, 0))
        img = inv_transform(image=img)['image']
        # Clip to [0, 1] before blending with the heat map
        rgb_img = np.clip(img, 0, 1)

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        # Display the images on the plot
        axis.imshow(visualization)
        axis.set_title(r"Correct: " + CLASS_NAMES[data[i][1].item()] + '\n' + 'Output: ' + CLASS_NAMES[data[i][2].item()])
        axis.set_xticks([])
        axis.set_yticks([])

    return fig