# model.py - Fixed for CodeT5+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from functools import lru_cache
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
logger = logging.getLogger(__name__)

# Module-level singletons so the tokenizer/model are loaded at most once per process.
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False

@lru_cache(maxsize=1)
def get_model_config():
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }

def load_model_sync():
    """Load the tokenizer and model synchronously. Idempotent: returns the
    cached pair if a previous call already succeeded."""
    global _tokenizer, _model, _model_loaded
    if _model_loaded:
        return _tokenizer, _model
    config = get_model_config()
    model_id = config["model_id"]
    logger.info(f"Loading model {model_id}...")
    try:
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)
        logger.info("Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,
        )
        logger.info("Loading model...")
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
        )
        _model.eval()
        _model_loaded = True
        logger.info("Model loaded successfully!")
        return _tokenizer, _model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

async def load_model_async():
    """Load the model off the event loop; concurrent callers wait for the
    in-flight load instead of starting a second one."""
    global _model_loading
    if _model_loaded:
        return _tokenizer, _model
    if _model_loading:
        # Another coroutine is already loading; poll until it finishes.
        # (If that load failed, the returned pair may still be None.)
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        return _tokenizer, _model
    _model_loading = True
    try:
        # get_running_loop() is the supported way to grab the loop from
        # inside a coroutine (get_event_loop() is deprecated here).
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
        return tokenizer, model
    finally:
        _model_loading = False
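
# Async usage sketch (illustrative, not part of the original module): in an
# ASGI app such as FastAPI, load_model_async can be awaited at startup so the
# first request never blocks the event loop on model download. The app and
# hook names below are assumptions for demonstration.
#
#   from fastapi import FastAPI
#   app = FastAPI()
#
#   @app.on_event("startup")
#   async def warm_up():
#       await load_model_async()
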
def get_model():
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model

def is_model_loaded():
    return _model_loaded

def get_model_info():
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
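
# Smoke-test sketch (illustrative, not from the original file): exercises the
# synchronous loader and a standard seq2seq generation round-trip. The prompt
# text and max_new_tokens value are assumptions for demonstration; CodeT5+
# follows the usual tokenize -> generate -> decode flow.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tokenizer, model = get_model()
    prompt = "def hello_world():"  # hypothetical example input
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))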