# Olubakka / src/models.py
# Author: Sachi Wagaarachchi
# bugfixes, CPU local works with 0.6B (commit 60f0153)
from transformers import pipeline
import logging
from functools import lru_cache
import os
# Global cache for pipelines to ensure they're initialized only once.
# NOTE(review): get_pipeline() below is also wrapped in lru_cache(maxsize=5),
# so this dict duplicates that caching — and entries here are never evicted.
_PIPELINE_CACHE = {}
# Registry of supported models: friendly name -> {"repo_id", "description"}.
# "repo_id" is the Hugging Face hub identifier passed to transformers.pipeline.
MODELS = {
"Qwen3-0.6B": {"repo_id":"Qwen/Qwen3-0.6B","description":"Dense causal language model with 0.6 B total parameters (0.44 B non-embedding), 28 transformer layers, 16 query heads & 8 KV heads, native 32 768-token context window, dual-mode generation, full multilingual & agentic capabilities."},
"Qwen3-8B": {"repo_id":"Qwen/Qwen3-8B","description":"Dense causal language model with 8.2 B total parameters (6.95 B non-embedding), 36 layers, 32 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), excels at multilingual instruction following & zero-shot tasks."},
"Qwen3-14B": {"repo_id":"Qwen/Qwen3-14B","description":"Dense causal language model with 14.8 B total parameters (13.2 B non-embedding), 40 layers, 40 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), enhanced human preference alignment & advanced agent integration."},
}
@lru_cache(maxsize=5)
def get_pipeline(model_id: str, task: str = "text-generation", device: str = "auto"):
    """
    Get or create a model pipeline, cached per (model_id, task, device).

    Caching is handled entirely by ``lru_cache``. The previous version also
    stored every pipeline in the module-level ``_PIPELINE_CACHE`` dict with no
    eviction, which kept a strong reference to every loaded model forever and
    defeated the ``maxsize=5`` bound (evicted pipelines could never be freed).

    Args:
        model_id (str): The Hugging Face model ID (e.g. "Qwen/Qwen3-0.6B").
        task (str): The pipeline task (default: "text-generation").
        device (str): Passed to transformers as ``device_map``
            (default: "auto").

    Returns:
        The transformers pipeline object.
    """
    logger = logging.getLogger(__name__)
    # Lazy %-style args: formatting is skipped when INFO logging is disabled.
    logger.info(
        "Loading model: %s for task: %s on device: %s", model_id, task, device
    )
    # Downloads/loads the model on first call; subsequent identical calls
    # return the lru_cache'd pipeline without touching transformers again.
    return pipeline(
        task,
        model=model_id,
        device_map=device,
    )
class ModelManager:
    """Manages loading and caching of Qwen models.

    Thin lookup layer over the module-level ``MODELS`` registry; actual
    pipeline construction/caching is delegated to the module-level
    ``get_pipeline`` function.
    """

    def __init__(self):
        # Map friendly model name -> Hugging Face repo id.
        self.models = {k: v["repo_id"] for k, v in MODELS.items()}
        self.logger = logging.getLogger(__name__)

    def get_pipeline(self, model_name: str, device: str = "auto"):
        """Get or create a pipeline for *model_name*.

        Args:
            model_name (str): A key of the model registry (e.g. "Qwen3-8B").
            device (str): Forwarded to the pipeline loader (default: "auto").

        Returns:
            The transformers pipeline object.

        Raises:
            ValueError: If *model_name* is not a known model.
        """
        # Route through get_model_id so the unknown-model error is translated
        # in exactly one place instead of duplicating the try/except here.
        model_id = self.get_model_id(model_name)
        return get_pipeline(model_id, device=device)

    def get_model_id(self, model_name: str) -> str:
        """Return the Hugging Face repo id for *model_name*.

        Raises:
            ValueError: If *model_name* is not a known model.
        """
        try:
            return self.models[model_name]
        except KeyError as err:
            # Chain the original KeyError so the traceback shows the root cause.
            raise ValueError(f"Model {model_name} not found in available models") from err
# Execution device from the QWEN_DEVICE environment variable; defaults to
# "auto" (transformers device_map auto-placement) when unset.
DEVICE = os.getenv("QWEN_DEVICE", "auto")
# Example usage (construct a manager first — none is created at module level):
#   ModelManager().get_pipeline("Qwen3-0.6B", device=DEVICE)