#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
import inspect
import os
import tarfile
import zipfile
from pydoc import locate

from downloader import utils
from downloader.types import DatasetType


class DataDownloader:
"""
A unified dataset downloader class.
Can download from TensorFlow Datasets, Torchvision, Hugging Face, and generic web URLs. If initialized for a
dataset catalog, the download method will return a dataset object of type tensorflow.data.Dataset,
torch.utils.data.Dataset, or datasets.arrow_dataset.Dataset. If initialized for a web URL that is a zipfile or a
tarfile, the file will be extracted and the path, or list of paths, to the extracted contents will be returned.
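
    Example::

        # A minimal usage sketch; the dataset name and directory are illustrative
        downloader = DataDownloader('mnist', '/tmp/datasets', catalog='tensorflow_datasets')
        train_split = downloader.download(split='train')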
"""
def __init__(self, dataset_name, dataset_dir, catalog=None, url=None, **kwargs):
"""
Class constructor for a DataDownloader.
Args:
dataset_name (str): Name of the dataset
dataset_dir (str): Local destination directory of dataset
catalog (str, optional): The catalog to download the dataset from; options are 'tensorflow_datasets',
'torchvision', 'hugging_face', and None which will result in a GENERIC type dataset which expects
an accompanying url input
url (str, optional): If downloading from the web, provide the URL location
kwargs (optional): Some catalogs accept additional keyword arguments when downloading
raises:
ValueError if both catalog and url are omitted or if both are provided
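
        Example::

            # A sketch of the generic web-download mode; the URL is hypothetical
            downloader = DataDownloader('my_archive', '/tmp/datasets', url='https://example.com/data.tar.gz')
            extracted_paths = downloader.download()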
"""
if catalog is None and url is None:
raise ValueError("Must provide either a catalog or url as the source.")
if catalog is not None and url is not None:
raise ValueError("Only one of catalog or url should be provided. Found {} and {}.".format(catalog, url))
        os.makedirs(dataset_dir, exist_ok=True)
self._dataset_name = dataset_name
self._dataset_dir = dataset_dir
self._type = DatasetType.from_str(catalog)
self._url = url
self._args = kwargs
def download(self, split='train'):
"""
Download the dataset
Args:
split (str): desired split, optional
Returns:
tensorflow.data.Dataset, torch.utils.data.Dataset, datasets.arrow_dataset.Dataset, str, or list[str]
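
        Example::

            # Sketch: request a different split from an already-initialized downloader
            validation_data = downloader.download(split='validation')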
"""
if self._type == DatasetType.TENSORFLOW_DATASETS:
import tensorflow_datasets as tfds
            # tfds.load accepts a list of splits; normalize a single split string into a list
            if isinstance(split, str):
                split = [split]
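            # Opt out of the TFDS check for Google Cloud (GCE) resources; the check can
            # slow down or break downloads when running outside of Google Cloud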
os.environ['NO_GCE_CHECK'] = 'true'
return tfds.load(self._dataset_name,
data_dir=self._dataset_dir,
split=split,
**self._args)
elif self._type == DatasetType.TORCHVISION:
from torchvision.datasets import __all__ as torchvision_datasets
dataset_class = locate('torchvision.datasets.{}'.format(self._dataset_name))
if dataset_class:
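                # Torchvision dataset constructors are inconsistent: some accept a 'split' string while
                # others accept a boolean 'train' flag, so inspect the constructor signature and keep
                # only the keyword arguments this dataset class actually supports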
params = inspect.signature(dataset_class).parameters
kwargs = dict(download=True, split=split, train=split == 'train')
                kwargs = {k: v for k, v in kwargs.items() if k in params}
return dataset_class(self._dataset_dir, **kwargs)
else:
raise ValueError("Torchvision dataset {} not found in following: {}"
.format(self._dataset_name, torchvision_datasets))
elif self._type == DatasetType.HUGGING_FACE:
from datasets import load_dataset
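            # A 'subset' keyword argument, if provided, is passed through as the dataset configuration name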
if 'subset' in self._args:
return load_dataset(self._dataset_name, self._args['subset'], split=split, cache_dir=self._dataset_dir)
else:
return load_dataset(self._dataset_name, split=split, cache_dir=self._dataset_dir)
elif self._type == DatasetType.GENERIC:
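            # Download the raw file; if it is a tar or zip archive, extract it into the dataset directory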
file_path = utils.download_file(self._url, self._dataset_dir)
if os.path.isfile(file_path):
if tarfile.is_tarfile(file_path):
contents = utils.extract_tar_file(file_path, self._dataset_dir)
elif zipfile.is_zipfile(file_path):
contents = utils.extract_zip_file(file_path, self._dataset_dir)
else:
return file_path
                # Contents are a list of top-level extracted members.
                # Join each with the dataset directory and return a single path string if there is only one member.
if len(contents) > 1:
return [os.path.join(self._dataset_dir, i) for i in contents]
else:
return os.path.join(self._dataset_dir, contents[0])
else:
            raise FileNotFoundError("Unable to find the downloaded file at: {}".format(file_path))