# model.py - Fixed for CodeT5+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from functools import lru_cache
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging

logger = logging.getLogger(__name__)

# Module-level singletons guarding a single shared model instance.
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False


@lru_cache(maxsize=1)
def get_model_config():
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }


def load_model_sync():
    """Load the tokenizer and model synchronously (idempotent)."""
    global _tokenizer, _model, _model_loaded
    if _model_loaded:
        return _tokenizer, _model

    config = get_model_config()
    model_id = config["model_id"]
    logger.info(f"🔧 Loading model {model_id}...")

    try:
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)

        logger.info("📝 Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,
        )

        logger.info("🧠 Loading model...")
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
        )
        _model.eval()  # inference mode; disables dropout etc.

        _model_loaded = True
        logger.info("✅ Model loaded successfully!")
        return _tokenizer, _model
    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise


async def load_model_async():
    """Load the model off the event loop; concurrent callers wait for the first load."""
    global _model_loading
    if _model_loaded:
        return _tokenizer, _model

    if _model_loading:
        # Another coroutine is already loading; poll until it finishes.
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        return _tokenizer, _model

    _model_loading = True
    try:
        # get_running_loop() is correct inside a coroutine; get_event_loop()
        # is deprecated for this use.
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
        return tokenizer, model
    finally:
        _model_loading = False


def get_model():
    """Return the (tokenizer, model) pair, loading synchronously on first use."""
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model


def is_model_loaded():
    return _model_loaded


def get_model_info():
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
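

# --- Usage sketch (illustrative, not part of the module API) ---
# A minimal example of how this loader might be exercised, assuming the
# CodeT5+ checkpoint can be downloaded or is already in the cache dir.
# The prompt text and generation parameters below are hypothetical.
if __name__ == "__main__":

    async def _demo():
        tokenizer, model = await load_model_async()
        # Encode a toy snippet and run seq2seq generation on it.
        inputs = tokenizer("def add(a, b): return a + b", return_tensors="pt")
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=32)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

    logging.basicConfig(level=logging.INFO)
    asyncio.run(_demo())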