Spaces:
Running
Running
""" | |
Unit tests for TokenizerService | |
""" | |
import pytest | |
from unittest.mock import Mock, patch, MagicMock | |
from app.services.tokenizer_service import TokenizerService | |
import time | |
class TestTokenizerService:
    """Test cases for TokenizerService.

    NOTE(review): the ``mock_tokenizer``, ``mock_auto_tokenizer``,
    ``mock_time`` and ``app`` parameters used below are not defined in this
    class; presumably they are pytest fixtures from a conftest.py (or
    ``@patch`` decorators missing from this copy) -- TODO confirm.
    """

    def setup_method(self):
        """Create a fresh TokenizerService before every test."""
        self.service = TokenizerService()

    def test_is_predefined_model(self):
        """Only ids present in the predefined model table are recognized."""
        # Known model id.
        assert self.service.is_predefined_model('gpt2') is True
        # Unknown model id.
        assert self.service.is_predefined_model('nonexistent-model') is False
        # The empty string must not match anything.
        assert self.service.is_predefined_model('') is False

    def test_get_tokenizer_info_basic(self, mock_tokenizer):
        """Basic info (vocab size, type name, special tokens) is extracted."""
        info = self.service.get_tokenizer_info(mock_tokenizer)

        assert 'vocab_size' in info
        assert 'tokenizer_type' in info
        assert 'special_tokens' in info
        assert info['vocab_size'] == 50257
        assert info['tokenizer_type'] == 'MockTokenizer'

        # Special tokens are reported by attribute name.
        special_tokens = info['special_tokens']
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert special_tokens['pad_token'] == '<pad>'
        assert special_tokens['eos_token'] == '</s>'

    def test_get_tokenizer_info_with_max_length(self, mock_tokenizer):
        """model_max_length is reported when the tokenizer defines it."""
        mock_tokenizer.model_max_length = 2048

        info = self.service.get_tokenizer_info(mock_tokenizer)

        assert 'model_max_length' in info
        assert info['model_max_length'] == 2048

    def test_get_tokenizer_info_error_handling(self):
        """An exception raised while reading tokenizer attributes is reported under 'error'."""
        from unittest.mock import PropertyMock

        broken_tokenizer = Mock()
        broken_tokenizer.__class__.__name__ = 'BrokenTokenizer'
        # Assigning ``property(...)`` to a Mock *instance* attribute does not
        # create a descriptor: reading ``broken_tokenizer.vocab_size`` would
        # simply return the property object and never raise. Per the
        # unittest.mock docs, a PropertyMock must be attached to the mock's
        # type instead. Mock gives each instance its own class, so this does
        # not leak into other mocks.
        type(broken_tokenizer).vocab_size = PropertyMock(
            side_effect=Exception("Test error")
        )

        info = self.service.get_tokenizer_info(broken_tokenizer)

        assert 'error' in info
        assert 'Test error' in info['error']

    def test_load_predefined_tokenizer_success(self, mock_auto_tokenizer, mock_tokenizer):
        """A predefined model loads via from_pretrained and returns no error."""
        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

        tokenizer, info, error = self.service.load_tokenizer('gpt2')

        assert tokenizer is not None
        assert error is None
        assert isinstance(info, dict)
        mock_auto_tokenizer.from_pretrained.assert_called_once()

    def test_load_tokenizer_failure(self, mock_auto_tokenizer):
        """A from_pretrained exception yields no tokenizer and an error message."""
        mock_auto_tokenizer.from_pretrained.side_effect = Exception("Failed to load")

        tokenizer, _info, error = self.service.load_tokenizer('gpt2')

        assert tokenizer is None
        assert error is not None
        assert "Failed to load" in error

    def test_load_nonexistent_predefined_model(self):
        """An unknown model id yields no tokenizer and a 'not found' error."""
        tokenizer, _info, error = self.service.load_tokenizer('nonexistent-model')

        assert tokenizer is None
        assert error is not None
        assert "not found" in error.lower()

    def test_custom_tokenizer_caching(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """A custom tokenizer requested twice within the cache window is loaded once."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

            # First load populates the cache.
            tokenizer1, _info1, error1 = self.service.load_tokenizer('custom/model')

            # Second load 500 s later; the test assumes this is still inside
            # the service's cache lifetime (value not visible here).
            mock_time.return_value = 1500.0
            tokenizer2, _info2, error2 = self.service.load_tokenizer('custom/model')

            # Both loads must succeed; otherwise the identity check below
            # could pass trivially on two None results.
            assert error1 is None
            assert error2 is None
            # Only the first load should reach from_pretrained.
            assert mock_auto_tokenizer.from_pretrained.call_count == 1
            assert tokenizer1 is tokenizer2

    def test_custom_tokenizer_cache_expiration(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """A custom tokenizer requested after the cache window is reloaded."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

            # First load populates the cache.
            self.service.load_tokenizer('custom/model')

            # Second load 4000 s later; the test assumes this is beyond the
            # service's cache lifetime (value not visible here).
            mock_time.return_value = 5000.0
            self.service.load_tokenizer('custom/model')

            # The expired entry forces a second from_pretrained call.
            assert mock_auto_tokenizer.from_pretrained.call_count == 2

    def test_tokenizer_models_constant(self):
        """TOKENIZER_MODELS is a non-empty mapping of id -> {'name', 'alias'}."""
        models = self.service.TOKENIZER_MODELS

        assert isinstance(models, dict)
        assert len(models) > 0

        # Every entry needs string 'name' and 'alias' fields.
        for model_id, model_info in models.items():
            assert isinstance(model_id, str)
            assert isinstance(model_info, dict)
            assert 'name' in model_info
            assert 'alias' in model_info
            assert isinstance(model_info['name'], str)
            assert isinstance(model_info['alias'], str)

    def test_cache_initialization(self):
        """A new service starts with dict-typed tokenizer and info caches."""
        service = TokenizerService()

        assert hasattr(service, 'tokenizers')
        assert hasattr(service, 'custom_tokenizers')
        assert hasattr(service, 'tokenizer_info_cache')
        assert isinstance(service.tokenizers, dict)
        assert isinstance(service.custom_tokenizers, dict)
        assert isinstance(service.tokenizer_info_cache, dict)

    def test_special_tokens_filtering(self, mock_tokenizer):
        """Only non-None, non-blank special tokens are reported."""
        # Mix of valid, None, empty and whitespace-only special tokens.
        mock_tokenizer.pad_token = '<pad>'
        mock_tokenizer.eos_token = '</s>'
        mock_tokenizer.bos_token = None
        mock_tokenizer.sep_token = ''
        mock_tokenizer.cls_token = ' '  # whitespace only
        mock_tokenizer.unk_token = '<unk>'
        mock_tokenizer.mask_token = '<mask>'

        info = self.service.get_tokenizer_info(mock_tokenizer)
        special_tokens = info['special_tokens']

        # Real tokens are kept.
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert 'unk_token' in special_tokens
        assert 'mask_token' in special_tokens

        # None, empty and whitespace-only tokens are dropped.
        assert 'bos_token' not in special_tokens
        assert 'sep_token' not in special_tokens
        assert 'cls_token' not in special_tokens