Mojo committed
Commit 4f6f9e3 · 1 Parent(s): 923fe1a

Optimised the files

models/custom_resnet.py ADDED
@@ -0,0 +1,456 @@
+ """Module to define the model."""
+
+ # Resources
+ # https://lightning.ai/docs/pytorch/stable/starter/introduction.html
+ # https://lightning.ai/docs/pytorch/stable/starter/converting.html
+ # https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/cifar10-baseline.html
+
+ import modules.config as config
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import torchinfo
+ from torch.optim.lr_scheduler import OneCycleLR
+ from torch_lr_finder import LRFinder
+ from torchmetrics import Accuracy
+
+ # What are the start LR and weight decay you'd prefer?
+ PREFERRED_START_LR = config.PREFERRED_START_LR
+ PREFERRED_WEIGHT_DECAY = config.PREFERRED_WEIGHT_DECAY
+
+
+ def detailed_model_summary(model, input_size):
+     """Print the model summary."""
+
+     # https://github.com/TylerYep/torchinfo
+     torchinfo.summary(
+         model,
+         input_size=input_size,
+         batch_dim=0,
+         col_names=(
+             "input_size",
+             "kernel_size",
+             "output_size",
+             "num_params",
+             "trainable",
+         ),
+         verbose=1,
+         col_width=16,
+     )
+
+
+ ############# Assignment 13 Model #############
+
+
+ # This is for Assignment 13
+ # Model used from Assignment 11 and converted to a Lightning model
+ class CustomResNet(pl.LightningModule):
+     """This defines the structure of the NN."""
+
+     # Class variable to print shape
+     print_shape = False
+     # Default dropout value
+     dropout_value = 0.02
+
+     def __init__(self):
+         super().__init__()
+
+         # Define loss function
+         # https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+         self.loss_function = torch.nn.CrossEntropyLoss()
+
+         # Define accuracy function
+         # https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html
+         self.accuracy_function = Accuracy(task="multiclass", num_classes=10)
+
+         # Add results dictionary
+         self.results = {
+             "train_loss": [],
+             "train_acc": [],
+             "test_loss": [],
+             "test_acc": [],
+             "val_loss": [],
+             "val_acc": [],
+         }
+
+         # Save misclassified images
+         self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+
+         # LR
+         self.learning_rate = PREFERRED_START_LR
+
+         # Model Notes
+
+         # PrepLayer - Conv 3x3 (s1, p1) >> BN >> RELU [64k]
+         # 1. Input size: 32x32x3
+         self.prep = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=3,
+                 out_channels=64,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer1: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [128k]
+         self.layer1_x = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=64,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer1: R1 = ResBlock( (Conv-BN-ReLU-Conv-BN-ReLU))(X) [128k]
+         self.layer1_r1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 2: Conv 3x3 [256k], MaxPooling2D, BN, ReLU
+         self.layer2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=128,
+                 out_channels=256,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 3: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [512k]
+         self.layer3_x = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=256,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # Layer 3: R2 = ResBlock( (Conv-BN-ReLU-Conv-BN-ReLU))(X) [512k]
+         self.layer3_r2 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+             nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=1,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Dropout(self.dropout_value),
+         )
+
+         # MaxPooling with Kernel Size 4
+         # If stride is None, it is set to kernel_size
+         self.maxpool = nn.MaxPool2d(kernel_size=4, stride=4)
+
+         # FC Layer
+         self.fc = nn.Linear(512, 10)
+
+         # Save hyperparameters
+         self.save_hyperparameters()
+
+     def print_view(self, x, msg=""):
+         """Print the shape of the tensor"""
+         if self.print_shape:
+             if msg != "":
+                 print(msg, "\n\t", x.shape, "\n")
+             else:
+                 print(x.shape)
+
+     def forward(self, x):
+         """Forward pass"""
+
+         # PrepLayer
+         x = self.prep(x)
+         self.print_view(x, "PrepLayer")
+
+         # Layer 1
+         x = self.layer1_x(x)
+         self.print_view(x, "Layer 1, X")
+         r1 = self.layer1_r1(x)
+         self.print_view(r1, "Layer 1, R1")
+         x = x + r1
+         self.print_view(x, "Layer 1, X + R1")
+
+         # Layer 2
+         x = self.layer2(x)
+         self.print_view(x, "Layer 2")
+
+         # Layer 3
+         x = self.layer3_x(x)
+         self.print_view(x, "Layer 3, X")
+         r2 = self.layer3_r2(x)
+         self.print_view(r2, "Layer 3, R2")
+         x = x + r2
+         self.print_view(x, "Layer 3, X + R2")
+
+         # MaxPooling
+         x = self.maxpool(x)
+         self.print_view(x, "Max Pooling")
+
+         # FC Layer
+         # Reshape before FC such that it becomes 1D
+         x = x.view(x.shape[0], -1)
+         self.print_view(x, "Reshape before FC")
+         x = self.fc(x)
+         self.print_view(x, "After FC")
+
+         # Softmax
+         return F.log_softmax(x, dim=-1)
+
+     # Alert: Remove this function later as Tuner is now being used to automatically find the best LR
+     def find_optimal_lr(self, train_loader):
+         """Use LR Finder to find the best starting learning rate"""
+
+         # https://github.com/davidtvs/pytorch-lr-finder
+         # https://github.com/davidtvs/pytorch-lr-finder#notes
+         # https://github.com/davidtvs/pytorch-lr-finder/blob/master/torch_lr_finder/lr_finder.py
+
+         # New optimizer with default LR
+         tmp_optimizer = optim.Adam(self.parameters(), lr=PREFERRED_START_LR, weight_decay=PREFERRED_WEIGHT_DECAY)
+
+         # Create LR finder object
+         lr_finder = LRFinder(self, optimizer=tmp_optimizer, criterion=self.loss_function)
+         lr_finder.range_test(train_loader=train_loader, end_lr=10, num_iter=100)
+         # https://github.com/davidtvs/pytorch-lr-finder/issues/88
+         _, suggested_lr = lr_finder.plot(suggest_lr=True)
+         lr_finder.reset()
+         # plot.figure.savefig("LRFinder - Suggested Max LR.png")
+
+         print(f"Suggested Max LR: {suggested_lr}")
+
+         if suggested_lr is None:
+             suggested_lr = PREFERRED_START_LR
+
+         return suggested_lr
+
+     # Optimiser function
+     def configure_optimizers(self):
+         """Add Adam optimizer to the lightning module"""
+         optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=PREFERRED_WEIGHT_DECAY)
+
+         # Percent start for OneCycleLR
+         # Handles the case where max_epochs is less than 5
+         percent_start = 5 / int(self.trainer.max_epochs)
+         if percent_start >= 1:
+             percent_start = 0.3
+
+         # https://lightning.ai/docs/pytorch/stable/common/optimization.html#total-stepping-batches
+         scheduler_dict = {
+             "scheduler": OneCycleLR(
+                 optimizer=optimizer,
+                 max_lr=self.learning_rate,
+                 total_steps=int(self.trainer.estimated_stepping_batches),
+                 pct_start=percent_start,
+                 div_factor=100,
+                 three_phase=False,
+                 anneal_strategy="linear",
+                 final_div_factor=100,
+                 verbose=False,
+             ),
+             "interval": "step",
+         }
+
+         return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
+
+     # Define loss function
+     def compute_loss(self, prediction, target):
+         """Compute Loss"""
+
+         # Calculate loss
+         loss = self.loss_function(prediction, target)
+
+         return loss
+
+     # Define accuracy function
+     def compute_accuracy(self, prediction, target):
+         """Compute accuracy"""
+
+         # Calculate accuracy
+         acc = self.accuracy_function(prediction, target)
+
+         return acc * 100
+
+     # Function to compute loss and accuracy for both training and validation
+     def compute_metrics(self, batch):
+         """Calculate loss and accuracy"""
+
+         # Get data and target from batch
+         data, target = batch
+
+         # Generate predictions using model
+         pred = self(data)
+
+         # Calculate loss for the batch
+         loss = self.compute_loss(prediction=pred, target=target)
+
+         # Calculate accuracy for the batch
+         acc = self.compute_accuracy(prediction=pred, target=target)
+
+         return loss, acc
+
+     # Get misclassified images based on how many images to return
+     def store_misclassified_images(self):
+         """Get an array of misclassified images"""
+
+         self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+
+         # Put the model in evaluation mode
+         self.eval()
+
+         # Disable gradient calculation while testing
+         with torch.no_grad():
+             for batch in self.trainer.test_dataloaders:
+                 # Move data and labels to device
+                 data, target = batch
+                 data, target = data.to(self.device), target.to(self.device)
+
+                 # Predict using model
+                 pred = self(data)
+
+                 # Get the index of the max log-probability
+                 output = pred.argmax(dim=1)
+
+                 # Save the incorrect predictions
+                 incorrect_indices = ~output.eq(target)
+
+                 # Store the incorrectly predicted images, the generated predictions and the actual values
+                 self.misclassified_image_data["images"].extend(data[incorrect_indices])
+                 self.misclassified_image_data["ground_truths"].extend(target[incorrect_indices])
+                 self.misclassified_image_data["predicted_vals"].extend(output[incorrect_indices])
+
+     # Training function
+     def training_step(self, batch, batch_idx):
+         """Training step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("train_loss", loss, prog_bar=True, on_epoch=True, logger=True)
+         self.log("train_acc", acc, prog_bar=True, on_epoch=True, logger=True)
+         # Return training loss
+         return loss
+
+     # Validation function
+     def validation_step(self, batch, batch_idx):
+         """Validation step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("val_loss", loss, prog_bar=True, on_epoch=True, logger=True)
+         self.log("val_acc", acc, prog_bar=True, on_epoch=True, logger=True)
+         # Return validation loss
+         return loss
+
+     # Test function mirrors the validation step
+     def test_step(self, batch, batch_idx):
+         """Test step"""
+
+         # Compute loss and accuracy
+         loss, acc = self.compute_metrics(batch)
+
+         self.log("test_loss", loss, prog_bar=False, on_epoch=True, logger=True)
+         self.log("test_acc", acc, prog_bar=False, on_epoch=True, logger=True)
+         # Return test loss
+         return loss
+
+     # At the end of each training epoch, append the training loss and accuracy to the results instance variable
+     def on_train_epoch_end(self):
+         """On train epoch end"""
+
+         # Append training loss and accuracy to results
+         self.results["train_loss"].append(self.trainer.callback_metrics["train_loss"].detach().item())
+         self.results["train_acc"].append(self.trainer.callback_metrics["train_acc"].detach().item())
+
+     # At the end of each validation epoch, append the validation loss and accuracy to the results instance variable
+     def on_validation_epoch_end(self):
+         """On validation epoch end"""
+
+         # Append validation loss and accuracy to results
+         self.results["test_loss"].append(self.trainer.callback_metrics["val_loss"].detach().item())
+         self.results["test_acc"].append(self.trainer.callback_metrics["val_acc"].detach().item())
+
+     # # At the end of each test epoch, append the test loss and accuracy to the results instance variable
+     # def on_test_epoch_end(self):
+     #     """On test epoch end"""
+
+     #     # Append test loss and accuracy to results
+     #     self.results["test_loss"].append(self.trainer.callback_metrics["test_loss"].detach().item())
+     #     self.results["test_acc"].append(self.trainer.callback_metrics["test_acc"].detach().item())
+
+     # At the end of testing, save the misclassified images, predictions and ground truths in misclassified_image_data
+     def on_test_end(self):
+         """On test end"""
+
+         print("Test ended! Saving misclassified images")
+         # Get misclassified images
+         self.store_misclassified_images()
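A minimal usage sketch (not part of this commit; the import path assumes the repository layout shown above): instantiate the model, print its summary with the helper defined in this file, and run a dummy forward pass.

import torch
from models.custom_resnet import CustomResNet, detailed_model_summary

model = CustomResNet()
# CIFAR10 images are 3x32x32; detailed_model_summary passes batch_dim=0,
# so torchinfo adds the batch dimension itself.
detailed_model_summary(model, input_size=(3, 32, 32))
# The forward pass returns log-probabilities over the 10 classes.
log_probs = model(torch.randn(1, 3, 32, 32))
print(log_probs.shape)  # torch.Size([1, 10])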
modules/config.py ADDED
@@ -0,0 +1,50 @@
+ # Alert: Change these when running in production
+
+ # Constants naming convention: all caps, separated by underscores
+ # https://realpython.com/python-constants/
+
+ # Where do we store the data?
+ DATA_PATH = "../../data/"
+ CHECKPOINT_PATH = "../../checkpoints/"
+ LOGGING_PATH = "../../logs/"
+ MISCLASSIFIED_PATH = "Misclassified_Data.pt"
+ MODEL_PATH = "CustomResNet.pt"
+
+ # Specify the number of epochs
+ NUM_EPOCHS = 24
+
+ # Set the batch size
+ BATCH_SIZE = 512
+
+ # Set seed value for reproducibility
+ SEED = 53
+
+ # What are the start LR and weight decay you'd prefer?
+ PREFERRED_START_LR = 5e-3
+ PREFERRED_WEIGHT_DECAY = 1e-5
+
+
+ # What are the mean and std deviation of the dataset?
+ CIFAR_MEAN = (0.4915, 0.4823, 0.4468)
+ CIFAR_STD = (0.2470, 0.2435, 0.2616)
+
+ # What is the cutout size?
+ CUTOUT_SIZE = 16
+
+ # What are the classes in CIFAR10?
+ # Create class labels and convert to tuple
+ CIFAR_CLASSES = tuple(
+     c.capitalize()
+     for c in [
+         "plane",
+         "car",
+         "bird",
+         "cat",
+         "deer",
+         "dog",
+         "frog",
+         "horse",
+         "ship",
+         "truck",
+     ]
+ )
modules/dataset.py ADDED
@@ -0,0 +1,110 @@
+ """This file contains functions to download and transform the CIFAR10 dataset"""
+ # Needed for image transformations
+ import albumentations as A
+ import modules.config as config
+
+ # # Needed for padding issues in albumentations
+ # import cv2
+ import numpy as np
+ from albumentations.pytorch.transforms import ToTensorV2
+ from torch.utils.data import Dataset
+
+ # Use precomputed values for mean and standard deviation of the dataset
+ CIFAR_MEAN = config.CIFAR_MEAN
+ CIFAR_STD = config.CIFAR_STD
+ CUTOUT_SIZE = config.CUTOUT_SIZE
+
+ # Create class labels and convert to tuple
+ CIFAR_CLASSES = config.CIFAR_CLASSES
+
+
+ class CIFAR10Transforms(Dataset):
+     """Apply albumentations augmentations to the CIFAR10 dataset"""
+
+     # Given a dataset and transformations,
+     # apply the transformations and return the dataset
+     def __init__(self, dataset, transforms):
+         self.dataset = dataset
+         self.transforms = transforms
+
+     def __getitem__(self, idx):
+         # Get the image and label from the dataset
+         image, label = self.dataset[idx]
+
+         # Apply transformations on the image
+         image = self.transforms(image=np.array(image))["image"]
+
+         return image, label
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __repr__(self):
+         return f"CIFAR10Transforms(dataset={self.dataset}, transforms={self.transforms})"
+
+     def __str__(self):
+         return f"CIFAR10Transforms(dataset={self.dataset}, transforms={self.transforms})"
+
+
+ def apply_cifar_image_transformations(mean=CIFAR_MEAN, std=CIFAR_STD, cutout_size=CUTOUT_SIZE):
+     """
+     Function to apply the required transformations to the CIFAR10 dataset.
+     """
+     # Apply the required transformations to the CIFAR10 dataset
+     train_transforms = A.Compose(
+         [
+             # Normalize the images with mean and standard deviation from the whole dataset
+             # https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize
+             # # transforms.Normalize(cifar_mean, cifar_std),
+             A.Normalize(mean=list(mean), std=list(std)),
+             # RandomCrop 32, 32 (after padding of 4)
+             # https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.PadIfNeeded
+             # min_height and min_width are set to 36 to ensure that the image is padded to 36x36 before cropping
+             # border_mode (OpenCV flag): flag used to specify the pixel extrapolation method. Should be one of:
+             # cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
+             # Default: cv2.BORDER_REFLECT_101
+             A.PadIfNeeded(min_height=36, min_width=36),
+             # https://albumentations.ai/docs/api_reference/augmentations/crops/transforms/#albumentations.augmentations.crops.transforms.RandomCrop
+             A.RandomCrop(32, 32),
+             # CutOut(8, 8)
+             # # https://albumentations.ai/docs/api_reference/augmentations/dropout/cutout/#albumentations.augmentations.dropout.cutout.Cutout
+             # # Because we normalized the images with mean and standard deviation from the whole dataset, the fill_value is set to the mean of the dataset
+             # A.Cutout(
+             #     num_holes=1, max_h_size=cutout_size, max_w_size=cutout_size, p=1.0
+             # ),
+             # https://albumentations.ai/docs/api_reference/augmentations/dropout/coarse_dropout/#coarsedropout-augmentation-augmentationsdropoutcoarse_dropout
+             A.CoarseDropout(
+                 max_holes=1,
+                 max_height=cutout_size,
+                 max_width=cutout_size,
+                 min_holes=1,
+                 min_height=cutout_size,
+                 min_width=cutout_size,
+                 p=1.0,
+             ),
+             # Convert the images to tensors
+             # # transforms.ToTensor(),
+             ToTensorV2(),
+         ]
+     )
+
+     # Test data transformations
+     test_transforms = A.Compose(
+         # Convert the images to tensors
+         # Normalize the images with mean and standard deviation from the whole dataset
+         [
+             A.Normalize(mean=list(mean), std=list(std)),
+             # Convert the images to tensors
+             ToTensorV2(),
+         ]
+     )
+
+     return train_transforms, test_transforms
+
+
+ def calculate_mean_std(dataset):
+     """Function to calculate the mean and standard deviation of the CIFAR dataset"""
+     data = dataset.data.astype(np.float32) / 255.0
+     mean = np.mean(data, axis=(0, 1, 2))
+     std = np.std(data, axis=(0, 1, 2))
+     return mean, std
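A minimal usage sketch (not part of this commit; assumes torchvision is installed and "./data" is writable): recompute the dataset statistics and wrap the raw CIFAR10 dataset with the albumentations pipeline above.

import torchvision.datasets as datasets
from modules.dataset import CIFAR10Transforms, apply_cifar_image_transformations, calculate_mean_std

raw_train = datasets.CIFAR10("./data", train=True, download=True)
# These should come out close to config.CIFAR_MEAN and config.CIFAR_STD.
mean, std = calculate_mean_std(raw_train)
train_transforms, _ = apply_cifar_image_transformations(mean=mean, std=std)
train_dataset = CIFAR10Transforms(raw_train, train_transforms)
image, label = train_dataset[0]  # image is a normalized, augmented 3x32x32 tensor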
modules/lightning_dataset.py ADDED
@@ -0,0 +1,109 @@
+ """This file contains functions to prepare the dataloaders in the way Lightning expects"""
+ import pytorch_lightning as pl
+ import torchvision.datasets as datasets
+ from lightning_fabric.utilities.seed import seed_everything
+ from modules.dataset import CIFAR10Transforms, apply_cifar_image_transformations
+ from torch.utils.data import DataLoader, random_split
+
+
+ class CIFARDataModule(pl.LightningDataModule):
+     """Lightning DataModule for the CIFAR10 dataset"""
+
+     def __init__(self, data_path, batch_size, seed, val_split=0, num_workers=0):
+         super().__init__()
+
+         self.data_path = data_path
+         self.batch_size = batch_size
+         self.seed = seed
+         self.val_split = val_split
+         self.num_workers = num_workers
+         self.dataloader_dict = {
+             # "shuffle": True,
+             "batch_size": self.batch_size,
+             "num_workers": self.num_workers,
+             "pin_memory": True,
+             # "worker_init_fn": self._init_fn,
+             "persistent_workers": self.num_workers > 0,
+         }
+         self.prepare_data_per_node = False
+
+         # Fixes the attribute-defined-outside-__init__ warning
+         self.training_dataset = None
+         self.validation_dataset = None
+         self.testing_dataset = None
+
+         # # Make sure data is downloaded
+         # self.prepare_data()
+
+     def _split_train_val(self, dataset):
+         """Split the dataset into train and validation sets"""
+
+         # Throw an error if the validation split is not between 0 and 1
+         if not 0 < self.val_split < 1:
+             raise ValueError("Validation split must be between 0 and 1")
+
+         # # Set seed again, might not be necessary
+         # seed_everything(int(self.seed))
+
+         # Calculate lengths of each dataset
+         total_length = len(dataset)
+         train_length = int((1 - self.val_split) * total_length)
+         val_length = total_length - train_length
+
+         # Split the dataset
+         train_dataset, val_dataset = random_split(dataset, [train_length, val_length])
+
+         return train_dataset, val_dataset
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data
+     def prepare_data(self):
+         # Download the CIFAR10 dataset if it doesn't exist
+         datasets.CIFAR10(self.data_path, train=True, download=True)
+         datasets.CIFAR10(self.data_path, train=False, download=True)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.hooks.DataHooks.html#lightning.pytorch.core.hooks.DataHooks.setup
+     def setup(self, stage=None):
+         # seed_everything(int(self.seed))
+
+         # Define the data transformations
+         train_transforms, test_transforms = apply_cifar_image_transformations()
+         val_transforms = test_transforms
+
+         # Create train and validation datasets
+         if stage == "fit" or stage is None:
+             if self.val_split != 0:
+                 # Split the training data into training and validation sets
+                 data_train, data_val = self._split_train_val(datasets.CIFAR10(self.data_path, train=True))
+                 # Apply transformations
+                 self.training_dataset = CIFAR10Transforms(data_train, train_transforms)
+                 self.validation_dataset = CIFAR10Transforms(data_val, val_transforms)
+             else:
+                 # Only training data here
+                 self.training_dataset = CIFAR10Transforms(
+                     datasets.CIFAR10(self.data_path, train=True), train_transforms
+                 )
+                 # Validation will be the same as test
+                 self.validation_dataset = CIFAR10Transforms(
+                     datasets.CIFAR10(self.data_path, train=False), val_transforms
+                 )
+
+         # Create test dataset
+         if stage == "test" or stage is None:
+             # Assign Test split(s) for use in Dataloaders
+             self.testing_dataset = CIFAR10Transforms(datasets.CIFAR10(self.data_path, train=False), test_transforms)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#train-dataloader
+     def train_dataloader(self):
+         return DataLoader(self.training_dataset, **self.dataloader_dict, shuffle=True)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#val-dataloader
+     def val_dataloader(self):
+         return DataLoader(self.validation_dataset, **self.dataloader_dict, shuffle=False)
+
+     # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#test-dataloader
+     def test_dataloader(self):
+         return DataLoader(self.testing_dataset, **self.dataloader_dict, shuffle=False)
+
+     def _init_fn(self, worker_id):
+         seed_everything(int(self.seed) + worker_id)
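A minimal usage sketch (not part of this commit): build the datamodule from the constants in modules/config.py and pull one training batch.

import modules.config as config
from modules.lightning_dataset import CIFARDataModule

datamodule = CIFARDataModule(
    data_path=config.DATA_PATH,
    batch_size=config.BATCH_SIZE,
    seed=config.SEED,
    val_split=0,  # val_split=0 reuses the test set for validation
    num_workers=2,
)
datamodule.prepare_data()
datamodule.setup()
images, labels = next(iter(datamodule.train_dataloader()))
print(images.shape)  # torch.Size([512, 3, 32, 32]) with BATCH_SIZE = 512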
modules/trainer.py ADDED
@@ -0,0 +1,120 @@
+ """Module to define the train and test functions."""
+
+ # from functools import partial
+
+ import modules.config as config
+ import pytorch_lightning as pl
+ import torch
+ from modules.utils import create_folder_if_not_exists
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ModelSummary
+
+ # Import tuner
+ from pytorch_lightning.tuner.tuning import Tuner
+
+ # What is the start LR you'd prefer?
+ PREFERRED_START_LR = config.PREFERRED_START_LR
+
+
+ def train_and_test_model(
+     batch_size,
+     num_epochs,
+     model,
+     datamodule,
+     logger,
+     debug=False,
+ ):
+     """Trains and tests the model by iterating through epochs using Lightning Trainer."""
+
+     print(f"\n\nBatch size: {batch_size}, Total epochs: {num_epochs}\n\n")
+
+     print("Defining Lightning Callbacks")
+
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
+     checkpoint = ModelCheckpoint(
+         dirpath=config.CHECKPOINT_PATH, monitor="val_acc", mode="max", filename="model_best_epoch", save_last=True
+     )
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#learningratemonitor
+     lr_rate_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=False)
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelSummary.html#lightning.pytorch.callbacks.ModelSummary
+     model_summary = ModelSummary(max_depth=0)
+
+     print("Defining Lightning Trainer")
+     # Change trainer settings for debugging
+     if debug:
+         num_epochs = 1
+         fast_dev_run = True
+         overfit_batches = 0.1
+         profiler = "advanced"
+     else:
+         fast_dev_run = False
+         overfit_batches = 0.0
+         profiler = None
+
+     # https://lightning.ai/docs/pytorch/stable/common/trainer.html#methods
+     trainer = pl.Trainer(
+         precision=16,
+         fast_dev_run=fast_dev_run,
+         # deterministic=True,
+         # devices="auto",
+         # accelerator="auto",
+         max_epochs=num_epochs,
+         logger=logger,
+         # enable_model_summary=False,
+         overfit_batches=overfit_batches,
+         log_every_n_steps=10,
+         # num_sanity_val_steps=5,
+         profiler=profiler,
+         # check_val_every_n_epoch=1,
+         callbacks=[checkpoint, lr_rate_monitor, model_summary],
+         # callbacks=[checkpoint],
+     )
+
+     # # Using the learning rate finder
+     # model.learning_rate = model.find_optimal_lr(train_loader=datamodule.train_dataloader())
+
+     # Use lr_find from the Lightning Tuner instead
+     # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.tuner.tuning.Tuner.html#lightning.pytorch.tuner.tuning.Tuner
+     # https://www.youtube.com/watch?v=cLZv0eZQSIE
+     print("Finding the optimal learning rate using Lightning Tuner.")
+     tuner = Tuner(trainer)
+     tuner.lr_find(
+         model=model,
+         datamodule=datamodule,
+         min_lr=PREFERRED_START_LR,
+         max_lr=5,
+         num_training=200,
+         mode="linear",
+         early_stop_threshold=10,
+         attr_name="learning_rate",
+     )
+
+     trainer.fit(model, datamodule=datamodule)
+     trainer.test(model, dataloaders=datamodule.test_dataloader())
+
+     # Obtain the results dictionary from the model
+     print("Collecting epoch level model results.")
+     results = model.results
+     # print(f"Results Length: {len(results)}")
+
+     # Get the list of misclassified images
+     print("Collecting misclassified images.")
+     misclassified_image_data = model.misclassified_image_data
+     # print(f"Misclassified Images Length: {len(misclassified_image_data)}")
+
+     # Save the model using torch.save as a backup
+     print("Saving the model.")
+     create_folder_if_not_exists(config.MODEL_PATH)
+     torch.save(model.state_dict(), config.MODEL_PATH)
+
+     # Save the first few misclassified images to a file
+     num_elements = 20
+     print(f"Saving first {num_elements} misclassified images.")
+     subset_misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}
+     subset_misclassified_image_data["images"] = misclassified_image_data["images"][:num_elements]
+     subset_misclassified_image_data["ground_truths"] = misclassified_image_data["ground_truths"][:num_elements]
+     subset_misclassified_image_data["predicted_vals"] = misclassified_image_data["predicted_vals"][:num_elements]
+     create_folder_if_not_exists(config.MISCLASSIFIED_PATH)
+     torch.save(subset_misclassified_image_data, config.MISCLASSIFIED_PATH)
+
+     return trainer, results, misclassified_image_data
+     # return trainer
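A minimal end-to-end sketch (not part of this commit; the TensorBoardLogger choice is an assumption, any Lightning logger would do): wire the model and datamodule into the training entry point above.

import modules.config as config
from models.custom_resnet import CustomResNet
from modules.lightning_dataset import CIFARDataModule
from modules.trainer import train_and_test_model
from pytorch_lightning.loggers import TensorBoardLogger

model = CustomResNet()
datamodule = CIFARDataModule(
    data_path=config.DATA_PATH, batch_size=config.BATCH_SIZE, seed=config.SEED
)
logger = TensorBoardLogger(save_dir=config.LOGGING_PATH, name="custom_resnet")
trainer, results, misclassified = train_and_test_model(
    batch_size=config.BATCH_SIZE,
    num_epochs=config.NUM_EPOCHS,
    model=model,
    datamodule=datamodule,
    logger=logger,
)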
modules/utils.py ADDED
@@ -0,0 +1,70 @@
+ """Module to define utility functions for the project."""
+ import os
+
+ import torch
+
+
+ def get_num_workers(model_run_location):
+     """Given a run mode, return the number of workers to be used for data loading."""
+
+     # Calculate the number of workers
+     num_workers = (os.cpu_count() - 1) if os.cpu_count() > 3 else 2
+
+     # Use multiple workers only when running on Colab; otherwise use 0
+     num_workers = num_workers if model_run_location == "colab" else 0
+
+     return num_workers
+
+
+ # Function to save the model
+ # https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/
+ def save_model(epoch, model, optimizer, scheduler, batch_size, criterion, file_name):
+     """
+     Function to save the trained model along with other information to disk.
+     """
+     # print(f"Saving model from epoch {epoch}...")
+     torch.save(
+         {
+             "epoch": epoch,
+             "model_state_dict": model.state_dict(),
+             "optimizer_state_dict": optimizer.state_dict(),
+             "scheduler_state_dict": scheduler.state_dict(),
+             "batch_size": batch_size,
+             "loss": criterion,
+         },
+         file_name,
+     )
+
+
+ # Given lists of train_losses, train_accuracies, test_losses and
+ # test_accuracies, loop through the epochs and print the metrics
+ def pretty_print_metrics(num_epochs, results):
+     """
+     Function to print the metrics in a pretty format.
+     """
+     # Extract train_losses, train_acc, test_losses, test_acc from results
+     train_losses = results["train_loss"]
+     train_acc = results["train_acc"]
+     test_losses = results["test_loss"]
+     test_acc = results["test_acc"]
+
+     for i in range(num_epochs):
+         print(
+             f"Epoch: {i+1:02d}, Train Loss: {train_losses[i]:.4f}, "
+             f"Test Loss: {test_losses[i]:.4f}, Train Accuracy: {train_acc[i]:.4f}, "
+             f"Test Accuracy: {test_acc[i]:.4f}"
+         )
+
+
+ # Given a file path, extract the folder path and create the folder recursively if it does not already exist
+ def create_folder_if_not_exists(file_path):
+     """
+     Function to create a folder if it does not exist.
+     """
+     # Extract the folder path
+     folder_path = os.path.dirname(file_path)
+
+     # Create the folder if it does not exist
+     if not os.path.exists(folder_path):
+         os.makedirs(folder_path)
+         print(f"Created folder: {folder_path}")
modules/visualize.py ADDED
@@ -0,0 +1,169 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+
+ def convert_back_image(image):
+     """Using the mean and std deviation, convert the image back to normal"""
+     cifar10_mean = (0.4914, 0.4822, 0.4471)
+     cifar10_std = (0.2469, 0.2433, 0.2615)
+     image = image.numpy().astype(dtype=np.float32)
+
+     for i in range(image.shape[0]):
+         image[i] = (image[i] * cifar10_std[i]) + cifar10_mean[i]
+
+     # To avoid the warning that image pixel values exceed bounds
+     image = image.clip(0, 1)
+
+     return np.transpose(image, (1, 2, 0))
+
+
+ def plot_sample_training_images(batch_data, batch_label, class_label, num_images=30):
+     """Function to plot sample images from the training data."""
+     images, labels = batch_data, batch_label
+
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(images))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))
+
+     # Iterate through the images and plot them in the grid along with class labels
+
+     for img_index in range(1, num_images + 1):
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(convert_back_image(images[img_index - 1]))
+         plt.title(class_label[labels[img_index - 1].item()])
+         plt.xticks([])
+         plt.yticks([])
+
+     return fig, axs
+
+
+ def plot_train_test_metrics(results):
+     """
+     Function to plot the training and test metrics.
+     """
+     # Extract train_losses, train_acc, test_losses, test_acc from results
+     train_losses = results["train_loss"]
+     train_acc = results["train_acc"]
+     test_losses = results["test_loss"]
+     test_acc = results["test_acc"]
+
+     # Plot the graphs in a 1x2 grid showing the training and test metrics
+     fig, axs = plt.subplots(1, 2, figsize=(16, 8))
+
+     # Loss plot
+     axs[0].plot(train_losses, label="Train")
+     axs[0].plot(test_losses, label="Test")
+     axs[0].set_title("Loss")
+     axs[0].legend(loc="upper right")
+
+     # Accuracy plot
+     axs[1].plot(train_acc, label="Train")
+     axs[1].plot(test_acc, label="Test")
+     axs[1].set_title("Accuracy")
+     axs[1].legend(loc="upper right")
+
+     return fig, axs
+
+
+ def plot_misclassified_images(data, class_label, num_images=10):
+     """Plot the misclassified images from the test dataset."""
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(data["ground_truths"]))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
+
+     # Iterate through the images and plot them in the grid along with class labels
+
+     for img_index in range(1, num_images + 1):
+         # Get the ground truth and predicted labels for the image
+         label = data["ground_truths"][img_index - 1].cpu().item()
+         pred = data["predicted_vals"][img_index - 1].cpu().item()
+         # Get the image
+         image = data["images"][img_index - 1].cpu()
+         # Plot the image
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(convert_back_image(image))
+         plt.title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
+         plt.xticks([])
+         plt.yticks([])
+
+     return fig, axs
+
+
+ # Function to plot gradcam for misclassified images using pytorch_grad_cam
+ def plot_gradcam_images(
+     model,
+     data,
+     class_label,
+     target_layers,
+     targets=None,
+     num_images=10,
+     image_weight=0.25,
+ ):
+     """Show gradcam for misclassified images"""
+
+     # Calculate the number of images to plot
+     num_images = min(num_images, len(data["ground_truths"]))
+     # Calculate the number of rows and columns to plot
+     num_cols = 5
+     num_rows = int(np.ceil(num_images / num_cols))
+
+     # Initialize a subplot with the required number of rows and columns
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))
+
+     # Initialize the GradCAM object
+     # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam.py
+     # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/base_cam.py
+     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
+
+     # Iterate through the images and plot them in the grid along with class labels
+     for img_index in range(1, num_images + 1):
+         # Extract elements from the data dictionary
+         # Get the ground truth and predicted labels for the image
+         label = data["ground_truths"][img_index - 1].cpu().item()
+         pred = data["predicted_vals"][img_index - 1].cpu().item()
+         # Get the image
+         image = data["images"][img_index - 1].cpu()
+
+         # Get the GradCAM output
+         # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py
+         grad_cam_output = cam(
+             input_tensor=image.unsqueeze(0),
+             targets=targets,
+             aug_smooth=True,
+             eigen_smooth=True,
+         )
+         grad_cam_output = grad_cam_output[0, :]
+
+         # Overlay the gradcam output on top of the numpy image
+         overlayed_image = show_cam_on_image(
+             convert_back_image(image),
+             grad_cam_output,
+             use_rgb=True,
+             image_weight=image_weight,
+         )
+
+         # Plot the image
+         plt.subplot(num_rows, num_cols, img_index)
+         plt.tight_layout()
+         plt.axis("off")
+         plt.imshow(overlayed_image)
+         plt.title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")
+         plt.xticks([])
+         plt.yticks([])
+     return fig, axs
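A minimal usage sketch (not part of this commit; assumes a training run has already written config.MISCLASSIFIED_PATH): reload the saved misclassified subset and plot it with the helper above.

import torch
import modules.config as config
from modules.visualize import plot_misclassified_images

misclassified = torch.load(config.MISCLASSIFIED_PATH)
fig, axs = plot_misclassified_images(misclassified, config.CIFAR_CLASSES, num_images=10)
fig.savefig("misclassified.png")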
utilities/callbacks.py DELETED
@@ -1,64 +0,0 @@
- import pytorch_lightning as pl
- from pytorch_lightning.callbacks import Callback
-
- from .visualize import plot_model_training_curves
-
-
- class TrainingEndCallback(Callback):
-     def on_train_end(self, trainer, pl_module):
-         # Perform actions at the end of the entire training process
-         print("Training, validation, and testing completed!")
-
-         logged_metrics = pl_module.log_store
-
-         plot_model_training_curves(
-             train_accs=logged_metrics["train_acc_epoch"],
-             test_accs=logged_metrics["val_acc_epoch"],
-             train_losses=logged_metrics["train_loss_epoch"],
-             test_losses=logged_metrics["val_loss_epoch"],
-         )
-
-
- class PrintLearningMetricsCallback(Callback):
-     def on_train_epoch_end(
-         self, trainer: pl.Trainer, pl_module: pl.LightningModule
-     ) -> None:
-         super().on_train_epoch_end(trainer, pl_module)
-         print(
-             f"\nEpoch: {trainer.current_epoch}, Train Loss: {trainer.logged_metrics['train_loss_epoch']}, Train Accuracy: {trainer.logged_metrics['train_acc_epoch']}"
-         )
-         pl_module.log_store.get("train_loss_epoch").append(
-             trainer.logged_metrics["train_loss_epoch"].cpu().detach().item()
-         )
-         pl_module.log_store.get("train_acc_epoch").append(
-             trainer.logged_metrics["train_acc_epoch"].cpu().detach().item()
-         )
-
-     def on_validation_epoch_end(
-         self, trainer: pl.Trainer, pl_module: pl.LightningModule
-     ) -> None:
-         super().on_validation_epoch_end(trainer, pl_module)
-         print(
-             f"\nEpoch: {trainer.current_epoch}, Val Loss: {trainer.logged_metrics['val_loss_epoch']}, Val Accuracy: {trainer.logged_metrics['val_acc_epoch']}"
-         )
-         pl_module.log_store.get("val_loss_epoch").append(
-             trainer.logged_metrics["val_loss_epoch"].cpu().detach().item()
-         )
-         pl_module.log_store.get("val_acc_epoch").append(
-             trainer.logged_metrics["val_acc_epoch"].cpu().detach().item()
-         )
-
-
-     def on_test_epoch_end(
-         self, trainer: pl.Trainer, pl_module: pl.LightningModule
-     ) -> None:
-         super().on_test_epoch_end(trainer, pl_module)
-         print(
-             f"\nEpoch: {trainer.current_epoch}, Test Loss: {trainer.logged_metrics['test_loss_epoch']}, Test Accuracy: {trainer.logged_metrics['test_acc_epoch']}"
-         )
-         pl_module.log_store.get("test_loss_epoch").append(
-             trainer.logged_metrics["test_loss_epoch"].cpu().detach().item()
-         )
-         pl_module.log_store.get("test_acc_epoch").append(
-             trainer.logged_metrics["test_acc_epoch"].cpu().detach().item()
-         )
utilities/config.py DELETED
@@ -1,58 +0,0 @@
- # Seed
- SEED = 1
-
- # Dataset
-
- CLASSES = (
-     "Airplane",
-     "Automobile",
-     "Bird",
-     "Cat",
-     "Deer",
-     "Dog",
-     "Frog",
-     "Horse",
-     "Ship",
-     "Truck",
- )
-
- SHUFFLE = True
- DATA_DIR = "../data"
- NUM_WORKERS = 4
- PIN_MEMORY = True
-
- # Training Hyperparameters
-
- INPUT_SIZE = (3, 32, 32)
- NUM_CLASSES = 10
- LEARNING_RATE = 0.001
- WEIGHT_DECAY = 1e-4
- BATCH_SIZE = 512
- NUM_EPOCHS = 24
- DROPOUT_PERCENTAGE = 0.05
- LAYER_NORM = "bn"  # Batch Normalization
-
- # OPTIMIZER & SCHEDULER
-
- LRFINDER_END_LR = 0.1
- LRFINDER_NUM_ITERATIONS = 50
- LRFINDER_STEP_MODE = "exp"
-
- OCLR_DIV_FACTOR = 100
- OCLR_FINAL_DIV_FACTOR = 100
- OCLR_THREE_PHASE = False
- OCLR_ANNEAL_STRATEGY = "linear"
-
- # Compute Related
-
- ACCELERATOR = "cuda"
- PRECISION = 32
-
- # Store
-
- TRAINING_STAT_STORE = "Store/training_stats.csv"
- MODEL_SAVE_PATH = "Store/model.pth"
-
- # Visualization
-
- NORM_CONF_MAT = True
utilities/dataset.py DELETED
@@ -1,92 +0,0 @@
- import numpy as np
- import pytorch_lightning as pl
- import torch
- from torchvision import datasets
-
-
- class CIFAR10(torch.utils.data.Dataset):
-     def __init__(self, dataset, transform=None) -> None:
-         # Initialize dataset and transform
-         self.dataset = dataset
-         self.transform = transform
-
-     def __len__(self) -> int:
-         # Return the length of the dataset
-         return len(self.dataset)
-
-     def __getitem__(self, index):
-         # Get image and label
-         image, label = self.dataset[index]
-
-         # Convert PIL image to numpy array
-         image = np.array(image)
-
-         # Apply transformations
-         if self.transform:
-             image = self.transform(image=image)["image"]
-
-         return (image, label)
-
-
- class CIFAR10DataModule(pl.LightningDataModule):
-     def __init__(
-         self,
-         train_transforms,
-         val_transforms,
-         shuffle=True,
-         data_dir="../data",
-         batch_size=64,
-         num_workers=-1,
-         pin_memory=True,
-     ):
-         super().__init__()
-         self.shuffle = shuffle
-         self.data_dir = data_dir
-         self.batch_size = batch_size
-         self.num_workers = num_workers
-         self.pin_memory = pin_memory
-         self.train_transforms = train_transforms
-         self.val_transforms = val_transforms
-         self.train_data = None
-         self.val_data = None
-
-     def prepare_data(self):
-         datasets.CIFAR10(self.data_dir, train=True, download=True)
-         datasets.CIFAR10(self.data_dir, train=False, download=True)
-
-     def setup(self, stage):
-         self.train_data = CIFAR10(
-             datasets.CIFAR10(root=self.data_dir, train=True, download=False),
-             transform=self.train_transforms,
-         )
-         self.val_data = CIFAR10(
-             datasets.CIFAR10(root=self.data_dir, train=False, download=False),
-             transform=self.val_transforms,
-         )
-
-     def train_dataloader(self):
-         return torch.utils.data.DataLoader(
-             self.train_data,
-             batch_size=self.batch_size,
-             shuffle=self.shuffle,
-             num_workers=self.num_workers,
-             pin_memory=self.pin_memory,
-         )
-
-     def val_dataloader(self):
-         return torch.utils.data.DataLoader(
-             self.val_data,
-             batch_size=self.batch_size,
-             shuffle=False,
-             num_workers=self.num_workers,
-             pin_memory=self.pin_memory,
-         )
-
-     def test_dataloader(self):
-         return torch.utils.data.DataLoader(
-             self.val_data,
-             batch_size=self.batch_size,
-             shuffle=False,
-             num_workers=self.num_workers,
-             pin_memory=self.pin_memory,
-         )
utilities/resnet.py DELETED
@@ -1,162 +0,0 @@
- """
- ResNet in PyTorch.
- For Pre-activation ResNet, see 'preact_resnet.py'.
-
- Reference:
- [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
-     Deep Residual Learning for Image Recognition. arXiv:1512.03385
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import pytorch_lightning as pl
- from torchmetrics.functional import accuracy
- from torchvision import transforms
- from torch.utils.data import DataLoader
- from torchvision.datasets import CIFAR10
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
-
-
- class BasicBlock(nn.Module):
-     expansion = 1
-
-     def __init__(self, in_planes, planes, stride=1):
-         super(BasicBlock, self).__init__()
-         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-         self.bn1 = nn.BatchNorm2d(planes)
-         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
-         self.bn2 = nn.BatchNorm2d(planes)
-
-         self.shortcut = nn.Sequential()
-         if stride != 1 or in_planes != self.expansion*planes:
-             self.shortcut = nn.Sequential(
-                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
-                 nn.BatchNorm2d(self.expansion*planes)
-             )
-
-     def forward(self, x):
-         out = F.relu(self.bn1(self.conv1(x)))
-         out = self.bn2(self.conv2(out))
-         out += self.shortcut(x)
-         out = F.relu(out)
-         return out
-
-
- class LitResNet(pl.LightningModule):
-     def __init__(self, block, num_blocks, num_classes=10, batch_size=128):
-         super(LitResNet, self).__init__()
-         self.batch_size = batch_size
-         self.in_planes = 64
-         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
-         self.bn1 = nn.BatchNorm2d(64)
-         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
-         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
-         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
-         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
-         self.linear = nn.Linear(512*block.expansion, num_classes)
-
-     def _make_layer(self, block, planes, num_blocks, stride):
-         strides = [stride] + [1]*(num_blocks-1)
-         layers = []
-         for stride in strides:
-             layers.append(block(self.in_planes, planes, stride))
-             self.in_planes = planes * block.expansion
-         return nn.Sequential(*layers)
-
-
-     def forward(self, x):
-         out = F.relu(self.bn1(self.conv1(x)))
-         out = self.layer1(out)
-         out = self.layer2(out)
-         out = self.layer3(out)
-         out = self.layer4(out)
-         out = F.avg_pool2d(out, 4)
-         out = out.view(out.size(0), -1)
-         out = self.linear(out)
-         return out
-
-
-
-     def training_step(self, batch, batch_idx):
-         x, y = batch
-         y_hat = self(x)
-         # Calculate loss
-         loss = F.cross_entropy(y_hat, y)
-         # Calculate accuracy
-         acc = accuracy(y_hat, y)
-         self.log_dict(
-             {"train_loss": loss, "train_acc": acc},
-             on_step=True,
-             on_epoch=True,
-             prog_bar=True,
-             logger=True,
-         )
-         return loss
-
-     def validation_step(self, batch, batch_idx):
-         x, y = batch
-         y_hat = self(x)
-         loss = F.cross_entropy(y_hat, y)
-         acc = accuracy(y_hat, y)
-         self.log_dict(
-             {"val_loss": loss, "val_acc": acc},
-             on_step=True,
-             on_epoch=True,
-             prog_bar=True,
-             logger=True,
-         )
-         return loss
-
-     def test_step(self, batch, batch_idx):
-         x, y = batch
-         y_hat = self(x)
-
-         argmax_pred = y_hat.argmax(dim=1).cpu()
-         loss = F.cross_entropy(y_hat, y)
-         acc = accuracy(y_hat, y)
-         self.log_dict(
-             {"test_loss": loss, "test_acc": acc},
-             on_step=True,
-             on_epoch=True,
-             prog_bar=True,
-             logger=True,
-         )
-
-         # Update the confusion matrix
-         self.confusion_matrix.update(y_hat, y)
-
-         # Store the predictions, labels and incorrect predictions
-         x, y, y_hat, argmax_pred = (
-             x.cpu(),
-             y.cpu(),
-             y_hat.cpu(),
-             argmax_pred.cpu(),
-         )
-         self.pred_store["test_preds"] = torch.cat(
-             (self.pred_store["test_preds"], argmax_pred), dim=0
-         )
-         self.pred_store["test_labels"] = torch.cat(
-             (self.pred_store["test_labels"], y), dim=0
-         )
-         for d, t, p, o in zip(x, y, argmax_pred, y_hat):
-             if p.eq(t.view_as(p)).item() == False:
-                 self.pred_store["test_incorrect"].append(
-                     (d.cpu(), t, p, o[p.item()].cpu())
-                 )
-
-         return loss
-
-
-     def configure_optimizers(self):
-         return torch.optim.Adam(self.parameters(), lr=0.02)
-
- def LitResNet18():
-     return LitResNet(BasicBlock, [2, 2, 2, 2])
-
- def LitResNet34():
-     return LitResNet(BasicBlock, [3, 4, 6, 3])
utilities/transforms.py DELETED
@@ -1,20 +0,0 @@
- # Third-Party Imports
- import torch
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
-
-
- # Train Phase transformations
- train_set_transforms = {
-     'randomcrop': A.RandomCrop(height=32, width=32, p=0.2),
-     'horizontalflip': A.HorizontalFlip(),
-     'cutout': A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=1, min_width=1, fill_value=[0.49139968*255, 0.48215827*255, 0.44653124*255], mask_fill_value=None),
-     'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
-     'standardize': ToTensorV2(),
- }
-
- # Test Phase transformations
- test_set_transforms = {
-     'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
-     'standardize': ToTensorV2()
- }
utilities/visualise.py DELETED
@@ -1,78 +0,0 @@
- import matplotlib.pyplot as plt
- from torchvision import transforms
-
-
- def plot_class_label_counts(data_loader, classes):
-     class_counts = {}
-     for class_name in classes:
-         class_counts[class_name] = 0
-     for _, batch_label in data_loader:
-         for label in batch_label:
-             class_counts[classes[label.item()]] += 1
-
-     fig = plt.figure()
-     plt.suptitle("Class Distribution")
-     plt.bar(range(len(class_counts)), list(class_counts.values()))
-     plt.xticks(range(len(class_counts)), list(class_counts.keys()), rotation=90)
-     plt.tight_layout()
-     plt.show()
-
-
- def plot_data_samples(data_loader, classes):
-     batch_data, batch_label = next(iter(data_loader))
-
-     fig = plt.figure()
-     plt.suptitle("Data Samples with Labels post Transforms")
-     for i in range(12):
-         plt.subplot(3, 4, i + 1)
-         plt.tight_layout()
-         # unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
-         unnormalized = transforms.Normalize(
-             (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
-         )(batch_data[i])
-         plt.imshow(transforms.ToPILImage()(unnormalized))
-         plt.title(
-             classes[batch_label[i].item()],
-         )
-
-         plt.xticks([])
-         plt.yticks([])
-
-
- def plot_model_training_curves(train_accs, test_accs, train_losses, test_losses):
-     fig, axs = plt.subplots(2, 2, figsize=(15, 10))
-     axs[0, 0].plot(train_losses)
-     axs[0, 0].set_title("Training Loss")
-     axs[1, 0].plot(train_accs)
-     axs[1, 0].set_title("Training Accuracy")
-     axs[0, 1].plot(test_losses)
-     axs[0, 1].set_title("Test Loss")
-     axs[1, 1].plot(test_accs)
-     axs[1, 1].set_title("Test Accuracy")
-     plt.plot()
-
-
- def plot_incorrect_preds(incorrect, classes, num_imgs):
-     # num_imgs is a multiple of 5
-     assert num_imgs % 5 == 0
-     assert len(incorrect) >= num_imgs
-
-     # incorrect (data, target, pred, output)
-     print(f"Total Incorrect Predictions {len(incorrect)}")
-     fig = plt.figure(figsize=(10, num_imgs // 2))
-     plt.suptitle("Target | Predicted Label")
-     for i in range(num_imgs):
-         plt.subplot(num_imgs // 5, 5, i + 1, aspect="auto")
-
-         # unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
-         unnormalized = transforms.Normalize(
-             (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
-         )(incorrect[i][0])
-         plt.imshow(transforms.ToPILImage()(unnormalized))
-         plt.title(
-             f"{classes[incorrect[i][1].item()]}|{classes[incorrect[i][2].item()]}",
-             # fontsize=8,
-         )
-         plt.xticks([])
-         plt.yticks([])
-     plt.tight_layout()