Mojo commited on
Commit
229755d
·
1 Parent(s): b267f13

add utilities file

Browse files
S13.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
utilities/callbacks.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.callbacks import Callback
3
+
4
+ from .visualize import plot_model_training_curves
5
+
6
+
7
+ class TrainingEndCallback(Callback):
8
+ def on_train_end(self, trainer, pl_module):
9
+ # Perform actions at the end of the entire training process
10
+ print("Training, validation, and testing completed!")
11
+
12
+ logged_metrics = pl_module.log_store
13
+
14
+ plot_model_training_curves(
15
+ train_accs=logged_metrics["train_acc_epoch"],
16
+ test_accs=logged_metrics["val_acc_epoch"],
17
+ train_losses=logged_metrics["train_loss_epoch"],
18
+ test_losses=logged_metrics["val_loss_epoch"],
19
+ )
20
+
21
+
22
+ class PrintLearningMetricsCallback(Callback):
23
+ def on_train_epoch_end(
24
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
25
+ ) -> None:
26
+ super().on_train_epoch_end(trainer, pl_module)
27
+ print(
28
+ f"\nEpoch: {trainer.current_epoch}, Train Loss: {trainer.logged_metrics['train_loss_epoch']}, Train Accuracy: {trainer.logged_metrics['train_acc_epoch']}"
29
+ )
30
+ pl_module.log_store.get("train_loss_epoch").append(
31
+ trainer.logged_metrics["train_loss_epoch"].cpu().detach().item()
32
+ )
33
+ pl_module.log_store.get("train_acc_epoch").append(
34
+ trainer.logged_metrics["train_acc_epoch"].cpu().detach().item()
35
+ )
36
+
37
+ def on_validation_epoch_end(
38
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
39
+ ) -> None:
40
+ super().on_validation_epoch_end(trainer, pl_module)
41
+ print(
42
+ f"\nEpoch: {trainer.current_epoch}, Val Loss: {trainer.logged_metrics['val_loss_epoch']}, Val Accuracy: {trainer.logged_metrics['val_acc_epoch']}"
43
+ )
44
+ pl_module.log_store.get("val_loss_epoch").append(
45
+ trainer.logged_metrics["val_loss_epoch"].cpu().detach().item()
46
+ )
47
+ pl_module.log_store.get("val_acc_epoch").append(
48
+ trainer.logged_metrics["val_acc_epoch"].cpu().detach().item()
49
+ )
50
+
51
+
52
+ def on_test_epoch_end(
53
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
54
+ ) -> None:
55
+ super().on_test_epoch_end(trainer, pl_module)
56
+ print(
57
+ f"\nEpoch: {trainer.current_epoch}, Test Loss: {trainer.logged_metrics['test_loss_epoch']}, Test Accuracy: {trainer.logged_metrics['test_acc_epoch']}"
58
+ )
59
+ pl_module.log_store.get("test_loss_epoch").append(
60
+ trainer.logged_metrics["test_loss_epoch"].cpu().detach().item()
61
+ )
62
+ pl_module.log_store.get("test_acc_epoch").append(
63
+ trainer.logged_metrics["test_acc_epoch"].cpu().detach().item()
64
+ )
utilities/dataset.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from torch.utils.data import DataLoader
3
+ from torchvision.datasets import CIFAR10
4
+ from torchvision import transforms
5
+ import torch
6
+ import numpy as np
7
+
8
+
9
+ class CIFAR10(torch.utils.data.Dataset):
10
+ def __init__(self, dataset, transform=None) -> None:
11
+ # Initialize dataset and transform
12
+ self.dataset = dataset
13
+ self.transform = transform
14
+
15
+ def __len__(self) -> int:
16
+ # Return the length of the dataset
17
+ return len(self.dataset)
18
+
19
+ def __getitem__(self, index):
20
+ # Get image and label
21
+ image, label = self.dataset[index]
22
+
23
+ # Convert PIL image to numpy array
24
+ image = np.array(image)
25
+
26
+ # Apply transformations
27
+ if self.transform:
28
+ image = self.transform(image=image)["image"]
29
+
30
+ return (image, label)
31
+
32
+ class CIFAR10DataModule(pl.LightningDataModule):
33
+ def __init__(self, train_set_transforms,test_set_transforms, data_dir: str = "./data",batch_size: int = 64, num_workers: int = 4):
34
+ super().__init__()
35
+ self.data_dir = data_dir
36
+ self.batch_size = batch_size
37
+ self.num_workers = num_workers
38
+ self.train_set_transforms =train_set_transforms
39
+ self.test_set_transforms = test_set_transforms
40
+
41
+ def prepare_data(self):
42
+ # Download the CIFAR10 dataset
43
+ CIFAR10(self.data_dir, train=True, download=True)
44
+ CIFAR10(self.data_dir, train=False, download=True)
45
+
46
+ def setup(self, stage: str = None):
47
+ # Load the dataset
48
+ if stage == "fit" or stage is None:
49
+ self.cifar10_train = CIFAR10(self.data_dir, train=True, transform=self.train_set_transforms)
50
+ self.cifar10_val = CIFAR10(self.data_dir, train=False, transform=self.train_set_transforms)
51
+ if stage == "test" or stage is None:
52
+ self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.test_set_transforms)
53
+
54
+ def train_dataloader(self):
55
+ return DataLoader(self.cifar10_train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
56
+
57
+ def val_dataloader(self):
58
+ return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=self.num_workers)
59
+
60
+ def test_dataloader(self):
61
+ return DataLoader(self.cifar10_test, batch_size=self.batch_size, num_workers=self.num_workers)
utilities/resnet.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ResNet in PyTorch.
3
+ For Pre-activation ResNet, see 'preact_resnet.py'.
4
+
5
+ Reference:
6
+ [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import pytorch_lightning as pl
13
+ from torchmetrics.functional import accuracy
14
+ from torchvision import transforms
15
+ from torch.utils.data import DataLoader
16
+ from torchvision.datasets import CIFAR10
17
+ import albumentations as A
18
+ from albumentations.pytorch import ToTensorV2
19
+
20
+
21
+ class BasicBlock(nn.Module):
22
+ expansion = 1
23
+
24
+ def __init__(self, in_planes, planes, stride=1):
25
+ super(BasicBlock, self).__init__()
26
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
27
+ self.bn1 = nn.BatchNorm2d(planes)
28
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
29
+ self.bn2 = nn.BatchNorm2d(planes)
30
+
31
+ self.shortcut = nn.Sequential()
32
+ if stride != 1 or in_planes != self.expansion*planes:
33
+ self.shortcut = nn.Sequential(
34
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
35
+ nn.BatchNorm2d(self.expansion*planes)
36
+ )
37
+
38
+ def forward(self, x):
39
+ out = F.relu(self.bn1(self.conv1(x)))
40
+ out = self.bn2(self.conv2(out))
41
+ out += self.shortcut(x)
42
+ out = F.relu(out)
43
+ return out
44
+
45
+
46
+ class LitResNet(pl.LightningModule):
47
+ def __init__(self, block, num_blocks, num_classes=10,batch_size=128):
48
+ super(LitResNet, self).__init__()
49
+ self.batch_size = batch_size
50
+ self.in_planes = 64
51
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
52
+ self.bn1 = nn.BatchNorm2d(64)
53
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
54
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
55
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
56
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
57
+ self.linear = nn.Linear(512*block.expansion, num_classes)
58
+
59
+ def _make_layer(self, block, planes, num_blocks, stride):
60
+ strides = [stride] + [1]*(num_blocks-1)
61
+ layers = []
62
+ for stride in strides:
63
+ layers.append(block(self.in_planes, planes, stride))
64
+ self.in_planes = planes * block.expansion
65
+ return nn.Sequential(*layers)
66
+
67
+
68
+ def forward(self, x):
69
+ out = F.relu(self.bn1(self.conv1(x)))
70
+ out = self.layer1(out)
71
+ out = self.layer2(out)
72
+ out = self.layer3(out)
73
+ out = self.layer4(out)
74
+ out = F.avg_pool2d(out, 4)
75
+ out = out.view(out.size(0), -1)
76
+ out = self.linear(out)
77
+ return out
78
+
79
+
80
+
81
+ def training_step(self, batch, batch_idx):
82
+ x, y = batch
83
+ y_hat = self(x)
84
+ # Calculate loss
85
+ loss = F.cross_entropy(y_hat, y)
86
+ #Calculate accuracy
87
+ acc = accuracy(y_hat, y)
88
+ self.log_dict(
89
+ {"train_loss": loss, "train_acc": acc},
90
+ on_step=True,
91
+ on_epoch=True,
92
+ prog_bar=True,
93
+ logger=True,
94
+ )
95
+ return loss
96
+
97
+ def validation_step(self, batch, batch_idx):
98
+ x, y = batch
99
+ y_hat = self(x)
100
+ loss = F.cross_entropy(y_hat, y)
101
+ acc = accuracy(y_hat, y)
102
+ self.log_dict(
103
+ {"val_loss": loss, "val_acc": acc},
104
+ on_step=True,
105
+ on_epoch=True,
106
+ prog_bar=True,
107
+ logger=True,
108
+ )
109
+ return loss
110
+
111
+ def test_step(self, batch, batch_idx):
112
+ x, y = batch
113
+ y_hat = self(x)
114
+
115
+ argmax_pred = y_hat.argmax(dim=1).cpu()
116
+ loss = F.cross_entropy(y_hat, y)
117
+ acc = accuracy(y_hat, y)
118
+ self.log_dict(
119
+ {"test_loss": loss, "test_acc": acc},
120
+ on_step=True,
121
+ on_epoch=True,
122
+ prog_bar=True,
123
+ logger=True,
124
+ )
125
+
126
+ # Update the confusion matrix
127
+ self.confusion_matrix.update(y_hat, y)
128
+
129
+ # Store the predictions, labels and incorrect predictions
130
+ x, y, y_hat, argmax_pred = (
131
+ x.cpu(),
132
+ y.cpu(),
133
+ y_hat.cpu(),
134
+ argmax_pred.cpu(),
135
+ )
136
+ self.pred_store["test_preds"] = torch.cat(
137
+ (self.pred_store["test_preds"], argmax_pred), dim=0
138
+ )
139
+ self.pred_store["test_labels"] = torch.cat(
140
+ (self.pred_store["test_labels"], y), dim=0
141
+ )
142
+ for d, t, p, o in zip(x, y, argmax_pred, y_hat):
143
+ if p.eq(t.view_as(p)).item() == False:
144
+ self.pred_store["test_incorrect"].append(
145
+ (d.cpu(), t, p, o[p.item()].cpu())
146
+ )
147
+
148
+ return loss
149
+
150
+
151
+ def configure_optimizers(self):
152
+ return torch.optim.Adam(self.parameters(), lr=0.02)
153
+
154
+ def LitResNet18():
155
+ return LitResNet(BasicBlock, [2, 2, 2, 2])
156
+
157
+ def LitResNet34():
158
+ return LitResNet(BasicBlock, [3, 4, 6, 3])
159
+
160
+
161
+
162
+
utilities/transforms.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Third-Party Imports
2
+ import torch
3
+ import albumentations as A
4
+ from albumentations.pytorch import ToTensorV2
5
+
6
+
7
+ # Train Phase transformations
8
+ train_set_transforms = {
9
+ 'randomcrop': A.RandomCrop(height=32, width=32, p=0.2),
10
+ 'horizontalflip': A.HorizontalFlip(),
11
+ 'cutout': A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=1, min_width=1, fill_value=[0.49139968*255, 0.48215827*255 ,0.44653124*255], mask_fill_value=None),
12
+ 'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
13
+ 'standardize': ToTensorV2(),
14
+ }
15
+
16
+ # Test Phase transformations
17
+ test_set_transforms = {
18
+ 'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
19
+ 'standardize': ToTensorV2()
20
+ }
utilities/visualise.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from torchvision import transforms
3
+ import torch
4
+
5
+
6
+ def plot_class_label_counts(data_loader, classes):
7
+ class_counts = {}
8
+ for class_name in classes:
9
+ class_counts[class_name] = 0
10
+ for _, batch_label in data_loader:
11
+ for label in batch_label:
12
+ class_counts[classes[label.item()]] += 1
13
+
14
+ fig = plt.figure()
15
+ plt.suptitle("Class Distribution")
16
+ plt.bar(range(len(class_counts)), list(class_counts.values()))
17
+ plt.xticks(range(len(class_counts)), list(class_counts.keys()), rotation=90)
18
+ plt.tight_layout()
19
+ plt.show()
20
+
21
+
22
+ def plot_data_samples(data_loader, classes):
23
+ batch_data, batch_label = next(iter(data_loader))
24
+
25
+ fig = plt.figure()
26
+ plt.suptitle("Data Samples with Labels post Transforms")
27
+ for i in range(12):
28
+ plt.subplot(3, 4, i + 1)
29
+ plt.tight_layout()
30
+ # unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
31
+ unnormalized = transforms.Normalize(
32
+ (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
33
+ )(batch_data[i])
34
+ plt.imshow(transforms.ToPILImage()(unnormalized))
35
+ plt.title(
36
+ classes[batch_label[i].item()],
37
+ )
38
+
39
+ plt.xticks([])
40
+ plt.yticks([])
41
+
42
+
43
+ def plot_model_training_curves(train_accs, test_accs, train_losses, test_losses):
44
+ fig, axs = plt.subplots(2, 2, figsize=(15, 10))
45
+ axs[0, 0].plot(train_losses)
46
+ axs[0, 0].set_title("Training Loss")
47
+ axs[1, 0].plot(train_accs)
48
+ axs[1, 0].set_title("Training Accuracy")
49
+ axs[0, 1].plot(test_losses)
50
+ axs[0, 1].set_title("Test Loss")
51
+ axs[1, 1].plot(test_accs)
52
+ axs[1, 1].set_title("Test Accuracy")
53
+ plt.plot()
54
+
55
+
56
+ def plot_incorrect_preds(incorrect, classes, num_imgs):
57
+ # num_imgs is a multiple of 5
58
+ assert num_imgs % 5 == 0
59
+ assert len(incorrect) >= num_imgs
60
+
61
+ # incorrect (data, target, pred, output)
62
+ print(f"Total Incorrect Predictions {len(incorrect)}")
63
+ fig = plt.figure(figsize=(10, num_imgs // 2))
64
+ plt.suptitle("Target | Predicted Label")
65
+ for i in range(num_imgs):
66
+ plt.subplot(num_imgs // 5, 5, i + 1, aspect="auto")
67
+
68
+ # unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
69
+ unnormalized = transforms.Normalize(
70
+ (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
71
+ )(incorrect[i][0])
72
+ plt.imshow(transforms.ToPILImage()(unnormalized))
73
+ plt.title(
74
+ f"{classes[incorrect[i][1].item()]}|{classes[incorrect[i][2].item()]}",
75
+ # fontsize=8,
76
+ )
77
+ plt.xticks([])
78
+ plt.yticks([])
79
+ plt.tight_layout()
80
+
81
+
82
+ def display_cifar_data_samples(data_set, number_of_samples: int, classes: list):
83
+ """
84
+ Function to display samples for data_set
85
+ :param data_set: Train or Test data_set transformed to Tensor
86
+ :param number_of_samples: Number of samples to be displayed
87
+ :param classes: Name of classes to be displayed
88
+ """
89
+ # Get batch from the data_set
90
+ batch_data = []
91
+ batch_label = []
92
+ for count, item in enumerate(data_set):
93
+ if not count <= number_of_samples:
94
+ break
95
+ batch_data.append(item[0])
96
+ batch_label.append(item[1])
97
+ batch_data = torch.stack(batch_data, dim=0).numpy()
98
+
99
+ # Plot the samples from the batch
100
+ fig = plt.figure()
101
+ x_count = 5
102
+ y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
103
+
104
+ for i in range(number_of_samples):
105
+ plt.subplot(y_count, x_count, i + 1)
106
+ plt.tight_layout()
107
+ plt.imshow(np.transpose(batch_data[i].squeeze(), (1, 2, 0)))
108
+ plt.title(classes[batch_label[i]])
109
+ plt.xticks([])
110
+ plt.yticks([])
111
+
112
+
113
+ # ---------------------------- MISCLASSIFIED DATA ----------------------------
114
+ def display_cifar_misclassified_data(data: list,
115
+ classes: list[str],
116
+ inv_normalize: transforms.Normalize,
117
+ number_of_samples: int = 10):
118
+ """
119
+ Function to plot images with labels
120
+ :param data: List[Tuple(image, label)]
121
+ :param classes: Name of classes in the dataset
122
+ :param inv_normalize: Mean and Standard deviation values of the dataset
123
+ :param number_of_samples: Number of images to print
124
+ """
125
+ fig = plt.figure(figsize=(10, 10))
126
+
127
+ x_count = 5
128
+ y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
129
+
130
+ for i in range(number_of_samples):
131
+ plt.subplot(y_count, x_count, i + 1)
132
+ img = data[i][0].squeeze().to('cpu')
133
+ img = inv_normalize(img)
134
+ plt.imshow(np.transpose(img, (1, 2, 0)))
135
+ plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
136
+ plt.xticks([])
137
+ plt.yticks([])
138
+
139
+
140
+ def display_mnist_misclassified_data(data: list,
141
+ number_of_samples: int = 10):
142
+ """
143
+ Function to plot images with labels
144
+ :param data: List[Tuple(image, label)]
145
+ :param number_of_samples: Number of images to print
146
+ """
147
+ fig = plt.figure(figsize=(8, 5))
148
+
149
+ x_count = 5
150
+ y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
151
+
152
+ for i in range(number_of_samples):
153
+ plt.subplot(y_count, x_count, i + 1)
154
+ img = data[i][0].squeeze(0).to('cpu')
155
+ plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
156
+ plt.title(r"Correct: " + str(data[i][1].item()) + '\n' + 'Output: ' + str(data[i][2].item()))
157
+ plt.xticks([])
158
+ plt.yticks([])
159
+
160
+
161
+ # ---------------------------- AUGMENTATION SAMPLES ----------------------------
162
+ def visualize_cifar_augmentation(data_set, data_transforms):
163
+ """
164
+ Function to visualize the augmented data
165
+ :param data_set: Dataset without transformations
166
+ :param data_transforms: Dictionary of transforms
167
+ """
168
+ sample, label = data_set[6]
169
+ total_augmentations = len(data_transforms)
170
+
171
+ fig = plt.figure(figsize=(10, 5))
172
+ for count, (key, trans) in enumerate(data_transforms.items()):
173
+ if count == total_augmentations - 1:
174
+ break
175
+ plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
176
+ augmented = trans(image=sample)['image']
177
+ plt.imshow(augmented)
178
+ plt.title(key)
179
+ plt.xticks([])
180
+ plt.yticks([])
181
+
182
+
183
+ def visualize_mnist_augmentation(data_set, data_transforms):
184
+ """
185
+ Function to visualize the augmented data
186
+ :param data_set: Dataset to visualize the augmentations
187
+ :param data_transforms: Dictionary of transforms
188
+ """
189
+ sample, label = data_set[6]
190
+ total_augmentations = len(data_transforms)
191
+
192
+ fig = plt.figure(figsize=(10, 5))
193
+ for count, (key, trans) in enumerate(data_transforms.items()):
194
+ if count == total_augmentations - 1:
195
+ break
196
+ plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
197
+ img = trans(sample).to('cpu')
198
+ plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
199
+ plt.title(key)
200
+ plt.xticks([])
201
+ plt.yticks([])
202
+
203
+
204
+ # ---------------------------- LOSS AND ACCURACIES ----------------------------
205
+ def display_loss_and_accuracies(train_losses: list,
206
+ train_acc: list,
207
+ test_losses: list,
208
+ test_acc: list,
209
+ plot_size: tuple = (10, 10)) -> NoReturn:
210
+ """
211
+ Function to display training and test information(losses and accuracies)
212
+ :param train_losses: List containing training loss of each epoch
213
+ :param train_acc: List containing training accuracy of each epoch
214
+ :param test_losses: List containing test loss of each epoch
215
+ :param test_acc: List containing test accuracy of each epoch
216
+ :param plot_size: Size of the plot
217
+ """
218
+ # Create a plot of 2x2 of size
219
+ fig, axs = plt.subplots(2, 2, figsize=plot_size)
220
+
221
+ # Plot the training loss and accuracy for each epoch
222
+ axs[0, 0].plot(train_losses)
223
+ axs[0, 0].set_title("Training Loss")
224
+ axs[1, 0].plot(train_acc)
225
+ axs[1, 0].set_title("Training Accuracy")
226
+
227
+ # Plot the test loss and accuracy for each epoch
228
+ axs[0, 1].plot(test_losses)
229
+ axs[0, 1].set_title("Test Loss")
230
+ axs[1, 1].plot(test_acc)
231
+ axs[1, 1].set_title("Test Accuracy")
232
+
233
+
234
+ # ---------------------------- Feature Maps and Kernels ----------------------------
235
+
236
+ @dataclass
237
+ class ConvLayerInfo:
238
+ """
239
+ Data Class to store Conv layer's information
240
+ """
241
+ layer_number: int
242
+ weights: torch.nn.parameter.Parameter
243
+ layer_info: torch.nn.modules.conv.Conv2d
244
+
245
+
246
+ class FeatureMapVisualizer:
247
+ """
248
+ Class to visualize Feature Map of the Layers
249
+ """
250
+
251
+ def __init__(self, model):
252
+ """
253
+ Contructor
254
+ :param model: Model Architecture
255
+ """
256
+ self.conv_layers = []
257
+ self.outputs = []
258
+ self.layerwise_kernels = None
259
+
260
+ # Disect the model
261
+ counter = 0
262
+ model_children = model.children()
263
+ for children in model_children:
264
+ if type(children) == nn.Sequential:
265
+ for child in children:
266
+ if type(child) == nn.Conv2d:
267
+ counter += 1
268
+ self.conv_layers.append(ConvLayerInfo(layer_number=counter,
269
+ weights=child.weight,
270
+ layer_info=child)
271
+ )
272
+
273
+ def get_model_weights(self):
274
+ """
275
+ Method to get the model weights
276
+ """
277
+ model_weights = [layer.weights for layer in self.conv_layers]
278
+ return model_weights
279
+
280
+ def get_conv_layers(self):
281
+ """
282
+ Get the convolution layers
283
+ """
284
+ conv_layers = [layer.layer_info for layer in self.conv_layers]
285
+ return conv_layers
286
+
287
+ def get_total_conv_layers(self) -> int:
288
+ """
289
+ Get total number of convolution layers
290
+ """
291
+ out = self.get_conv_layers()
292
+ return len(out)
293
+
294
+ def feature_maps_of_all_kernels(self, image: torch.Tensor) -> dict:
295
+ """
296
+ Get feature maps from all the kernels of all the layers
297
+ :param image: Image to be passed to the network
298
+ """
299
+ image = image.unsqueeze(0)
300
+ image = image.to('cpu')
301
+
302
+ outputs = {}
303
+
304
+ layers = self.get_conv_layers()
305
+ for index, layer in enumerate(layers):
306
+ image = layer(image)
307
+ outputs[str(layer)] = image
308
+ self.outputs = outputs
309
+ return outputs
310
+
311
+ def visualize_feature_map_of_kernel(self, image: torch.Tensor, kernel_number: int) -> None:
312
+ """
313
+ Function to visualize feature map of kernel number from each layer
314
+ :param image: Image passed to the network
315
+ :param kernel_number: Number of kernel in each layer (Should be less than or equal to the minimum number of kernel in the network)
316
+ """
317
+ # List to store processed feature maps
318
+ processed = []
319
+
320
+ # Get feature maps from all kernels of all the conv layers
321
+ outputs = self.feature_maps_of_all_kernels(image)
322
+
323
+ # Extract the n_th kernel's output from each layer and convert it to grayscale
324
+ for feature_map in outputs.values():
325
+ try:
326
+ feature_map = feature_map[0][kernel_number]
327
+ except IndexError:
328
+ print("Filter number should be less than the minimum number of channels in a network")
329
+ break
330
+ finally:
331
+ gray_scale = feature_map / feature_map.shape[0]
332
+ processed.append(gray_scale.data.numpy())
333
+
334
+ # Plot the Feature maps with layer and kernel number
335
+ x_range = len(outputs) // 5 + 4
336
+ fig = plt.figure(figsize=(10, 10))
337
+ for i in range(len(processed)):
338
+ a = fig.add_subplot(x_range, 5, i + 1)
339
+ imgplot = plt.imshow(processed[i])
340
+ a.axis("off")
341
+ title = f"{list(outputs.keys())[i].split('(')[0]}_l{i + 1}_k{kernel_number}"
342
+ a.set_title(title, fontsize=10)
343
+
344
+ def get_max_kernel_number(self):
345
+ """
346
+ Function to get maximum number of kernels in the network (for a layer)
347
+ """
348
+ layers = self.get_conv_layers()
349
+ channels = [layer.out_channels for layer in layers]
350
+ self.layerwise_kernels = channels
351
+ return max(channels)
352
+
353
+ def visualize_kernels_from_layer(self, layer_number: int):
354
+ """
355
+ Visualize Kernels from a layer
356
+ :param layer_number: Number of layer from which kernels are to be visualized
357
+ """
358
+ # Get the kernels number for each layer
359
+ self.get_max_kernel_number()
360
+
361
+ # Zero Indexing
362
+ layer_number = layer_number - 1
363
+ _kernels = self.layerwise_kernels[layer_number]
364
+
365
+ grid = math.ceil(math.sqrt(_kernels))
366
+
367
+ plt.figure(figsize=(5, 4))
368
+ model_weights = self.get_model_weights()
369
+ _layer_weights = model_weights[layer_number].cpu()
370
+ for i, filter in enumerate(_layer_weights):
371
+ plt.subplot(grid, grid, i + 1)
372
+ plt.imshow(filter[0, :, :].detach(), cmap='gray')
373
+ plt.axis('off')
374
+ plt.show()
375
+
376
+
377
+ # ---------------------------- Confusion Matrix ----------------------------
378
+ def visualize_confusion_matrix(classes: list[str], device: str, model: 'DL Model',
379
+ test_loader: torch.utils.data.DataLoader):
380
+ """
381
+ Function to generate and visualize confusion matrix
382
+ :param classes: List of class names
383
+ :param device: cuda/cpu
384
+ :param model: Model Architecture
385
+ :param test_loader: DataLoader for test set
386
+ """
387
+ nb_classes = len(classes)
388
+ device = 'cuda'
389
+ cm = torch.zeros(nb_classes, nb_classes)
390
+
391
+ model.eval()
392
+ with torch.no_grad():
393
+ for inputs, labels in test_loader:
394
+ inputs = inputs.to(device)
395
+ labels = labels.to(device)
396
+ model = model.to(device)
397
+
398
+ preds = model(inputs)
399
+ preds = preds.argmax(dim=1)
400
+
401
+ for t, p in zip(labels.view(-1), preds.view(-1)):
402
+ cm[t, p] = cm[t, p] + 1
403
+
404
+ # Build confusion matrix
405
+ labels = labels.to('cpu')
406
+ preds = preds.to('cpu')
407
+ cf_matrix = confusion_matrix(labels, preds)
408
+ df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
409
+ index=[i for i in classes],
410
+ columns=[i for i in classes])
411
+ plt.figure(figsize=(12, 7))
412
+ sn.heatmap(df_cm, annot=True)