""" Unit tests for TokenizerService """ import pytest from unittest.mock import Mock, patch, MagicMock from app.services.tokenizer_service import TokenizerService import time class TestTokenizerService: """Test cases for TokenizerService.""" def setup_method(self): """Set up test fixtures.""" self.service = TokenizerService() def test_is_predefined_model(self): """Test predefined model checking.""" # Test with existing model assert self.service.is_predefined_model('gpt2') is True # Test with non-existing model assert self.service.is_predefined_model('nonexistent-model') is False # Test with empty string assert self.service.is_predefined_model('') is False def test_get_tokenizer_info_basic(self, mock_tokenizer): """Test basic tokenizer info extraction.""" info = self.service.get_tokenizer_info(mock_tokenizer) assert 'vocab_size' in info assert 'tokenizer_type' in info assert 'special_tokens' in info assert info['vocab_size'] == 50257 assert info['tokenizer_type'] == 'MockTokenizer' # Check special tokens special_tokens = info['special_tokens'] assert 'pad_token' in special_tokens assert 'eos_token' in special_tokens assert special_tokens['pad_token'] == '' assert special_tokens['eos_token'] == '' def test_get_tokenizer_info_with_max_length(self, mock_tokenizer): """Test tokenizer info with model_max_length.""" mock_tokenizer.model_max_length = 2048 info = self.service.get_tokenizer_info(mock_tokenizer) assert 'model_max_length' in info assert info['model_max_length'] == 2048 def test_get_tokenizer_info_error_handling(self): """Test error handling in tokenizer info extraction.""" # Create a mock that raises an exception broken_tokenizer = Mock() broken_tokenizer.__class__.__name__ = 'BrokenTokenizer' broken_tokenizer.vocab_size = property(Mock(side_effect=Exception("Test error"))) info = self.service.get_tokenizer_info(broken_tokenizer) assert 'error' in info assert 'Test error' in info['error'] @patch('app.services.tokenizer_service.AutoTokenizer') def test_load_predefined_tokenizer_success(self, mock_auto_tokenizer, mock_tokenizer): """Test successful loading of predefined tokenizer.""" mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer tokenizer, info, error = self.service.load_tokenizer('gpt2') assert tokenizer is not None assert error is None assert isinstance(info, dict) mock_auto_tokenizer.from_pretrained.assert_called_once() @patch('app.services.tokenizer_service.AutoTokenizer') def test_load_tokenizer_failure(self, mock_auto_tokenizer): """Test tokenizer loading failure.""" mock_auto_tokenizer.from_pretrained.side_effect = Exception("Failed to load") tokenizer, info, error = self.service.load_tokenizer('gpt2') assert tokenizer is None assert error is not None assert "Failed to load" in error def test_load_nonexistent_predefined_model(self): """Test loading non-existent predefined model.""" tokenizer, info, error = self.service.load_tokenizer('nonexistent-model') assert tokenizer is None assert error is not None assert "not found" in error.lower() @patch('app.services.tokenizer_service.AutoTokenizer') @patch('time.time') def test_custom_tokenizer_caching(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app): """Test custom tokenizer caching behavior.""" with app.app_context(): mock_time.return_value = 1000.0 mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer # First load tokenizer1, info1, error1 = self.service.load_tokenizer('custom/model') # Second load (should use cache) mock_time.return_value = 1500.0 # Still within cache time tokenizer2, info2, 

            # Should only call from_pretrained once
            assert mock_auto_tokenizer.from_pretrained.call_count == 1
            assert tokenizer1 is tokenizer2

    @patch('app.services.tokenizer_service.AutoTokenizer')
    @patch('time.time')
    def test_custom_tokenizer_cache_expiration(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """Test custom tokenizer cache expiration."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

            # First load
            self.service.load_tokenizer('custom/model')

            # Second load after cache expiration
            mock_time.return_value = 5000.0  # Beyond cache expiration
            self.service.load_tokenizer('custom/model')

            # Should call from_pretrained twice
            assert mock_auto_tokenizer.from_pretrained.call_count == 2

    def test_tokenizer_models_constant(self):
        """Test that TOKENIZER_MODELS contains expected models."""
        models = self.service.TOKENIZER_MODELS

        assert isinstance(models, dict)
        assert len(models) > 0

        # Check that each model has required fields
        for model_id, model_info in models.items():
            assert isinstance(model_id, str)
            assert isinstance(model_info, dict)
            assert 'name' in model_info
            assert 'alias' in model_info
            assert isinstance(model_info['name'], str)
            assert isinstance(model_info['alias'], str)

    def test_cache_initialization(self):
        """Test that caches are properly initialized."""
        service = TokenizerService()

        assert hasattr(service, 'tokenizers')
        assert hasattr(service, 'custom_tokenizers')
        assert hasattr(service, 'tokenizer_info_cache')
        assert isinstance(service.tokenizers, dict)
        assert isinstance(service.custom_tokenizers, dict)
        assert isinstance(service.tokenizer_info_cache, dict)

    def test_special_tokens_filtering(self, mock_tokenizer):
        """Test that only valid special tokens are included."""
        # Mix valid special tokens with None, empty, and whitespace-only ones
        # (placeholder token strings; any non-empty values work here)
        mock_tokenizer.pad_token = '<pad>'
        mock_tokenizer.eos_token = '<eos>'
        mock_tokenizer.bos_token = None
        mock_tokenizer.sep_token = ''
        mock_tokenizer.cls_token = ' '  # Whitespace only
        mock_tokenizer.unk_token = '<unk>'
        mock_tokenizer.mask_token = '<mask>'

        info = self.service.get_tokenizer_info(mock_tokenizer)
        special_tokens = info['special_tokens']

        # Should only include non-None, non-empty tokens
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert 'unk_token' in special_tokens
        assert 'mask_token' in special_tokens

        # Should not include None or empty tokens
        assert 'bos_token' not in special_tokens
        assert 'sep_token' not in special_tokens
        assert 'cls_token' not in special_tokens
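
# ---------------------------------------------------------------------------
# For reference: `mock_tokenizer` and `app` used above are pytest fixtures
# expected to come from the suite's conftest.py (not part of this file).
# The commented sketch below shows roughly what such fixtures could look
# like; the attribute values and the `create_app` factory import are
# assumptions for illustration only -- just the fixture names and the
# vocab_size / class-name expectations come from the tests themselves.
#
# import pytest
# from app import create_app
#
#
# @pytest.fixture
# def mock_tokenizer():
#     class MockTokenizer:
#         """Bare stand-in; its class name surfaces as tokenizer_type."""
#     tokenizer = MockTokenizer()
#     tokenizer.vocab_size = 50257
#     tokenizer.pad_token = '<pad>'
#     tokenizer.eos_token = '<eos>'
#     return tokenizer
#
#
# @pytest.fixture
# def app():
#     flask_app = create_app()
#     flask_app.config['TESTING'] = True
#     return flask_app
# ---------------------------------------------------------------------------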