#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#

import tensorflow as tf

from tlt.datasets.tf_dataset import TFDataset
from tlt.datasets.image_classification.image_classification_dataset import ImageClassificationDataset
from downloader.datasets import DataDownloader


class TFDSImageClassificationDataset(ImageClassificationDataset, TFDataset):
    """
    An image classification dataset from the TensorFlow datasets catalog
    """

    def __init__(self, dataset_dir, dataset_name, split=None, as_supervised=True, shuffle_files=True,
                 seed=None, **kwargs):
        """
        Class constructor

        Downloads the named dataset through DataDownloader (catalog='tfds') and assigns the requested
        splits to this object's datasets.

        Args:
            dataset_dir (str): directory passed to DataDownloader for storing the dataset
            dataset_name (str): name of the dataset in the TFDS catalog
            split (None or list[str]): split names to load; None (the default) means ["train"].
                With a single split, that split becomes the whole dataset and no validation type is
                set. With several splits, splits named 'train', 'validation' and 'test' are assigned
                to the matching subsets, and every requested split is concatenated into the dataset.
            as_supervised (bool): forwarded to DataDownloader
            shuffle_files (bool): forwarded to DataDownloader
            seed (None or int): seed given to the random augmentation layers built in preprocess()
            kwargs: accepted for interface compatibility; not used by this constructor

        Raises:
            ValueError: if split is not a list, or is an empty list
        """
        # None sentinel instead of a shared mutable default list
        if split is None:
            split = ["train"]
        if not isinstance(split, list):
            raise ValueError("Value of split argument must be a list.")
        if not split:
            # Fail before downloading anything; an empty split would leave no data to work with
            raise ValueError("Value of split argument must be a non-empty list.")

        ImageClassificationDataset.__init__(self, dataset_dir, dataset_name)
        self._preprocessed = {}
        self._seed = seed

        tf.get_logger().setLevel('ERROR')

        downloader = DataDownloader(dataset_name, dataset_dir=dataset_dir, catalog='tfds',
                                    as_supervised=as_supervised, shuffle_files=shuffle_files,
                                    with_info=True)
        data, self._info = downloader.download(split=split)

        self._dataset = None
        self._train_subset = None
        self._validation_subset = None
        self._test_subset = None

        if len(split) == 1:
            self._validation_type = None  # Train & evaluate on the whole dataset
            self._dataset = data[0]
        else:
            self._validation_type = 'defined_split'  # Defined by user or TFDS
            for split_name, split_data in zip(split, data):
                if split_name == 'train':
                    self._train_subset = split_data
                elif split_name == 'validation':
                    self._validation_subset = split_data
                elif split_name == 'test':
                    self._test_subset = split_data
                # The full dataset is the concatenation of every requested split, in order
                if self._dataset is None:
                    self._dataset = split_data
                else:
                    self._dataset = self._dataset.concatenate(split_data)

    @property
    def class_names(self):
        """Returns the list of class names (read from the 'label' feature of the dataset info)"""
        return self._info.features["label"].names

    @property
    def info(self):
        """Returns a dictionary with the TFDS dataset info and the preprocessing parameters used"""
        return {'dataset_info': self._info, 'preprocessing_info': self._preprocessed}

    @property
    def dataset(self):
        """
        Returns the framework dataset object (tf.data.Dataset)
        """
        return self._dataset

    def preprocess(self, image_size, batch_size, add_aug=None, preprocessor=None):
        """
        Preprocess the dataset to convert to float32, resize, and batch the images

        Each defined dataset (full dataset and train/validation/test subsets) is mapped through the
        resize/convert step, cached, batched and prefetched. Augmentation, when requested, is applied
        only to the full dataset and the train subset, after batching.

        Args:
            image_size (int): desired square image size
            batch_size (int): desired batch size
            add_aug (None or list[str]): augmentations to apply to the training data, in the given
                order. Supported values: 'hvflip' (RandomFlip "horizontal_and_vertical"),
                'hflip' (RandomFlip "horizontal"), 'vflip' (RandomFlip "vertical"),
                'rotate' (RandomRotation(0.5)), 'zoom' (RandomZoom(0.3))
            preprocessor (None or preprocess_input function from keras.applications): Should be
                provided when using Keras Applications models, which have model-specific
                preprocessors; otherwise, use None (the default) to apply generic type conversion
                and resizing

        Raises:
            ValueError: if the data has already been preprocessed, if batch_size or image_size is
                not a positive integer, if no dataset is defined, or if add_aug contains an
                unsupported augmentation
        """
        if self._preprocessed:
            raise ValueError("Data has already been preprocessed: {}".format(self._preprocessed))
        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("batch_size should be a positive integer")
        if not isinstance(image_size, int) or image_size < 1:
            raise ValueError("image_size should be a positive integer")

        # Attribute names of the datasets this object may hold; any of them may be None
        split_list = ['_dataset', '_train_subset', '_validation_subset', '_test_subset']
        subsets = [s for s in split_list if getattr(self, s, None) is not None]
        if not subsets:
            raise ValueError("Unable to preprocess, because the dataset hasn't been defined.")

        data_augmentation = None
        if add_aug is not None:
            aug_list = ['hvflip', 'hflip', 'vflip', 'rotate', 'zoom']
            # Validate every requested option before building any Keras layers
            for option in add_aug:
                if option not in aug_list:
                    raise ValueError("Unsupported augmentation for TensorFlow:{}. "
                                     "Supported augmentations are {}".format(option, aug_list))

            input_shape = (image_size, image_size, 3)
            aug_dict = {
                'hvflip': tf.keras.layers.RandomFlip("horizontal_and_vertical", input_shape=input_shape,
                                                     seed=self._seed),
                'hflip': tf.keras.layers.RandomFlip("horizontal", input_shape=input_shape, seed=self._seed),
                'vflip': tf.keras.layers.RandomFlip("vertical", input_shape=input_shape, seed=self._seed),
                'rotate': tf.keras.layers.RandomRotation(0.5, seed=self._seed),
                'zoom': tf.keras.layers.RandomZoom(0.3, seed=self._seed),
            }
            data_augmentation = tf.keras.Sequential()
            for option in add_aug:
                data_augmentation.add(aug_dict[option])

        def preprocess_image(image, label):
            # Without a model-specific preprocessor, convert to float32 here (for integer images
            # convert_image_dtype also rescales values); with one, leave the pixel values for it.
            if preprocessor is None:
                image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize_with_pad(image, image_size, image_size)
            return (image, label)

        for subset in subsets:
            ds = getattr(self, subset).map(preprocess_image)
            if preprocessor is not None:
                ds = ds.map(lambda x, y: (preprocessor(x), y))
            # Cache the deterministic preprocessing only, so augmentation stays random per epoch
            ds = ds.cache().batch(batch_size)
            # NOTE(review): in the multi-split case _dataset is the concatenation of all requested
            # splits, so it is augmented too (unchanged from the previous behavior).
            if data_augmentation is not None and subset in ('_dataset', '_train_subset'):
                ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                            num_parallel_calls=tf.data.AUTOTUNE)
            # Prefetch as the final stage so it overlaps every preceding step, including augmentation
            setattr(self, subset, ds.prefetch(tf.data.AUTOTUNE))

        self._preprocessed = {'image_size': image_size, 'batch_size': batch_size}