ERA_V2_S13 / dataset.py
AkashDataScience's picture
Added feature for misclassified images
1ffed57
raw
history blame
2.12 kB
import numpy as np
import albumentations
from torchvision import datasets
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
class CIFAR10Data(Dataset):
    """Dataset wrapper that yields NumPy images with an optional transform.

    Args:
        dataset: Indexable collection whose items unpack into
            ``(image, label)`` pairs.
        transforms: Optional callable invoked as ``transforms(image=array)``
            whose result is indexed with ``'image'``. It is skipped when
            falsy (e.g. ``None``).
    """

    def __init__(self, dataset, transforms=None) -> None:
        self.transforms = transforms
        self.dataset = dataset

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The image is converted with ``np.array`` first; the label is passed
        through untouched.
        """
        sample, target = self.dataset[index]
        pixels = np.array(sample)

        # No transform configured: hand back the raw array as-is.
        if not self.transforms:
            return pixels, target

        # Transform is called with the image as a keyword argument and is
        # expected to return a mapping holding the result under 'image'.
        augmented = self.transforms(image=pixels)
        return augmented['image'], target
def _get_test_transforms():
    """Build the albumentations pipeline applied to test images.

    Returns:
        object: ``albumentations.Compose`` that first normalizes with fixed
        per-channel mean/std values and then applies ``ToTensorV2``.
    """
    # Per-channel statistics; presumably computed over CIFAR-10 - verify
    # before reusing them for another dataset.
    channel_mean = [0.49139968, 0.48215841, 0.44653091]
    channel_std = [0.24703223, 0.24348513, 0.26158784]

    pipeline = [
        albumentations.Normalize(channel_mean, channel_std),
        ToTensorV2(),
    ]
    return albumentations.Compose(pipeline)
def _get_data(is_train, is_download, root='../data'):
    """Method to get data for training or testing

    Args:
        is_train (bool): True if data is for training else False
        is_download (bool): True to download the dataset from the internet
        root (str): Directory handed to ``datasets.CIFAR10`` as the dataset
            root. Defaults to '../data', the location that used to be
            hard-coded, so existing callers are unaffected.

    Returns:
        object: Object of dataset
    """
    # A relative default resolves against the current working directory;
    # callers that run from elsewhere can now pass an explicit root.
    return datasets.CIFAR10(root, train=is_train, download=is_download)
def _get_data_loader(data, **kwargs):
    """Wrap a dataset in a DataLoader.

    Args:
        data (object): Dataset object to iterate over
        **kwargs: Options forwarded unchanged to ``DataLoader``

    Returns:
        object: DataLoader instance used to feed data to a neural network
        model
    """
    return DataLoader(data, **kwargs)
def get_test_data_loader(**kwargs):
    """Build the data loader used for testing.

    The test split is fetched with downloading enabled, wrapped in
    ``CIFAR10Data`` together with the test transforms, and handed to the
    loader factory.

    Args:
        **kwargs: Options forwarded to the data loader (for example
            ``batch_size``)

    Returns:
        object: DataLoader instance used to feed test data to a neural
        network model
    """
    transforms = _get_test_transforms()
    raw_test_set = _get_data(is_train=False, is_download=True)
    wrapped_test_set = CIFAR10Data(raw_test_set, transforms)
    return _get_data_loader(data=wrapped_test_set, **kwargs)