import math
import os

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, random_split
from torch_lr_finder import LRFinder
from torchmetrics import Accuracy
from torchsummary import summary
from torchvision import transforms
from torchvision.datasets import CIFAR10

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256


# Model
class custom_ResNet(pl.LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, learning_rate=2e-4):
        super(custom_ResNet, self).__init__()

        # Set our init args as class attributes
        # Hardcode some dataset specific attributes
        self.data_dir = data_dir
        self.learning_rate = learning_rate
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck']
        self.num_classes = 10

        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # Convert PIL image to tensor
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        ])

        self.test_transform = transforms.Compose([
            transforms.ToTensor(),  # Convert PIL image to tensor
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        ])

        # Define PyTorch model

        # PREPARATION BLOCK
        self.prepblock = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),
            nn.ReLU(), nn.BatchNorm2d(64))
        # output_size = 32, RF=3

        # CONVOLUTION BLOCK 1
        self.convblock1_l1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),  # output_size = 32, RF=5
            nn.MaxPool2d(2, 2), nn.ReLU(), nn.BatchNorm2d(128))
        # output_size = 16, RF=6

        self.convblock1_r1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),
            nn.ReLU(), nn.BatchNorm2d(128),  # output_size = 16, RF=10
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),
            nn.ReLU(), nn.BatchNorm2d(128))
        # output_size = 16, RF=14

        # CONVOLUTION BLOCK 2
        self.convblock2_l1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),  # output_size = 16, RF=18
            nn.MaxPool2d(2, 2), nn.ReLU(), nn.BatchNorm2d(256))
        # output_size = 8, RF=20

        # CONVOLUTION BLOCK 3
        self.convblock3_l1 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),  # output_size = 8, RF=28
            nn.MaxPool2d(2, 2), nn.ReLU(), nn.BatchNorm2d(512))
        # output_size = 4, RF=32

        self.convblock3_r2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),
            nn.ReLU(), nn.BatchNorm2d(512),  # output_size = 4, RF=48
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3),
                      padding=1, dilation=1, stride=1, bias=False),
            nn.ReLU(), nn.BatchNorm2d(512))
        # output_size = 4, RF=64

        # CONVOLUTION BLOCK 4
        self.convblock4_mp = nn.Sequential(nn.MaxPool2d(4))
        # output_size = 1, RF = 88

        # OUTPUT BLOCK - Fully Connected layer
        self.output_block = nn.Sequential(
            nn.Linear(in_features=512, out_features=10, bias=False))
        # output_size = 1, RF = 88
    def forward(self, x):
        # Preparation Block
        x1 = self.prepblock(x)

        # Convolution Block 1
        x2 = self.convblock1_l1(x1)
        x3 = self.convblock1_r1(x2)
        x4 = x2 + x3  # residual connection

        # Convolution Block 2
        x5 = self.convblock2_l1(x4)

        # Convolution Block 3
        x6 = self.convblock3_l1(x5)
        x7 = self.convblock3_r2(x6)
        x8 = x7 + x6  # residual connection

        # Convolution Block 4
        x9 = self.convblock4_mp(x8)

        # Output Block
        x9 = x9.view(x9.size(0), -1)
        x10 = self.output_block(x9)
        return F.log_softmax(x10, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        # forward() already returns log-probabilities, so use nll_loss here;
        # cross_entropy would apply a second log_softmax
        loss = F.nll_loss(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = F.nll_loss(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = F.nll_loss(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return pred  # Return predictions instead of loss

    def configure_optimizers(self):
        # Note: uses a fixed lr of 1e-3 rather than the constructor's
        # self.learning_rate (default 2e-4)
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.test_transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE, num_workers=os.cpu_count())

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE, num_workers=os.cpu_count())

    def collect_misclassified_images(self, num_images):
        misclassified_images = []
        misclassified_true_labels = []
        misclassified_predicted_labels = []
        num_collected = 0

        for batch in self.test_dataloader():
            x, y = batch
            with torch.no_grad():  # inference only, no graph needed
                pred = self.forward(x).argmax(dim=1, keepdim=True)
            # mask is True for correctly classified samples; invert it to keep the misclassified ones
            correct_mask = pred.eq(y.view_as(pred)).squeeze().cpu().numpy()
            misclassified_images.extend(x[~correct_mask])
            misclassified_true_labels.extend(y[~correct_mask])
            misclassified_predicted_labels.extend(pred[~correct_mask])

            num_collected += sum(~correct_mask)
            if num_collected >= num_images:
                break

        return (misclassified_images[:num_images],
                misclassified_true_labels[:num_images],
                misclassified_predicted_labels[:num_images],
                len(misclassified_images))

    def normalize_image(self, img_tensor):
        # Rescale the normalized tensor back to [0, 1] so imshow renders it correctly
        min_val = img_tensor.min()
        max_val = img_tensor.max()
        return (img_tensor - min_val) / (max_val - min_val)
    def show_misclassified_images(self, num_images=10):
        misclassified_images, true_labels, predicted_labels, num_misclassified = \
            self.collect_misclassified_images(num_images)

        num_rows = 2
        num_cols = math.ceil(num_images / num_rows)
        fig, axs = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows))
        fig.suptitle(f"Misclassified Images (Showing {num_images} out of {num_misclassified})")
        plt.subplots_adjust(hspace=0.5)  # Adjust vertical space between subplots

        for i in range(num_images):
            img = self.normalize_image(misclassified_images[i]).permute(1, 2, 0)
            row_idx = i // num_cols
            col_idx = i % num_cols
            axs[row_idx, col_idx].imshow(img)
            axs[row_idx, col_idx].set_title(
                f"True label: {self.classes[true_labels[i]]}\n"
                f"Predicted: {self.classes[predicted_labels[i]]}")
            axs[row_idx, col_idx].axis("off")

        # Remove any empty subplots in the last row (when num_images is not divisible by num_rows)
        for i in range(num_images, num_rows * num_cols):
            row_idx = i // num_cols
            col_idx = i % num_cols
            axs[row_idx, col_idx].remove()

        plt.show()
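

# ------------------------------------------------------------------
# Usage sketch (assumptions: the variable names `model`, `trainer` and
# `lr_finder`, the epoch count, and the LR range-test settings below are
# illustrative, not part of the original training recipe).
# It shows how this LightningModule is typically driven: an optional LR
# range test with torch_lr_finder, then Trainer.fit / .test, and finally
# the misclassification plot defined above.
# ------------------------------------------------------------------
if __name__ == "__main__":
    model = custom_ResNet()

    # Optional: LR range test. LRFinder takes a model, an optimizer, a loss
    # and a train loader; nll_loss matches the log-probabilities returned by
    # forward(). prepare_data()/setup() are called manually because the LR
    # finder bypasses the Lightning Trainer.
    model.prepare_data()
    model.setup("fit")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    lr_finder = LRFinder(model, optimizer, F.nll_loss,
                         device="cuda" if torch.cuda.is_available() else "cpu")
    lr_finder.range_test(model.train_dataloader(), end_lr=10, num_iter=100)
    lr_finder.plot()   # inspect the curve to pick a learning rate
    lr_finder.reset()  # restore model and optimizer state

    # Train and evaluate with Lightning.
    trainer = Trainer(max_epochs=24, accelerator="auto", devices=1)
    trainer.fit(model)
    trainer.test(model)

    # Plot a few misclassified test images. Run on CPU so the raw
    # test_dataloader() batches and the model live on the same device.
    model = model.cpu()
    model.show_misclassified_images(num_images=10)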