File size: 3,429 Bytes
b352179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from typing import Optional, Any, Literal
from smolagents.models import OpenAIServerModel, InferenceClientModel

class SmoLModelManager:
    """Create and manage SmoLAgents model instances for different client backends.

    Supported backends (``model_client``):

    * ``"openai"``: builds an ``OpenAIServerModel``. The API key falls back to
      the ``OPENROUTER_API_KEY`` environment variable; the base URL falls back
      to ``OPENROUTER_BASE_URL`` and then to ``https://openrouter.ai/api/v1``.
    * ``"inference"``: builds an ``InferenceClientModel``. The API key falls
      back to the ``INFERENCE_API_KEY`` environment variable. This backend has
      no base-URL environment variable or default, so ``api_base`` stays
      ``None`` unless passed explicitly (and is not forwarded to the model).
    """

    def __init__(self,
                 model_id: str,
                 model_client: Literal["openai", "inference"] = "openai",
                 api_key: Optional[str] = None,
                 api_base: Optional[str] = None):
        """
        Initialize with model configuration parameters.

        Args:
            model_id: The model identifier to use.
            model_client: The client backend to use ("openai" or "inference").
            api_key: The API key. If falsy, the backend's key environment
                variable is read instead.
            api_base: The API base URL. If falsy, the backend's base-URL
                environment variable and then its default are used; for the
                "inference" backend this resolves to None.

        Raises:
            ValueError: If ``model_id`` is empty or ``model_client`` is not
                one of the supported backends.
        """
        if not model_id:
            raise ValueError("model_id cannot be empty")
        if model_client not in ("openai", "inference"):
            raise ValueError("model_client must be either 'openai' or 'inference'")

        self.model_id = model_id
        self.model_client = model_client

        # Client-specific environment variable names and defaults. Every
        # attribute is assigned on both branches so the credential lookup
        # below never touches a missing attribute.
        if model_client == "openai":
            self.api_key_env = "OPENROUTER_API_KEY"
            self.api_base_env: Optional[str] = "OPENROUTER_BASE_URL"
            self.default_api_base: Optional[str] = "https://openrouter.ai/api/v1"
        else:  # inference
            self.api_key_env = "INFERENCE_API_KEY"
            # No base-URL env var or default is defined for this backend.
            self.api_base_env = None
            self.default_api_base = None

        # Store API credentials; explicit (truthy) arguments win over the environment.
        self.api_key = api_key or os.getenv(self.api_key_env)
        if api_base:
            self.api_base = api_base
        elif self.api_base_env is not None:
            self.api_base = os.getenv(self.api_base_env, self.default_api_base)
        else:
            self.api_base = self.default_api_base

    def create_model(self) -> Any:
        """
        Create and return the appropriate model instance based on model_client.

        Returns:
            The configured model instance, or None if no API key is available
            or creation fails.

        Note:
            This method catches exceptions (printing them) to prevent app
            crashes; callers must check for a None result.
        """
        # Validate API key is available before touching the client classes.
        if not self.api_key:
            print(f"Warning: No API key provided and {self.api_key_env} environment variable not set")
            return None

        try:
            if self.model_client == "openai":
                return self._create_openai_model()
            else:  # inference
                return self._create_inference_model()
        except Exception as e:
            # Deliberate best-effort boundary: report and signal failure with None.
            print(f"Error creating model: {str(e)}")
            return None

    def _create_openai_model(self) -> Any:
        """Create an OpenAIServerModel using the stored model id, base URL and key.

        Returns None if an ImportError is raised during construction.
        """
        try:
            return OpenAIServerModel(
                model_id=self.model_id,
                api_base=self.api_base,
                api_key=self.api_key
            )
        except ImportError:
            # NOTE(review): presumably triggered by a lazily-imported optional
            # dependency of the client class — confirm against smolagents.
            print("Failed to import OpenAIServerModel. Please ensure smolagents is installed.")
            return None

    def _create_inference_model(self) -> Any:
        """Create an InferenceClientModel using the stored model id and key.

        Returns None if an ImportError is raised during construction.
        """
        try:
            # InferenceClientModel selects the model via ``model_id`` (like
            # OpenAIServerModel above); a ``model=`` keyword is not the
            # model selector and would leave the default model in place.
            return InferenceClientModel(
                model_id=self.model_id,
                api_key=self.api_key
            )
        except ImportError:
            print("Failed to import InferenceClientModel. Please ensure smolagents is installed.")
            return None