import logging
from functools import lru_cache

from transformers import pipeline


@lru_cache(maxsize=5)
def get_pipeline(model_id, task="text-generation"):
    """
    Get or create a model pipeline with caching.

    Results are memoized with functools.lru_cache: repeated calls with the
    same arguments reuse the already-loaded pipeline, and at most five
    pipelines are kept, evicting the least recently used first.

    Args:
        model_id (str): The Hugging Face model ID
        task (str): The pipeline task (default: "text-generation")

    Returns:
        The pipeline object
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Loading model: {model_id} for task: {task}")

    # device_map="auto" lets transformers place the weights on available
    # devices automatically (this requires the `accelerate` package).
    return pipeline(
        task,
        model=model_id,
        device_map="auto",
    )
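

# Usage sketch: because get_pipeline is memoized with lru_cache, repeated
# calls with the same arguments return the very same pipeline object.
#
#     pipe = get_pipeline("Qwen/Qwen3-8B")
#     assert pipe is get_pipeline("Qwen/Qwen3-8B")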


class ModelManager:
    """Manages loading and caching of Qwen models."""

    def __init__(self):
        # Friendly names mapped to Hugging Face Hub model IDs.
        self.models = {
            "Qwen3-14B": "Qwen/Qwen3-14B",
            "Qwen3-8B": "Qwen/Qwen3-8B",
        }
        self.logger = logging.getLogger(__name__)

    def get_pipeline(self, model_name):
        """Get or create a model pipeline for a friendly model name."""
        # Delegate name resolution to get_model_id so the KeyError-to-
        # ValueError translation lives in one place.
        return get_pipeline(self.get_model_id(model_name))

    def get_model_id(self, model_name):
        """Get the model ID for a given model name."""
        try:
            return self.models[model_name]
        except KeyError:
            raise ValueError(f"Model {model_name} not found in available models")
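

if __name__ == "__main__":
    # Minimal usage sketch. Running this really downloads and loads
    # Qwen3-8B, which needs substantial GPU/CPU memory; treat it as
    # illustrative rather than as a test.
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager()
    pipe = manager.get_pipeline("Qwen3-8B")
    result = pipe("Write a haiku about caching.", max_new_tokens=32)
    print(result[0]["generated_text"])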