# tests/test_tokenizer_service.py
"""
Unit tests for TokenizerService
"""
import time
from unittest.mock import MagicMock, Mock, PropertyMock, patch

import pytest

from app.services.tokenizer_service import TokenizerService
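# The tests below rely on two pytest fixtures that are assumed to be defined
# in tests/conftest.py (not shown here): `mock_tokenizer`, a mock configured
# roughly like
#     mock_tokenizer = MagicMock()
#     mock_tokenizer.__class__.__name__ = 'MockTokenizer'
#     mock_tokenizer.vocab_size = 50257
#     mock_tokenizer.pad_token = '<pad>'
#     mock_tokenizer.eos_token = '</s>'
# and `app`, presumably a Flask application whose app_context() the caching
# tests enter. These fixture sketches are an assumption reconstructed from the
# values asserted in this file, not copied from conftest.py.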
class TestTokenizerService:
    """Test cases for TokenizerService."""

    def setup_method(self):
        """Set up test fixtures."""
        self.service = TokenizerService()
    def test_is_predefined_model(self):
        """Test predefined model checking."""
        # Test with existing model
        assert self.service.is_predefined_model('gpt2') is True

        # Test with non-existing model
        assert self.service.is_predefined_model('nonexistent-model') is False

        # Test with empty string
        assert self.service.is_predefined_model('') is False
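    # A minimal implementation consistent with this test would be a simple
    # membership check, e.g. `return model_id in self.TOKENIZER_MODELS`; that
    # is an assumption inferred from the TOKENIZER_MODELS test further down,
    # not taken from the service source.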
    def test_get_tokenizer_info_basic(self, mock_tokenizer):
        """Test basic tokenizer info extraction."""
        info = self.service.get_tokenizer_info(mock_tokenizer)

        assert 'vocab_size' in info
        assert 'tokenizer_type' in info
        assert 'special_tokens' in info
        assert info['vocab_size'] == 50257
        assert info['tokenizer_type'] == 'MockTokenizer'

        # Check special tokens
        special_tokens = info['special_tokens']
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert special_tokens['pad_token'] == '<pad>'
        assert special_tokens['eos_token'] == '</s>'
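    # From the assertions above, get_tokenizer_info is expected to return a
    # dict shaped roughly like:
    #     {
    #         'vocab_size': 50257,
    #         'tokenizer_type': 'MockTokenizer',
    #         'special_tokens': {'pad_token': '<pad>', 'eos_token': '</s>'},
    #     }
    # Additional keys (e.g. model_max_length) appear only when the underlying
    # tokenizer exposes them, as the next test shows.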
    def test_get_tokenizer_info_with_max_length(self, mock_tokenizer):
        """Test tokenizer info with model_max_length."""
        mock_tokenizer.model_max_length = 2048

        info = self.service.get_tokenizer_info(mock_tokenizer)

        assert 'model_max_length' in info
        assert info['model_max_length'] == 2048
    def test_get_tokenizer_info_error_handling(self):
        """Test error handling in tokenizer info extraction."""
        # Create a mock whose vocab_size access raises. A PropertyMock has to
        # be attached to the mock's type, not the instance, for the side
        # effect to fire on attribute access.
        broken_tokenizer = Mock()
        broken_tokenizer.__class__.__name__ = 'BrokenTokenizer'
        type(broken_tokenizer).vocab_size = PropertyMock(side_effect=Exception("Test error"))

        info = self.service.get_tokenizer_info(broken_tokenizer)

        assert 'error' in info
        assert 'Test error' in info['error']
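    # These assertions imply that get_tokenizer_info wraps attribute access in
    # try/except and reports failures as something like {'error': str(exc)}
    # instead of raising; the exact return shape is an assumption based only
    # on this test.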
    @patch('app.services.tokenizer_service.AutoTokenizer')
    def test_load_predefined_tokenizer_success(self, mock_auto_tokenizer, mock_tokenizer):
        """Test successful loading of predefined tokenizer."""
        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

        tokenizer, info, error = self.service.load_tokenizer('gpt2')

        assert tokenizer is not None
        assert error is None
        assert isinstance(info, dict)
        mock_auto_tokenizer.from_pretrained.assert_called_once()
    @patch('app.services.tokenizer_service.AutoTokenizer')
    def test_load_tokenizer_failure(self, mock_auto_tokenizer):
        """Test tokenizer loading failure."""
        mock_auto_tokenizer.from_pretrained.side_effect = Exception("Failed to load")

        tokenizer, info, error = self.service.load_tokenizer('gpt2')

        assert tokenizer is None
        assert error is not None
        assert "Failed to load" in error
    def test_load_nonexistent_predefined_model(self):
        """Test loading non-existent predefined model."""
        tokenizer, info, error = self.service.load_tokenizer('nonexistent-model')

        assert tokenizer is None
        assert error is not None
        assert "not found" in error.lower()
    @patch('app.services.tokenizer_service.AutoTokenizer')
    @patch('time.time')
    def test_custom_tokenizer_caching(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """Test custom tokenizer caching behavior."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

            # First load
            tokenizer1, info1, error1 = self.service.load_tokenizer('custom/model')

            # Second load (should use cache)
            mock_time.return_value = 1500.0  # Still within cache time
            tokenizer2, info2, error2 = self.service.load_tokenizer('custom/model')

            # Should only call from_pretrained once
            assert mock_auto_tokenizer.from_pretrained.call_count == 1
            assert tokenizer1 is tokenizer2
    @patch('app.services.tokenizer_service.AutoTokenizer')
    @patch('time.time')
    def test_custom_tokenizer_cache_expiration(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """Test custom tokenizer cache expiration."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

            # First load
            self.service.load_tokenizer('custom/model')

            # Second load after cache expiration
            mock_time.return_value = 5000.0  # Beyond cache expiration
            self.service.load_tokenizer('custom/model')

            # Should call from_pretrained twice
            assert mock_auto_tokenizer.from_pretrained.call_count == 2
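    # Taken together, the two caching tests imply that load_tokenizer keeps
    # custom tokenizers in a timestamped cache (custom_tokenizers) and reloads
    # once an entry is older than some TTL greater than 500 s and at most
    # 4000 s (an hour would fit); the exact TTL constant is an assumption,
    # only those bounds follow from the mocked timestamps.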
    def test_tokenizer_models_constant(self):
        """Test that TOKENIZER_MODELS contains expected models."""
        models = self.service.TOKENIZER_MODELS

        assert isinstance(models, dict)
        assert len(models) > 0

        # Check that each model has required fields
        for model_id, model_info in models.items():
            assert isinstance(model_id, str)
            assert isinstance(model_info, dict)
            assert 'name' in model_info
            assert 'alias' in model_info
            assert isinstance(model_info['name'], str)
            assert isinstance(model_info['alias'], str)
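    # TOKENIZER_MODELS entries therefore look roughly like
    #     'gpt2': {'name': 'GPT-2', 'alias': 'gpt2'}
    # The concrete values are illustrative; only the presence of string
    # 'name' and 'alias' fields is actually asserted here.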
    def test_cache_initialization(self):
        """Test that caches are properly initialized."""
        service = TokenizerService()

        assert hasattr(service, 'tokenizers')
        assert hasattr(service, 'custom_tokenizers')
        assert hasattr(service, 'tokenizer_info_cache')
        assert isinstance(service.tokenizers, dict)
        assert isinstance(service.custom_tokenizers, dict)
        assert isinstance(service.tokenizer_info_cache, dict)
    def test_special_tokens_filtering(self, mock_tokenizer):
        """Test that only valid special tokens are included."""
        # Add some None and empty special tokens
        mock_tokenizer.pad_token = '<pad>'
        mock_tokenizer.eos_token = '</s>'
        mock_tokenizer.bos_token = None
        mock_tokenizer.sep_token = ''
        mock_tokenizer.cls_token = ' '  # Whitespace only
        mock_tokenizer.unk_token = '<unk>'
        mock_tokenizer.mask_token = '<mask>'

        info = self.service.get_tokenizer_info(mock_tokenizer)
        special_tokens = info['special_tokens']

        # Should only include non-None, non-empty tokens
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert 'unk_token' in special_tokens
        assert 'mask_token' in special_tokens

        # Should not include None or empty tokens
        assert 'bos_token' not in special_tokens
        assert 'sep_token' not in special_tokens
        assert 'cls_token' not in special_tokens
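    # The expected filtering is consistent with a predicate along the lines of
    # `token is not None and str(token).strip() != ''` applied to each special
    # token attribute; the actual check inside TokenizerService is an
    # assumption inferred from these assertions.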