""" Stats Service - Handles token statistics and color generation """ import hashlib import math from typing import List, Dict, Any class StatsService: """Service for calculating token statistics and generating colors.""" @staticmethod def get_varied_color(token: str) -> Dict[str, str]: """Generate vibrant colors with HSL for better visual distinction.""" token_hash = hashlib.md5(token.encode()).hexdigest() hue = int(token_hash[:3], 16) % 360 saturation = 70 + (int(token_hash[3:5], 16) % 20) lightness = 80 + (int(token_hash[5:7], 16) % 10) text_lightness = 20 if lightness > 50 else 90 return { 'background': f'hsl({hue}, {saturation}%, {lightness}%)', 'text': f'hsl({hue}, {saturation}%, {text_lightness}%)' } @staticmethod def fix_token(token: str) -> str: """Fix token for display with improved space visualization.""" if token.startswith('Ġ'): space_count = token.count('Ġ') return '·' * space_count + token[space_count:] return token @staticmethod def get_token_stats(tokens: List[str], original_text: str) -> Dict[str, Any]: """Calculate enhanced statistics about the tokens.""" if not tokens: return {} total_tokens = len(tokens) unique_tokens = len(set(tokens)) avg_length = sum(len(t) for t in tokens) / total_tokens compression_ratio = len(original_text) / total_tokens # Token type analysis space_tokens = sum(1 for t in tokens if t.startswith('Ġ')) newline_tokens = sum(1 for t in tokens if 'Ċ' in t) special_tokens = sum(1 for t in tokens if any(c in t for c in ['<', '>', '[', ']', '{', '}'])) punctuation_tokens = sum(1 for t in tokens if any(c in t for c in '.,!?;:()')) # Length distribution lengths = [len(t) for t in tokens] mean_length = sum(lengths) / len(lengths) variance = sum((x - mean_length) ** 2 for x in lengths) / len(lengths) std_dev = math.sqrt(variance) return { 'basic_stats': { 'total_tokens': total_tokens, 'unique_tokens': unique_tokens, 'compression_ratio': round(compression_ratio, 2), 'space_tokens': space_tokens, 'newline_tokens': newline_tokens, 'special_tokens': special_tokens, 'punctuation_tokens': punctuation_tokens, 'unique_percentage': round(unique_tokens/total_tokens * 100, 1) }, 'length_stats': { 'avg_length': round(avg_length, 2), 'std_dev': round(std_dev, 2), 'min_length': min(lengths), 'max_length': max(lengths), 'median_length': sorted(lengths)[len(lengths)//2] } } @staticmethod def format_tokens_for_display(tokens: List[str], tokenizer) -> List[Dict[str, Any]]: """Format tokens for frontend display with colors and metadata.""" token_data = [] for idx, token in enumerate(tokens): colors = StatsService.get_varied_color(token) fixed_token = StatsService.fix_token(token) # Compute the numerical token ID from the tokenizer token_id = tokenizer.convert_tokens_to_ids(token) token_data.append({ 'original': token, 'display': fixed_token[:-1] if fixed_token.endswith('Ċ') else fixed_token, 'colors': colors, 'newline': fixed_token.endswith('Ċ'), 'token_id': token_id, 'token_index': idx }) return token_data # Global instance stats_service = StatsService()