Spaces:
Configuration error
Configuration error
| import os | |
| import pytest | |
| import shutil | |
| import tempfile | |
| try: | |
| from torch.nn import Module | |
| except ModuleNotFoundError: | |
| print("WARNING: Unable to import torch. Torch may not be installed") | |
| try: | |
| from tensorflow_hub.keras_layer import KerasLayer | |
| except ModuleNotFoundError: | |
| print("WARNING: Unable to import KerasLayer. Tensorflow Hub may not be installed") | |
| try: | |
| from tensorflow.keras import Model | |
| except ModuleNotFoundError: | |
| print("WARNING: Unable to import Keras Model. Tensorflow may not be installed") | |
| from downloader import models | |
| from downloader.types import ModelType | |
def test_bad_hub(hub):
    """
    Verify that constructing a ModelDownloader with an invalid hub raises ValueError.

    NOTE(review): ``hub`` is not a fixture defined in this view; it is presumably
    supplied by a ``@pytest.mark.parametrize`` decorator — confirm.
    """
    with pytest.raises(ValueError):
        models.ModelDownloader('model', hub)
| class TestModelDownload: | |
| """ | |
| Tests the model downloader with a temp download directory that is initialized and cleaned up | |
| """ | |
| def setup_class(cls): | |
| cls._model_dir = tempfile.mkdtemp() | |
| def teardown_class(cls): | |
| if os.path.exists(cls._model_dir): | |
| print("Deleting test directory:", cls._model_dir) | |
| shutil.rmtree(cls._model_dir) | |
| # Has previously been skipped due to HTTP Error 403: rate limit exceeded') | |
| def test_hub_download(self, model_name, hub, kwargs): | |
| """ | |
| Tests downloader for different model hubs | |
| """ | |
| downloader = models.ModelDownloader(model_name, hub, model_dir=self._model_dir, **kwargs) | |
| model = downloader.download() | |
| # Check the type of the downloader and returned object | |
| if downloader._type == ModelType.TF_HUB: | |
| assert isinstance(model, KerasLayer) | |
| elif downloader._type == ModelType.TORCHVISION: | |
| assert isinstance(model, Module) | |
| elif downloader._type == ModelType.PYTORCH_HUB: | |
| assert isinstance(model, Module) | |
| elif downloader._type == ModelType.HUGGING_FACE: | |
| assert isinstance(model, Module) | |
| elif downloader._type == ModelType.KERAS_APPLICATIONS: | |
| assert isinstance(model, Model) | |
| elif downloader._type == ModelType.TF_BERT_HUGGINGFACE: | |
| assert isinstance(model, Model) | |
| else: | |
| assert False | |
| # Check that the directory is not empty | |
| assert os.listdir(self._model_dir) is not None | |