""" ResNet in PyTorch. For Pre-activation ResNet, see 'preact_resnet.py'. Reference: [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun Deep Residual Learning for Image Recognition. arXiv:1512.03385 """ import torch import torch.nn as nn import torch.nn.functional as F import pytorch_lightning as pl from torchmetrics.functional import accuracy from torchvision import transforms from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 import albumentations as A from albumentations.pytorch import ToTensorV2 class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1): super(BasicBlock, self).__init__() self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion*planes: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion*planes) ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = F.relu(out) return out class LitResNet(pl.LightningModule): def __init__(self, block, num_blocks, num_classes=10,batch_size=128): super(LitResNet, self).__init__() self.batch_size = batch_size self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512*block.expansion, num_classes) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1]*(num_blocks-1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = out.view(out.size(0), -1) out = self.linear(out) return out def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) # Calculate loss loss = F.cross_entropy(y_hat, y) #Calculate accuracy acc = accuracy(y_hat, y) self.log_dict( {"train_loss": loss, "train_acc": acc}, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) return loss def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) acc = accuracy(y_hat, y) self.log_dict( {"val_loss": loss, "val_acc": acc}, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) return loss def test_step(self, batch, batch_idx): x, y = batch y_hat = self(x) argmax_pred = y_hat.argmax(dim=1).cpu() loss = F.cross_entropy(y_hat, y) acc = accuracy(y_hat, y) self.log_dict( {"test_loss": loss, "test_acc": acc}, on_step=True, on_epoch=True, prog_bar=True, logger=True, ) # Update the confusion matrix self.confusion_matrix.update(y_hat, y) # Store the predictions, labels and incorrect predictions x, y, y_hat, argmax_pred = ( x.cpu(), y.cpu(), y_hat.cpu(), argmax_pred.cpu(), ) self.pred_store["test_preds"] = torch.cat( (self.pred_store["test_preds"], 
argmax_pred), dim=0 ) self.pred_store["test_labels"] = torch.cat( (self.pred_store["test_labels"], y), dim=0 ) for d, t, p, o in zip(x, y, argmax_pred, y_hat): if p.eq(t.view_as(p)).item() == False: self.pred_store["test_incorrect"].append( (d.cpu(), t, p, o[p.item()].cpu()) ) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.02) def LitResNet18(): return LitResNet(BasicBlock, [2, 2, 2, 2]) def LitResNet34(): return LitResNet(BasicBlock, [3, 4, 6, 3])
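

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original training
# pipeline): fits LitResNet18 on CIFAR10 with plain torchvision transforms and
# a pytorch_lightning Trainer. The normalization statistics, one-epoch run,
# worker count, and data root path are illustrative defaults, not tuned values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    train_ds = CIFAR10(root="./data", train=True, download=True, transform=tf)
    test_ds = CIFAR10(root="./data", train=False, download=True, transform=tf)

    model = LitResNet18()
    train_loader = DataLoader(
        train_ds, batch_size=model.batch_size, shuffle=True, num_workers=2
    )
    test_loader = DataLoader(test_ds, batch_size=model.batch_size, num_workers=2)

    trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(model, train_loader, test_loader)
    trainer.test(model, test_loader)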