from transformers import pipeline
import logging
from functools import lru_cache
import os

# Pipelines are cached with functools.lru_cache (see get_pipeline below), so
# each (model, task, device) combination is initialized only once.

MODELS = {
    "Qwen3-0.6B":    {"repo_id":"Qwen/Qwen3-0.6B","description":"Dense causal language model with 0.6 B total parameters (0.44 B non-embedding), 28 transformer layers, 16 query heads & 8 KV heads, native 32 768-token context window, dual-mode generation, full multilingual & agentic capabilities."},
    "Qwen3-8B":      {"repo_id":"Qwen/Qwen3-8B","description":"Dense causal language model with 8.2 B total parameters (6.95 B non-embedding), 36 layers, 32 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), excels at multilingual instruction following & zero-shot tasks."},
    "Qwen3-14B":     {"repo_id":"Qwen/Qwen3-14B","description":"Dense causal language model with 14.8 B total parameters (13.2 B non-embedding), 40 layers, 40 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), enhanced human preference alignment & advanced agent integration."},


}

@lru_cache(maxsize=5)
def get_pipeline(model_id, task="text-generation", device="auto"):
    """
    Get or create a model pipeline with caching.

    Decorated with lru_cache, so repeated calls with the same arguments
    return the already-loaded pipeline instead of reloading the model.

    Args:
        model_id (str): The Hugging Face model ID
        task (str): The pipeline task (default: "text-generation")
        device (str): The device to use for execution (default: "auto")

    Returns:
        The pipeline object
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Loading model: {model_id} for task: {task} on device: {device}")

    return pipeline(
        task,
        model=model_id,
        device_map=device,
    )
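
# Caching note (a minimal sketch of the behavior provided by lru_cache;
# assumes the model weights are available locally or downloadable from the Hub):
#
#   p1 = get_pipeline("Qwen/Qwen3-0.6B")
#   p2 = get_pipeline("Qwen/Qwen3-0.6B")
#   assert p1 is p2  # same cached object; the model is loaded only once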


class ModelManager:
    """Manages loading and caching of Qwen models"""
    def __init__(self):
        self.models = {k: v["repo_id"] for k, v in MODELS.items()}
        self.logger = logging.getLogger(__name__)
    
    def get_pipeline(self, model_name, device="auto"):
        """Get or create a model pipeline"""
        try:
            model_id = self.models[model_name]
            return get_pipeline(model_id, device=device)
        except KeyError:
            raise ValueError(f"Model {model_name} not found in available models") from None
    
    def get_model_id(self, model_name):
        """Get the model ID for a given model name"""
        try:
            return self.models[model_name]
        except KeyError:
            raise ValueError(f"Model {model_name} not found in available models") from None

# Device selection: override via the QWEN_DEVICE environment variable
# (e.g. "cuda:0" or "cpu"); defaults to "auto", which lets transformers
# place the model automatically.
DEVICE = os.getenv("QWEN_DEVICE", "auto")
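
# Example usage (a minimal sketch: the prompt and generation settings below are
# illustrative; max_new_tokens is a standard transformers generation argument):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model_manager = ModelManager()
    pipe = model_manager.get_pipeline("Qwen3-0.6B", device=DEVICE)
    result = pipe("Briefly explain what a language model is.", max_new_tokens=64)
    print(result[0]["generated_text"])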