|  | from tempfile import TemporaryDirectory | 
					
						
						|  | from unittest import TestCase | 
					
						
						|  | from unittest.mock import MagicMock, patch | 
					
						
						|  |  | 
					
						
						|  | from transformers import AutoModel, TFAutoModel | 
					
						
						|  | from transformers.onnx import FeaturesManager | 
					
						
						|  | from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @require_torch | 
					
						
						|  | @require_tf | 
					
						
						|  | class DetermineFrameworkTest(TestCase): | 
					
						
						|  | """ | 
					
						
						|  | Test `FeaturesManager.determine_framework` | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def setUp(self): | 
					
						
						|  | self.test_model = SMALL_MODEL_IDENTIFIER | 
					
						
						|  | self.framework_pt = "pt" | 
					
						
						|  | self.framework_tf = "tf" | 
					
						
						|  |  | 
					
						
						|  | def _setup_pt_ckpt(self, save_dir): | 
					
						
						|  | model_pt = AutoModel.from_pretrained(self.test_model) | 
					
						
						|  | model_pt.save_pretrained(save_dir) | 
					
						
						|  |  | 
					
						
						|  | def _setup_tf_ckpt(self, save_dir): | 
					
						
						|  | model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True) | 
					
						
						|  | model_tf.save_pretrained(save_dir) | 
					
						
						|  |  | 
					
						
						|  | def test_framework_provided(self): | 
					
						
						|  | """ | 
					
						
						|  | Ensure the that the provided framework is returned. | 
					
						
						|  | """ | 
					
						
						|  | mock_framework = "mock_framework" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | result = FeaturesManager.determine_framework(self.test_model, mock_framework) | 
					
						
						|  | self.assertEqual(result, mock_framework) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with TemporaryDirectory() as local_pt_ckpt: | 
					
						
						|  | self._setup_pt_ckpt(local_pt_ckpt) | 
					
						
						|  | result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework) | 
					
						
						|  | self.assertEqual(result, mock_framework) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with TemporaryDirectory() as local_tf_ckpt: | 
					
						
						|  | self._setup_tf_ckpt(local_tf_ckpt) | 
					
						
						|  | result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework) | 
					
						
						|  | self.assertEqual(result, mock_framework) | 
					
						
						|  |  | 
					
						
						|  | def test_checkpoint_provided(self): | 
					
						
						|  | """ | 
					
						
						|  | Ensure that the determined framework is the one used for the local checkpoint. | 
					
						
						|  |  | 
					
						
						|  | For the functionality to execute, local checkpoints are provided but framework is not. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | with TemporaryDirectory() as local_pt_ckpt: | 
					
						
						|  | self._setup_pt_ckpt(local_pt_ckpt) | 
					
						
						|  | result = FeaturesManager.determine_framework(local_pt_ckpt) | 
					
						
						|  | self.assertEqual(result, self.framework_pt) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with TemporaryDirectory() as local_tf_ckpt: | 
					
						
						|  | self._setup_tf_ckpt(local_tf_ckpt) | 
					
						
						|  | result = FeaturesManager.determine_framework(local_tf_ckpt) | 
					
						
						|  | self.assertEqual(result, self.framework_tf) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with TemporaryDirectory() as local_invalid_ckpt: | 
					
						
						|  | with self.assertRaises(FileNotFoundError): | 
					
						
						|  | result = FeaturesManager.determine_framework(local_invalid_ckpt) | 
					
						
						|  |  | 
					
						
						|  | def test_from_environment(self): | 
					
						
						|  | """ | 
					
						
						|  | Ensure that the determined framework is the one available in the environment. | 
					
						
						|  |  | 
					
						
						|  | For the functionality to execute, framework and local checkpoints are not provided. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mock_tf_available = MagicMock(return_value=False) | 
					
						
						|  | with patch("transformers.onnx.features.is_tf_available", mock_tf_available): | 
					
						
						|  | result = FeaturesManager.determine_framework(self.test_model) | 
					
						
						|  | self.assertEqual(result, self.framework_pt) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mock_torch_available = MagicMock(return_value=False) | 
					
						
						|  | with patch("transformers.onnx.features.is_torch_available", mock_torch_available): | 
					
						
						|  | result = FeaturesManager.determine_framework(self.test_model) | 
					
						
						|  | self.assertEqual(result, self.framework_tf) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mock_tf_available = MagicMock(return_value=True) | 
					
						
						|  | mock_torch_available = MagicMock(return_value=True) | 
					
						
						|  | with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch( | 
					
						
						|  | "transformers.onnx.features.is_torch_available", mock_torch_available | 
					
						
						|  | ): | 
					
						
						|  | result = FeaturesManager.determine_framework(self.test_model) | 
					
						
						|  | self.assertEqual(result, self.framework_pt) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mock_tf_available = MagicMock(return_value=False) | 
					
						
						|  | mock_torch_available = MagicMock(return_value=False) | 
					
						
						|  | with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch( | 
					
						
						|  | "transformers.onnx.features.is_torch_available", mock_torch_available | 
					
						
						|  | ): | 
					
						
						|  | with self.assertRaises(EnvironmentError): | 
					
						
						|  | result = FeaturesManager.determine_framework(self.test_model) | 
					
						
						|  |  |