|
|
|
import os

import torch
import torch.optim as optim
import torchvision
from pytorch_lightning import LightningModule
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision.datasets import CIFAR10

from Custom_Resnet_v1 import CustomResNet
from cyclic_lr_util import custom_one_cycle_lr
from data_transform_cifar10_custom_resnet import get_train_transform, get_test_transform

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

# LambdaLR multiplier for the custom one-cycle schedule:
# ramp from base_lr towards max_lr over the first stage, then anneal down to final_lr.
one_cycle_lr = custom_one_cycle_lr(
    no_of_images=50176,
    batch_size=2,
    base_lr=0.04,
    max_lr=0.4,
    final_lr=0.004,
    epoch_stage1=5,
    epoch_stage2=18,
    total_epochs=24,
)
|
class Assignment12Resnet(LightningModule):

    def __init__(self, lr=0.05, data_dir=PATH_DATASETS):
        super().__init__()

        self.data_dir = data_dir
        self.learning_rate = lr
        self.num_classes = 10

        self.train_transform = get_train_transform()
        self.test_transform = get_test_transform()

        self.cifar10_trainset = None
        self.cifar10_testset = None

        self.save_hyperparameters()
        self.model = CustomResNet()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # NLL loss expects log-probabilities, so CustomResNet is assumed to end in log_softmax.
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        # SGD base LR matches the scheduler's base_lr; the LambdaLR multiplier rescales it every step.
        optimizer = optim.SGD(self.model.parameters(), lr=0.04, momentum=0.9)
        scheduler_dict = {
            "scheduler": LambdaLR(optimizer, lr_lambda=one_cycle_lr),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    def prepare_data(self):
        # Download CIFAR-10 once; transforms are attached in setup().
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            cifar10_trainset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, download=True, transform=self.train_transform
            )
            # 46k/4k train/val split; note the validation subset shares the training transforms.
            self.cifar_train, self.cifar_val = random_split(cifar10_trainset, [46000, 4000])

        if stage == "test" or stage is None:
            self.cifar10_testset = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, download=True, transform=self.test_transform
            )

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())

    def test_dataloader(self):
        return DataLoader(self.cifar10_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())
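

# A minimal usage sketch (not part of the original module): fit and evaluate the model with a
# pytorch_lightning Trainer. The epoch count and accelerator settings below are assumptions;
# the Trainer itself calls prepare_data()/setup() and the *_dataloader() hooks defined above.
if __name__ == "__main__":
    from pytorch_lightning import Trainer, seed_everything

    seed_everything(1)

    model = Assignment12Resnet(lr=0.05, data_dir=PATH_DATASETS)
    trainer = Trainer(
        max_epochs=24,  # matches total_epochs of the one-cycle LR schedule
        accelerator="gpu" if AVAIL_GPUS else "cpu",
        devices=AVAIL_GPUS if AVAIL_GPUS else 1,
    )
    trainer.fit(model)
    trainer.test(model)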