File size: 3,984 Bytes
d66ab65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""

Stats Service - Handles token statistics and color generation

"""
import hashlib
import math
from typing import List, Dict, Any


class StatsService:
    """Token statistics and per-token color generation for token display.

    Every method is a ``@staticmethod``; the class only groups related
    helpers. Raw tokens are expected to use 'Ġ' as the leading-space marker
    and 'Ċ' as the newline marker (presumably a byte-level BPE tokenizer's
    vocabulary, e.g. GPT-2 style -- confirm against the tokenizer in use).
    """

    @staticmethod
    def get_varied_color(token: str) -> Dict[str, str]:
        """Derive a stable background/text HSL color pair for *token*.

        The token's MD5 digest is used purely as a deterministic hash (not
        for security) to pick hue (0-359), saturation (70-89%) and
        background lightness (80-89%). The same token therefore always maps
        to the same colors.

        Args:
            token: Raw token string.

        Returns:
            Dict with ``'background'`` and ``'text'`` CSS ``hsl(...)`` strings.
        """
        token_hash = hashlib.md5(token.encode()).hexdigest()
        hue = int(token_hash[:3], 16) % 360
        saturation = 70 + (int(token_hash[3:5], 16) % 20)
        lightness = 80 + (int(token_hash[5:7], 16) % 10)
        # Dark text on light backgrounds, light text on dark ones. With the
        # 80-89% lightness range above, the text is currently always dark.
        text_lightness = 20 if lightness > 50 else 90

        return {
            'background': f'hsl({hue}, {saturation}%, {lightness}%)',
            'text': f'hsl({hue}, {saturation}%, {text_lightness}%)'
        }

    @staticmethod
    def fix_token(token: str) -> str:
        """Make leading space markers visible for display.

        Each *leading* 'Ġ' is replaced with a middle dot '·'; the remainder
        of the token (including any later 'Ġ') is returned unchanged.

        Args:
            token: Raw token string.

        Returns:
            The display form of *token* (unchanged if it has no leading 'Ġ').
        """
        body = token.lstrip('Ġ')
        # Count only the leading markers. Counting every 'Ġ' in the token
        # would over-slice and drop real characters for tokens such as 'ĠaĠ'.
        space_count = len(token) - len(body)
        return '·' * space_count + body

    @staticmethod
    def get_token_stats(tokens: List[str], original_text: str) -> Dict[str, Any]:
        """Compute summary statistics for a tokenization.

        Args:
            tokens: Raw token strings produced for *original_text*.
            original_text: The tokenized text; only its length is used, to
                compute ``compression_ratio`` (characters per token).

        Returns:
            ``{}`` when *tokens* is empty. Otherwise a dict with
            ``basic_stats`` (counts per token category, compression ratio,
            unique percentage) and ``length_stats`` (mean, population
            standard deviation, min, max and median token length in
            characters).
        """
        if not tokens:
            return {}

        total_tokens = len(tokens)
        unique_tokens = len(set(tokens))
        compression_ratio = len(original_text) / total_tokens

        # Token type analysis. Categories are not exclusive: one token can
        # be counted in several of them.
        space_tokens = sum(1 for t in tokens if t.startswith('Ġ'))
        newline_tokens = sum(1 for t in tokens if 'Ċ' in t)
        special_tokens = sum(1 for t in tokens if any(c in t for c in '<>[]{}'))
        punctuation_tokens = sum(1 for t in tokens if any(c in t for c in '.,!?;:()'))

        # Length distribution (population variance: divide by N).
        lengths = [len(t) for t in tokens]
        mean_length = sum(lengths) / total_tokens
        variance = sum((x - mean_length) ** 2 for x in lengths) / total_tokens
        std_dev = math.sqrt(variance)

        # Median: for an even count, average the two middle values instead
        # of returning only the upper-middle element.
        ordered = sorted(lengths)
        mid = total_tokens // 2
        if total_tokens % 2:
            median_length = ordered[mid]
        else:
            median_length = (ordered[mid - 1] + ordered[mid]) / 2

        return {
            'basic_stats': {
                'total_tokens': total_tokens,
                'unique_tokens': unique_tokens,
                'compression_ratio': round(compression_ratio, 2),
                'space_tokens': space_tokens,
                'newline_tokens': newline_tokens,
                'special_tokens': special_tokens,
                'punctuation_tokens': punctuation_tokens,
                'unique_percentage': round(unique_tokens / total_tokens * 100, 1)
            },
            'length_stats': {
                'avg_length': round(mean_length, 2),
                'std_dev': round(std_dev, 2),
                'min_length': ordered[0],
                'max_length': ordered[-1],
                'median_length': median_length
            }
        }

    @staticmethod
    def format_tokens_for_display(tokens: List[str], tokenizer) -> List[Dict[str, Any]]:
        """Build one display record per token for the frontend.

        Args:
            tokens: Raw token strings, in order.
            tokenizer: Object exposing ``convert_tokens_to_ids(token)``
                (presumably a Hugging Face tokenizer -- confirm with caller).

        Returns:
            A list of dicts, one per token, with keys ``original`` (raw
            token), ``display`` (leading space markers made visible, one
            trailing 'Ċ' removed), ``colors`` (from ``get_varied_color``),
            ``newline`` (whether the display form ended with 'Ċ'),
            ``token_id`` and ``token_index`` (position in *tokens*).
        """
        token_data = []
        for idx, token in enumerate(tokens):
            colors = StatsService.get_varied_color(token)
            fixed_token = StatsService.fix_token(token)
            ends_with_newline = fixed_token.endswith('Ċ')
            # Numerical token ID as reported by the tokenizer.
            token_id = tokenizer.convert_tokens_to_ids(token)
            token_data.append({
                'original': token,
                # The trailing newline marker is stripped from the visible
                # text; the 'newline' flag records that it was there.
                'display': fixed_token[:-1] if ends_with_newline else fixed_token,
                'colors': colors,
                'newline': ends_with_newline,
                'token_id': token_id,
                'token_index': idx
            })
        return token_data


# Module-level shared instance for callers that import `stats_service`.
# StatsService has no __init__ and only static methods, so a single
# instance is sufficient and carries no state.
stats_service = StatsService()