#myaddition
import os
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim
from pytorch_lightning import LightningModule
from Custom_Resnet_v1 import CustomResNet
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
import torchvision
from torchmetrics.functional import accuracy
from torchvision.datasets import CIFAR10
from data_transform_cifar10_custom_resnet import get_train_transform, get_test_transform
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64
from cyclic_lr_util import custom_one_cycle_lr
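# custom_one_cycle_lr (from the project's cyclic_lr_util helper) is assumed to return a callable
# mapping the global optimizer step index to an LR scaling factor; LambdaLR (used in
# configure_optimizers below) multiplies the optimizer's base LR by that factor at every
# scheduler step, tracing the one-cycle shape (ramp towards max_lr, then anneal towards
# final_lr) described by the arguments below.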
one_cycle_lr = custom_one_cycle_lr(no_of_images=50176, batch_size=2, base_lr=0.04, max_lr=0.4, final_lr=0.004, epoch_stage1=5, epoch_stage2=18, total_epochs=24)
class Assignment12Resnet(LightningModule):
    def __init__(self, lr=0.05, data_dir=PATH_DATASETS):
        super().__init__()
        # Set our init args as class attributes
        self.data_dir = data_dir
        self.learning_rate = lr

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.train_transform = get_train_transform()
        self.test_transform = get_test_transform()
        self.cifar10_trainset  = None
        self.cifar10_testset = None
        self.save_hyperparameters()
        self.model = CustomResNet()

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
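        # F.nll_loss expects log-probabilities, so CustomResNet is assumed to end in a log_softmax layer.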
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=10)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        # The optimizer's base LR is rescaled by the one-cycle lambda on every optimizer step
        # ("interval": "step") rather than once per epoch.
        optimizer = optim.SGD(self.model.parameters(), lr=0.04, momentum=0.9)
        scheduler_dict = {
            "scheduler": LambdaLR(optimizer, lr_lambda=[one_cycle_lr]),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar10_trainset = torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=self.train_transform)
            self.cifar_train, self.cifar_val = random_split(cifar10_trainset, [46000, 4000])
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar10_testset = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=self.test_transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())

    def test_dataloader(self):
        return DataLoader(self.cifar10_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())
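
# Usage sketch (illustrative, not part of the original training script): how this
# LightningModule could be driven by a pytorch_lightning Trainer. max_epochs=24 mirrors
# total_epochs passed to custom_one_cycle_lr above; the remaining arguments are assumptions
# for a single-GPU (or CPU) run.
if __name__ == "__main__":
    from pytorch_lightning import Trainer

    model = Assignment12Resnet(lr=0.05)
    trainer = Trainer(
        max_epochs=24,
        accelerator="gpu" if AVAIL_GPUS else "cpu",
        devices=1,
    )
    trainer.fit(model)   # prepare_data/setup/train_dataloader are invoked automatically
    trainer.test(model)  # evaluates on the CIFAR10 test split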