import os
import pytest
import shutil
import tempfile

try:
    from datasets.arrow_dataset import Dataset as HF_Dataset
except ModuleNotFoundError:
    print("WARNING: datasets may not be installed")

try:
    from torch.utils.data import Dataset as TV_Dataset
except ModuleNotFoundError:
    print("WARNING: torch may not be installed")

try:
    from tensorflow.data import Dataset as TF_Dataset
except ModuleNotFoundError:
    print("WARNING: tensorflow may not be installed")

from downloader import datasets
from downloader.types import DatasetType

def test_bad_download(dataset_name, catalog, url):
    """
    Tests that the downloader raises a ValueError for bad inputs
    """
    with pytest.raises(ValueError):
        datasets.DataDownloader(dataset_name, dataset_dir='/tmp/data', catalog=catalog, url=url)

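# A hypothetical parametrization sketch for test_bad_download (illustrative
# values only; the real cases are presumably supplied elsewhere, e.g. via
# conftest.py fixtures or an existing @pytest.mark.parametrize). Bad inputs
# could include an unrecognized catalog name or a missing catalog/url pair:
#
#   @pytest.mark.parametrize('dataset_name,catalog,url', [
#       ('my_dataset', 'not_a_real_catalog', None),
#       ('my_dataset', None, None),
#   ])
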
class TestDatasetDownload:
    """
    Tests the dataset downloader with a temp download directory that is initialized and cleaned up
    """
    URLS = {
        'sms_spam_collection':
            'https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip',
        'flowers':
            'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
        'imagenet_labels':
            'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt',
        'peacock':
            'https://c8.staticflickr.com/8/7095/7210797228_c7fe51c3cb_z.jpg',
        'pennfudan':
            'https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip'
    }

    def setup_class(cls):
        # Create a temporary download directory shared by all tests in this class
        cls._dataset_dir = tempfile.mkdtemp()

    def teardown_class(cls):
        # Remove the temporary download directory once all tests have run
        if os.path.exists(cls._dataset_dir):
            print("Deleting test directory:", cls._dataset_dir)
            shutil.rmtree(cls._dataset_dir)

    def test_catalog_download(self, dataset_name, catalog, split, kwargs, size):
        """
        Tests downloader for different dataset catalog types and splits
        """
        downloader = datasets.DataDownloader(dataset_name, dataset_dir=self._dataset_dir, catalog=catalog, **kwargs)
        data = downloader.download(split=split)

        # Check the type of the downloader and returned object
        if catalog == 'tfds':
            data = data[0]  # TFDS returns a list with the dataset in it
            assert downloader._type == DatasetType.TENSORFLOW_DATASETS
            assert isinstance(data, TF_Dataset)
        elif catalog == 'torchvision':
            assert downloader._type == DatasetType.TORCHVISION
            assert isinstance(data, TV_Dataset)
        elif catalog == 'huggingface':
            assert downloader._type == DatasetType.HUGGING_FACE
            assert isinstance(data, HF_Dataset)

        # Verify the split size
        assert len(data) == size
        # Check that the download directory is not empty
        assert len(os.listdir(self._dataset_dir)) > 0

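    # A hypothetical parametrization sketch for test_catalog_download
    # (placeholder values only; the real cases are presumably supplied
    # elsewhere, e.g. via conftest.py or an existing @pytest.mark.parametrize):
    #
    #   @pytest.mark.parametrize('dataset_name,catalog,split,kwargs,size', [
    #       ('<catalog dataset name>', 'huggingface', 'train', {}, <expected split size>),
    #   ])
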
    def test_generic_download(self, dataset_name, url, num_contents):
        """
        Tests downloader for different web URLs and file types
        """
        downloader = datasets.DataDownloader(dataset_name, dataset_dir=self._dataset_dir, url=url)
        data_path = downloader.download()
        assert downloader._type == DatasetType.GENERIC

        # Test that the returned object is the expected type and length
        if num_contents == 1:
            assert isinstance(data_path, str)
            assert os.path.exists(data_path)
        else:
            assert isinstance(data_path, list)
            assert len(data_path) == num_contents
            for path in data_path:
                assert os.path.exists(path)
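
    # A hypothetical parametrization sketch for test_generic_download
    # (illustrative only; the 'num_contents' values are assumptions about what
    # each download extracts to). The URLs could be drawn from the URLS dict above:
    #
    #   @pytest.mark.parametrize('dataset_name,url,num_contents', [
    #       ('imagenet_labels', URLS['imagenet_labels'], 1),
    #       ('pennfudan', URLS['pennfudan'], <number of extracted items>),
    #   ])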