import logging
import os
from functools import lru_cache

from transformers import pipeline

MODELS = {
    "Qwen3-0.6B": {
        "repo_id": "Qwen/Qwen3-0.6B",
        "description": (
            "Dense causal language model with 0.6B total parameters "
            "(0.44B non-embedding), 28 transformer layers, 16 query heads "
            "& 8 KV heads, native 32,768-token context window, dual-mode "
            "generation, full multilingual & agentic capabilities."
        ),
    },
    "Qwen3-8B": {
        "repo_id": "Qwen/Qwen3-8B",
        "description": (
            "Dense causal language model with 8.2B total parameters "
            "(6.95B non-embedding), 36 layers, 32 query heads & 8 KV heads, "
            "32,768-token context (131,072 via YaRN), excels at multilingual "
            "instruction following & zero-shot tasks."
        ),
    },
    "Qwen3-14B": {
        "repo_id": "Qwen/Qwen3-14B",
        "description": (
            "Dense causal language model with 14.8B total parameters "
            "(13.2B non-embedding), 40 layers, 40 query heads & 8 KV heads, "
            "32,768-token context (131,072 via YaRN), enhanced human "
            "preference alignment & advanced agent integration."
        ),
    },
}


@lru_cache(maxsize=5)
def get_pipeline(model_id, task="text-generation", device="auto"):
    """
    Get or create a model pipeline.

    Cached with lru_cache so each (model_id, task, device) combination is
    loaded only once per process (up to maxsize entries).

    Args:
        model_id (str): The Hugging Face model ID
        task (str): The pipeline task (default: "text-generation")
        device (str): The device to use for execution (default: "auto")

    Returns:
        The pipeline object
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Loading model: {model_id} for task: {task} on device: {device}")
    return pipeline(task, model=model_id, device_map=device)


class ModelManager:
    """Manages loading and caching of Qwen models."""

    def __init__(self):
        # Map friendly model names to their Hugging Face repo IDs.
        self.models = {k: v["repo_id"] for k, v in MODELS.items()}
        self.logger = logging.getLogger(__name__)

    def get_pipeline(self, model_name, device="auto"):
        """Get or create a model pipeline."""
        try:
            model_id = self.models[model_name]
        except KeyError:
            raise ValueError(f"Model {model_name} not found in available models")
        return get_pipeline(model_id, device=device)

    def get_model_id(self, model_name):
        """Get the model ID for a given model name."""
        try:
            return self.models[model_name]
        except KeyError:
            raise ValueError(f"Model {model_name} not found in available models")


# Determine device based on environment variable
DEVICE = os.getenv("QWEN_DEVICE", "auto")

# Example usage:
#   model_manager = ModelManager()
#   model_manager.get_pipeline("Qwen3-0.6B", device=DEVICE)
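

# A minimal runnable sketch of how this module is meant to be used. It assumes
# network access to download the Qwen checkpoint; the prompt string and
# max_new_tokens value below are illustrative, not part of this module's API.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = ModelManager()

    # First call loads the model; a repeated call with the same arguments
    # hits the lru_cache in get_pipeline and returns the same object.
    pipe = manager.get_pipeline("Qwen3-0.6B", device=DEVICE)
    assert pipe is manager.get_pipeline("Qwen3-0.6B", device=DEVICE)

    # Standard transformers text-generation pipeline call: returns a list of
    # dicts, each with a "generated_text" key.
    output = pipe("Write a haiku about caching.", max_new_tokens=64)
    print(output[0]["generated_text"])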