swapniel99 committed on
Commit
375ca58
·
1 Parent(s): 9d2eeb0

Upload 8 files

Browse files
datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .cifar10 import CIFAR10
datasets/cifar10.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from torchvision import datasets
4
+ import albumentations as A
5
+
6
+ from .generic import MyDataSet
7
+
8
+
9
class AlbCIFAR10(datasets.CIFAR10):
    """CIFAR-10 dataset whose images can be run through an Albumentations-style transform.

    Args:
        root: Dataset root directory, forwarded to ``datasets.CIFAR10``.
        alb_transform: Optional callable invoked as ``alb_transform(image=array)``;
            the ``'image'`` entry of its result replaces the sample image.
        **kwargs: Remaining keyword arguments forwarded to ``datasets.CIFAR10``.
    """

    def __init__(self, root, alb_transform=None, **kwargs):
        super().__init__(root, **kwargs)
        self.alb_transform = alb_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, transformed when a transform is set."""
        image, label = super().__getitem__(index)
        if self.alb_transform is None:
            return image, label
        # The transform is handed a NumPy array built from the parent's image.
        augmented = self.alb_transform(image=np.array(image))
        return augmented['image'], label
19
+
20
+
21
class CIFAR10(MyDataSet):
    """CIFAR-10 settings for ``MyDataSet``: dataset class, statistics and default augmentations."""

    DataSet = AlbCIFAR10
    # Three-value mean/std tuples consumed by the normalisation step.
    mean = (0.49139968, 0.48215827, 0.44653124)
    std = (0.24703233, 0.24348505, 0.26158768)
    default_alb_transforms = [
        A.ToGray(p=0.2),
        # Pad up to 40x40, then take a random 32x32 crop.
        A.PadIfNeeded(min_height=40, min_width=40, p=1),
        A.RandomCrop(height=32, width=32, p=1),
        A.HorizontalFlip(p=0.5),
        # Normalisation runs before these transforms, so the data mean is already 0
        # and a cutout fill_value of 0 corresponds to the mean.
        A.CoarseDropout(max_holes=1, max_height=8, max_width=8, fill_value=0, p=1),
    ]
datasets/generic.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC
3
+ from functools import cached_property
4
+
5
+ import torch
6
+ import albumentations as A
7
+ from albumentations.pytorch import ToTensorV2
8
+
9
+ try:
10
+ from epoch.utils import plot_examples
11
+ except ModuleNotFoundError:
12
+ from utils import plot_examples
13
+
14
+
15
class MyDataSet(ABC):
    """Base class bundling a dataset class with its transforms and data loaders.

    Subclasses fill in the class attributes below. Datasets, loaders and the
    example iterator are built lazily and cached on first access.

    Args:
        batch_size: Batch size passed to both data loaders.
        normalize: When true, ``A.Normalize(mean, std)`` is the first transform.
        shuffle: Whether the training loader shuffles.
        augment: When true, ``alb_transforms`` are applied to training data.
        alb_transforms: Training augmentations; any falsy value (``None`` or an
            empty list) falls back to ``default_alb_transforms``.
    """

    # Dataset class, called as DataSet(root, train=..., download=..., alb_transform=...).
    DataSet = None
    # Statistics handed to A.Normalize and reused by denormalise().
    mean = None
    std = None
    # Mapping of label index -> class name, filled in by set_classes().
    classes = None
    # Augmentations used when the constructor receives no alb_transforms.
    default_alb_transforms = None

    def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
        self.batch_size = batch_size
        self.normalize = normalize
        self.shuffle = shuffle
        self.augment = augment
        self.alb_transforms = alb_transforms or self.default_alb_transforms

        self.loader_kwargs = {
            'batch_size': batch_size,
            # os.cpu_count() may return None; fall back to loading in the main process.
            'num_workers': os.cpu_count() or 0,
            'pin_memory': True,
        }

    @classmethod
    def set_classes(cls, data):
        """Record ``data.classes`` as an index -> name mapping, only if not already set."""
        if cls.classes is None:
            cls.classes = dict(enumerate(data.classes))

    @cached_property
    def train_data(self):
        """Training split (downloaded if needed) with the training transforms attached."""
        res = self.DataSet('../data', train=True, download=True, alb_transform=self.get_train_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def test_data(self):
        """Test split (downloaded if needed) with the test transforms attached."""
        res = self.DataSet('../data', train=False, download=True, alb_transform=self.get_test_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def train_loader(self):
        """DataLoader over ``train_data``; shuffling follows ``self.shuffle``."""
        return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)

    @cached_property
    def test_loader(self):
        """DataLoader over ``test_data``; never shuffled."""
        return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)

    @cached_property
    def example_iter(self):
        """Iterator over ``train_loader`` used by ``show_examples``."""
        return iter(self.train_loader)

    def get_train_transforms(self):
        """Compose normalisation (optional), augmentations (optional) and ``ToTensorV2``."""
        all_transforms = []
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        if self.augment and self.alb_transforms is not None:
            all_transforms.extend(self.alb_transforms)
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def get_test_transforms(self):
        """Compose normalisation (optional) and ``ToTensorV2``; no augmentation."""
        all_transforms = []
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def download(self):
        """Instantiate both splits with ``download=True`` so the data is fetched."""
        self.DataSet('../data', train=True, download=True)
        self.DataSet('../data', train=False, download=True)

    def denormalise(self, tensor):
        """Return a detached copy of ``tensor`` with normalisation undone.

        The copy's first dimension is iterated in lockstep with ``mean`` and
        ``std``, so this expects a single channel-first image rather than a batch.
        When ``self.normalize`` is false the copy is returned unchanged.
        """
        result = tensor.clone().detach().requires_grad_(False)
        if self.normalize:
            for t, m, s in zip(result, self.mean, self.std):
                # In-place on each slice view, which updates ``result``.
                t.mul_(s).add_(m)
        return result

    def show_transform(self, img):
        """Convert a (possibly normalised) image tensor into a layout suitable for plotting."""
        if self.normalize:
            img = self.denormalise(img)
        if len(self.mean) == 3:
            # Three statistics: move the first dimension to the end.
            return img.permute(1, 2, 0)
        return img.squeeze(0)

    def show_examples(self, figsize=(8, 6)):
        """Plot the next training batch with its labels via ``plot_examples``."""
        try:
            batch_data, batch_label = next(self.example_iter)
        except StopIteration:
            # The cached iterator is exhausted: drop it so the cached_property
            # builds a fresh one, then retry once.
            self.__dict__.pop('example_iter', None)
            batch_data, batch_label = next(self.example_iter)

        images = []
        labels = []
        for image, label in zip(batch_data, batch_label):
            label = label.item()
            if self.classes is not None:
                label = f'{label}:{self.classes[label]}'

            images.append(self.show_transform(image))
            labels.append(label)

        plot_examples(images, labels, figsize=figsize)
models/__init__.py ADDED
File without changes
models/custom_resnet.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torch import optim
3
+ from pytorch_lightning import LightningModule
4
+ from torchmetrics import MeanMetric
5
+ from torch_lr_finder import LRFinder
6
+
7
+ from utils.metrics import RunningAccuracy
8
+
9
+
10
class ConvLayer(nn.Module):
    """3x3 convolution followed by optional max-pool, batch norm, ReLU and optional dropout.

    Args:
        input_c: Number of input channels.
        output_c: Number of output channels.
        bias: Whether the convolution has a bias term.
        stride: Convolution stride.
        padding: Convolution padding (replicate mode).
        pool: When true, a 2x2 max-pool is inserted right after the convolution.
        dropout: Dropout probability; no dropout layer is added unless it is > 0.
    """

    def __init__(self, input_c, output_c, bias=False, stride=1, padding=1, pool=False, dropout=0.):
        super().__init__()

        conv = nn.Conv2d(
            input_c,
            output_c,
            kernel_size=3,
            bias=bias,
            stride=stride,
            padding=padding,
            padding_mode='replicate',
        )
        stages = [conv]
        if pool:
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        stages += [nn.BatchNorm2d(output_c), nn.ReLU()]
        if dropout > 0:
            stages.append(nn.Dropout(dropout))

        self.all_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.all_layers(x)
30
+
31
+
32
class CustomLayer(nn.Module):
    """A ``ConvLayer`` optionally followed by a residual stack of further ``ConvLayer``s.

    Args:
        input_c: Number of input channels.
        output_c: Number of output channels.
        pool: Forwarded to the first ``ConvLayer``.
        residue: Number of ``ConvLayer``s in the residual branch; no branch when <= 0.
        dropout: Dropout probability forwarded to every ``ConvLayer``.
    """

    def __init__(self, input_c, output_c, pool=True, residue=2, dropout=0.):
        super().__init__()

        self.pool_block = ConvLayer(input_c, output_c, pool=pool, dropout=dropout)
        if residue > 0:
            branch = [
                ConvLayer(output_c, output_c, pool=False, dropout=dropout)
                for _ in range(residue)
            ]
            self.res_block = nn.Sequential(*branch)
        else:
            self.res_block = None

    def forward(self, x):
        """Run the first block, then add the residual branch output when one exists."""
        x = self.pool_block(x)
        if self.res_block is None:
            return x
        # Out-of-place add: += right after a ReLU causes in-place errors in PyTorch.
        return self.res_block(x) + x
52
+
53
+
54
class Model(LightningModule):
    """Residual CNN classifier with 10 outputs, wrapped as a LightningModule.

    Args:
        dataset: Object providing ``train_loader``, ``test_loader`` and ``download()``.
        dropout: Dropout probability forwarded to every ``CustomLayer``.
        max_epochs: Number of epochs the one-cycle LR schedule is built for.
    """

    def __init__(self, dataset, dropout=0.05, max_epochs=24):
        super().__init__()

        self.dataset = dataset

        self.network = nn.Sequential(
            CustomLayer(3, 64, pool=False, residue=0, dropout=dropout),
            CustomLayer(64, 128, pool=True, residue=2, dropout=dropout),
            CustomLayer(128, 256, pool=True, residue=0, dropout=dropout),
            CustomLayer(256, 512, pool=True, residue=2, dropout=dropout),
            nn.MaxPool2d(kernel_size=4, stride=4),
            nn.Flatten(),
            nn.Linear(512, 10)
        )

        self.criterion = nn.CrossEntropyLoss()
        self.train_accuracy = RunningAccuracy()
        self.val_accuracy = RunningAccuracy()
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()

        self.max_epochs = max_epochs
        # Epoch number used in the printed summaries; advanced at the end of each training epoch.
        self.epoch_counter = 1

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.network(x)

    def common_step(self, batch, loss_metric, acc_metric):
        """Compute the loss for ``batch`` and update the given loss/accuracy metrics."""
        x, y = batch
        batch_len = y.numel()
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        # Second argument weights the batch loss by the number of targets.
        loss_metric.update(loss, batch_len)
        acc_metric.update(logits, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``."""
        return self.common_step(batch, self.train_loss, self.train_accuracy)

    def on_train_epoch_end(self):
        """Print the training summary, reset the training metrics and advance the epoch counter."""
        print(f"Epoch: {self.epoch_counter}, Train: Loss: {self.train_loss.compute():0.4f}, Accuracy: "
              f"{self.train_accuracy.compute():0.2f}")
        self.train_loss.reset()
        self.train_accuracy.reset()
        self.epoch_counter += 1

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log the validation metrics."""
        loss = self.common_step(batch, self.val_loss, self.val_accuracy)
        self.log("val_step_loss", self.val_loss, prog_bar=True, logger=True)
        self.log("val_step_acc", self.val_accuracy, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        """Print the validation summary and reset the validation metrics."""
        print(f"Epoch: {self.epoch_counter}, Valid: Loss: {self.val_loss.compute():0.4f}, Accuracy: "
              f"{self.val_accuracy.compute():0.2f}")
        self.val_loss.reset()
        self.val_accuracy.reset()

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the network output for ``batch``, dropping the label when one is present."""
        # Accept tuple as well as list so an (inputs, labels) batch is unpacked either way.
        if isinstance(batch, (list, tuple)):
            x, _ = batch
        else:
            x = batch
        return self.forward(x)

    def find_lr(self, optimizer):
        """Run an LR range test on the training loader and return the rate it reports."""
        lr_finder = LRFinder(self, optimizer, self.criterion)
        lr_finder.range_test(self.dataset.train_loader, end_lr=0.1, num_iter=100, step_mode='exp')
        # Second element of plot()'s result is used as the learning rate.
        _, best_lr = lr_finder.plot()
        # NOTE(review): reset() presumably restores model/optimizer state after the test - confirm.
        lr_finder.reset()
        return best_lr

    def configure_optimizers(self):
        """Build Adam plus a per-step OneCycleLR schedule peaking at the rate from ``find_lr``."""
        optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
        best_lr = self.find_lr(optimizer)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=best_lr,
            steps_per_epoch=len(self.dataset.train_loader),
            epochs=self.max_epochs,
            # Fraction of the schedule spent increasing the LR: 5 epochs' worth.
            pct_start=5/self.max_epochs,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear'
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "interval": "step",
            }
        }

    def prepare_data(self):
        """Download the dataset through the dataset wrapper."""
        self.dataset.download()

    def train_dataloader(self):
        """Return the dataset wrapper's training loader."""
        return self.dataset.train_loader

    def val_dataloader(self):
        """Return the dataset wrapper's test loader, used for validation."""
        return self.dataset.test_loader

    def predict_dataloader(self):
        """Return the same loader as ``val_dataloader``."""
        return self.val_dataloader()
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .misc import *
utils/metrics.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from torchmetrics import Metric
4
+
5
+
6
class RunningAccuracy(Metric):
    """Accuracy, as a percentage, accumulated over successive ``update`` calls."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Both counters start at zero and are registered with a "sum" reduction.
        for state_name in ("correct", "total"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor):
        """Add one batch: ``preds`` is arg-maxed over dim 1 and compared with ``target``."""
        predicted_classes = preds.argmax(dim=1)
        self.correct += predicted_classes.eq(target).sum()
        self.total += target.numel()

    def compute(self):
        """Return ``100 * correct / total`` as a float tensor."""
        return 100 * self.correct.float() / self.total
utils/misc.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchinfo
3
+ from matplotlib import pyplot as plt
4
+ from pytorch_grad_cam import GradCAM
5
+ from pytorch_grad_cam.utils.image import show_cam_on_image
6
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
7
+
8
# Default seed used by set_seed().
SEED = 42
# Cached result of get_device(); None until the first call.
DEVICE = None


def get_device():
    """Return the device name ("cuda", "mps" or "cpu"), selecting and caching it on first use.

    The choice is printed once, when it is made; later calls return the cached value.
    """
    global DEVICE
    if DEVICE is None:
        if torch.cuda.is_available():
            DEVICE = "cuda"
        elif torch.backends.mps.is_available():
            DEVICE = "mps"
        else:
            DEVICE = "cpu"
        print("Device Selected:", DEVICE)
    return DEVICE
25
+
26
+
27
def set_seed(seed=SEED):
    """Seed torch's RNG, and additionally the CUDA RNG when the selected device is "cuda"."""
    torch.manual_seed(seed)
    if get_device() != 'cuda':
        return
    torch.cuda.manual_seed(seed)
