mkthoma's picture
Update resnet.py
83214bf
raw
history blame
9.66 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
# imports
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch_lr_finder import LRFinder
import math
import torch
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pytorch_lightning as pl
import matplotlib.pyplot as plt
# Directory CIFAR-10 is downloaded to / read from; can be overridden through
# the PATH_DATASETS environment variable (falls back to the working directory).
PATH_DATASETS: str = os.getenv("PATH_DATASETS", ".")
# Mini-batch size used by the train, validation and test dataloaders.
BATCH_SIZE: int = 256
# Model
class custom_ResNet(pl.LightningModule):
    """ResNet-style CIFAR-10 classifier that also owns its data pipeline.

    Network: prep conv -> block 1 (conv + max-pool, plus a two-conv residual
    branch) -> block 2 (conv + max-pool) -> block 3 (conv + max-pool, plus a
    two-conv residual branch) -> 4x4 max-pool -> bias-free linear head.
    ``forward`` returns log-probabilities (``log_softmax``), which the
    train/val/test steps pair with ``F.nll_loss``.

    Args:
        data_dir: Directory CIFAR-10 is downloaded to / loaded from.
        learning_rate: Learning rate handed to Adam in ``configure_optimizers``.
    """

    def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4):
        super().__init__()
        # Constructor arguments, read later by the data and optimizer hooks.
        self.data_dir = data_dir
        self.learning_rate = learning_rate
        # Hardcoded dataset-specific attributes (CIFAR-10 class names in label order).
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.num_classes = 10
        # Per-channel normalization statistics shared by both transforms.
        cifar_mean = (0.4914, 0.4822, 0.4465)
        cifar_std = (0.247, 0.243, 0.261)
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # Convert PIL image to tensor
            transforms.Normalize(cifar_mean, cifar_std),
        ])
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),  # Convert PIL image to tensor
            transforms.Normalize(cifar_mean, cifar_std),
        ])

        # NOTE: the attribute names below and the order of modules inside each
        # nn.Sequential define the state_dict keys; keep them stable so
        # previously saved checkpoints still load.

        # PREPARATION BLOCK
        self.prepblock = nn.Sequential(
            self._conv3x3(3, 64),
            nn.ReLU(), nn.BatchNorm2d(64))
        # output_size = 32, RF=3

        # CONVOLUTION BLOCK 1
        self.convblock1_l1 = nn.Sequential(
            self._conv3x3(64, 128),
            # output_size = 32, RF=5
            nn.MaxPool2d(2, 2), nn.ReLU(), nn.BatchNorm2d(128))
        # output_size = 16, RF=6
        self.convblock1_r1 = nn.Sequential(
            self._conv3x3(128, 128),
            nn.ReLU(), nn.BatchNorm2d(128),
            # output_size = 16, RF=10
            self._conv3x3(128, 128),
            nn.ReLU(), nn.BatchNorm2d(128))
        # output_size = 16, RF=14

        # CONVOLUTION BLOCK 2
        self.convblock2_l1 = nn.Sequential(
            self._conv3x3(128, 256),
            # output_size = 16, RF=18
            nn.MaxPool2d(2, 2), nn.ReLU(), nn.BatchNorm2d(256))
        # output_size = 8, RF=20

        # CONVOLUTION BLOCK 3
        self.convblock3_l1 = nn.Sequential(
            self._conv3x3(256, 512),
            # output_size = 8, RF=28
            nn.MaxPool2d(2, 2), nn.ReLU(), nn.BatchNorm2d(512))
        # output_size = 4, RF=32
        self.convblock3_r2 = nn.Sequential(
            self._conv3x3(512, 512),
            nn.ReLU(), nn.BatchNorm2d(512),
            # output_size = 4, RF=48
            self._conv3x3(512, 512),
            nn.ReLU(), nn.BatchNorm2d(512))
        # output_size = 4, RF=64

        # CONVOLUTION BLOCK 4
        self.convblock4_mp = nn.Sequential(nn.MaxPool2d(4))
        # output_size = 1, RF = 88

        # OUTPUT BLOCK - Fully Connected layer
        self.output_block = nn.Sequential(
            nn.Linear(in_features=512, out_features=self.num_classes, bias=False))
        # output_size = 1, RF = 88

    @staticmethod
    def _conv3x3(in_channels, out_channels):
        """Return the bias-free, size-preserving 3x3 convolution used by every block."""
        return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3),
                         padding=1, dilation=1, stride=1, bias=False)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, num_classes).

        ``x`` is presumably a (batch, 3, 32, 32) CIFAR image batch: the head
        flattens to 512 features, which requires a 1x1 map after the final pool.
        """
        # Preparation Block
        x1 = self.prepblock(x)
        # Convolution Block 1 (residual connection around convblock1_r1)
        x2 = self.convblock1_l1(x1)
        x3 = self.convblock1_r1(x2)
        x4 = x2 + x3
        # Convolution Block 2
        x5 = self.convblock2_l1(x4)
        # Convolution Block 3 (residual connection around convblock3_r2)
        x6 = self.convblock3_l1(x5)
        x7 = self.convblock3_r2(x6)
        x8 = x7 + x6
        # Convolution Block 4
        x9 = self.convblock4_mp(x8)
        # Output Block
        x9 = x9.view(x9.size(0), -1)
        x10 = self.output_block(x9)
        return F.log_softmax(x10, dim=1)

    def _shared_step(self, batch):
        """Compute (loss, accuracy, predictions) for one ``(x, y)`` batch.

        ``forward`` already returns log-probabilities, so ``nll_loss`` is the
        matching criterion (same value and gradient as the previous
        ``cross_entropy`` call, without a redundant second log-softmax).
        Predictions keep a trailing dim: shape (batch, 1).
        """
        x, y = batch
        log_probs = self.forward(x)
        loss = F.nll_loss(log_probs, y)
        pred = log_probs.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        return loss, acc, pred

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch training loss/accuracy; return the loss."""
        loss, acc, _ = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy; return the loss."""
        loss, acc, _ = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy; return the (batch, 1) predictions instead of the loss."""
        loss, acc, pred = self._shared_step(batch)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return pred  # Return predictions instead of loss

    def configure_optimizers(self):
        """Adam over all parameters at ``self.learning_rate``.

        Previously the rate was hard-coded to 0.001, silently ignoring the
        constructor argument (and any value an LR finder wrote back to it).
        """
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Download the CIFAR-10 train and test sets into ``self.data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required by ``stage``.

        ``"fit"``, ``"validate"`` or ``None`` create a 45k/5k train/val split of
        the CIFAR-10 training set; ``"test"`` or ``None`` load the test set.
        """
        if stage in ("fit", "validate", None):
            train_view = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
            # Second view over the same images with the deterministic test-time
            # transform, so validation is not scored on random crops/flips.
            eval_view = CIFAR10(self.data_dir, train=True, transform=self.test_transform)
            # Seeded split: the same images are held out on every run/resume,
            # so validation images never leak into training across restarts.
            train_idx, val_idx = random_split(
                range(len(train_view)), [45000, 5000],
                generator=torch.Generator().manual_seed(42))
            self.cifar_train = torch.utils.data.Subset(train_view, train_idx.indices)
            self.cifar_val = torch.utils.data.Subset(eval_view, val_idx.indices)
        if stage in ("test", None):
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.test_transform)

    @staticmethod
    def _num_workers():
        """Worker count for the dataloaders; ``os.cpu_count()`` may return None."""
        return os.cpu_count() or 0

    def train_dataloader(self):
        """Shuffled training loader (reshuffled every epoch)."""
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=self._num_workers())

    def val_dataloader(self):
        """Validation loader in fixed order."""
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE, num_workers=self._num_workers())

    def test_dataloader(self):
        """Test loader in fixed order."""
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE, num_workers=self._num_workers())

    def collect_misclassified_images(self, num_images):
        """Gather up to ``num_images`` test images the model predicts wrongly.

        The model runs in eval mode under ``torch.no_grad()``. Its previous
        train/eval mode is restored afterwards. Inputs are moved to the model's
        device, and results are kept on the CPU.

        Returns:
            ``(images, true_labels, predicted_labels, num_collected)``. The first
            three are lists truncated to ``num_images``. Predicted labels are
            1-element tensors, as before. ``num_collected`` is how many
            misclassified images were gathered before stopping. It may exceed
            ``num_images`` but is not a total over the whole test set.
        """
        if not hasattr(self, "cifar_test"):
            # e.g. right after trainer.fit(), which only sets up the fit stage.
            self.setup(stage="test")
        images, true_labels, predicted_labels = [], [], []
        was_training = self.training
        # eval(): BatchNorm must use (and not update) its running statistics.
        self.eval()
        try:
            with torch.no_grad():
                for x, y in self.test_dataloader():
                    pred = self.forward(x.to(self.device)).argmax(dim=1, keepdim=True).cpu()
                    # squeeze(1) rather than squeeze(): stays 1-D for a batch of one.
                    wrong = pred.squeeze(1) != y
                    images.extend(x[wrong])
                    true_labels.extend(y[wrong])
                    predicted_labels.extend(pred[wrong])
                    if len(images) >= num_images:
                        break
        finally:
            self.train(was_training)
        return (images[:num_images], true_labels[:num_images],
                predicted_labels[:num_images], len(images))

    def normalize_image(self, img_tensor):
        """Min-max scale ``img_tensor`` into [0, 1]; a constant image maps to all zeros."""
        min_val = img_tensor.min()
        value_range = img_tensor.max() - min_val
        if value_range == 0:
            # Avoid 0/0 (NaN) for a constant image.
            return img_tensor - min_val
        return (img_tensor - min_val) / value_range

    def show_misclassified_images(self, num_images=10):
        """Plot misclassified test images in a two-row grid with true/predicted labels.

        Fewer than ``num_images`` panels are drawn if the test set does not
        contain that many misclassifications. Nothing is drawn if there are none.
        """
        images, true_labels, predicted_labels, num_misclassified = self.collect_misclassified_images(num_images)
        num_shown = len(images)
        if num_shown == 0:
            return
        num_rows = 2
        num_cols = math.ceil(num_shown / num_rows)
        # squeeze=False keeps axs 2-D even when there is only one column.
        fig, axs = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
        fig.suptitle(f"Misclassified Images (Showing {num_shown} out of {num_misclassified})")
        plt.subplots_adjust(hspace=0.5)  # Adjust vertical space between subplots
        for ax, image, true_label, predicted_label in zip(axs.flat, images, true_labels, predicted_labels):
            ax.imshow(self.normalize_image(image).permute(1, 2, 0).cpu().numpy())
            ax.set_title(f"True label: {self.classes[int(true_label)]}\nPredicted: {self.classes[int(predicted_label)]}")
            ax.axis("off")
        # Remove empty subplots in the last row (when num_shown is odd).
        for ax in axs.flat[num_shown:]:
            ax.remove()
        plt.show()