# model.py - Fixed for CodeT5+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging

logger = logging.getLogger(__name__)

# Module-level singletons: load the tokenizer/model once and reuse them.
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False

def get_model_config():
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }
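
# Note on trust_remote_code: it lets Transformers execute model code hosted on
# the Hub. The 220M CodeT5+ checkpoint loads with the built-in T5 classes, but
# the larger CodeT5+ variants ship custom modeling files, so the flag is kept on.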

def load_model_sync():
    """Load the tokenizer and model synchronously. Idempotent."""
    global _tokenizer, _model, _model_loaded
    if _model_loaded:
        return _tokenizer, _model

    config = get_model_config()
    model_id = config["model_id"]
    logger.info(f"Loading model {model_id}...")
    try:
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)

        logger.info("Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,
        )

        logger.info("Loading model...")
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
        )
        _model.eval()  # inference mode: disables dropout
        _model_loaded = True
        logger.info("Model loaded successfully!")
        return _tokenizer, _model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
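
# Device note (an addition, not original behavior): from_pretrained() leaves
# the model on CPU. For GPU inference you could move it after loading:
#     if torch.cuda.is_available():
#         _model.to("cuda")
# Inputs passed to generate() must then be moved to the same device.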

async def load_model_async():
    """Load the model without blocking the event loop. Safe to call concurrently."""
    global _model_loading
    if _model_loaded:
        return _tokenizer, _model
    if _model_loading:
        # Another caller is already loading; wait for it to finish.
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        if not _model_loaded:
            raise RuntimeError("Model loading failed in another task")
        return _tokenizer, _model
    _model_loading = True
    try:
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated there since Python 3.10.
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
        return tokenizer, model
    finally:
        _model_loading = False
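
# Usage sketch (an assumption about the surrounding app, not part of this
# module): an ASGI service such as FastAPI could warm the model at startup so
# the first request does not block:
#
#     @app.on_event("startup")
#     async def warmup():
#         await load_model_async()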

def get_model():
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model


def is_model_loaded():
    return _model_loaded


def get_model_info():
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
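

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # module). CodeT5+ is a seq2seq model trained with T5-style span masking,
    # so we can ask it to fill a <extra_id_0> sentinel via generate().
    logging.basicConfig(level=logging.INFO)
    tokenizer, model = get_model()
    code = "def greet(user): print(f'Hello <extra_id_0>!')"
    inputs = tokenizer(code, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))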