"""CIFAR-10 dataset wrapper and test-set DataLoader helpers.

Images from a torchvision CIFAR-10 dataset are converted to NumPy arrays and
passed through an albumentations pipeline (Normalize followed by ToTensorV2)
before being batched by a ``torch.utils.data.DataLoader``.
"""
import numpy as np
import albumentations
from torchvision import datasets
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

# Per-channel (R, G, B) mean and standard deviation handed to Normalize.
# NOTE(review): presumably the CIFAR-10 training-set statistics; confirm.
CIFAR10_MEAN = (0.49139968, 0.48215841, 0.44653091)
CIFAR10_STD = (0.24703223, 0.24348513, 0.26158784)

# Directory in which torchvision stores / looks for the CIFAR-10 files.
DEFAULT_DATA_ROOT = '../data'


class CIFAR10Data(Dataset):
    """Wrap an ``(image, label)`` dataset and apply albumentations transforms.

    Args:
        dataset: Indexable, sized dataset whose items are ``(image, label)``
            tuples (e.g. a ``torchvision.datasets.CIFAR10`` instance).
        transforms: Optional callable invoked as ``transforms(image=array)``
            that returns a mapping; its ``'image'`` entry becomes the sample.
            When None, the NumPy image is returned untransformed.
    """

    def __init__(self, dataset, transforms=None) -> None:
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index`` after transforms."""
        image, label = self.dataset[index]
        # albumentations works on NumPy arrays, so convert the source image.
        image = np.array(image)
        # Compare against None explicitly: a transform object that defines
        # __len__/__bool__ could otherwise evaluate falsy and be skipped.
        if self.transforms is not None:
            image = self.transforms(image=image)['image']
        return image, label


def _get_test_transforms():
    """Build the evaluation-time transform pipeline.

    Returns:
        albumentations.Compose: Normalizes the image with ``CIFAR10_MEAN`` /
        ``CIFAR10_STD`` and then converts it to a PyTorch tensor.
    """
    return albumentations.Compose([
        albumentations.Normalize(CIFAR10_MEAN, CIFAR10_STD),
        ToTensorV2(),
    ])


def _get_data(is_train, is_download, root=DEFAULT_DATA_ROOT):
    """Load one split of the CIFAR-10 dataset.

    Args:
        is_train (bool): True for the training split, False for the test split.
        is_download (bool): True to download the dataset from the internet
            into ``root``.
        root (str): Directory where the dataset is stored. Defaults to
            ``'../data'``.

    Returns:
        torchvision.datasets.CIFAR10: The requested dataset split.
    """
    return datasets.CIFAR10(root, train=is_train, download=is_download)


def _get_data_loader(data, **kwargs):
    """Wrap a dataset in a DataLoader.

    Args:
        data (Dataset): Dataset to draw samples from.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``, ``num_workers``).

    Returns:
        DataLoader: Loader used to feed batches of ``data`` to a model.
    """
    return DataLoader(data, **kwargs)


def get_test_data_loader(*, root=DEFAULT_DATA_ROOT, download=True, **kwargs):
    """Build a DataLoader over the normalized CIFAR-10 test split.

    Args:
        root (str): Directory where the dataset is stored. Defaults to
            ``'../data'``.
        download (bool): Whether to download the dataset into ``root``.
            Defaults to True (the previous hard-coded behavior).
        **kwargs: Forwarded to ``torch.utils.data.DataLoader``
            (e.g. ``batch_size``, ``shuffle``, ``num_workers``).

    Returns:
        DataLoader: Loader yielding ``(image_tensor, label)`` batches.
    """
    test_transforms = _get_test_transforms()
    raw_test_data = _get_data(is_train=False, is_download=download, root=root)
    test_data = CIFAR10Data(raw_test_data, test_transforms)
    return _get_data_loader(data=test_data, **kwargs)