""" | |
Stats Service - Handles token statistics and color generation | |
""" | |
import hashlib | |
import math | |
from typing import List, Dict, Any | |
class StatsService:
    """Service for calculating token statistics and generating colors."""

    @staticmethod
    def get_varied_color(token: str) -> Dict[str, str]:
        """Generate vibrant colors with HSL for better visual distinction."""
        # Hash the token so the same token always maps to the same color.
        token_hash = hashlib.md5(token.encode()).hexdigest()
        hue = int(token_hash[:3], 16) % 360
        saturation = 70 + (int(token_hash[3:5], 16) % 20)  # 70-89%
        lightness = 80 + (int(token_hash[5:7], 16) % 10)   # 80-89%
        # Dark text on light backgrounds, light text on dark ones.
        text_lightness = 20 if lightness > 50 else 90
        return {
            'background': f'hsl({hue}, {saturation}%, {lightness}%)',
            'text': f'hsl({hue}, {saturation}%, {text_lightness}%)'
        }

    @staticmethod
    def fix_token(token: str) -> str:
        """Fix token for display with improved space visualization."""
        # 'Ġ' is the byte-level BPE marker for a leading space (GPT-2 style);
        # render each marker as a visible middle dot.
        if token.startswith('Ġ'):
            space_count = token.count('Ġ')
            return '·' * space_count + token[space_count:]
        return token

    @staticmethod
    def get_token_stats(tokens: List[str], original_text: str) -> Dict[str, Any]:
        """Calculate enhanced statistics about the tokens."""
        if not tokens:
            return {}

        total_tokens = len(tokens)
        unique_tokens = len(set(tokens))
        # Characters of input per token: higher means fewer, longer tokens.
        compression_ratio = len(original_text) / total_tokens

        # Token type analysis
        space_tokens = sum(1 for t in tokens if t.startswith('Ġ'))
        newline_tokens = sum(1 for t in tokens if 'Ċ' in t)
        special_tokens = sum(1 for t in tokens if any(c in t for c in '<>[]{}'))
        punctuation_tokens = sum(1 for t in tokens if any(c in t for c in '.,!?;:()'))

        # Length distribution
        lengths = [len(t) for t in tokens]
        avg_length = sum(lengths) / total_tokens
        variance = sum((x - avg_length) ** 2 for x in lengths) / total_tokens
        std_dev = math.sqrt(variance)

        return {
            'basic_stats': {
                'total_tokens': total_tokens,
                'unique_tokens': unique_tokens,
                'compression_ratio': round(compression_ratio, 2),
                'space_tokens': space_tokens,
                'newline_tokens': newline_tokens,
                'special_tokens': special_tokens,
                'punctuation_tokens': punctuation_tokens,
                'unique_percentage': round(unique_tokens / total_tokens * 100, 1)
            },
            'length_stats': {
                'avg_length': round(avg_length, 2),
                'std_dev': round(std_dev, 2),
                'min_length': min(lengths),
                'max_length': max(lengths),
                'median_length': sorted(lengths)[len(lengths) // 2]
            }
        }

    @staticmethod
    def format_tokens_for_display(tokens: List[str], tokenizer) -> List[Dict[str, Any]]:
        """Format tokens for frontend display with colors and metadata."""
        token_data = []
        for idx, token in enumerate(tokens):
            colors = StatsService.get_varied_color(token)
            fixed_token = StatsService.fix_token(token)
            # Compute the numerical token ID from the tokenizer
            token_id = tokenizer.convert_tokens_to_ids(token)
            # 'Ċ' marks a newline in byte-level BPE vocabularies; strip it
            # from the display text and flag it so the UI can break the line.
            is_newline = fixed_token.endswith('Ċ')
            token_data.append({
                'original': token,
                'display': fixed_token[:-1] if is_newline else fixed_token,
                'colors': colors,
                'newline': is_newline,
                'token_id': token_id,
                'token_index': idx
            })
        return token_data


# Global instance
stats_service = StatsService()
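

# Illustrative usage sketch, not part of the service itself: it assumes a
# Hugging Face `transformers` tokenizer such as GPT-2, which produces the
# 'Ġ'/'Ċ' byte-level BPE markers this service expects. The model name and
# sample text are placeholders; guarded so importing the module stays
# side-effect free.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    text = "Hello world!\nTokenization made visible."
    tokens = tokenizer.tokenize(text)

    stats = stats_service.get_token_stats(tokens, text)
    print(stats['basic_stats']['total_tokens'], 'tokens,',
          stats['basic_stats']['compression_ratio'], 'chars/token')

    display = stats_service.format_tokens_for_display(tokens, tokenizer)
    print(display[0])  # first token with its colors, ID, and display form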