|
import logging
import os
from functools import lru_cache

from transformers import pipeline
|
MODELS = {
    "Qwen3-0.6B": {
        "repo_id": "Qwen/Qwen3-0.6B",
        "description": (
            "Dense causal language model with 0.6B total parameters "
            "(0.44B non-embedding), 28 transformer layers, 16 query heads & "
            "8 KV heads, native 32,768-token context window, dual-mode "
            "generation, full multilingual & agentic capabilities."
        ),
    },
    "Qwen3-8B": {
        "repo_id": "Qwen/Qwen3-8B",
        "description": (
            "Dense causal language model with 8.2B total parameters "
            "(6.95B non-embedding), 36 layers, 32 query heads & 8 KV heads, "
            "32,768-token context (131,072 via YaRN), excels at multilingual "
            "instruction following & zero-shot tasks."
        ),
    },
    "Qwen3-14B": {
        "repo_id": "Qwen/Qwen3-14B",
        "description": (
            "Dense causal language model with 14.8B total parameters "
            "(13.2B non-embedding), 40 layers, 40 query heads & 8 KV heads, "
            "32,768-token context (131,072 via YaRN), enhanced human "
            "preference alignment & advanced agent integration."
        ),
    },
}
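
# Example (sketch): reading an entry from the registry above.
#
#     spec = MODELS["Qwen3-8B"]
#     spec["repo_id"]       # -> "Qwen/Qwen3-8B"
#     spec["description"]   # human-readable summary of the checkpoint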
|
|
|
@lru_cache(maxsize=5)
def get_pipeline(model_id, task="text-generation", device="auto"):
    """
    Get or create a model pipeline with caching.

    Results are memoized with functools.lru_cache, so repeated calls with the
    same (model_id, task, device) arguments reuse the already-loaded pipeline
    instead of reloading the weights.

    Args:
        model_id (str): The Hugging Face model ID
        task (str): The pipeline task (default: "text-generation")
        device (str): The device to use for execution (default: "auto")

    Returns:
        The pipeline object
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Loading model: {model_id} for task: {task} on device: {device}")
    # lru_cache only invokes this body on a cache miss, so no manual cache
    # dictionary is needed alongside it.
    return pipeline(task, model=model_id, device_map=device)
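
# Example (sketch): expected usage of get_pipeline, assuming the transformers
# package (plus accelerate, for device_map="auto") is installed and the
# checkpoint can be fetched from the Hugging Face Hub.
#
#     pipe = get_pipeline("Qwen/Qwen3-0.6B")
#     out = pipe("Write a haiku about caching.", max_new_tokens=64)
#     print(out[0]["generated_text"])
#
# A second call with the same arguments hits the lru_cache and reuses the
# already-loaded pipeline instead of reloading the weights.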


class ModelManager:
    """Manages loading and caching of Qwen models."""

    def __init__(self):
        self.models = {name: spec["repo_id"] for name, spec in MODELS.items()}
        self.logger = logging.getLogger(__name__)

    def get_pipeline(self, model_name, device="auto"):
        """Get or create a pipeline for the named model."""
        return get_pipeline(self.get_model_id(model_name), device=device)

    def get_model_id(self, model_name):
        """Get the Hugging Face repo ID for a given model name."""
        try:
            return self.models[model_name]
        except KeyError:
            raise ValueError(
                f"Model {model_name} not found in available models"
            ) from None
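
# Example (sketch): typical ModelManager usage against the MODELS registry
# defined above.
#
#     manager = ModelManager()
#     manager.get_model_id("Qwen3-0.6B")         # -> "Qwen/Qwen3-0.6B"
#     pipe = manager.get_pipeline("Qwen3-0.6B")  # loads (or reuses) the model
#     manager.get_pipeline("NoSuchModel")        # raises ValueError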


# Default device for all pipelines; override via the QWEN_DEVICE environment
# variable (e.g. "cpu", "cuda:0", "auto").
DEVICE = os.getenv("QWEN_DEVICE", "auto") |
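
if __name__ == "__main__":
    # Minimal smoke test (sketch): loads the smallest registered model on the
    # configured device and generates a short completion. Downloading the
    # weights requires network access, so this is illustrative only.
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager()
    pipe = manager.get_pipeline("Qwen3-0.6B", device=DEVICE)
    print(pipe("Hello from Qwen!", max_new_tokens=32)[0]["generated_text"])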