AkashDataScience committed
Commit 9d912f9 · 1 Parent(s): e9dfeb1

First commit

__pycache__/resnet.cpython-312.pyc ADDED
Binary file (7.95 kB).
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (9.6 kB).
 
app.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from torchvision import transforms
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from resnet import ResNet18
+
+ model = ResNet18()
+ model.load_state_dict(torch.load("model.ckpt", map_location=torch.device('cpu')), strict=False)
+ model.eval()  # fix BatchNorm statistics for single-image inference
+
+ # Inverts transforms.Normalize(mean=0.50, std=0.23); unused in this script but
+ # matches the argument expected by utils.display_gradcam_output.
+ inv_normalize = transforms.Normalize(
+     mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
+     std=[1/0.23, 1/0.23, 1/0.23]
+ )
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
+            'dog', 'frog', 'horse', 'ship', 'truck')
+
+ def resize_image_pil(image, new_width, new_height):
+     # Convert to PIL image
+     img = Image.fromarray(np.array(image))
+
+     # Scale so the image fits inside the target box, preserving aspect ratio
+     width, height = img.size
+     scale = min(new_width / width, new_height / height)
+     resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)
+
+     # Crop (or pad with black) to the exact target size
+     resized = resized.crop((0, 0, new_width, new_height))
+     return resized
+
+ def inference(input_img, transparency=0.5, target_layer_number=-1):
+     input_img = resize_image_pil(input_img, 32, 32)
+
+     input_img = np.array(input_img)
+     org_img = input_img
+     input_img = transforms.ToTensor()(input_img).unsqueeze(0)
+
+     outputs = model(input_img)
+     o = torch.nn.functional.softmax(outputs.flatten(), dim=0)
+     confidences = {classes[i]: float(o[i]) for i in range(10)}
+     _, prediction = torch.max(outputs, 1)
+
+     # GradCAM on the selected block of layer2 (-1 = last block, -2 = first block)
+     target_layers = [model.layer2[target_layer_number]]
+     cam = GradCAM(model=model, target_layers=target_layers)
+     grayscale_cam = cam(input_tensor=input_img, targets=None)[0, :]
+     visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True,
+                                       image_weight=transparency)
+     return classes[prediction[0].item()], visualization, confidences
+
+ title = "ResNet18 trained on CIFAR10, with GradCAM"
+ description = "A simple Gradio interface to run inference on a ResNet18 model and visualize GradCAM results"
+ examples = [["cat.jpg", 0.5, -1], ["dog.jpg", 0.5, -1]]
+ demo = gr.Interface(
+     inference,
+     inputs=[
+         gr.Image(width=256, height=256, label="Input Image"),
+         gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
+         gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?")
+     ],
+     outputs=[
+         "text",
+         gr.Image(width=256, height=256, label="Output"),
+         gr.Label(num_top_classes=3)
+     ],
+     title=title,
+     description=description,
+     examples=examples,
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
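With the __main__ guard above, inference can also be exercised without starting the Gradio UI. A hypothetical smoke test (assuming the LFS weights and cat.jpg have been pulled locally):

import torch
from PIL import Image
from app import inference

# Returns (predicted class, GradCAM overlay as an RGB array, per-class confidences)
label, cam_overlay, confidences = inference(Image.open("cat.jpg"),
                                            transparency=0.5,
                                            target_layer_number=-1)
print(label, sorted(confidences.items(), key=lambda kv: -kv[1])[:3])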
cat.jpg ADDED
dog.jpg ADDED
model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:940f21f828787740b7b275a45b29051806977b52570b4e2afbb50a3f1dd04cab
+ size 89492032
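Note: model.ckpt is stored with Git LFS, so the three lines above are only a pointer file; the ~89 MB of actual weights live on the LFS server. A clone made without LFS support leaves the pointer text in place and torch.load("model.ckpt") in app.py will fail, so run git lfs install followed by git lfs pull after cloning.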
requirements.txt ADDED
@@ -0,0 +1,116 @@
+ aiofiles==23.2.1
+ aiohttp==3.9.5
+ aiosignal==1.3.1
+ altair==5.3.0
+ annotated-types==0.6.0
+ anyio==4.3.0
+ asttokens==2.4.1
+ attrs==23.2.0
+ certifi==2024.2.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ comm==0.2.2
+ contourpy==1.2.1
+ cycler==0.12.1
+ debugpy==1.8.1
+ decorator==5.1.1
+ executing==2.0.1
+ fastapi==0.110.2
+ ffmpy==0.3.2
+ filelock==3.13.1
+ fonttools==4.51.0
+ frozenlist==1.4.1
+ fsspec==2024.2.0
+ grad-cam==1.5.0
+ gradio==4.28.3
+ gradio_client==0.16.0
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.22.2
+ idna==3.7
+ importlib_resources==6.4.0
+ intel-openmp==2021.4.0
+ ipykernel==6.29.4
+ ipython==8.24.0
+ jedi==0.19.1
+ Jinja2==3.1.3
+ joblib==1.4.0
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ jupyter_client==8.6.1
+ jupyter_core==5.7.2
+ kiwisolver==1.4.5
+ lightning==2.2.3
+ lightning-utilities==0.11.2
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.8.4
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mkl==2021.4.0
+ mpmath==1.3.0
+ multidict==6.0.5
+ nest-asyncio==1.6.0
+ networkx==3.2.1
+ numpy==1.26.3
+ opencv-python==4.9.0.80
+ orjson==3.10.1
+ packaging==24.0
+ pandas==2.2.2
+ parso==0.8.4
+ pillow==10.2.0
+ platformdirs==4.2.1
+ prompt-toolkit==3.0.43
+ psutil==5.9.8
+ pure-eval==0.2.2
+ pydantic==2.7.1
+ pydantic_core==2.18.2
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.9
+ pytorch-lightning==2.2.3
+ pytz==2024.1
+ pywin32==306
+ PyYAML==6.0.1
+ pyzmq==26.0.2
+ referencing==0.35.0
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.0
+ ruff==0.4.2
+ scikit-learn==1.4.2
+ scipy==1.13.0
+ semantic-version==2.10.0
+ setuptools==69.5.1
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ stack-data==0.6.3
+ starlette==0.37.2
+ sympy==1.12
+ tbb==2021.11.0
+ threadpoolctl==3.4.0
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.3.0+cu121
+ torch-lr-finder==0.2.1
+ torchaudio==2.3.0+cu121
+ torchmetrics==1.3.2
+ torchsummary==1.5.1
+ torchvision==0.18.0+cu121
+ tornado==6.4
+ tqdm==4.66.2
+ traitlets==5.14.3
+ ttach==0.0.3
+ typer==0.12.3
+ typing_extensions==4.9.0
+ tzdata==2024.1
+ urllib3==2.2.1
+ uvicorn==0.29.0
+ wcwidth==0.2.13
+ websockets==11.0.3
+ yarl==1.9.4
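A practical note on these pins: torch==2.3.0+cu121, torchaudio==2.3.0+cu121 and torchvision==0.18.0+cu121 are CUDA wheels that are not hosted on PyPI, so a plain pip install -r requirements.txt needs an extra index such as --extra-index-url https://download.pytorch.org/whl/cu121 (or the pins relaxed to the CPU wheels). pywin32==306 is a Windows-only package and will not resolve on Linux, which suggests the file was exported with pip freeze from a Windows development machine rather than generated for the Space's runtime.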
resnet.py ADDED
@@ -0,0 +1,122 @@
+ """
+ ResNet in PyTorch.
+ For Pre-activation ResNet, see 'preact_resnet.py'.
+
+ Reference:
+ [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+     Deep Residual Learning for Image Recognition. arXiv:1512.03385
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchmetrics import Accuracy
+ from pytorch_lightning import LightningModule
+
+ import utils
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         # Projection shortcut when the spatial size or channel count changes
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion * planes:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(self.expansion * planes)
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+ class ResNet(LightningModule):
+     def __init__(self, block, num_blocks, num_classes=10, loss='cross_entropy', learning_rate=2e-4,
+                  momentum=0.9, optimizer="SGD", epochs=20):
+         super(ResNet, self).__init__()
+         self.in_planes = 64
+
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+         self.linear = nn.Linear(512 * block.expansion, num_classes)
+         self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+         self.learning_rate = learning_rate
+         self.optimizer = optimizer
+         self.momentum = momentum
+         self.loss = utils.get_criterion(loss)
+         self.epochs = epochs
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         # Only the first block of a stage may downsample; the rest keep stride 1
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = F.avg_pool2d(out, 4)
+         out = out.view(out.size(0), -1)
+         out = self.linear(out)
+         return out
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         loss = self.loss(self(x), y)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss(logits, y)
+         preds = torch.argmax(logits, dim=1)
+         self.accuracy(preds, y)
+
+         # Calling self.log surfaces these scalars in TensorBoard
+         self.log("val_loss", loss, prog_bar=True)
+         self.log("val_acc", self.accuracy, prog_bar=True)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         # Reuse validation_step for testing
+         return self.validation_step(batch, batch_idx)
+
+     def configure_optimizers(self):
+         # Use the optimizer type passed to the constructor instead of hard-coding "SGD"
+         optimizer = utils.get_optimizer(self, lr=self.learning_rate, momentum=self.momentum,
+                                         optimizer_type=self.optimizer)
+         max_lr = utils.get_learning_rate(self, optimizer, self.loss,
+                                          self.trainer.datamodule.train_dataloader())
+         scheduler = utils.get_OneCycleLR_scheduler(optimizer, max_lr=max_lr, epochs=self.epochs,
+                                                    steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
+                                                    max_at_epoch=5, anneal_strategy='linear',
+                                                    div_factor=10, final_div_factor=1)
+         return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
+
+
+ def ResNet18(loss='cross_entropy', learning_rate=2e-4, momentum=0.9, optimizer="SGD", epochs=20):
+     return ResNet(BasicBlock, [2, 2, 2, 2], loss=loss, learning_rate=learning_rate,
+                   momentum=momentum, optimizer=optimizer, epochs=epochs)
+
+
+ def ResNet34(loss='cross_entropy', learning_rate=2e-4, momentum=0.9, optimizer="SGD", epochs=20):
+     return ResNet(BasicBlock, [3, 4, 6, 3], loss=loss, learning_rate=learning_rate,
+                   momentum=momentum, optimizer=optimizer, epochs=epochs)
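A quick architecture sanity check (a hypothetical snippet, not part of the commit): with BasicBlock and [2, 2, 2, 2], a forward pass on a CIFAR-10-shaped batch should emit one logit per class.

import torch
from resnet import ResNet18

model = ResNet18()
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))  # two 3x32x32 CIFAR-10 images
print(logits.shape)  # expected: torch.Size([2, 10])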
utils.py ADDED
@@ -0,0 +1,237 @@
+ #!/usr/bin/env python3
+ """
+ Utility script containing functions to be used for training
+ Author: Shilpaj Bhalerao
+ """
+ # Standard Library Imports
+ import math
+
+ # Third-Party Imports
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.optim as optim
+ import torch.nn.functional as F
+ from torchsummary import summary
+ from torchvision import transforms
+ from torch_lr_finder import LRFinder
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+
+ def get_summary(model, input_size: tuple) -> None:
+     """
+     Function to get the summary of the model architecture
+     :param model: Object of model architecture class
+     :param input_size: Input data shape (Channels, Height, Width)
+     """
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+     network = model.to(device)
+     summary(network, input_size=input_size)
+
+
+ def get_misclassified_data(model, device, test_loader):
+     """
+     Function to run the model on the test set and return misclassified images
+     :param model: Network architecture
+     :param device: CPU/GPU
+     :param test_loader: DataLoader for the test set
+     """
+     # Prepare the model for evaluation, i.e. disable dropout and BatchNorm updates
+     model.eval()
+     model.to(device)
+
+     # List to store misclassified images
+     misclassified_data = []
+
+     # Disable gradient tracking
+     with torch.no_grad():
+         # Extract images and labels batch by batch
+         for data, target in test_loader:
+
+             # Migrate the data to the device
+             data, target = data.to(device), target.to(device)
+
+             # Extract a single image and label from the batch
+             for image, label in zip(data, target):
+
+                 # Add a batch dimension to the image
+                 image = image.unsqueeze(0)
+
+                 # Get the model prediction on the image
+                 output = model(image)
+
+                 # Convert the logits to a predicted class index
+                 pred = output.argmax(dim=1, keepdim=True)
+
+                 # If the prediction is incorrect, append the data
+                 if pred != label:
+                     misclassified_data.append((image, label, pred))
+     return misclassified_data
+
+
+ # -------------------- GradCam --------------------
+ def display_gradcam_output(data: list,
+                            classes: list[str],
+                            inv_normalize: transforms.Normalize,
+                            model,
+                            target_layers,
+                            targets=None,
+                            number_of_samples: int = 10,
+                            transparency: float = 0.60):
+     """
+     Function to visualize GradCAM output on the data
+     :param data: List[Tuple(image, label, prediction)]
+     :param classes: Names of the classes in the dataset
+     :param inv_normalize: Transform that undoes the dataset normalization
+     :param model: Model architecture
+     :param target_layers: Layers on which GradCAM should be executed
+     :param targets: Classes to be focused on for GradCAM
+     :param number_of_samples: Number of images to plot
+     :param transparency: Weight of the original image when mixed with activations
+     """
+     # Plot configuration
+     plt.figure(figsize=(10, 10))
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     # Create a GradCAM object; grad-cam 1.5.0 (the version pinned in
+     # requirements.txt) removed the use_cuda argument and infers the device
+     # from the model instead
+     cam = GradCAM(model=model, target_layers=target_layers)
+
+     # Iterate over the specified number of images
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         input_tensor = data[i][0]
+
+         # Get the activations of the layer for the image
+         grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
+         grayscale_cam = grayscale_cam[0, :]
+
+         # Get back the original image: un-normalize and move channels last
+         img = input_tensor.squeeze(0).to('cpu')
+         img = inv_normalize(img)
+         rgb_img = img.permute(1, 2, 0).numpy()
+
+         # Mix the activations with the original image
+         visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
+
+         # Display the images on the plot
+         plt.imshow(visualization)
+         plt.title("Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
+         plt.xticks([])
+         plt.yticks([])
+
+
+ def get_optimizer(model, lr, momentum=0, weight_decay=0, optimizer_type='SGD'):
+     """Method to get an optimizer object, used to update the weights.
+
+     Args:
+         model (Object): Neural network model
+         lr (float): Learning rate
+         momentum (float): Momentum (used by SGD)
+         weight_decay (float): Weight decay (used by ADAM)
+         optimizer_type (str): Type of optimizer, 'SGD' or 'ADAM'
+
+     Returns:
+         object: Optimizer used to update the weights
+     """
+     if optimizer_type == 'SGD':
+         optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
+     elif optimizer_type == 'ADAM':
+         optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+     else:
+         raise ValueError(f"Unsupported optimizer_type: {optimizer_type}")
+     return optimizer
+
+
+ def get_StepLR_scheduler(optimizer, step_size, gamma):
+     """Method to get a scheduler object, used to update the learning rate.
+
+     Args:
+         optimizer (Object): Optimizer object
+         step_size (int): Period of learning rate decay
+         gamma (float): Factor to multiply the learning rate by
+
+     Returns:
+         object: StepLR scheduler used to update the learning rate
+     """
+     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma, verbose=True)
+     return scheduler
+
+
+ def get_ReduceLROnPlateau_scheduler(optimizer, factor, patience):
+     """Method to get a scheduler object, used to update the learning rate.
+
+     Args:
+         optimizer (Object): Optimizer object
+         factor (float): Factor to multiply the learning rate by
+         patience (int): Number of epochs to wait before reducing
+
+     Returns:
+         object: ReduceLROnPlateau scheduler used to update the learning rate
+     """
+     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, patience=patience, verbose=True)
+     return scheduler
+
+
+ def get_OneCycleLR_scheduler(optimizer, max_lr, epochs, steps_per_epoch, max_at_epoch, anneal_strategy, div_factor, final_div_factor):
+     """Method to get a scheduler object, used to update the learning rate.
+
+     Args:
+         optimizer (Object): Optimizer object
+         max_lr (float): Maximum learning rate to reach during training
+         epochs (int): Total number of epochs
+         steps_per_epoch (int): Total steps in an epoch
+         max_at_epoch (int): Epoch at which the maximum learning rate is reached
+         anneal_strategy (str): Strategy to interpolate between minimum and maximum lr
+         div_factor (int): Division factor used to compute the initial learning rate
+         final_div_factor (int): Division factor used to compute the minimum learning rate
+
+     Returns:
+         object: OneCycleLR scheduler used to update the learning rate
+     """
+     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, epochs=epochs,
+                                               steps_per_epoch=steps_per_epoch,
+                                               pct_start=max_at_epoch / epochs,
+                                               anneal_strategy=anneal_strategy,
+                                               div_factor=div_factor,
+                                               final_div_factor=final_div_factor)
+     return scheduler
+
+
+ def get_criterion(loss_type='cross_entropy'):
+     """Method to get the loss criterion.
+
+     Args:
+         loss_type (str): Type of loss, 'nll_loss' or 'cross_entropy'
+
+     Returns:
+         object: Callable that computes the loss
+     """
+     if loss_type == 'nll_loss':
+         criterion = F.nll_loss
+     elif loss_type == 'cross_entropy':
+         criterion = F.cross_entropy
+     else:
+         raise ValueError(f"Unsupported loss_type: {loss_type}")
+     return criterion
+
+
+ def get_learning_rate(model, optimizer, criterion, trainloader):
+     """Method to find a learning rate using the LR range test.
+
+     Args:
+         model (Object): Model object
+         optimizer (Object): Optimizer object
+         criterion (Object): Loss function
+         trainloader (Object): DataLoader for the training set
+
+     Returns:
+         float: Learning rate suggested by the LR finder
+     """
+     # Create the finder object and perform the range test
+     lr_finder = LRFinder(model, optimizer, criterion)
+     lr_finder.range_test(trainloader, end_lr=100, num_iter=100)
+
+     # Plot the result and store the suggested lr (steepest-gradient point)
+     ax, suggested_lr = lr_finder.plot(suggest_lr=True)
+
+     # Reset the model and optimizer to their initial state
+     lr_finder.reset()
+
+     return suggested_lr
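A minimal sketch of how these helpers compose outside Lightning (hypothetical stand-in model and hyperparameters; the real values come from resnet.ResNet.configure_optimizers):

import torch.nn as nn
import utils

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in model
criterion = utils.get_criterion('cross_entropy')
optimizer = utils.get_optimizer(model, lr=0.01, momentum=0.9, optimizer_type='SGD')
scheduler = utils.get_OneCycleLR_scheduler(optimizer, max_lr=0.1, epochs=20,
                                           steps_per_epoch=98, max_at_epoch=5,
                                           anneal_strategy='linear',
                                           div_factor=10, final_div_factor=1)
# Inside the training loop, each batch would do:
#   loss = criterion(model(x), y); loss.backward(); optimizer.step(); scheduler.step()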