#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
import inspect
import os
import tarfile
import zipfile
from pydoc import locate

from downloader.types import DatasetType
from downloader import utils


class DataDownloader():
    """
    A unified dataset downloader class.

    Can download from TensorFlow Datasets, Torchvision, Hugging Face, and generic web URLs. If initialized for a
    dataset catalog, the download method will return a dataset object of type tensorflow.data.Dataset,
    torch.utils.data.Dataset, or datasets.arrow_dataset.Dataset. If initialized for a web URL that is a zipfile or a
    tarfile, the file will be extracted and the path, or list of paths, to the extracted contents will be returned.
    """

    def __init__(self, dataset_name, dataset_dir, catalog=None, url=None, **kwargs):
        """
        Class constructor for a DataDownloader.

        Args:
            dataset_name (str): Name of the dataset
            dataset_dir (str): Local destination directory of dataset; created if it does not exist
            catalog (str, optional): The catalog to download the dataset from; options are 'tensorflow_datasets',
                'torchvision', 'hugging_face', and None which will result in a GENERIC type dataset which expects
                an accompanying url input
            url (str, optional): If downloading from the web, provide the URL location
            kwargs (optional): Some catalogs accept additional keyword arguments when downloading

        Raises:
            ValueError: if both catalog and url are omitted or if both are provided
        """
        # Exactly one of catalog/url must be given; validate before touching the filesystem
        if catalog is None and url is None:
            raise ValueError("Must provide either a catalog or url as the source.")
        if catalog is not None and url is not None:
            raise ValueError("Only one of catalog or url should be provided. Found {} and {}.".format(catalog, url))

        # exist_ok=True avoids the check-then-create race of isdir() followed by makedirs()
        os.makedirs(dataset_dir, exist_ok=True)

        self._dataset_name = dataset_name
        self._dataset_dir = dataset_dir
        self._type = DatasetType.from_str(catalog)
        self._url = url
        self._args = kwargs

    def download(self, split='train'):
        """
        Download the dataset.

        Args:
            split (str): desired split, optional; defaults to 'train'

        Returns:
            tensorflow.data.Dataset, torch.utils.data.Dataset, datasets.arrow_dataset.Dataset, str, or list[str]

        Raises:
            ValueError: if a torchvision dataset name is not found, or the dataset type is unsupported
            FileNotFoundError: if a generic download cannot be located on disk or its archive is empty
        """
        if self._type == DatasetType.TENSORFLOW_DATASETS:
            # Imported lazily so TF is only required when actually used
            import tensorflow_datasets as tfds
            if isinstance(split, str):
                split = [split]
            # Skip the Google Compute Engine metadata check, which can hang off-GCE
            os.environ['NO_GCE_CHECK'] = 'true'
            return tfds.load(self._dataset_name,
                             data_dir=self._dataset_dir,
                             split=split,
                             **self._args)
        elif self._type == DatasetType.TORCHVISION:
            from torchvision.datasets import __all__ as torchvision_datasets
            dataset_class = locate('torchvision.datasets.{}'.format(self._dataset_name))
            if dataset_class:
                # Torchvision datasets vary in signature; only pass the kwargs this class accepts
                params = inspect.signature(dataset_class).parameters
                kwargs = dict(download=True, split=split, train=split == 'train')
                kwargs = {k: v for k, v in kwargs.items() if k in params}
                return dataset_class(self._dataset_dir, **kwargs)
            else:
                raise ValueError("Torchvision dataset {} not found in following: {}"
                                 .format(self._dataset_name, torchvision_datasets))
        elif self._type == DatasetType.HUGGING_FACE:
            from datasets import load_dataset
            if 'subset' in self._args:
                return load_dataset(self._dataset_name, self._args['subset'], split=split, cache_dir=self._dataset_dir)
            else:
                return load_dataset(self._dataset_name, split=split, cache_dir=self._dataset_dir)
        elif self._type == DatasetType.GENERIC:
            file_path = utils.download_file(self._url, self._dataset_dir)
            if not os.path.isfile(file_path):
                raise FileNotFoundError("Unable to find the downloaded file at: {}".format(file_path))
            if tarfile.is_tarfile(file_path):
                contents = utils.extract_tar_file(file_path, self._dataset_dir)
            elif zipfile.is_zipfile(file_path):
                contents = utils.extract_zip_file(file_path, self._dataset_dir)
            else:
                # Not an archive; return the downloaded file itself
                return file_path
            # Contents are a list of top-level extracted members.
            # Convert to absolute paths and return a single string if length is 1.
            if len(contents) > 1:
                return [os.path.join(self._dataset_dir, i) for i in contents]
            elif contents:
                return os.path.join(self._dataset_dir, contents[0])
            else:
                # Previously this raised a bare IndexError; fail with a clear message instead
                raise FileNotFoundError("No contents were extracted from the archive at: {}".format(file_path))
        else:
            # Defensive: DatasetType.from_str should cover all types, but fail loudly if not
            raise ValueError("Unsupported dataset type: {}".format(self._type))