# model.py - Fixed for CodeT5+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from functools import lru_cache
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging

logger = logging.getLogger(__name__)

# Process-wide singletons; the two flags below coordinate lazy loading.
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False

@lru_cache(maxsize=1)
def get_model_config():
    """Return the static model configuration (computed once thanks to lru_cache)."""
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }
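
# If you need to swap models without editing code, one option (a sketch, not
# existing behavior in this module) is to read the id from the environment
# inside get_model_config():
#
#     "model_id": os.environ.get("CODET5_MODEL_ID", "Salesforce/codet5p-220m"),
#
# where CODET5_MODEL_ID is a hypothetical variable name.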

def load_model_sync():
    """Blocking load of tokenizer and model; idempotent, so repeat calls return the cached pair."""
    global _tokenizer, _model, _model_loaded

    if _model_loaded:
        return _tokenizer, _model

    config = get_model_config()
    model_id = config["model_id"]

    logger.info(f"πŸ”§ Loading model {model_id}...")

    try:
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)

        logger.info("πŸ“ Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,
        )

        logger.info("🧠 Loading model...")
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir
        )

        _model.eval()
        _model_loaded = True
        logger.info("βœ… Model loaded successfully!")
        return _tokenizer, _model

    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise

async def load_model_async():
    """Non-blocking load: runs load_model_sync in a worker thread and
    deduplicates concurrent callers via the _model_loading flag."""
    global _model_loading

    if _model_loaded:
        return _tokenizer, _model

    if _model_loading:
        # Another task is already loading; wait for it instead of loading twice.
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        if not _model_loaded:
            raise RuntimeError("Concurrent model load failed; see logs for the original error")
        return _tokenizer, _model

    _model_loading = True

    try:
        # get_running_loop() is the non-deprecated way to obtain the loop
        # from inside a coroutine (get_event_loop() is deprecated here).
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
        return tokenizer, model
    finally:
        _model_loading = False
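
# Example: warm the model from an async framework's startup hook so the first
# request doesn't pay the load cost (illustrative sketch; FastAPI is not a
# dependency of this module, and `app` is a hypothetical application object):
#
#     @app.on_event("startup")
#     async def warmup():
#         await load_model_async()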

def get_model():
    """Return (tokenizer, model), loading synchronously on first use."""
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model

def is_model_loaded():
    return _model_loaded

def get_model_info():
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
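
# Minimal smoke test (a sketch: the prompt and generation settings below are
# illustrative assumptions, not part of this module's contract):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tokenizer, model = get_model()
    # Encode a short prompt and generate a completion without tracking gradients.
    inputs = tokenizer("def greet(name):", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))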