Updated resnet class
resnet.py (CHANGED)
@@ -24,13 +24,20 @@ import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import pytorch_lightning as pl
 import matplotlib.pyplot as plt
-
+import matplotlib.gridspec as gridspec
 
 
 PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
 BATCH_SIZE = 256
 
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchsummary import summary
+from io import BytesIO
+import numpy as np
+
 # Model
 class custom_ResNet(pl.LightningModule):
     def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4):
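The new module-level imports (torch, torch.nn, torch.nn.functional, torchsummary, BytesIO, numpy) suggest a model-summary utility elsewhere in the file, but the call site is not part of this diff. Purely as an illustration, torchsummary could be exercised against the class like this (the call and the device argument are assumptions; only the 3x32x32 CIFAR-10 input size is implied by the file):

from torchsummary import summary

# Illustrative only: the diff adds the import, not the call site.
model = custom_ResNet()                               # a LightningModule is a torch.nn.Module
summary(model, input_size=(3, 32, 32), device="cpu")  # CIFAR-10 images are 3x32x32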
@@ -168,7 +175,7 @@ class custom_ResNet(pl.LightningModule):
         acc = pred.eq(y.view_as(pred)).float().mean()
         self.log('test_loss', loss, prog_bar=True)
         self.log('test_acc', acc, prog_bar=True)
-        return pred
+        return pred
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
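test_step now returns pred, which makes the per-batch predictions available to hooks that receive the step output. A minimal sketch of one way to consume them (PredCollector is a hypothetical callback name; the hook signature assumes a recent PyTorch Lightning release):

import pytorch_lightning as pl

class PredCollector(pl.Callback):
    """Hypothetical callback that gathers the predictions returned by test_step."""
    def __init__(self):
        self.preds = []

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # `outputs` is the value returned by test_step -- here, the argmax predictions.
        self.preds.append(outputs.detach().cpu())

# trainer = pl.Trainer(callbacks=[PredCollector()])
# trainer.test(model)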
@@ -188,7 +195,7 @@ class custom_ResNet(pl.LightningModule):
 
         # Assign train/val datasets for use in dataloaders
         if stage == "fit" or stage is None:
-            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
+            cifar_full = CIFAR10(self.data_dir, train=True, download=True, transform=self.train_transform)
             self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
 
         # Assign test dataset for use in dataloader(s)
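The only change here is download=True, which lets the CIFAR10 constructor fetch the dataset into self.data_dir when it is not already present (a fresh Space starts without it). A minimal usage sketch, assuming the class also defines the usual train_dataloader() hook that the comment refers to:

model = custom_ResNet()

# The first call downloads CIFAR-10 into PATH_DATASETS if needed, then splits 45k/5k.
model.setup(stage="fit")
train_loader = model.train_dataloader()

images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([256, 3, 32, 32]) with BATCH_SIZE = 256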
@@ -215,9 +222,9 @@ class custom_ResNet(pl.LightningModule):
             y_hat = self.forward(x)
             pred = y_hat.argmax(dim=1, keepdim=True)
             misclassified_mask = pred.eq(y.view_as(pred)).squeeze()
-            misclassified_images.extend(x[~misclassified_mask].detach())
-            misclassified_true_labels.extend(y[~misclassified_mask].detach())
-            misclassified_predicted_labels.extend(pred[~misclassified_mask].detach())
+            misclassified_images.extend(x[~misclassified_mask].detach())
+            misclassified_true_labels.extend(y[~misclassified_mask].detach())
+            misclassified_predicted_labels.extend(pred[~misclassified_mask].detach())
 
             num_collected += sum(~misclassified_mask)
 
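One subtlety worth noting: pred.eq(y.view_as(pred)) is True where the prediction is correct, so despite its name misclassified_mask marks the correctly classified samples, and every use above inverts it with ~. A self-contained check of that logic:

import torch

y    = torch.tensor([3, 5, 1, 7])
pred = torch.tensor([[3], [0], [1], [2]])   # argmax(dim=1, keepdim=True) gives shape (N, 1)

mask = pred.eq(y.view_as(pred)).squeeze()   # tensor([ True, False,  True, False])
print((~mask).nonzero(as_tuple=True)[0])    # tensor([1, 3]) -- the misclassified indices
print(int(sum(~mask)))                      # 2 collected from this batch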
@@ -260,15 +267,33 @@ class custom_ResNet(pl.LightningModule):
 
         return gradcam_images
 
-
+    def create_layout(self, num_images, use_gradcam):
+        num_cols = 3 if use_gradcam else 2
+        fig = plt.figure(figsize=(12, 5 * num_images))
+        gs = gridspec.GridSpec(num_images, num_cols, figure=fig, width_ratios=[0.3, 1, 1] if use_gradcam else [0.5, 1])
+
+        return fig, gs
+
+    def show_images_with_labels(self, fig, gs, i, img, label_text, use_gradcam=False, gradcam_img=None):
+        ax_img = fig.add_subplot(gs[i, 1])
+        ax_img.imshow(img)
+        ax_img.set_title("Original Image")
+        ax_img.axis("off")
+
+        if use_gradcam:
+            ax_gradcam = fig.add_subplot(gs[i, 2])
+            ax_gradcam.imshow(gradcam_img)
+            ax_gradcam.set_title("GradCAM Image")
+            ax_gradcam.axis("off")
+
+        ax_label = fig.add_subplot(gs[i, 0])
+        ax_label.text(0, 0.5, label_text, fontsize=10, verticalalignment='center')
+        ax_label.axis("off")
+
     def show_misclassified_images(self, num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5):
         misclassified_images, true_labels, predicted_labels, num_misclassified = self.collect_misclassified_images(num_images)
-
-        # Create subplots based on the number of columns required
-        num_rows = num_images
-        num_cols = 2 if use_gradcam else 1 # Show GradCAM images side by side with misclassified images if 'use_gradcam' is True
 
-        fig,
+        fig, gs = self.create_layout(num_images, use_gradcam)
 
         if use_gradcam:
            grad_cam_images = self.get_gradcam_images(target_layer=gradcam_layer, transparency=transparency, num_images=num_images)
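The new create_layout/show_images_with_labels pair replaces the old plt.subplots grid with a GridSpec whose narrow first column holds the label text. A standalone sketch of the same arrangement with placeholder data (the labels and random images are dummies):

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Narrow text column on the left, original image in the middle, GradCAM overlay on the right.
num_images, use_gradcam = 2, True
num_cols = 3 if use_gradcam else 2

fig = plt.figure(figsize=(12, 5 * num_images))
gs = gridspec.GridSpec(num_images, num_cols, figure=fig,
                       width_ratios=[0.3, 1, 1] if use_gradcam else [0.5, 1])

for i in range(num_images):
    ax_label = fig.add_subplot(gs[i, 0])
    ax_label.text(0, 0.5, "True: cat\nPredicted: dog", fontsize=10, verticalalignment='center')
    ax_label.axis("off")
    for col in range(1, num_cols):
        ax = fig.add_subplot(gs[i, col])
        ax.imshow(np.random.rand(32, 32, 3))   # placeholder for the real images
        ax.axis("off")

fig.tight_layout()
plt.show()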
@@ -277,22 +302,9 @@ class custom_ResNet(pl.LightningModule):
             img = misclassified_images[i].numpy().transpose((1, 2, 0)) # Convert tensor to numpy and transpose to (H, W, C)
             img = self.normalize_image(img) # Normalize the image
 
-
-
-
-            axs[i, 0].axis("off")
-
-            if use_gradcam:
-                # gradcam_img = grad_cam_images[i].numpy().transpose((1, 2, 0)) # Convert tensor to numpy and transpose to (H, W, C)
-                gradcam_img = self.normalize_image(grad_cam_images[i]) # Normalize the image
-                axs[i, 1].imshow(gradcam_img)
-                axs[i, 1].set_title("GradCAM")
-                axs[i, 1].axis("off")
-            else: # Use a single column for subplots
-                axs[i].imshow(img)
-                axs[i].set_title(f"True label: {self.classes[true_labels[i]]}\nPredicted: {self.classes[predicted_labels[i]]}")
-                axs[i].axis("off")
-
-        fig.tight_layout()
-        return fig
+            # Show true label and predicted label on the left, and images on the right
+            label_text = f"True Label: {self.classes[true_labels[i]]}\nPredicted Label: {self.classes[predicted_labels[i]]}"
+            self.show_images_with_labels(fig, gs, i, img, label_text, use_gradcam, grad_cam_images[i] if use_gradcam else None)
 
+        plt.tight_layout()
+        return fig
|