# my addition
import os

import torch
import torch.optim as optim
from pytorch_lightning import LightningModule
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision.datasets import CIFAR10

from Custom_Resnet_v1 import CustomResNet
from cyclic_lr_util import custom_one_cycle_lr
from data_transform_cifar10_custom_resnet import get_train_transform, get_test_transform
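# For reference, the transform helpers imported above are assumed to return
# torchvision-style callables along these lines (an illustrative sketch only;
# the real definitions live in data_transform_cifar10_custom_resnet):
#
# from torchvision import transforms
# def get_train_transform():
#     return transforms.Compose([
#         transforms.RandomCrop(32, padding=4),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
#     ])
# def get_test_transform():
#     return transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
#     ])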
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

# Per-step LR schedule factory; the returned callable is passed to LambdaLR
# in configure_optimizers below (see the sketch that follows for the assumed
# stage layout)
one_cycle_lr = custom_one_cycle_lr(
    no_of_images=50176,
    batch_size=2,
    base_lr=0.04,
    max_lr=0.4,
    final_lr=0.004,
    epoch_stage1=5,
    epoch_stage2=18,
    total_epochs=24,
)
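# A minimal sketch of what cyclic_lr_util.custom_one_cycle_lr is assumed to
# provide (the real module is not shown here): a factory producing a per-step
# callable for LambdaLR. The three linear stages and the lr/base_lr
# normalization below are assumptions inferred from the parameter names.
#
# def custom_one_cycle_lr(no_of_images, batch_size, base_lr, max_lr, final_lr,
#                         epoch_stage1, epoch_stage2, total_epochs):
#     steps_per_epoch = no_of_images // batch_size
#     up, down, total = (e * steps_per_epoch for e in (epoch_stage1, epoch_stage2, total_epochs))
#
#     def lr_lambda(step):
#         if step <= up:        # ramp base_lr -> max_lr
#             lr = base_lr + (max_lr - base_lr) * step / up
#         elif step <= down:    # anneal max_lr -> base_lr
#             lr = max_lr - (max_lr - base_lr) * (step - up) / (down - up)
#         else:                 # decay base_lr -> final_lr
#             lr = base_lr - (base_lr - final_lr) * (step - down) / (total - down)
#         return lr / base_lr   # LambdaLR multiplies the optimizer's initial lr by this
#
#     return lr_lambda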
class Assignment12Resnet(LightningModule):
    def __init__(self, lr=0.05, data_dir=PATH_DATASETS):
super().__init__()
# Set our init args as class attributes
self.data_dir = data_dir
self.learning_rate = lr
# Hardcode some dataset specific attributes
self.num_classes = 10
self.train_transform = get_train_transform()
self.test_transform = get_test_transform()
        self.cifar_train = None
        self.cifar_val = None
        self.cifar10_testset = None
self.save_hyperparameters()
self.model = CustomResNet()
def forward(self, x):
out = self.model(x)
return out
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
        # CustomResNet is assumed to emit log-probabilities (log_softmax),
        # which is what F.nll_loss expects
        loss = F.nll_loss(logits, y)
self.log("train_loss", loss)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
if stage:
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
def validation_step(self, batch, batch_idx):
self.evaluate(batch, "val")
def test_step(self, batch, batch_idx):
self.evaluate(batch, "test")
    def configure_optimizers(self):
        # The base lr matches the schedule's base_lr; one_cycle_lr then
        # rescales it at every optimizer step via LambdaLR
        optimizer = optim.SGD(self.model.parameters(), lr=0.04, momentum=0.9)
        scheduler_dict = {
            "scheduler": LambdaLR(optimizer, lr_lambda=[one_cycle_lr]),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
####################
# DATA RELATED HOOKS
####################
def prepare_data(self):
# download
CIFAR10(self.data_dir, train=True, download=True)
CIFAR10(self.data_dir, train=False, download=True)
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar10_full = CIFAR10(root=self.data_dir, train=True, download=True, transform=self.train_transform)
            # Note: without a seeded generator, random_split produces a
            # different split on every run
            self.cifar_train, self.cifar_val = random_split(cifar10_full, [46000, 4000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar10_testset = CIFAR10(root=self.data_dir, train=False, download=True, transform=self.test_transform)
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())

    def test_dataloader(self):
        return DataLoader(self.cifar10_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())
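# A minimal usage sketch, assuming pytorch_lightning's standard Trainer API;
# max_epochs matches the schedule's total_epochs above, and the accelerator
# settings are illustrative assumptions
if __name__ == "__main__":
    from pytorch_lightning import Trainer

    model = Assignment12Resnet()
    trainer = Trainer(
        max_epochs=24,
        accelerator="gpu" if AVAIL_GPUS else "cpu",
        devices=1,
    )
    trainer.fit(model)
    trainer.test(model)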