# Olubakka / src/models.py
# Last commit 0b7ba67 by Sachi Wagaarachchi: "debug: updated the pipeline"
from transformers import pipeline
import logging
from functools import lru_cache
# Process-wide cache of pipelines keyed by (model_id, task). Every call into
# get_pipeline consults it, so a model is loaded at most once per task even when
# callers spell the arguments differently (positional vs. keyword), which
# lru_cache alone would treat as distinct cache entries.
_PIPELINE_CACHE = {}


@lru_cache(maxsize=5)
def get_pipeline(model_id, task="text-generation"):
    """
    Get or create a model pipeline with caching.

    Two cache layers are involved: ``lru_cache`` memoizes on the exact call
    arguments, and ``_PIPELINE_CACHE`` deduplicates on ``(model_id, task)``.

    Args:
        model_id (str): The Hugging Face model ID
        task (str): The pipeline task (default: "text-generation")

    Returns:
        The pipeline object created by ``transformers.pipeline``.
    """
    # A tuple key cannot collide the way an underscore-joined string could
    # (("a_b", "c") and ("a", "b_c") both formatted to "a_b_c").
    cache_key = (model_id, task)
    if cache_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[cache_key]

    logger = logging.getLogger(__name__)
    logger.info("Loading model: %s for task: %s", model_id, task)
    # NOTE(review): entries are never removed from _PIPELINE_CACHE, so a pipeline
    # stays referenced (and its weights stay loaded) even after lru_cache evicts it.
    pipe = pipeline(
        task,
        model=model_id,
        device_map="auto",
    )
    _PIPELINE_CACHE[cache_key] = pipe
    return pipe
class ModelManager:
    """Manages loading and caching of Qwen models.

    Maps short display names (e.g. ``"Qwen3-8B"``) to Hugging Face model IDs
    and hands out pipelines through the module-level ``get_pipeline`` cache.
    """

    def __init__(self):
        # Display name -> Hugging Face model ID.
        self.models = {
            "Qwen3-14B": "Qwen/Qwen3-14B",
            "Qwen3-8B": "Qwen/Qwen3-8B",
        }
        self.logger = logging.getLogger(__name__)

    def get_pipeline(self, model_name):
        """Get or create a model pipeline.

        Args:
            model_name (str): A key of ``self.models``.

        Returns:
            The (cached) pipeline for the corresponding model ID.

        Raises:
            ValueError: If ``model_name`` is not a known model.
        """
        # Resolve the name outside any try block so that a KeyError raised while
        # loading the model is not misreported as an unknown model name.
        model_id = self.get_model_id(model_name)
        # Bare names in a method body resolve to module globals, so this calls the
        # module-level get_pipeline function, not this method.
        return get_pipeline(model_id)

    def get_model_id(self, model_name):
        """Get the model ID for a given model name.

        Args:
            model_name (str): A key of ``self.models``.

        Returns:
            str: The Hugging Face model ID.

        Raises:
            ValueError: If ``model_name`` is not a known model.
        """
        try:
            return self.models[model_name]
        except KeyError:
            # The message already names the model; drop the redundant KeyError
            # context from the traceback.
            raise ValueError(
                f"Model {model_name} not found in available models"
            ) from None