File size: 7,468 Bytes
d66ab65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""

Tokenizer Service - Handles tokenizer loading, caching, and management

"""
import time
from typing import Dict, Tuple, Optional, Any
from transformers import AutoTokenizer
from flask import current_app


class TokenizerService:
    """Service for managing tokenizer loading and caching.

    Predefined models (keys of :attr:`TOKENIZER_MODELS`) are cached for the
    lifetime of the process; custom Hugging Face model paths are cached with
    a timestamp and re-loaded once the Flask ``CACHE_EXPIRATION`` interval
    (seconds, default 3600) has elapsed.
    """

    # Predefined tokenizer models: short ID -> Hugging Face repo name plus a
    # human-readable display alias for the UI.
    TOKENIZER_MODELS = {
        'qwen3': {
            'name': 'Qwen/Qwen3-0.6B',
            'alias': 'Qwen 3'
        },
        'gemma3-27b': {
            'name': 'google/gemma-3-27b-it',
            'alias': 'Gemma 3 27B'
        },
        'glm4': {
            'name': 'THUDM/GLM-4-32B-0414',
            'alias': 'GLM 4'
        },
        'mistral-small': {
            'name': 'mistralai/Mistral-Small-3.1-24B-Instruct-2503',
            'alias': 'Mistral Small 3.1'
        },
        'llama4': {
            'name': 'meta-llama/Llama-4-Scout-17B-16E-Instruct',
            'alias': 'Llama 4'
        },
        'deepseek-r1': {
            'name': 'deepseek-ai/DeepSeek-R1',
            'alias': 'Deepseek R1'
        },
        'qwen_25_72b': {
            'name': 'Qwen/Qwen2.5-72B-Instruct',
            # Fixed: alias previously read 'QWQ 32B', which does not match
            # the Qwen2.5-72B-Instruct repo this entry actually loads.
            'alias': 'Qwen 2.5 72B'
        },
        'llama_33': {
            'name': 'unsloth/Llama-3.3-70B-Instruct-bnb-4bit',
            'alias': 'Llama 3.3 70B'
        },
        'gemma2_2b': {
            'name': 'google/gemma-2-2b-it',
            'alias': 'Gemma 2 2B'
        },
        'bert-large-uncased': {
            'name': 'google-bert/bert-large-uncased',
            'alias': 'Bert Large Uncased'
        },
        'gpt2': {
            'name': 'openai-community/gpt2',
            'alias': 'GPT-2'
        }
    }

    def __init__(self):
        """Initialize the tokenizer service with empty caches."""
        # Predefined-model ID -> tokenizer instance (never expires).
        self.tokenizers: Dict[str, Any] = {}
        # Custom model path -> (tokenizer, load timestamp); entries expire.
        self.custom_tokenizers: Dict[str, Tuple[Any, float]] = {}
        # Model ID or path -> info dict produced by get_tokenizer_info().
        self.tokenizer_info_cache: Dict[str, Dict] = {}
        # Custom model path -> last load-failure message.
        self.custom_model_errors: Dict[str, str] = {}

    def get_tokenizer_info(self, tokenizer) -> Dict:
        """Extract useful display information from a tokenizer.

        Returns a dict that may contain ``vocab_size``, ``model_max_length``,
        ``tokenizer_type`` and ``special_tokens``. Never raises: on failure
        the dict carries an ``error`` key instead.
        """
        info: Dict[str, Any] = {}
        try:
            # Vocabulary size: prefer the explicit attribute, fall back to
            # counting the vocab mapping.
            if hasattr(tokenizer, 'vocab_size'):
                info['vocab_size'] = tokenizer.vocab_size
            elif hasattr(tokenizer, 'get_vocab'):
                info['vocab_size'] = len(tokenizer.get_vocab())

            # Skip the huge sentinel (e.g. int(1e30)) transformers uses when
            # no real max length is configured for the model.
            if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length < 1000000:
                info['model_max_length'] = tokenizer.model_max_length

            # Concrete tokenizer class name (e.g. 'GPT2TokenizerFast').
            info['tokenizer_type'] = tokenizer.__class__.__name__

            # Collect the special tokens that are set and non-empty.
            special_tokens = {}
            for token_name in ['pad_token', 'eos_token', 'bos_token', 'sep_token', 'cls_token', 'unk_token', 'mask_token']:
                if hasattr(tokenizer, token_name) and getattr(tokenizer, token_name) is not None:
                    token_value = getattr(tokenizer, token_name)
                    if token_value and str(token_value).strip():
                        special_tokens[token_name] = str(token_value)

            info['special_tokens'] = special_tokens

        except Exception as e:
            # Info extraction is best-effort; report rather than propagate so
            # a quirky tokenizer never breaks loading.
            info['error'] = f"Error extracting tokenizer info: {str(e)}"

        return info

    def load_tokenizer(self, model_id_or_name: str) -> Tuple[Optional[Any], Dict, Optional[str]]:
        """Load a tokenizer, serving it from cache when possible.

        Args:
            model_id_or_name: Either a key of :attr:`TOKENIZER_MODELS` or an
                arbitrary Hugging Face model path.

        Returns:
            Tuple of ``(tokenizer, tokenizer_info, error_message)``. On
            failure ``tokenizer`` is None and ``error_message`` is set.
        """
        error_message = None
        tokenizer_info = {}

        # Serve previously extracted info if we have it.
        if model_id_or_name in self.tokenizer_info_cache:
            tokenizer_info = self.tokenizer_info_cache[model_id_or_name]

        try:
            # Predefined model IDs are cached without expiration.
            if model_id_or_name in self.TOKENIZER_MODELS:
                model_name = self.TOKENIZER_MODELS[model_id_or_name]['name']
                if model_id_or_name not in self.tokenizers:
                    self.tokenizers[model_id_or_name] = AutoTokenizer.from_pretrained(model_name)
                tokenizer = self.tokenizers[model_id_or_name]

                # Extract and cache tokenizer info on first use.
                if model_id_or_name not in self.tokenizer_info_cache:
                    tokenizer_info = self.get_tokenizer_info(tokenizer)
                    self.tokenizer_info_cache[model_id_or_name] = tokenizer_info

                return tokenizer, tokenizer_info, None

            # Custom model path: honour the timed cache. Requires an active
            # Flask application context for current_app.
            current_time = time.time()
            cache_expiration = current_app.config.get('CACHE_EXPIRATION', 3600)

            if model_id_or_name in self.custom_tokenizers:
                cached_tokenizer, timestamp = self.custom_tokenizers[model_id_or_name]
                if current_time - timestamp < cache_expiration:
                    # Extract and cache tokenizer info on first use.
                    if model_id_or_name not in self.tokenizer_info_cache:
                        tokenizer_info = self.get_tokenizer_info(cached_tokenizer)
                        self.tokenizer_info_cache[model_id_or_name] = tokenizer_info
                    return cached_tokenizer, tokenizer_info, None

            # Not cached or expired: load fresh and stamp the cache entry.
            tokenizer = AutoTokenizer.from_pretrained(model_id_or_name)
            self.custom_tokenizers[model_id_or_name] = (tokenizer, current_time)
            # A successful load supersedes any recorded error for this model.
            self.custom_model_errors.pop(model_id_or_name, None)

            tokenizer_info = self.get_tokenizer_info(tokenizer)
            self.tokenizer_info_cache[model_id_or_name] = tokenizer_info

            return tokenizer, tokenizer_info, None

        except Exception as e:
            error_message = f"Failed to load tokenizer: {str(e)}"
            # Remember the failure so callers can surface it later.
            self.custom_model_errors[model_id_or_name] = error_message
            return None, tokenizer_info, error_message

    def get_model_alias(self, model_id: str) -> str:
        """Return the display alias for a model ID (the ID itself if unknown)."""
        if model_id in self.TOKENIZER_MODELS:
            return self.TOKENIZER_MODELS[model_id]['alias']
        return model_id

    def is_predefined_model(self, model_id: str) -> bool:
        """Return True if ``model_id`` is one of the predefined models."""
        return model_id in self.TOKENIZER_MODELS

    def clear_cache(self):
        """Drop all cached tokenizers, extracted info, and recorded errors."""
        self.tokenizers.clear()
        self.custom_tokenizers.clear()
        self.tokenizer_info_cache.clear()
        self.custom_model_errors.clear()


# Module-level singleton: created at import time so every importer shares the
# same tokenizer caches (NOTE(review): presumably consumed by the Flask
# request handlers — confirm against the routes that import this module).
tokenizer_service = TokenizerService()