# model.py - Optimized version
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from functools import lru_cache
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging

logger = logging.getLogger(__name__)

# Global variables to store loaded model
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False

# Previous configuration (deepseek-ai/deepseek-coder-1.3b-instruct), kept for reference:
#     "model_id": "deepseek-ai/deepseek-coder-1.3b-instruct",
#     "torch_dtype": torch.bfloat16,
#     "device_map": "auto",
#     "trust_remote_code": True,
#     "low_cpu_mem_usage": True,
#     "use_cache": True,
@lru_cache(maxsize=1)
def get_model_config():
    """Cache the model configuration."""
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }

def load_model_sync():
    """Synchronous model loading with optimizations"""
    global _tokenizer, _model, _model_loaded
    
    if _model_loaded:
        return _tokenizer, _model
    
    config = get_model_config()
    model_id = config["model_id"]
    
    logger.info(f"πŸ”§ Loading model {model_id}...")
    
    try:
        # Set cache directory to avoid re-downloading
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)
        
        # Load tokenizer first (faster)
        logger.info("πŸ“ Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,  # Use fast tokenizer if available
        )
        
        # Load model with optimizations
        logger.info("🧠 Loading model...")
        # codet5p-220m is a T5-style encoder-decoder checkpoint, so it is loaded
        # with AutoModelForSeq2SeqLM rather than AutoModelForCausalLM.
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            # Fall back to library defaults when the active config omits these keys
            # (the previous deepseek config defined them explicitly).
            torch_dtype=config.get("torch_dtype"),
            device_map=config.get("device_map"),
            low_cpu_mem_usage=config.get("low_cpu_mem_usage", True),
            cache_dir=cache_dir,
            offload_folder="offload",
            offload_state_dict=True,
        )
        
        # Set to evaluation mode
        _model.eval()
        
        _model_loaded = True
        logger.info("βœ… Model loaded successfully!")
        return _tokenizer, _model
        
    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise

async def load_model_async():
    """Asynchronous model loading"""
    global _model_loading
    
    if _model_loaded:
        return _tokenizer, _model
    
    if _model_loading:
        # Wait for ongoing loading to complete
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        return _tokenizer, _model
    
    _model_loading = True
    
    try:
        # Run model loading in thread pool to avoid blocking
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(
                executor, load_model_sync
            )
        return tokenizer, model
    finally:
        _model_loading = False

def get_model():
    """Get the loaded model (for synchronous access)"""
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model

def is_model_loaded():
    """Check if model is loaded"""
    return _model_loaded

def get_model_info():
    """Get model information without loading"""
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
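
if __name__ == "__main__":
    # Minimal usage sketch (not part of the serving code): report the model
    # metadata without loading anything, then load the checkpoint through the
    # async path. Assumes network access to the Hugging Face Hub or a warm
    # TRANSFORMERS_CACHE directory.
    logging.basicConfig(level=logging.INFO)

    # Before loading: {'model_id': 'Salesforce/codet5p-220m', 'loaded': False, 'loading': False}
    print(get_model_info())

    async def _demo():
        tokenizer, model = await load_model_async()
        print(f"Loaded {model.__class__.__name__}, vocab size {tokenizer.vocab_size}")
        print(get_model_info())

    asyncio.run(_demo())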