"""
Stats Service - Handles token statistics and color generation
"""
import hashlib
import math
from typing import List, Dict, Any
class StatsService:
"""Service for calculating token statistics and generating colors."""

    @staticmethod
    def get_varied_color(token: str) -> Dict[str, str]:
        """Generate vibrant colors with HSL for better visual distinction."""
        # Hash the token so the same token always maps to the same color.
        token_hash = hashlib.md5(token.encode()).hexdigest()
        hue = int(token_hash[:3], 16) % 360                 # 0-359
        saturation = 70 + (int(token_hash[3:5], 16) % 20)   # 70-89%
        lightness = 80 + (int(token_hash[5:7], 16) % 10)    # 80-89%
        # Backgrounds are always light here (>= 80%), so dark text is picked.
        text_lightness = 20 if lightness > 50 else 90
        return {
            'background': f'hsl({hue}, {saturation}%, {lightness}%)',
            'text': f'hsl({hue}, {saturation}%, {text_lightness}%)'
        }
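
    # Illustrative call (HSL values below are range placeholders, not
    # computed output):
    #   StatsService.get_varied_color("Ġhello")
    #   -> {'background': 'hsl(<0-359>, <70-89>%, <80-89>%)',
    #       'text': 'hsl(<0-359>, <70-89>%, 20%)'}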

    @staticmethod
    def fix_token(token: str) -> str:
        """Fix token for display with improved space visualization."""
        # BPE tokenizers mark leading spaces with 'Ġ'; render each one as a
        # visible middle dot. Count only the leading markers so a 'Ġ'
        # embedded mid-token is left alone.
        if token.startswith('Ġ'):
            space_count = len(token) - len(token.lstrip('Ġ'))
            return '·' * space_count + token[space_count:]
        return token
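
    # For example, fix_token('Ġhello') returns '·hello', and a two-space
    # token 'ĠĠ' becomes '··'.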

    @staticmethod
    def get_token_stats(tokens: List[str], original_text: str) -> Dict[str, Any]:
        """Calculate enhanced statistics about the tokens."""
        if not tokens:
            return {}

        total_tokens = len(tokens)
        unique_tokens = len(set(tokens))
        # Characters of input per token: higher means stronger compression.
        compression_ratio = len(original_text) / total_tokens

        # Token type analysis
        space_tokens = sum(1 for t in tokens if t.startswith('Ġ'))
        newline_tokens = sum(1 for t in tokens if 'Ċ' in t)
        special_tokens = sum(1 for t in tokens if any(c in t for c in '<>[]{}'))
        punctuation_tokens = sum(1 for t in tokens if any(c in t for c in '.,!?;:()'))

        # Length distribution (population variance: every token is observed)
        lengths = [len(t) for t in tokens]
        avg_length = sum(lengths) / total_tokens
        variance = sum((x - avg_length) ** 2 for x in lengths) / total_tokens
        std_dev = math.sqrt(variance)
        return {
            'basic_stats': {
                'total_tokens': total_tokens,
                'unique_tokens': unique_tokens,
                'compression_ratio': round(compression_ratio, 2),
                'space_tokens': space_tokens,
                'newline_tokens': newline_tokens,
                'special_tokens': special_tokens,
                'punctuation_tokens': punctuation_tokens,
                'unique_percentage': round(unique_tokens / total_tokens * 100, 1)
            },
            'length_stats': {
                'avg_length': round(avg_length, 2),
                'std_dev': round(std_dev, 2),
                'min_length': min(lengths),
                'max_length': max(lengths),
                # Upper median for even-length lists; adequate for display.
                'median_length': sorted(lengths)[len(lengths) // 2]
            }
        }
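
    # Illustrative sketch of the returned shape (values elided):
    #   StatsService.get_token_stats(['Ġhello', ',', 'Ġworld'], 'hello, world')
    #   -> {'basic_stats': {'total_tokens': 3, 'unique_tokens': 3, ...},
    #       'length_stats': {'avg_length': ..., 'median_length': ..., ...}}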

    @staticmethod
    def format_tokens_for_display(tokens: List[str], tokenizer) -> List[Dict[str, Any]]:
        """Format tokens for frontend display with colors and metadata."""
        token_data = []
        for idx, token in enumerate(tokens):
            colors = StatsService.get_varied_color(token)
            fixed_token = StatsService.fix_token(token)
            # Compute the numerical token ID from the tokenizer
            token_id = tokenizer.convert_tokens_to_ids(token)
            token_data.append({
                'original': token,
                # Strip a trailing newline marker from the display text, but
                # record its presence so the frontend can insert a line break.
                'display': fixed_token[:-1] if fixed_token.endswith('Ċ') else fixed_token,
                'colors': colors,
                'newline': fixed_token.endswith('Ċ'),
                'token_id': token_id,
                'token_index': idx
            })
        return token_data
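
    # Each list element is shaped like (placeholder values):
    #   {'original': 'Ġhello', 'display': '·hello', 'colors': {...},
    #    'newline': False, 'token_id': ..., 'token_index': 0}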


# Global instance
stats_service = StatsService()
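

# Minimal usage sketch, not part of the service: assumes the `transformers`
# package is available; "gpt2" is just an illustrative checkpoint.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    text = "Hello, tokenized world!"
    tokens = tokenizer.tokenize(text)
    print(stats_service.get_token_stats(tokens, text))
    for entry in stats_service.format_tokens_for_display(tokens, tokenizer):
        print(entry['token_index'], entry['token_id'], entry['display'])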