"""Tests for downloader.datasets DataDownloader (catalog and generic URL downloads)."""
| import os | |
| import pytest | |
| import shutil | |
| import tempfile | |
| try: | |
| from datasets.arrow_dataset import Dataset as HF_Dataset | |
| except ModuleNotFoundError: | |
| print("WARNING: datasets may not be installed") | |
| try: | |
| from torch.utils.data import Dataset as TV_Dataset | |
| except ModuleNotFoundError: | |
| print("WARNING: torch may not be installed") | |
| try: | |
| from tensorflow.data import Dataset as TF_Dataset | |
| except ModuleNotFoundError: | |
| print("WARNING: tensorflow may not be installed") | |
| from downloader import datasets | |
| from downloader.types import DatasetType | |
| def test_bad_download(dataset_name, catalog, url): | |
| """ | |
| Tests downloader throws ValueError for bad inputs | |
| """ | |
| with pytest.raises(ValueError): | |
| datasets.DataDownloader(dataset_name, dataset_dir='/tmp/data', catalog=catalog, url=url) | |
| class TestDatasetDownload: | |
| """ | |
| Tests the dataset downloader with a temp download directory that is initialized and cleaned up | |
| """ | |
| URLS = {'sms_spam_collection': | |
| 'https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip', | |
| 'flowers': | |
| 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', | |
| 'imagenet_labels': | |
| 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt', | |
| 'peacock': | |
| 'https://c8.staticflickr.com/8/7095/7210797228_c7fe51c3cb_z.jpg', | |
| 'pennfudan': | |
| 'https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip'} | |
| def setup_class(cls): | |
| cls._dataset_dir = tempfile.mkdtemp() | |
| def teardown_class(cls): | |
| if os.path.exists(cls._dataset_dir): | |
| print("Deleting test directory:", cls._dataset_dir) | |
| shutil.rmtree(cls._dataset_dir) | |
| def test_catalog_download(self, dataset_name, catalog, split, kwargs, size): | |
| """ | |
| Tests downloader for different dataset catalog types and splits | |
| """ | |
| downloader = datasets.DataDownloader(dataset_name, dataset_dir=self._dataset_dir, catalog=catalog, **kwargs) | |
| data = downloader.download(split=split) | |
| # Check the type of the downloader and returned object | |
| if catalog == 'tfds': | |
| data = data[0] # TFDS returns a list with the dataset in it | |
| assert downloader._type == DatasetType.TENSORFLOW_DATASETS | |
| assert isinstance(data, TF_Dataset) | |
| elif catalog == 'torchvision': | |
| assert downloader._type == DatasetType.TORCHVISION | |
| assert isinstance(data, TV_Dataset) | |
| elif catalog == 'huggingface': | |
| assert downloader._type == DatasetType.HUGGING_FACE | |
| assert isinstance(data, HF_Dataset) | |
| # Verify the split size | |
| assert len(data) == size | |
| # Check that the directory is not empty | |
| assert os.listdir(self._dataset_dir) is not None | |
| def test_generic_download(self, dataset_name, url, num_contents): | |
| """ | |
| Tests downloader for different web URLs and file types | |
| """ | |
| downloader = datasets.DataDownloader(dataset_name, dataset_dir=self._dataset_dir, url=url) | |
| data_path = downloader.download() | |
| assert downloader._type == DatasetType.GENERIC | |
| # Test that the returned object is the expected type and length | |
| if num_contents == 1: | |
| assert isinstance(data_path, str) | |
| assert os.path.exists(data_path) | |
| else: | |
| assert isinstance(data_path, list) | |
| for path in data_path: | |
| assert os.path.exists(path) | |