"""
Unit tests for StatsService
"""
import re
from unittest.mock import Mock

import pytest

from app.services.stats_service import StatsService
class TestStatsService:
    """Test cases for StatsService.

    Tests taking ``sample_tokens``, ``sample_text`` or ``mock_tokenizer``
    depend on pytest fixtures that are not defined in this class
    (presumably provided by a ``conftest.py`` -- confirm).
    """

    # '#RRGGBB': a '#' followed by exactly six hexadecimal digits.
    _HEX_COLOR = re.compile(r"#[0-9a-fA-F]{6}")

    def setup_method(self):
        """Give every test method its own freshly constructed StatsService."""
        self.service = StatsService()

    @classmethod
    def _assert_color_pair(cls, color):
        """Assert *color* is a dict whose 'background' and 'text' are '#RRGGBB'."""
        assert isinstance(color, dict)
        for key in ('background', 'text'):
            assert key in color, key
            # fullmatch also pins the length to 7 characters.
            assert cls._HEX_COLOR.fullmatch(color[key]), color[key]

    # ------------------------------------------------------------------
    # get_varied_color
    # ------------------------------------------------------------------

    def test_get_varied_color_basic(self):
        """A generated color is a dict of two '#RRGGBB' strings."""
        self._assert_color_pair(self.service.get_varied_color(0, 10))

    def test_get_varied_color_different_indices(self):
        """Different indices produce pairwise-distinct background colors."""
        backgrounds = [
            self.service.get_varied_color(index, 10)['background']
            for index in (0, 1, 5)
        ]
        assert len(set(backgrounds)) == len(backgrounds), backgrounds

    def test_get_varied_color_edge_cases(self):
        """Boundary index/total combinations still yield well-formed colors."""
        cases = (
            (0, 1),       # single token
            (999, 1000),  # large number of tokens
            (0, 5),       # zero index
        )
        for index, total in cases:
            self._assert_color_pair(self.service.get_varied_color(index, total))

    # ------------------------------------------------------------------
    # fix_token
    # ------------------------------------------------------------------

    def test_fix_token_basic(self):
        """Ordinary tokens pass through unchanged."""
        assert self.service.fix_token("hello") == "hello"
        assert self.service.fix_token("world") == "world"

    def test_fix_token_special_characters(self):
        """Whitespace becomes visible glyphs and the 'Ġ' prefix becomes a space."""
        # Whitespace characters map to visible glyphs.
        assert self.service.fix_token(" ") == "␣"
        assert self.service.fix_token("\t") == "→"
        assert self.service.fix_token("\n") == "↵"
        # 'Ġ' prefix (common in tokenizers) is rendered as a leading space.
        assert self.service.fix_token("Ġhello") == " hello"
        assert self.service.fix_token("Ġworld") == " world"
        # A lone marker yields a plain space, not the '␣' glyph.
        assert self.service.fix_token("Ġ") == " "

    def test_fix_token_edge_cases(self):
        """Empty, None, multi-character and repeated-marker inputs."""
        assert self.service.fix_token("") == ""
        # None shouldn't reach fix_token, but it must not raise if it does.
        result = self.service.fix_token(None)
        assert result is None or result == ""
        # Each whitespace character is converted independently.
        assert self.service.fix_token("\n\t ") == "↵→␣"
        # NOTE(review): if every 'Ġ' maps to a space (as "Ġ" -> " " above
        # suggests), this would be "  hello" (two spaces); confirm that a
        # single leading space is really the intended result.
        assert self.service.fix_token("ĠĠhello") == " hello"

    # ------------------------------------------------------------------
    # get_token_stats
    # ------------------------------------------------------------------

    def test_get_token_stats_basic(self, sample_tokens, sample_text):
        """The stats dict exposes the expected sections and keys."""
        stats = self.service.get_token_stats(sample_tokens, sample_text)
        assert isinstance(stats, dict)
        assert 'basic_stats' in stats
        assert 'length_stats' in stats

        basic = stats['basic_stats']
        for key in (
            'total_tokens',
            'unique_tokens',
            'unique_percentage',
            'special_tokens',
            'space_tokens',
            'newline_tokens',
            'compression_ratio',
        ):
            assert key in basic, key

        length = stats['length_stats']
        for key in ('avg_length', 'median_length', 'std_dev'):
            assert key in length, key

    def test_get_token_stats_calculations(self):
        """Totals, uniqueness and compression ratio for a small input."""
        tokens = ['Hello', ' world', '!', ' test']
        text = "Hello world! test"
        basic = self.service.get_token_stats(tokens, text)['basic_stats']

        assert basic['total_tokens'] == 4
        # Every token is distinct in this input.
        assert basic['unique_tokens'] == 4
        assert basic['unique_percentage'] == "100.0"
        # Compression ratio: characters of text per token.
        expected_ratio = len(text) / len(tokens)
        assert float(basic['compression_ratio']) == pytest.approx(
            expected_ratio, rel=1e-2
        )

    def test_get_token_stats_special_tokens(self):
        """Tokens of the form <...> are counted as special."""
        tokens = ['<s>', 'Hello', ' world', '</s>', '<pad>']
        text = "Hello world"
        basic = self.service.get_token_stats(tokens, text)['basic_stats']
        # NOTE(review): '<s>', '</s>' and '<pad>' all fit the "<...>" rule,
        # yet only >= 2 is required; tighten to == 3 once the exact
        # counting rule in StatsService is confirmed.
        assert basic['special_tokens'] >= 2

    def test_get_token_stats_whitespace_tokens(self):
        """Whitespace-only tokens are counted."""
        tokens = ['Hello', ' ', 'world', '\n', 'test', '\t']
        text = "Hello world\ntest\t"
        basic = self.service.get_token_stats(tokens, text)['basic_stats']
        # space_tokens must see at least the ' ' token (the '\t' token may
        # count as well); newline_tokens must see the '\n' token.
        assert basic['space_tokens'] >= 1
        assert basic['newline_tokens'] >= 1

    def test_get_token_stats_length_calculations(self):
        """Mean and median of token lengths."""
        tokens = ['a', 'bb', 'ccc', 'dddd']  # lengths 1, 2, 3, 4
        text = "a bb ccc dddd"
        length = self.service.get_token_stats(tokens, text)['length_stats']
        # Mean of 1..4 is 2.5.
        assert float(length['avg_length']) == pytest.approx(2.5, rel=1e-2)
        # Even count: the median is the midpoint of 2 and 3.
        assert float(length['median_length']) == pytest.approx(2.5, rel=1e-2)

    def test_get_token_stats_empty_input(self):
        """Empty input yields zeroed statistics."""
        stats = self.service.get_token_stats([], "")
        basic = stats['basic_stats']
        length = stats['length_stats']

        assert basic['total_tokens'] == 0
        assert basic['unique_tokens'] == 0
        # Ratios and length stats must degrade to "0.0" on empty input.
        for key in ('unique_percentage', 'compression_ratio'):
            assert basic[key] == "0.0", key
        for key in ('avg_length', 'median_length', 'std_dev'):
            assert length[key] == "0.0", key

    # ------------------------------------------------------------------
    # format_tokens_for_display
    # ------------------------------------------------------------------

    def test_format_tokens_for_display_basic(self, mock_tokenizer):
        """Each token becomes a dict carrying display data and colors."""
        tokens = ['Hello', ' world', '!']
        # Make the fixture tokenizer hand the same tokens back.
        mock_tokenizer.convert_ids_to_tokens.return_value = tokens
        formatted = self.service.format_tokens_for_display(tokens, mock_tokenizer)

        assert isinstance(formatted, list)
        assert len(formatted) == len(tokens)
        for token, token_data in zip(tokens, formatted):
            assert isinstance(token_data, dict)
            for key in ('display', 'original', 'token_id', 'colors', 'newline'):
                assert key in token_data, key
            assert token_data['original'] == token
            colors = token_data['colors']
            assert isinstance(colors, dict)
            assert 'background' in colors
            assert 'text' in colors

    def test_format_tokens_newline_detection(self, mock_tokenizer):
        """Only the newline token is flagged as a newline."""
        tokens = ['Hello', '\n', 'world']
        mock_tokenizer.convert_ids_to_tokens.return_value = tokens
        formatted = self.service.format_tokens_for_display(tokens, mock_tokenizer)
        # Identity checks: the flag must be an actual bool, not just truthy.
        assert formatted[1]['newline'] is True
        assert formatted[0]['newline'] is False
        assert formatted[2]['newline'] is False

    def test_format_tokens_color_consistency(self, mock_tokenizer):
        """Equal tokens share colors; distinct tokens do not."""
        tokens = ['hello', 'world', 'hello']  # 'hello' appears twice
        mock_tokenizer.convert_ids_to_tokens.return_value = tokens
        formatted = self.service.format_tokens_for_display(tokens, mock_tokenizer)

        first_hello = formatted[0]['colors']
        world = formatted[1]['colors']
        second_hello = formatted[2]['colors']
        # Same token -> same colors.
        assert first_hello['background'] == second_hello['background']
        assert first_hello['text'] == second_hello['text']
        # Different token -> different background.
        assert first_hello['background'] != world['background']

    def test_format_tokens_special_character_handling(self, mock_tokenizer):
        """Display strings use the visible-whitespace substitutions."""
        tokens = [' ', '\t', '\n', 'Ġhello']
        mock_tokenizer.convert_ids_to_tokens.return_value = tokens
        formatted = self.service.format_tokens_for_display(tokens, mock_tokenizer)
        assert formatted[0]['display'] == '␣'       # space
        assert formatted[1]['display'] == '→'       # tab
        assert formatted[2]['display'] == '↵'       # newline
        assert formatted[3]['display'] == ' hello'  # 'Ġ' prefix