File size: 5,618 Bytes
a01ef8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#

import os
from pydoc import locate
import tarfile
import zipfile
import inspect

from downloader.types import DatasetType
from downloader import utils


class DataDownloader():
    """
    A unified dataset downloader class.

    Can download from TensorFlow Datasets, Torchvision, Hugging Face, and generic web URLs. If initialized for a
    dataset catalog, the download method will return a dataset object of type tensorflow.data.Dataset,
    torch.utils.data.Dataset, or datasets.arrow_dataset.Dataset. If initialized for a web URL that is a zipfile or a
    tarfile, the file will be extracted and the path, or list of paths, to the extracted contents will be returned.
    """
    def __init__(self, dataset_name, dataset_dir, catalog=None, url=None, **kwargs):
        """
        Class constructor for a DataDownloader.

            Args:
                dataset_name (str): Name of the dataset
                dataset_dir (str): Local destination directory of dataset
                catalog (str, optional): The catalog to download the dataset from; options are 'tensorflow_datasets',
                    'torchvision', 'hugging_face', and None which will result in a GENERIC type dataset which expects
                    an accompanying url input
                url (str, optional): If downloading from the web, provide the URL location
                kwargs (optional): Some catalogs accept additional keyword arguments when downloading

            Raises:
                ValueError: if both catalog and url are omitted or if both are provided

        """
        # Exactly one of catalog/url must be given; validate before touching the filesystem.
        if catalog is None and url is None:
            raise ValueError("Must provide either a catalog or url as the source.")
        if catalog is not None and url is not None:
            raise ValueError("Only one of catalog or url should be provided. Found {} and {}.".format(catalog, url))

        # exist_ok=True avoids the check-then-create race of a separate isdir() test
        os.makedirs(dataset_dir, exist_ok=True)

        self._dataset_name = dataset_name
        self._dataset_dir = dataset_dir
        self._type = DatasetType.from_str(catalog)
        self._url = url
        self._args = kwargs

    def download(self, split='train'):
        """
        Download the dataset

            Args:
                split (str): desired split, optional

            Returns:
                tensorflow.data.Dataset, torch.utils.data.Dataset, datasets.arrow_dataset.Dataset, str, or list[str]

            Raises:
                ValueError: if the Torchvision dataset name is not found in the catalog
                FileNotFoundError: if the downloaded file cannot be located on disk

        """
        if self._type == DatasetType.TENSORFLOW_DATASETS:
            import tensorflow_datasets as tfds
            # tfds.load expects a list of split names; normalize a bare string
            if isinstance(split, str):
                split = [split]
            os.environ['NO_GCE_CHECK'] = 'true'
            return tfds.load(self._dataset_name,
                             data_dir=self._dataset_dir,
                             split=split,
                             **self._args)

        elif self._type == DatasetType.TORCHVISION:
            from torchvision.datasets import __all__ as torchvision_datasets
            dataset_class = locate('torchvision.datasets.{}'.format(self._dataset_name))
            if dataset_class:
                # Torchvision dataset constructors vary: some take 'split', some a boolean
                # 'train', some neither — pass only the keywords this class accepts.
                params = inspect.signature(dataset_class).parameters
                kwargs = dict(download=True, split=split, train=split == 'train')
                kwargs = dict([(k, v) for k, v in kwargs.items() if k in params])
                return dataset_class(self._dataset_dir, **kwargs)
            else:
                raise ValueError("Torchvision dataset {} not found in following: {}"
                                 .format(self._dataset_name, torchvision_datasets))

        elif self._type == DatasetType.HUGGING_FACE:
            from datasets import load_dataset
            if 'subset' in self._args:
                return load_dataset(self._dataset_name, self._args['subset'], split=split, cache_dir=self._dataset_dir)
            else:
                return load_dataset(self._dataset_name, split=split, cache_dir=self._dataset_dir)

        elif self._type == DatasetType.GENERIC:
            file_path = utils.download_file(self._url, self._dataset_dir)
            if os.path.isfile(file_path):
                if tarfile.is_tarfile(file_path):
                    contents = utils.extract_tar_file(file_path, self._dataset_dir)
                elif zipfile.is_zipfile(file_path):
                    contents = utils.extract_zip_file(file_path, self._dataset_dir)
                else:
                    # Not an archive; return the downloaded file itself
                    return file_path

                # Contents are a list of top-level extracted members
                # Convert to absolute paths and return a single string if length is 1
                if len(contents) > 1:
                    return [os.path.join(self._dataset_dir, i) for i in contents]
                else:
                    return os.path.join(self._dataset_dir, contents[0])

            else:
                # Single formatted message (not multiple args) so the exception prints cleanly
                raise FileNotFoundError("Unable to find the downloaded file at: {}".format(file_path))