Commit 375ca58
Parent(s): 9d2eeb0
Upload 8 files

Files changed:
- datasets/__init__.py +1 -0
- datasets/cifar10.py +32 -0
- datasets/generic.py +111 -0
- models/__init__.py +0 -0
- models/custom_resnet.py +159 -0
- utils/__init__.py +1 -0
- utils/metrics.py +19 -0
- utils/misc.py +69 -0
datasets/__init__.py
ADDED
@@ -0,0 +1 @@
+from .cifar10 import CIFAR10
datasets/cifar10.py
ADDED
@@ -0,0 +1,32 @@
+import numpy as np
+import cv2
+from torchvision import datasets
+import albumentations as A
+
+from .generic import MyDataSet
+
+
+class AlbCIFAR10(datasets.CIFAR10):
+    def __init__(self, root, alb_transform=None, **kwargs):
+        super(AlbCIFAR10, self).__init__(root, **kwargs)
+        self.alb_transform = alb_transform
+
+    def __getitem__(self, index):
+        image, label = super(AlbCIFAR10, self).__getitem__(index)
+        if self.alb_transform is not None:
+            image = self.alb_transform(image=np.array(image))['image']
+        return image, label
+
+
+class CIFAR10(MyDataSet):
+    DataSet = AlbCIFAR10
+    mean = (0.49139968, 0.48215827, 0.44653124)
+    std = (0.24703233, 0.24348505, 0.26158768)
+    default_alb_transforms = [
+        A.ToGray(p=0.2),
+        A.PadIfNeeded(40, 40, p=1),
+        A.RandomCrop(32, 32, p=1),
+        A.HorizontalFlip(p=0.5),
+        # Since normalisation was the first step, mean is already 0, so cutout fill_value = 0
+        A.CoarseDropout(max_holes=1, max_height=8, max_width=8, fill_value=0, p=1)
+    ]
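
A minimal usage sketch (not part of the commit; assumes the repo root is on PYTHONPATH and defaults from MyDataSet below):

    from datasets import CIFAR10

    data = CIFAR10(batch_size=512)           # normalize/shuffle/augment default to True
    images, labels = next(iter(data.train_loader))
    print(images.shape, labels.shape)        # torch.Size([512, 3, 32, 32]) torch.Size([512])
    data.show_examples()                     # grid of augmented, denormalised training samples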
datasets/generic.py
ADDED
@@ -0,0 +1,111 @@
+import os
+from abc import ABC
+from functools import cached_property
+
+import torch
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+try:
+    from epoch.utils import plot_examples
+except ModuleNotFoundError:
+    from utils import plot_examples
+
+
+class MyDataSet(ABC):
+    DataSet = None
+    mean = None
+    std = None
+    classes = None
+    default_alb_transforms = None
+
+    def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
+        self.batch_size = batch_size
+        self.normalize = normalize
+        self.shuffle = shuffle
+        self.augment = augment
+        self.alb_transforms = alb_transforms or self.default_alb_transforms
+
+        self.loader_kwargs = {'batch_size': batch_size, 'num_workers': os.cpu_count(), 'pin_memory': True}
+
+    @classmethod
+    def set_classes(cls, data):
+        if cls.classes is None:
+            cls.classes = {i: c for i, c in enumerate(data.classes)}
+
+    @cached_property
+    def train_data(self):
+        res = self.DataSet('../data', train=True, download=True, alb_transform=self.get_train_transforms())
+        self.set_classes(res)
+        return res
+
+    @cached_property
+    def test_data(self):
+        res = self.DataSet('../data', train=False, download=True, alb_transform=self.get_test_transforms())
+        self.set_classes(res)
+        return res
+
+    @cached_property
+    def train_loader(self):
+        return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)
+
+    @cached_property
+    def test_loader(self):
+        return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)
+
+    @cached_property
+    def example_iter(self):
+        return iter(self.train_loader)
+
+    def get_train_transforms(self):
+        all_transforms = list()
+        if self.normalize:
+            all_transforms.append(A.Normalize(self.mean, self.std))
+        if self.augment and self.alb_transforms is not None:
+            all_transforms.extend(self.alb_transforms)
+        all_transforms.append(ToTensorV2())
+        return A.Compose(all_transforms)
+
+    def get_test_transforms(self):
+        all_transforms = list()
+        if self.normalize:
+            all_transforms.append(A.Normalize(self.mean, self.std))
+        all_transforms.append(ToTensorV2())
+        return A.Compose(all_transforms)
+
+    def download(self):
+        self.DataSet('../data', train=True, download=True)
+        self.DataSet('../data', train=False, download=True)
+
+    def denormalise(self, tensor):
+        result = tensor.clone().detach().requires_grad_(False)
+        if self.normalize:
+            for t, m, s in zip(result, self.mean, self.std):
+                t.mul_(s).add_(m)
+        return result
+
+    def show_transform(self, img):
+        if self.normalize:
+            img = self.denormalise(img)
+        if len(self.mean) == 3:
+            return img.permute(1, 2, 0)
+        else:
+            return img.squeeze(0)
+
+    def show_examples(self, figsize=(8, 6)):
+        batch_data, batch_label = next(self.example_iter)
+        images = list()
+        labels = list()
+
+        for i in range(len(batch_data)):
+            image = batch_data[i]
+            image = self.show_transform(image)
+
+            label = batch_label[i].item()
+            if self.classes is not None:
+                label = f'{label}:{self.classes[label]}'
+
+            images.append(image)
+            labels.append(label)
+
+        plot_examples(images, labels, figsize=figsize)
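
MyDataSet is an abstract base: a concrete dataset supplies a DataSet class plus mean/std and optional default transforms, as datasets/cifar10.py does above. A hypothetical MNIST subclass (not part of the commit; statistics are the commonly cited MNIST mean/std) might look like this:

    import numpy as np
    from torchvision import datasets
    from datasets.generic import MyDataSet

    class AlbMNIST(datasets.MNIST):  # hypothetical wrapper, mirroring AlbCIFAR10
        def __init__(self, root, alb_transform=None, **kwargs):
            super().__init__(root, **kwargs)
            self.alb_transform = alb_transform

        def __getitem__(self, index):
            image, label = super().__getitem__(index)
            if self.alb_transform is not None:
                image = self.alb_transform(image=np.array(image))['image']
            return image, label

    class MNIST(MyDataSet):
        DataSet = AlbMNIST
        mean = (0.1307,)  # single channel, so show_transform() takes the squeeze(0) branch
        std = (0.3081,)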
models/__init__.py
ADDED
File without changes
models/custom_resnet.py
ADDED
@@ -0,0 +1,159 @@
+from torch import nn
+from torch import optim
+from pytorch_lightning import LightningModule
+from torchmetrics import MeanMetric
+from torch_lr_finder import LRFinder
+
+from utils.metrics import RunningAccuracy
+
+
+class ConvLayer(nn.Module):
+    def __init__(self, input_c, output_c, bias=False, stride=1, padding=1, pool=False, dropout=0.):
+        super(ConvLayer, self).__init__()
+
+        layers = list()
+        layers.append(
+            nn.Conv2d(input_c, output_c, kernel_size=3, bias=bias, stride=stride, padding=padding,
+                      padding_mode='replicate')
+        )
+        if pool:
+            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
+        layers.append(nn.BatchNorm2d(output_c))
+        layers.append(nn.ReLU())
+        if dropout > 0:
+            layers.append(nn.Dropout(dropout))
+
+        self.all_layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.all_layers(x)
+
+
+class CustomLayer(nn.Module):
+    def __init__(self, input_c, output_c, pool=True, residue=2, dropout=0.):
+        super(CustomLayer, self).__init__()
+
+        self.pool_block = ConvLayer(input_c, output_c, pool=pool, dropout=dropout)
+        self.res_block = None
+        if residue > 0:
+            layers = list()
+            for i in range(0, residue):
+                layers.append(ConvLayer(output_c, output_c, pool=False, dropout=dropout))
+            self.res_block = nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.pool_block(x)
+        if self.res_block is not None:
+            x_ = x
+            x = self.res_block(x)
+            # += operator causes inplace errors in pytorch if done right after relu.
+            x = x + x_
+        return x
+
+
+class Model(LightningModule):
+    def __init__(self, dataset, dropout=0.05, max_epochs=24):
+        super(Model, self).__init__()
+
+        self.dataset = dataset
+
+        self.network = nn.Sequential(
+            CustomLayer(3, 64, pool=False, residue=0, dropout=dropout),
+            CustomLayer(64, 128, pool=True, residue=2, dropout=dropout),
+            CustomLayer(128, 256, pool=True, residue=0, dropout=dropout),
+            CustomLayer(256, 512, pool=True, residue=2, dropout=dropout),
+            nn.MaxPool2d(kernel_size=4, stride=4),
+            nn.Flatten(),
+            nn.Linear(512, 10)
+        )
+
+        self.criterion = nn.CrossEntropyLoss()
+        self.train_accuracy = RunningAccuracy()
+        self.val_accuracy = RunningAccuracy()
+        self.train_loss = MeanMetric()
+        self.val_loss = MeanMetric()
+
+        self.max_epochs = max_epochs
+        self.epoch_counter = 1
+
+    def forward(self, x):
+        return self.network(x)
+
+    def common_step(self, batch, loss_metric, acc_metric):
+        x, y = batch
+        batch_len = y.numel()
+        logits = self.forward(x)
+        loss = self.criterion(logits, y)
+        loss_metric.update(loss, batch_len)
+        acc_metric.update(logits, y)
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        return self.common_step(batch, self.train_loss, self.train_accuracy)
+
+    def on_train_epoch_end(self):
+        print(f"Epoch: {self.epoch_counter}, Train: Loss: {self.train_loss.compute():0.4f}, Accuracy: "
+              f"{self.train_accuracy.compute():0.2f}")
+        self.train_loss.reset()
+        self.train_accuracy.reset()
+        self.epoch_counter += 1
+
+    def validation_step(self, batch, batch_idx):
+        loss = self.common_step(batch, self.val_loss, self.val_accuracy)
+        self.log("val_step_loss", self.val_loss, prog_bar=True, logger=True)
+        self.log("val_step_acc", self.val_accuracy, prog_bar=True, logger=True)
+        return loss
+
+    def on_validation_epoch_end(self):
+        print(f"Epoch: {self.epoch_counter}, Valid: Loss: {self.val_loss.compute():0.4f}, Accuracy: "
+              f"{self.val_accuracy.compute():0.2f}")
+        self.val_loss.reset()
+        self.val_accuracy.reset()
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        if isinstance(batch, list):
+            x, _ = batch
+        else:
+            x = batch
+        return self.forward(x)
+
+    def find_lr(self, optimizer):
+        lr_finder = LRFinder(self, optimizer, self.criterion)
+        lr_finder.range_test(self.dataset.train_loader, end_lr=0.1, num_iter=100, step_mode='exp')
+        _, best_lr = lr_finder.plot()
+        lr_finder.reset()
+        return best_lr
+
+    def configure_optimizers(self):
+        optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
+        best_lr = self.find_lr(optimizer)
+        scheduler = optim.lr_scheduler.OneCycleLR(
+            optimizer,
+            max_lr=best_lr,
+            steps_per_epoch=len(self.dataset.train_loader),
+            epochs=self.max_epochs,
+            pct_start=5/self.max_epochs,
+            div_factor=100,
+            three_phase=False,
+            final_div_factor=100,
+            anneal_strategy='linear'
+        )
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                "scheduler": scheduler,
+                "interval": "step",
+            }
+        }
+
+    def prepare_data(self):
+        self.dataset.download()
+
+    def train_dataloader(self):
+        return self.dataset.train_loader
+
+    def val_dataloader(self):
+        return self.dataset.test_loader
+
+    def predict_dataloader(self):
+        return self.val_dataloader()
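
A hedged sketch of driving this LightningModule with a Trainer (argument values are illustrative assumptions, not part of the commit). Note that configure_optimizers runs an LR range test via torch_lr_finder before training, then feeds the found LR into OneCycleLR:

    from pytorch_lightning import Trainer

    from datasets import CIFAR10
    from models.custom_resnet import Model

    data = CIFAR10(batch_size=512)
    model = Model(data, dropout=0.05, max_epochs=24)
    # The module supplies its own dataloaders (train_dataloader/val_dataloader),
    # so Trainer.fit needs only the module itself.
    trainer = Trainer(max_epochs=24, accelerator='auto', devices=1)
    trainer.fit(model)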
utils/__init__.py
ADDED
@@ -0,0 +1 @@
+from .misc import *
utils/metrics.py
ADDED
@@ -0,0 +1,19 @@
+import torch
+from torch import Tensor
+from torchmetrics import Metric
+
+
+class RunningAccuracy(Metric):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, target: Tensor):
+        preds = preds.argmax(dim=1)
+        total = target.numel()
+        self.correct += preds.eq(target).sum()
+        self.total += total
+
+    def compute(self):
+        return 100 * self.correct.float() / self.total
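
A quick sanity check of the metric (illustrative only, not part of the commit):

    import torch
    from utils.metrics import RunningAccuracy

    acc = RunningAccuracy()
    logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0]])  # argmax -> classes 0, 1, 0
    target = torch.tensor([0, 1, 1])
    acc.update(logits, target)
    print(acc.compute())  # tensor(66.6667): 2 of 3 predictions correct, as a percentage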
utils/misc.py
ADDED
@@ -0,0 +1,69 @@
+import torch
+import torchinfo
+from matplotlib import pyplot as plt
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+SEED = 42
+DEVICE = None
+
+
+def get_device():
+    global DEVICE
+    if DEVICE is not None:
+        return DEVICE
+
+    if torch.cuda.is_available():
+        DEVICE = "cuda"
+    elif torch.backends.mps.is_available():
+        DEVICE = "mps"
+    else:
+        DEVICE = "cpu"
+    print("Device Selected:", DEVICE)
+    return DEVICE
+
+
+def set_seed(seed=SEED):
+    torch.manual_seed(seed)
+    if get_device() == 'cuda':
+        torch.cuda.manual_seed(seed)
+
+
+def plot_examples(images, labels, figsize=None, n=20):
+    _ = plt.figure(figsize=figsize)
+
+    for i in range(n):
+        plt.subplot(4, n//4, i + 1)
+        plt.tight_layout()
+        image = images[i]
+        plt.imshow(image, cmap='gray')
+        label = labels[i]
+        plt.title(str(label))
+        plt.xticks([])
+        plt.yticks([])
+
+
+def get_incorrect_preds(prediction, labels):
+    prediction = prediction.argmax(dim=1)
+    indices = prediction.ne(labels).nonzero().reshape(-1).tolist()
+    return indices, prediction[indices].tolist(), labels[indices].tolist()
+
+
+def get_cam_visualisation(model, dataset, input_tensor, label, target_layer, use_cuda=False):
+    grad_cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=use_cuda)
+
+    targets = [ClassifierOutputTarget(label)]
+
+    grayscale_cam = grad_cam(input_tensor=input_tensor.unsqueeze(0), targets=targets)
+    # In this example grayscale_cam has only one image in the batch:
+    grayscale_cam = grayscale_cam[0, :]
+
+    output = show_cam_on_image(dataset.show_transform(input_tensor).cpu().numpy(), grayscale_cam,
+                               use_rgb=True)
+    return output
+
+
+def model_summary(model, input_size=None):
+    return torchinfo.summary(model, input_size=input_size, depth=5,
+                             col_names=["input_size", "output_size", "num_params", "params_percent"])
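
Illustrative use of the helpers above (hedged: `model` here is assumed to be the network from models/custom_resnet.py; nothing below is part of the commit):

    from utils import set_seed, get_device, model_summary

    set_seed()               # seeds torch, plus CUDA when that is the selected device
    print(get_device())      # 'cuda', 'mps', or 'cpu', cached in the module-level DEVICE
    print(model_summary(model, input_size=(1, 3, 32, 32)))  # torchinfo table to depth 5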