mkthoma committed
Commit 2f71151 · 1 Parent(s): 547f90a

Updated resnet class

Files changed (1)
  1. resnet.py +42 -30
resnet.py CHANGED
@@ -24,13 +24,20 @@ import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import pytorch_lightning as pl
 import matplotlib.pyplot as plt
-
+import matplotlib.gridspec as gridspec
 
 
 PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
 BATCH_SIZE = 256
 
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchsummary import summary
+from io import BytesIO
+import numpy as np
+
 # Model
 class custom_ResNet(pl.LightningModule):
     def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4):
@@ -168,7 +175,7 @@ class custom_ResNet(pl.LightningModule):
         acc = pred.eq(y.view_as(pred)).float().mean()
         self.log('test_loss', loss, prog_bar=True)
         self.log('test_acc', acc, prog_bar=True)
-        return pred # Return predictions instead of loss
+        return pred
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
@@ -188,7 +195,7 @@ class custom_ResNet(pl.LightningModule):
 
         # Assign train/val datasets for use in dataloaders
         if stage == "fit" or stage is None:
-            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
+            cifar_full = CIFAR10(self.data_dir, train=True, download=True, transform=self.train_transform)
             self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
 
         # Assign test dataset for use in dataloader(s)
@@ -215,9 +222,9 @@ class custom_ResNet(pl.LightningModule):
             y_hat = self.forward(x)
             pred = y_hat.argmax(dim=1, keepdim=True)
             misclassified_mask = pred.eq(y.view_as(pred)).squeeze()
-            misclassified_images.extend(x[~misclassified_mask].detach()) # Detach here to avoid CPU transfer
-            misclassified_true_labels.extend(y[~misclassified_mask].detach()) # Detach here to avoid CPU transfer
-            misclassified_predicted_labels.extend(pred[~misclassified_mask].detach()) # Detach here to avoid CPU transfer
+            misclassified_images.extend(x[~misclassified_mask].detach())
+            misclassified_true_labels.extend(y[~misclassified_mask].detach())
+            misclassified_predicted_labels.extend(pred[~misclassified_mask].detach())
 
             num_collected += sum(~misclassified_mask)
 
@@ -260,15 +267,33 @@ class custom_ResNet(pl.LightningModule):
 
        return gradcam_images
 
-    # Add a 'use_gradcam' parameter to the show_misclassified_images function
+    def create_layout(self, num_images, use_gradcam):
+        num_cols = 3 if use_gradcam else 2
+        fig = plt.figure(figsize=(12, 5 * num_images))
+        gs = gridspec.GridSpec(num_images, num_cols, figure=fig, width_ratios=[0.3, 1, 1] if use_gradcam else [0.5, 1])
+
+        return fig, gs
+
+    def show_images_with_labels(self, fig, gs, i, img, label_text, use_gradcam=False, gradcam_img=None):
+        ax_img = fig.add_subplot(gs[i, 1])
+        ax_img.imshow(img)
+        ax_img.set_title("Original Image")
+        ax_img.axis("off")
+
+        if use_gradcam:
+            ax_gradcam = fig.add_subplot(gs[i, 2])
+            ax_gradcam.imshow(gradcam_img)
+            ax_gradcam.set_title("GradCAM Image")
+            ax_gradcam.axis("off")
+
+        ax_label = fig.add_subplot(gs[i, 0])
+        ax_label.text(0, 0.5, label_text, fontsize=10, verticalalignment='center')
+        ax_label.axis("off")
+
     def show_misclassified_images(self, num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5):
         misclassified_images, true_labels, predicted_labels, num_misclassified = self.collect_misclassified_images(num_images)
-
-        # Create subplots based on the number of columns required
-        num_rows = num_images
-        num_cols = 2 if use_gradcam else 1 # Show GradCAM images side by side with misclassified images if 'use_gradcam' is True
 
-        fig, axs = plt.subplots(num_rows, num_cols, figsize=(8, 5 * num_rows))
+        fig, gs = self.create_layout(num_images, use_gradcam)
 
         if use_gradcam:
             grad_cam_images = self.get_gradcam_images(target_layer=gradcam_layer, transparency=transparency, num_images=num_images)
@@ -277,22 +302,9 @@ class custom_ResNet(pl.LightningModule):
             img = misclassified_images[i].numpy().transpose((1, 2, 0)) # Convert tensor to numpy and transpose to (H, W, C)
             img = self.normalize_image(img) # Normalize the image
 
-            if num_cols > 1: # Use multiple columns for subplots
-                axs[i, 0].imshow(img)
-                axs[i, 0].set_title(f"True label: {self.classes[true_labels[i]]}\nPredicted: {self.classes[predicted_labels[i]]}")
-                axs[i, 0].axis("off")
-
-                if use_gradcam:
-                    # gradcam_img = grad_cam_images[i].numpy().transpose((1, 2, 0)) # Convert tensor to numpy and transpose to (H, W, C)
-                    gradcam_img = self.normalize_image(grad_cam_images[i]) # Normalize the image
-                    axs[i, 1].imshow(gradcam_img)
-                    axs[i, 1].set_title("GradCAM")
-                    axs[i, 1].axis("off")
-            else: # Use a single column for subplots
-                axs[i].imshow(img)
-                axs[i].set_title(f"True label: {self.classes[true_labels[i]]}\nPredicted: {self.classes[predicted_labels[i]]}")
-                axs[i].axis("off")
-
-        fig.tight_layout()
-        return fig
+            # Show true label and predicted label on the left, and images on the right
+            label_text = f"True Label: {self.classes[true_labels[i]]}\nPredicted Label: {self.classes[predicted_labels[i]]}"
+            self.show_images_with_labels(fig, gs, i, img, label_text, use_gradcam, grad_cam_images[i] if use_gradcam else None)
 
+        plt.tight_layout()
+        return fig
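
For reference, a minimal usage sketch of the reworked plotting path introduced by this commit; the Trainer settings, training length, and output filenames below are illustrative assumptions and are not part of the change itself:

import pytorch_lightning as pl

# Train and run the test loop briefly so misclassified test samples exist to collect
model = custom_ResNet()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model)
trainer.test(model)

# Two-column GridSpec layout: label text on the left, the misclassified image on the right
fig = model.show_misclassified_images(num_images=10)
fig.savefig("misclassified.png")

# Three-column layout: labels, original image, and a GradCAM overlay
fig = model.show_misclassified_images(num_images=10, use_gradcam=True, gradcam_layer=-1, transparency=0.5)
fig.savefig("misclassified_gradcam.png")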