31
+
32
+
33
def plot_examples(images, labels, figsize=None, n=20):
    """Plot up to ``n`` images with their labels as titles on a four-row grid.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        labels: Sequence of labels; each is shown via ``str(label)``.
        figsize: Figure size forwarded to ``plt.figure``.
        n: Maximum number of examples to draw. Fewer are drawn when ``images``
            or ``labels`` is shorter, instead of raising ``IndexError``.
    """
    _ = plt.figure(figsize=figsize)

    count = min(n, len(images), len(labels))
    if count <= 0:
        return

    rows = 4
    # Ceiling division so the grid always has room for every example,
    # including counts that are not a multiple of four.
    cols = -(-count // rows)

    for position, image, label in zip(range(1, count + 1), images, labels):
        plt.subplot(rows, cols, position)
        plt.imshow(image, cmap='gray')
        plt.title(str(label))
        plt.xticks([])
        plt.yticks([])

    # One layout pass after all subplots exist.
    plt.tight_layout()
45
+
46
+
47
def get_incorrect_preds(prediction, labels):
    """Return the misclassified positions with their predicted and true classes.

    Args:
        prediction: Score tensor; the predicted class is the argmax over dim 1.
        labels: Tensor of true class indices.

    Returns:
        A tuple of three lists: indices of wrong predictions, the predicted
        classes at those indices, and the true labels at those indices.
    """
    predicted = prediction.argmax(dim=1)
    mismatch = predicted != labels
    indices = mismatch.nonzero().flatten().tolist()
    wrong_preds = predicted[indices].tolist()
    true_labels = labels[indices].tolist()
    return indices, wrong_preds, true_labels
51
+
52
+
53
def get_cam_visualisation(model, dataset, input_tensor, label, target_layer, use_cuda=False):
    """Return a Grad-CAM heatmap for ``label`` overlaid on the input image.

    Args:
        model: Model handed to ``GradCAM``.
        dataset: Object whose ``show_transform`` turns ``input_tensor`` into a plottable image.
        input_tensor: A single, unbatched input; a batch dimension is added here.
        label: Class index wrapped in ``ClassifierOutputTarget``.
        target_layer: Layer handed to ``GradCAM`` as its only target layer.
        use_cuda: Forwarded to ``GradCAM``.
    """
    cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=use_cuda)

    batch = input_tensor.unsqueeze(0)
    heatmaps = cam(input_tensor=batch, targets=[ClassifierOutputTarget(label)])
    # The batch holds exactly one image, so keep only its heatmap.
    heatmap = heatmaps[0, :]

    background = dataset.show_transform(input_tensor).cpu().numpy()
    return show_cam_on_image(background, heatmap, use_rgb=True)
65
+
66
+
67
def model_summary(model, input_size=None):
    """Return a ``torchinfo`` summary of ``model`` to depth 5 with size and parameter columns."""
    columns = ["input_size", "output_size", "num_params", "params_percent"]
    return torchinfo.summary(model, input_size=input_size, depth=5, col_names=columns)