# # model.py - Optimized version
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch
# from functools import lru_cache
# import os
# import asyncio
# from concurrent.futures import ThreadPoolExecutor
# import logging
#
# logger = logging.getLogger(__name__)
#
# # Global variables to store loaded model
# _tokenizer = None
# _model = None
# _model_loading = False
# _model_loaded = False
#
# @lru_cache(maxsize=1)
# def get_model_config():
#     """Cache model configuration"""
#     return {
#         "model_id": "deepseek-ai/deepseek-coder-1.3b-instruct",
#         "torch_dtype": torch.bfloat16,
#         "device_map": "auto",
#         "trust_remote_code": True,
#         # Add these optimizations
#         "low_cpu_mem_usage": True,
#         "use_cache": True,
#     }
#
# def load_model_sync():
#     """Synchronous model loading with optimizations"""
#     global _tokenizer, _model, _model_loaded
#     if _model_loaded:
#         return _tokenizer, _model
#
#     config = get_model_config()
#     model_id = config["model_id"]
#     logger.info(f"Loading model {model_id}...")
#
#     try:
#         # Set cache directory to avoid re-downloading
#         cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
#         os.makedirs(cache_dir, exist_ok=True)
#
#         # Load tokenizer first (faster)
#         logger.info("Loading tokenizer...")
#         _tokenizer = AutoTokenizer.from_pretrained(
#             model_id,
#             trust_remote_code=config["trust_remote_code"],
#             cache_dir=cache_dir,
#             use_fast=True,  # Use fast tokenizer if available
#         )
#
#         # Load model with optimizations
#         logger.info("Loading model...")
#         _model = AutoModelForCausalLM.from_pretrained(
#             model_id,
#             trust_remote_code=config["trust_remote_code"],
#             torch_dtype=config["torch_dtype"],
#             device_map=config["device_map"],
#             low_cpu_mem_usage=config["low_cpu_mem_usage"],
#             cache_dir=cache_dir,
#             offload_folder="offload",
#             offload_state_dict=True,
#         )
#
#         # Set to evaluation mode
#         _model.eval()
#         _model_loaded = True
#         logger.info("Model loaded successfully!")
#         return _tokenizer, _model
#     except Exception as e:
#         logger.error(f"Failed to load model: {e}")
#         raise
#
# async def load_model_async():
#     """Asynchronous model loading"""
#     global _model_loading
#     if _model_loaded:
#         return _tokenizer, _model
#     if _model_loading:
#         # Wait for ongoing loading to complete
#         while _model_loading and not _model_loaded:
#             await asyncio.sleep(0.1)
#         return _tokenizer, _model
#
#     _model_loading = True
#     try:
#         # Run model loading in thread pool to avoid blocking
#         loop = asyncio.get_event_loop()
#         with ThreadPoolExecutor(max_workers=1) as executor:
#             tokenizer, model = await loop.run_in_executor(
#                 executor, load_model_sync
#             )
#         return tokenizer, model
#     finally:
#         _model_loading = False
#
# def get_model():
#     """Get the loaded model (for synchronous access)"""
#     if not _model_loaded:
#         return load_model_sync()
#     return _tokenizer, _model
#
# def is_model_loaded():
#     """Check if model is loaded"""
#     return _model_loaded
#
# def get_model_info():
#     """Get model information without loading"""
#     config = get_model_config()
#     return {
#         "model_id": config["model_id"],
#         "loaded": _model_loaded,
#         "loading": _model_loading,
#     }
# Current implementation: a lightweight CodeT5+ (220M) seq2seq loader.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from functools import lru_cache
import logging

logger = logging.getLogger(__name__)

_model_loaded = False
_tokenizer = None
_model = None


@lru_cache(maxsize=1)
def get_model_config():
    """Cache the model configuration."""
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }


def load_model_sync():
    """Load the tokenizer and model once; later calls return the cached objects."""
    global _tokenizer, _model, _model_loaded
    if _model_loaded:
        return _tokenizer, _model

    config = get_model_config()
    model_id = config["model_id"]
    try:
        _tokenizer = AutoTokenizer.from_pretrained(model_id)
        _model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        _model.eval()  # inference-only: disable dropout etc.
        _model_loaded = True
        return _tokenizer, _model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
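

# Example usage (illustrative sketch, not part of the Space's app code): the
# prompt and generation parameters below are assumptions used only to show how
# load_model_sync() and the returned seq2seq model might be called.
if __name__ == "__main__":
    tokenizer, model = load_model_sync()
    # Encode a small code snippet and run generation with the CodeT5+ model.
    inputs = tokenizer("def add(a, b): return a + b", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))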