from transformers import pipeline
import logging
from functools import lru_cache

# Global cache for pipelines to ensure they're initialized only once; it also
# backstops the lru_cache below, so a pipeline survives an lru_cache eviction.
_PIPELINE_CACHE = {}


@lru_cache(maxsize=5)
def get_pipeline(model_id, task="text-generation"):
    """
    Get or create a model pipeline with caching.

    Repeated calls are short-circuited by lru_cache; on a cache miss the
    module-level _PIPELINE_CACHE is consulted before any model is loaded,
    so each (model_id, task) pair is initialized at most once.

    Args:
        model_id (str): The Hugging Face model ID
        task (str): The pipeline task (default: "text-generation")

    Returns:
        The pipeline object
    """
    cache_key = f"{model_id}_{task}"
    if cache_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[cache_key]

    logger = logging.getLogger(__name__)
    logger.info(f"Loading model: {model_id} for task: {task}")
    pipe = pipeline(
        task,
        model=model_id,
        device_map="auto"
    )
    _PIPELINE_CACHE[cache_key] = pipe
    return pipe
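
# Illustrative sketch of the caching contract (assumption: not part of the
# original module): calling get_pipeline twice with the same arguments should
# return the very same object, served from lru_cache or, after an eviction,
# from _PIPELINE_CACHE, so the model weights are never loaded twice.
#
#   p1 = get_pipeline("Qwen/Qwen3-8B")
#   p2 = get_pipeline("Qwen/Qwen3-8B")
#   assert p1 is p2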


class ModelManager:
    """Manages loading and caching of Qwen models"""

    def __init__(self):
        self.models = {
            "Qwen3-14B": "Qwen/Qwen3-14B",
            "Qwen3-8B": "Qwen/Qwen3-8B"
        }
        self.logger = logging.getLogger(__name__)

    def get_pipeline(self, model_name):
        """Get or create a model pipeline for a known model name"""
        try:
            model_id = self.models[model_name]
            return get_pipeline(model_id)
        except KeyError:
            raise ValueError(f"Model {model_name} not found in available models")

    def get_model_id(self, model_name):
        """Get the Hugging Face model ID for a given model name"""
        try:
            return self.models[model_name]
        except KeyError:
            raise ValueError(f"Model {model_name} not found in available models")
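

# Minimal usage sketch (assumption: illustrative only, not part of the original
# module; requires hardware that can actually load the Qwen weights, since
# device_map="auto" will place the model on available GPUs or CPU).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager()
    pipe = manager.get_pipeline("Qwen3-8B")
    # text-generation pipelines return a list of dicts with "generated_text"
    print(pipe("Hello, world!", max_new_tokens=20)[0]["generated_text"])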