ParamDev's picture
Upload folder using huggingface_hub
a01ef8c verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
import inspect
import math
import random

import numpy as np
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader as loader

from tlt.datasets.dataset import BaseDataset
class PyTorchDataset(BaseDataset):
    """
    Base class to represent a PyTorch Dataset
    """

    def __init__(self, dataset_dir, dataset_name="", dataset_catalog=""):
        """
        Class constructor

        Args:
            dataset_dir (str): forwarded unchanged to BaseDataset
            dataset_name (str): optional, forwarded unchanged to BaseDataset
            dataset_catalog (str): optional, forwarded unchanged to BaseDataset
        """
        BaseDataset.__init__(self, dataset_dir, dataset_name, dataset_catalog)

    @property
    def train_subset(self):
        """
        A subset of the dataset used for training, or None if no training indices are defined
        """
        return torch.utils.data.Subset(self._dataset, self._train_indices) if self._train_indices else None

    @property
    def validation_subset(self):
        """
        A subset of the dataset used for validation/evaluation, or None if no validation indices are defined
        """
        return torch.utils.data.Subset(self._dataset, self._validation_indices) if self._validation_indices else None

    @property
    def test_subset(self):
        """
        A subset of the dataset held out for final testing/evaluation, or None if no test indices are defined
        """
        return torch.utils.data.Subset(self._dataset, self._test_indices) if self._test_indices else None

    @property
    def data_loader(self):
        """
        A data loader object corresponding to the dataset
        """
        return self._data_loader

    @property
    def train_loader(self):
        """
        A data loader object corresponding to the training subset
        """
        return self._train_loader

    @property
    def validation_loader(self):
        """
        A data loader object corresponding to the validation subset
        """
        return self._validation_loader

    @property
    def test_loader(self):
        """
        A data loader object corresponding to the test subset
        """
        return self._test_loader

    def get_batch(self, subset='all'):
        """
        Get a single batch of images and labels from the dataset.

        Args:
            subset (str): default "all", can also be "train", "validation", or "test"

        Returns:
            (examples, labels)

        Raises:
            ValueError: if the dataset or the requested subset has no data loader yet (loaders are
                created by preprocess()), or if the given subset is not valid
        """
        # Check the loader itself, not the dataset: the loaders are only built by _make_data_loaders(),
        # so a defined-but-unpreprocessed dataset would otherwise hit iter(None) and raise a TypeError.
        if subset == 'all' and self._data_loader is not None:
            return next(iter(self._data_loader))
        elif subset == 'train' and self._train_loader is not None:
            return next(iter(self._train_loader))
        elif subset == 'validation' and self._validation_loader is not None:
            return next(iter(self._validation_loader))
        elif subset == 'test' and self._test_loader is not None:
            return next(iter(self._test_loader))
        else:
            raise ValueError("Unable to return a batch, because the dataset or subset hasn't been defined.")

    def shuffle_split(self, train_pct=.75, val_pct=.25, test_pct=0., shuffle_files=True, seed=None):
        """
        Randomly split the dataset into train, validation, and test subsets with a pseudo-random seed option.

        Args:
            train_pct (float): default .75, percentage of dataset to use for training
            val_pct (float): default .25, percentage of dataset to use for validation
            test_pct (float): default 0.0, percentage of dataset to use for testing
            shuffle_files (bool): default True, optionally control whether shuffling occurs
            seed (None or int): default None, can be set (including to 0) for pseudo-randomization

        Raises:
            ValueError: if percentage input args are not floats, are negative, or sum to greater than 1
        """
        percentages = (train_pct, val_pct, test_pct)
        if not all(isinstance(pct, float) for pct in percentages):
            raise ValueError("Percentage arguments must be floats.")
        if any(pct < 0 for pct in percentages):
            raise ValueError("Percentage arguments must be non-negative.")
        total_pct = sum(percentages)
        # Tolerate floating-point rounding: e.g. 0.1 + 0.2 + 0.7 evaluates to 1.0000000000000002
        if total_pct > 1.0 and not math.isclose(total_pct, 1.0):
            raise ValueError("Sum of percentage arguments must be less than or equal to 1.")

        length = len(self._dataset)
        train_size = int(train_pct * length)
        val_size = int(val_pct * length)
        test_size = int(test_pct * length)

        # Compare against None explicitly so that seed=0 still yields a seeded, reproducible generator
        generator = torch.Generator().manual_seed(seed) if seed is not None else None
        if shuffle_files:
            dataset_indices = torch.randperm(length, generator=generator).tolist()
        else:
            dataset_indices = range(length)

        val_end = train_size + val_size
        self._train_indices = dataset_indices[:train_size]
        self._validation_indices = dataset_indices[train_size:val_end]
        self._test_indices = dataset_indices[val_end:val_end + test_size] if test_pct else None
        self._validation_type = 'shuffle_split'

        # Rebuild the loaders for the new split, but only once preprocess() has fixed a batch size
        if self._preprocessed and 'batch_size' in self._preprocessed:
            self._make_data_loaders(batch_size=self._preprocessed['batch_size'], generator=generator)

    def _make_data_loaders(self, batch_size, generator=None):
        """
        Make data loaders for the whole dataset and for the subsets that have indices defined.
        A loader whose data is not defined is set to None.

        Args:
            batch_size (int): batch size used by every loader
            generator (torch.Generator or None): optional generator shared by all loaders
        """
        def seed_worker(worker_id):
            # Seed numpy and Python's random module in each worker from torch's per-worker seed
            # (the worker-seeding pattern from PyTorch's reproducibility notes)
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        def make_loader(data):
            # shuffle=False: ordering is governed by the (optionally shuffled) split indices instead
            return loader(data, batch_size=batch_size, shuffle=False, num_workers=self._num_workers,
                          worker_init_fn=seed_worker, generator=generator)

        self._data_loader = make_loader(self.dataset) if self._dataset else None
        self._train_loader = make_loader(self.train_subset) if self._train_indices else None
        self._validation_loader = make_loader(self.validation_subset) if self._validation_indices else None
        self._test_loader = make_loader(self.test_subset) if self._test_indices else None

    def preprocess(self, image_size='variable', batch_size=32, add_aug=None, **kwargs):
        """
        Preprocess the dataset to resize, normalize, and batch the images. Apply augmentation
        if specified.

        Args:
            image_size (int or 'variable'): desired square image size (if 'variable', does not alter image size)
            batch_size (int): desired batch size (default 32)
            add_aug (None or list[str]): augmentations to apply; supported values are 'hflip'
                (RandomHorizontalFlip) and 'rotate' (RandomRotation)
            kwargs: optional; additional keyword arguments for Resize and Normalize transforms

        Raises:
            ValueError: if the dataset is not defined or has already been processed, if batch_size or
                image_size is invalid, or if an unsupported augmentation is requested
        """
        # NOTE: Should this be part of init? If we get image_size and batch size during init,
        # then we don't need a separate call to preprocess.
        if not self._dataset:
            raise ValueError("Unable to preprocess, because the dataset hasn't been defined.")
        if self._preprocessed:
            raise ValueError("Data has already been preprocessed: {}".format(self._preprocessed))
        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("batch_size should be an positive integer")
        if not image_size == 'variable' and not (isinstance(image_size, int) and image_size >= 1):
            raise ValueError("Input image_size must be either a positive int or 'variable'")

        # Route each user-specified keyword argument to the transform whose constructor accepts it
        resize_args = {k: v for k, v in kwargs.items() if k in inspect.getfullargspec(T.Resize).args}
        normalize_args = {k: v for k, v in kwargs.items() if k in inspect.getfullargspec(T.Normalize).args}

        def get_transform(image_size, add_aug):
            transforms = []
            if isinstance(image_size, int):
                transforms.append(T.Resize([image_size, image_size], **resize_args))
            if add_aug is not None:
                # NOTE(review): RandomRotation(0.5) only rotates within +/-0.5 degrees; confirm this is intended
                aug_dict = {'hflip': T.RandomHorizontalFlip(),
                            'rotate': T.RandomRotation(0.5)}
                aug_list = list(aug_dict)
                for option in add_aug:
                    if option not in aug_list:
                        raise ValueError("Unsupported augmentation for PyTorch: {}. "
                                         "Supported augmentations are {}".format(option, aug_list))
                    transforms.append(aug_dict[option])
            transforms.append(T.ToTensor())
            # The commonly used ImageNet per-channel (RGB) mean and standard deviation
            transforms.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], **normalize_args))
            return T.Compose(transforms)

        # NOTE: the transform is attached to the whole dataset, and the train/validation/test subsets all
        # wrap that same dataset object, so any augmentation also applies to validation/test batches.
        self._dataset.transform = get_transform(image_size, add_aug)
        self._preprocessed = {'image_size': image_size, 'batch_size': batch_size}
        self._make_data_loaders(batch_size=batch_size)

    def get_inc_dataloaders(self):
        """
        Get a (calibration, evaluation) pair of data loaders. The 'inc' prefix presumably refers to
        Intel Neural Compressor; confirm against callers.

        Returns:
            (calib_dataloader, eval_dataloader): the training loader is used for calibration. Evaluation
            uses the validation loader, falling back to the test loader and then to the training loader.
        """
        calib_dataloader = self.train_loader
        if self.validation_loader is not None:
            eval_dataloader = self.validation_loader
        elif self.test_loader is not None:
            eval_dataloader = self.test_loader
        else:
            eval_dataloader = self.train_loader
        return calib_dataloader, eval_dataloader