Spaces:
Sleeping
Sleeping
import numpy as np | |
import albumentations | |
from torchvision import datasets | |
from albumentations.pytorch import ToTensorV2 | |
from torch.utils.data import Dataset, DataLoader | |
class CIFAR10Data(Dataset):
    """Dataset wrapper that yields ``(image, label)`` pairs with optional transforms.

    Each image from the wrapped dataset is converted with ``np.array`` before
    the optional transform is applied. The transform is called using the
    albumentations convention ``transforms(image=array)['image']``.
    """

    def __init__(self, dataset, transforms=None) -> None:
        # Indexable source dataset; each item must unpack into (image, label).
        self.dataset = dataset
        # Optional callable taking ``image=`` and returning a dict with an
        # 'image' entry; it is skipped whenever it is falsy (e.g. None).
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        sample_count = len(self.dataset)
        return sample_count

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index``, transforming the image if configured."""
        raw_image, label = self.dataset[index]
        pixels = np.array(raw_image)
        if not self.transforms:
            return pixels, label
        transformed = self.transforms(image=pixels)
        return transformed['image'], label
def _get_test_transforms():
    """Build the evaluation-time albumentations pipeline.

    The pipeline normalizes each channel with fixed per-channel mean/std values
    and then converts the result to a tensor with ``ToTensorV2``. The values
    are presumably CIFAR-10 training-set statistics; confirm before reusing
    them elsewhere.

    Returns:
        albumentations.Compose: The composed transform.
    """
    channel_mean = [0.49139968, 0.48215841, 0.44653091]
    channel_std = [0.24703223, 0.24348513, 0.26158784]
    pipeline_steps = [
        albumentations.Normalize(channel_mean, channel_std),
        ToTensorV2(),
    ]
    return albumentations.Compose(pipeline_steps)
def _get_data(is_train, is_download, root='../data'):
    """Load a CIFAR-10 split via ``torchvision.datasets.CIFAR10``.

    Args:
        is_train (bool): True for the training split, False for the test split.
        is_download (bool): True to download the dataset from the internet
            into ``root``.
        root (str): Directory where the dataset is stored. Defaults to
            ``'../data'`` (the previously hard-coded location), which is
            resolved relative to the current working directory.

    Returns:
        torchvision.datasets.CIFAR10: Dataset object for the requested split.
    """
    return datasets.CIFAR10(root, train=is_train, download=is_download)
def _get_data_loader(data, **kwargs): | |
"""Method to get data loader. | |
Args: | |
data (object): Oject of dataset | |
Returns: | |
object: Object of DataLoader class used to feed data to neural network model | |
""" | |
loader = DataLoader(data, **kwargs) | |
return loader | |
def get_test_data_loader(**kwargs):
    """Create a DataLoader over the normalized CIFAR-10 test split.

    The function builds the evaluation transforms and fetches the test split
    with downloading enabled. It then wraps the split in ``CIFAR10Data`` and
    hands it to a DataLoader.

    Args:
        **kwargs: Forwarded to ``DataLoader``; e.g. ``batch_size`` (int), the
            number of images in a batch.

    Returns:
        DataLoader: Loader yielding batches of (image, label) from the test split.
    """
    transform_pipeline = _get_test_transforms()
    raw_test_split = _get_data(is_train=False, is_download=True)
    wrapped_split = CIFAR10Data(raw_test_split, transform_pipeline)
    return _get_data_loader(data=wrapped_split, **kwargs)