|
"""This file contains functions to prepare dataloader in the way lightning expects""" |
|
import pytorch_lightning as pl |
|
import torchvision.datasets as datasets |
|
from lightning_fabric.utilities.seed import seed_everything |
|
from modules.dataset import CIFAR10Transforms, apply_cifar_image_transformations |
|
from torch.utils.data import DataLoader, random_split |
|
|
|
|
|
class CIFARDataModule(pl.LightningDataModule):
    """Lightning DataModule for the CIFAR10 dataset.

    Downloads CIFAR10 through torchvision, wraps the raw datasets in
    ``CIFAR10Transforms`` together with the transforms returned by
    ``apply_cifar_image_transformations`` and exposes train / validation /
    test dataloaders that share one set of ``DataLoader`` keyword arguments.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.CIFAR10``.
        batch_size: Batch size used by every dataloader.
        seed: Seed used to make the train/validation split reproducible and
            as the base seed in ``_init_fn``. ``None`` leaves the split
            unseeded.
        val_split: Fraction of the training set held out for validation.
            ``0`` (default) disables the split, in which case the CIFAR10
            ``train=False`` split is used for validation. Any other value
            must lie strictly between 0 and 1.
        num_workers: Number of dataloader worker processes.
    """

    def __init__(self, data_path, batch_size, seed, val_split=0, num_workers=0):
        super().__init__()

        self.data_path = data_path
        self.batch_size = batch_size
        self.seed = seed
        self.val_split = val_split
        self.num_workers = num_workers

        # Keyword arguments shared by the train/val/test dataloaders.
        self.dataloader_dict = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            # DataLoader rejects persistent_workers=True when there are no
            # worker processes, so only enable it when num_workers > 0.
            "persistent_workers": self.num_workers > 0,
        }

        # Lightning flag: run prepare_data() once overall rather than once
        # per node.
        self.prepare_data_per_node = False

        # Populated by setup().
        self.training_dataset = None
        self.validation_dataset = None
        self.testing_dataset = None

    def _split_train_val(self, dataset):
        """Split ``dataset`` into train and validation subsets.

        The split is driven by a generator seeded with ``self.seed`` so that
        repeated calls (for example one ``setup()`` per process) produce the
        same partition instead of depending on the global RNG state.

        Args:
            dataset: Sized dataset to split.

        Returns:
            Tuple ``(train_dataset, val_dataset)`` as returned by
            ``random_split``.

        Raises:
            ValueError: If ``self.val_split`` is not strictly between 0 and 1.
        """
        # Local import keeps this fix self-contained; torch is already a
        # dependency of the file through torch.utils.data.
        import torch

        if not 0 < self.val_split < 1:
            raise ValueError("Validation split must be between 0 and 1")

        total_length = len(dataset)

        # Round away floating-point noise before truncating: without it
        # (1 - 0.9) * 10 evaluates to 0.9999999999999998 and truncates to 0
        # training samples instead of 1. Genuine fractions still floor.
        train_length = int(round((1 - self.val_split) * total_length, 6))
        val_length = total_length - train_length

        split_kwargs = {}
        if self.seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(int(self.seed))

        train_dataset, val_dataset = random_split(
            dataset, [train_length, val_length], **split_kwargs
        )

        return train_dataset, val_dataset

    def prepare_data(self):
        """Download the CIFAR10 train and test splits into ``self.data_path``."""
        # Only the download side effect is wanted; the dataset objects are
        # discarded and re-created in setup().
        datasets.CIFAR10(self.data_path, train=True, download=True)
        datasets.CIFAR10(self.data_path, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``"fit"`` builds the training and validation datasets,
                ``"test"`` builds the testing dataset, and ``None`` builds
                all of them. Other values build nothing.

        Raises:
            ValueError: If ``self.val_split`` is non-zero but not strictly
                between 0 and 1 (raised while splitting for the fit stage).
        """
        train_transforms, test_transforms = apply_cifar_image_transformations()
        # Validation reuses the test-time transforms.
        val_transforms = test_transforms

        if stage == "fit" or stage is None:
            if self.val_split != 0:
                # Hold out part of the training split for validation.
                data_train, data_val = self._split_train_val(
                    datasets.CIFAR10(self.data_path, train=True)
                )
                self.training_dataset = CIFAR10Transforms(data_train, train_transforms)
                self.validation_dataset = CIFAR10Transforms(data_val, val_transforms)
            else:
                # No hold-out requested: train on the full training split and
                # validate on the CIFAR10 train=False split.
                self.training_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=True), train_transforms
                )
                self.validation_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=False), val_transforms
                )

        if stage == "test" or stage is None:
            self.testing_dataset = CIFAR10Transforms(
                datasets.CIFAR10(self.data_path, train=False), test_transforms
            )

    def train_dataloader(self):
        """Return the shuffled dataloader over ``self.training_dataset``."""
        return DataLoader(self.training_dataset, **self.dataloader_dict, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over ``self.validation_dataset``."""
        return DataLoader(self.validation_dataset, **self.dataloader_dict, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over ``self.testing_dataset``."""
        return DataLoader(self.testing_dataset, **self.dataloader_dict, shuffle=False)

    def _init_fn(self, worker_id):
        """Seed everything for one dataloader worker.

        Uses ``self.seed + worker_id`` so each worker gets a distinct seed.
        NOTE(review): nothing in this class passes this function to a
        ``DataLoader`` (e.g. as ``worker_init_fn``) -- confirm whether it is
        used elsewhere or intentionally left unwired.
        """
        seed_everything(int(self.seed) + worker_id)