|
"""Module to define the model.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
import modules.config as config |
|
import pytorch_lightning as pl |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import torchinfo |
|
from torch.optim.lr_scheduler import OneCycleLR |
|
from torch_lr_finder import LRFinder |
|
from torchmetrics import Accuracy |
|
|
|
|
|
# Module-level aliases of the project configuration values used below.
# PREFERRED_START_LR seeds the model's learning rate, the temporary LR-finder
# optimizer, and the fallback when the LR finder returns no suggestion.
PREFERRED_START_LR = config.PREFERRED_START_LR

# Weight decay passed to every Adam optimizer created in this module.
PREFERRED_WEIGHT_DECAY = config.PREFERRED_WEIGHT_DECAY
|
|
|
|
|
def detailed_model_summary(model, input_size):
    """Print a layer-by-layer torchinfo summary of ``model``.

    Args:
        model: Network to summarise; forwarded unchanged to
            ``torchinfo.summary``.
        input_size: Input size forwarded to ``torchinfo.summary``. Because
            ``batch_dim=0`` is also passed, this presumably excludes the
            batch dimension -- TODO confirm against the torchinfo docs.

    Returns:
        None. The summary object produced by torchinfo is not returned;
        ``verbose=1`` is passed so torchinfo emits the table itself.
    """
    # Columns to show for every layer, in display order.
    summary_columns = (
        "input_size",
        "kernel_size",
        "output_size",
        "num_params",
        "trainable",
    )

    summary_options = {
        "input_size": input_size,
        "batch_dim": 0,
        "col_names": summary_columns,
        "verbose": 1,
        "col_width": 16,
    }
    torchinfo.summary(model, **summary_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomResNet(pl.LightningModule):
    """Small ResNet-style convolutional classifier wrapped as a LightningModule.

    Layout (see ``forward``): a prep conv block, a pooled conv block followed
    by a two-conv residual branch, a pooled conv block, another pooled conv
    block followed by a two-conv residual branch, a 4x4 max-pool and a final
    linear layer with 10 outputs.

    Besides the network itself the module owns:

    * ``results``: per-epoch history lists filled by the epoch-end hooks.
    * ``misclassified_image_data``: filled by ``store_misclassified_images``.
    * ``learning_rate``: used both as the Adam LR and the OneCycleLR peak.
    """

    # When True, ``print_view`` prints the tensor shape after every stage.
    print_shape = False

    # Dropout probability used after every convolution block.
    dropout_value = 0.02

    def __init__(self):
        """Build the loss, the metric, the bookkeeping containers and all layers."""
        super().__init__()

        # Loss used for training/validation/testing and by the LR finder.
        self.loss_function = torch.nn.CrossEntropyLoss()

        # Multiclass accuracy over the 10 output classes.
        self.accuracy_function = Accuracy(task="multiclass", num_classes=10)

        # Per-epoch history, appended to by the epoch-end hooks.
        self.results = {
            "train_loss": [],
            "train_acc": [],
            "test_loss": [],
            "test_acc": [],
            "val_loss": [],
            "val_acc": [],
        }

        # Filled by ``store_misclassified_images``.
        self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}

        # Starting / peak learning rate; callers may overwrite it (e.g. with
        # the value returned by ``find_optimal_lr``) before training starts.
        self.learning_rate = PREFERRED_START_LR

        # NOTE: the order in which modules are created, and their position
        # inside each nn.Sequential, is kept exactly as in the original
        # hand-written definition so state_dict keys and seeded weight
        # initialisation are unchanged.

        # Prep layer: 3 -> 64 channels, no pooling.
        self.prep = nn.Sequential(*self._conv_block(3, 64, self.dropout_value))

        # Layer 1: pooled conv block plus a two-conv residual branch.
        self.layer1_x = nn.Sequential(*self._conv_block(64, 128, self.dropout_value, pool=True))
        self.layer1_r1 = nn.Sequential(
            *self._conv_block(128, 128, self.dropout_value),
            *self._conv_block(128, 128, self.dropout_value),
        )

        # Layer 2: pooled conv block only.
        self.layer2 = nn.Sequential(*self._conv_block(128, 256, self.dropout_value, pool=True))

        # Layer 3: pooled conv block plus a two-conv residual branch.
        self.layer3_x = nn.Sequential(*self._conv_block(256, 512, self.dropout_value, pool=True))
        self.layer3_r2 = nn.Sequential(
            *self._conv_block(512, 512, self.dropout_value),
            *self._conv_block(512, 512, self.dropout_value),
        )

        # Final spatial reduction before the classifier.
        self.maxpool = nn.MaxPool2d(kernel_size=4, stride=4)

        # Classifier head: expects 512 flattened features per sample.
        self.fc = nn.Linear(512, 10)

        self.save_hyperparameters()

    @staticmethod
    def _conv_block(in_channels, out_channels, dropout, pool=False):
        """Return the layers of one 3x3 conv block as a list.

        Order is Conv2d -> [MaxPool2d(2, 2) if ``pool``] -> BatchNorm2d ->
        ReLU -> Dropout, matching the original inline definitions.
        """
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
                dilation=1,
                bias=False,
            )
        ]
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        layers.extend([nn.BatchNorm2d(out_channels), nn.ReLU(), nn.Dropout(dropout)])
        return layers

    def print_view(self, x, msg=""):
        """Print the shape of ``x`` (optionally labelled) when ``print_shape`` is set.

        Args:
            x: Tensor whose ``shape`` is printed.
            msg: Optional label printed before the shape.
        """
        if not self.print_shape:
            return

        if msg != "":
            print(msg, "\n\t", x.shape, "\n")
        else:
            print(x.shape)

    def forward(self, x):
        """Run the network and return log-probabilities over the 10 classes.

        Args:
            x: Input batch; the first conv expects 3 input channels.

        Returns:
            Tensor of log-softmax scores with the class dimension last.
        """
        # Prep layer
        x = self.prep(x)
        self.print_view(x, "PrepLayer")

        # Layer 1: main path plus residual branch
        x = self.layer1_x(x)
        self.print_view(x, "Layer 1, X")
        r1 = self.layer1_r1(x)
        self.print_view(r1, "Layer 1, R1")
        x = x + r1
        self.print_view(x, "Layer 1, X + R1")

        # Layer 2
        x = self.layer2(x)
        self.print_view(x, "Layer 2")

        # Layer 3: main path plus residual branch
        x = self.layer3_x(x)
        self.print_view(x, "Layer 3, X")
        r2 = self.layer3_r2(x)
        self.print_view(r2, "Layer 3, R2")
        x = x + r2
        self.print_view(x, "Layer 3, X + R2")

        # Final pooling
        x = self.maxpool(x)
        self.print_view(x, "Max Pooling")

        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        self.print_view(x, "Reshape before FC")
        x = self.fc(x)
        self.print_view(x, "After FC")

        # NOTE: the loss is CrossEntropyLoss, which applies log-softmax
        # itself. Log-softmax is idempotent, so the loss value is unaffected;
        # the output is kept as log-probabilities for existing callers.
        return F.log_softmax(x, dim=-1)

    def find_optimal_lr(self, train_loader):
        """Run an LR range test and return the suggested learning rate.

        A temporary Adam optimizer is used for the range test, and
        ``lr_finder.reset()`` is called afterwards.

        Args:
            train_loader: Data loader handed to ``LRFinder.range_test``.

        Returns:
            The learning rate suggested by ``LRFinder.plot``; falls back to
            ``PREFERRED_START_LR`` when no suggestion is produced.
        """
        tmp_optimizer = optim.Adam(self.parameters(), lr=PREFERRED_START_LR, weight_decay=PREFERRED_WEIGHT_DECAY)

        lr_finder = LRFinder(self, optimizer=tmp_optimizer, criterion=self.loss_function)
        lr_finder.range_test(train_loader=train_loader, end_lr=10, num_iter=100)

        _, suggested_lr = lr_finder.plot(suggest_lr=True)
        lr_finder.reset()

        print(f"Suggested Max LR: {suggested_lr}")

        if suggested_lr is None:
            suggested_lr = PREFERRED_START_LR

        return suggested_lr

    def configure_optimizers(self):
        """Create the Adam optimizer and a per-step OneCycleLR schedule.

        The warm-up fraction is 5 epochs out of ``trainer.max_epochs``. When
        that ratio is not a usable fraction -- 5 or fewer epochs, or a
        ``max_epochs`` that is zero, negative (Lightning uses -1 for
        "unbounded") or unset -- it falls back to 0.3.

        Returns:
            Dict with the ``optimizer`` and an ``lr_scheduler`` config whose
            ``interval`` is ``"step"``.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=PREFERRED_WEIGHT_DECAY)

        max_epochs = self.trainer.max_epochs
        # Guard against None/0 before dividing.
        percent_start = 5 / int(max_epochs) if max_epochs else 0.0
        # OneCycleLR needs a fraction; anything outside (0, 1) uses the fallback.
        if not 0 < percent_start < 1:
            percent_start = 0.3

        # ``verbose`` is intentionally not passed: False was its historical
        # default, and the argument is believed to be deprecated in newer
        # PyTorch releases.
        scheduler = OneCycleLR(
            optimizer=optimizer,
            max_lr=self.learning_rate,
            total_steps=int(self.trainer.estimated_stepping_batches),
            pct_start=percent_start,
            div_factor=100,
            three_phase=False,
            anneal_strategy="linear",
            final_div_factor=100,
        )

        scheduler_dict = {"scheduler": scheduler, "interval": "step"}

        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    def compute_loss(self, prediction, target):
        """Return the cross-entropy loss between ``prediction`` and ``target``."""
        return self.loss_function(prediction, target)

    def compute_accuracy(self, prediction, target):
        """Return the accuracy of ``prediction`` against ``target`` as a percentage."""
        acc = self.accuracy_function(prediction, target)

        # The metric result is scaled by 100 to express it as a percentage.
        return acc * 100

    def compute_metrics(self, batch):
        """Run a forward pass on ``batch`` and return ``(loss, accuracy)``.

        Args:
            batch: A ``(data, target)`` pair.

        Returns:
            Tuple of the loss and the percentage accuracy for the batch.
        """
        data, target = batch

        pred = self(data)

        loss = self.compute_loss(prediction=pred, target=target)
        acc = self.compute_accuracy(prediction=pred, target=target)

        return loss, acc

    def store_misclassified_images(self):
        """Collect every misclassified sample from the trainer's test dataloader.

        Resets ``misclassified_image_data`` and then, with the model in eval
        mode and gradients disabled, iterates ``self.trainer.test_dataloaders``
        and stores the inputs, targets and predictions of each wrongly
        classified sample. The stored tensors stay on ``self.device``.

        NOTE(review): iterating ``trainer.test_dataloaders`` directly assumes
        it is a single dataloader yielding ``(data, target)`` batches --
        confirm for the Lightning version in use.
        """
        self.misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}

        # Evaluation mode: the model is not switched back to train mode here.
        self.eval()

        with torch.no_grad():
            for batch in self.trainer.test_dataloaders:
                data, target = batch
                data, target = data.to(self.device), target.to(self.device)

                pred = self(data)

                # Predicted class = index of the highest log-probability.
                output = pred.argmax(dim=1)

                # Boolean mask of the samples whose prediction is wrong.
                incorrect_indices = ~output.eq(target)

                self.misclassified_image_data["images"].extend(data[incorrect_indices])
                self.misclassified_image_data["ground_truths"].extend(target[incorrect_indices])
                self.misclassified_image_data["predicted_vals"].extend(output[incorrect_indices])

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; return the loss."""
        loss, acc = self.compute_metrics(batch)

        self.log("train_loss", loss, prog_bar=True, on_epoch=True, logger=True)
        self.log("train_acc", acc, prog_bar=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy; return the loss."""
        loss, acc = self.compute_metrics(batch)

        self.log("val_loss", loss, prog_bar=True, on_epoch=True, logger=True)
        self.log("val_acc", acc, prog_bar=True, on_epoch=True, logger=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/accuracy; return the loss."""
        loss, acc = self.compute_metrics(batch)

        self.log("test_loss", loss, prog_bar=False, on_epoch=True, logger=True)
        self.log("test_acc", acc, prog_bar=False, on_epoch=True, logger=True)

        return loss

    def on_train_epoch_end(self):
        """Append the logged training loss/accuracy to ``results``."""
        metrics = self.trainer.callback_metrics

        self.results["train_loss"].append(metrics["train_loss"].detach().item())
        self.results["train_acc"].append(metrics["train_acc"].detach().item())

    def on_validation_epoch_end(self):
        """Append the logged validation loss/accuracy to ``results``.

        The values are stored under ``val_loss`` / ``val_acc`` and are also
        mirrored into ``test_loss`` / ``test_acc``, which is where this hook
        has historically written them, so existing consumers of the
        ``test_*`` lists keep working.

        NOTE(review): this hook presumably also fires for Lightning's
        pre-training sanity check, which would add one leading entry --
        confirm whether that matters for downstream plots.
        """
        metrics = self.trainer.callback_metrics
        val_loss = metrics["val_loss"].detach().item()
        val_acc = metrics["val_acc"].detach().item()

        self.results["val_loss"].append(val_loss)
        self.results["val_acc"].append(val_acc)

        # Backward-compatible mirror of the validation metrics.
        self.results["test_loss"].append(val_loss)
        self.results["test_acc"].append(val_acc)

    def on_test_end(self):
        """After testing finishes, collect the misclassified test images."""
        print("Test ended! Saving misclassified images")

        self.store_misclassified_images()