# model.py - Fixed for CodeT5+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from functools import lru_cache
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
logger = logging.getLogger(__name__)
# Module-level singletons: one shared tokenizer/model pair per process.
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False

@lru_cache(maxsize=1)
def get_model_config():
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }

def load_model_sync():
    """Blocking load of the CodeT5+ tokenizer and model; idempotent."""
    global _tokenizer, _model, _model_loaded
    if _model_loaded:
        return _tokenizer, _model

    config = get_model_config()
    model_id = config["model_id"]
    logger.info(f"🔧 Loading model {model_id}...")

    try:
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)

        logger.info("📝 Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,
        )

        logger.info("🧠 Loading model...")
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
        )
        _model.eval()  # inference mode: disables dropout

        _model_loaded = True
        logger.info("✅ Model loaded successfully!")
        return _tokenizer, _model
    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise

async def load_model_async():
    """Non-blocking model load; safe to call from concurrent tasks."""
    global _model_loading
    if _model_loaded:
        return _tokenizer, _model

    if _model_loading:
        # Another task is already loading; poll until it finishes.
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        if not _model_loaded:
            # The other task failed; surface the error instead of
            # silently returning (None, None).
            raise RuntimeError("Model failed to load in a concurrent task")
        return _tokenizer, _model

    _model_loading = True
    try:
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated here.
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
        return tokenizer, model
    finally:
        _model_loading = False

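# Design note: the polling flag above works but busy-waits. A minimal sketch
# of an asyncio.Lock alternative (the names `_load_lock` and
# `load_model_async_locked` are assumptions, not part of the original):
_load_lock = asyncio.Lock()


async def load_model_async_locked():
    if _model_loaded:
        return _tokenizer, _model
    async with _load_lock:
        # Double-check after acquiring: another task may have finished first.
        if _model_loaded:
            return _tokenizer, _model
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, load_model_sync)
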
def get_model():
    """Return the (tokenizer, model) pair, loading synchronously if needed."""
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model


def is_model_loaded():
    return _model_loaded

def get_model_info():
    """Report the configured model id and the current load state."""
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
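

# --- Usage sketch -----------------------------------------------------------
# A minimal example of running inference with the loaded pair. The helper
# name `generate_explanation` and the generation settings below are
# illustrative assumptions, not the project's actual API.
def generate_explanation(code_snippet: str, max_new_tokens: int = 128) -> str:
    tokenizer, model = get_model()
    inputs = tokenizer(
        code_snippet,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():  # inference only; skip gradient tracking
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


if __name__ == "__main__":
    # Quick smoke test of the async path; downloads the checkpoint on first
    # run, so it assumes network access to the Hugging Face Hub.
    async def _demo():
        await load_model_async()
        print(get_model_info())
        print(generate_explanation("def add(a, b):\n    return a - b"))

    asyncio.run(_demo())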