# model.py - Optimized version
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from functools import lru_cache
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging

logger = logging.getLogger(__name__)

# Module-level cache for the loaded tokenizer/model and loading-state flags
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False

@lru_cache(maxsize=1)
def get_model_config():
    """Cache model configuration"""
    return {
        "model_id": "deepseek-ai/deepseek-coder-1.3b-instruct",
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
        # Loading optimizations (applied in load_model_sync)
        "low_cpu_mem_usage": True,
        "use_cache": True,
    }

def load_model_sync():
    """Synchronous model loading with optimizations"""
    global _tokenizer, _model, _model_loaded
    
    if _model_loaded:
        return _tokenizer, _model
    
    config = get_model_config()
    model_id = config["model_id"]
    
    logger.info(f"πŸ”§ Loading model {model_id}...")
    
    try:
        # Set cache directory to avoid re-downloading
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)
        
        # Load tokenizer first (faster)
        logger.info("πŸ“ Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,  # Use fast tokenizer if available
        )
        
        # Load model with optimizations
        logger.info("🧠 Loading model...")
        _model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            torch_dtype=config["torch_dtype"],
            device_map=config["device_map"],
            low_cpu_mem_usage=config["low_cpu_mem_usage"],
            use_cache=config["use_cache"],
            cache_dir=cache_dir,
            offload_folder="offload",  # spill weights here if they don't fit in memory
            offload_state_dict=True,
        )
        
        # Set to evaluation mode
        _model.eval()
        
        _model_loaded = True
        logger.info("βœ… Model loaded successfully!")
        return _tokenizer, _model
        
    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise

async def load_model_async():
    """Asynchronous model loading"""
    global _model_loading
    
    if _model_loaded:
        return _tokenizer, _model
    
    if _model_loading:
        # Wait for a load already in progress in another task
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        if not _model_loaded:
            raise RuntimeError("Model loading failed in a concurrent task")
        return _tokenizer, _model
    
    _model_loading = True
    
    try:
        # Run the blocking load in a thread pool so the event loop stays responsive
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(
                executor, load_model_sync
            )
        return tokenizer, model
    finally:
        _model_loading = False

def get_model():
    """Get the loaded model (for synchronous access)"""
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model

def is_model_loaded():
    """Check if model is loaded"""
    return _model_loaded

def get_model_info():
    """Get model information without loading"""
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
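

# --- Illustrative usage sketch ---
# A minimal example of how this module might be driven from an async entry
# point: warm the model once via load_model_async(), then run a short
# generation. The prompt and generation settings are placeholder assumptions
# chosen for demonstration only.
if __name__ == "__main__":
    async def _demo():
        tokenizer, model = await load_model_async()
        # Tokenize a small prompt and move it to the model's device
        inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=64)
        print(tokenizer.decode(output[0], skip_special_tokens=True))

    asyncio.run(_demo())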