import shutil
from typing import Any, Dict, List, Optional

import litellm
from litellm.utils import supports_response_schema

from starfish.common.logger import get_logger
from starfish.llm.proxy.litellm_adapter_ext import (
    OPENAI_COMPATIBLE_PROVIDERS_CONFIG,
    route_openai_compatible_request,
)

logger = get_logger(__name__)


def get_available_models() -> List[str]:
    """Returns a list of all available models from litellm.

    Returns:
        List[str]: A list of valid model names.
    """
    available_models = litellm.utils.get_valid_models()
    return available_models
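
# Example (illustrative only): the returned names depend on which provider API
# keys are set in the environment, so the exact list below is an assumption.
#
#     models = get_available_models()
#     # e.g. ["gpt-4o-mini", "claude-3-haiku-20240307", ...]  # varies by environment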


def build_chat_messages(user_instruction: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
    """Constructs a list of chat messages for the LLM.

    Args:
        user_instruction (str): The input message from the user.
        system_prompt (str, optional): An optional system prompt to guide the conversation.

    Returns:
        List[Dict[str, str]]: A list of message dictionaries formatted for the chat API.
    """
    messages: List[Dict[str, str]] = []
    if system_prompt:
        # Add the system prompt to the messages if provided
        messages.append({"role": "system", "content": system_prompt})
    # Add the user's message to the messages
    messages.append({"role": "user", "content": user_instruction})
    return messages
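
# Example (illustrative only): build_chat_messages prepends the optional system
# prompt and always appends the user message, so a call like
#
#     build_chat_messages("Summarize this text.", system_prompt="You are concise.")
#
# produces:
#
#     [{"role": "system", "content": "You are concise."},
#      {"role": "user", "content": "Summarize this text."}]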


"""
Model router for directing requests to appropriate model backends
"""
import shutil
from typing import Dict, List, Optional

from starfish.common.logger import get_logger

logger = get_logger(__name__)

# Installation guides by platform
OLLAMA_INSTALLATION_GUIDE = """
Ollama is not installed. Please install it:

Mac and Linux:
curl -fsSL https://ollama.com/install.sh | sh

Windows:
Download the installer from https://ollama.com/download

After installation, restart your application.
"""


async def route_ollama_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any:
    """Handle Ollama-specific model requests.

    Args:
        model_name: The full model name (e.g., "ollama/llama3")
        messages: The messages to send to the model
        model_kwargs: Additional keyword arguments for the model

    Returns:
        The response from the Ollama model
    """
    from starfish.llm.backend.ollama_adapter import OllamaError, ensure_model_ready

    # Extract the actual model name
    ollama_model_name = model_name.split("/", 1)[1]

    try:
        # Check if Ollama is installed
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama is not installed")
            raise OllamaError(f"Ollama is not installed.\n{OLLAMA_INSTALLATION_GUIDE}")

        # Ensure the Ollama model is ready before using
        logger.info(f"Ensuring Ollama model {ollama_model_name} is ready...")
        model_ready = await ensure_model_ready(ollama_model_name)

        if not model_ready:
            error_msg = f"Failed to provision Ollama model: {ollama_model_name}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        # The model is ready, make the API call
        logger.info(f"Model {ollama_model_name} is ready, making API call...")
        return await litellm.acompletion(model=model_name, messages=messages, **model_kwargs)

    except OllamaError as e:
        error_msg = f"Ollama error: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e
    except Exception as e:
        error_msg = f"Unexpected error with Ollama model: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e


async def route_huggingface_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any:
    """Handle HuggingFace model requests by importing into Ollama.

    Args:
        model_name: The full model name (e.g., "hf/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
        messages: The messages to send to the model
        model_kwargs: Additional keyword arguments for the model

    Returns:
        The response from the model, served through the local Ollama endpoint
    """
    from starfish.llm.model_hub.huggingface_adapter import ensure_hf_model_ready

    # Extract the HuggingFace model ID (everything after "hf/")
    hf_model_id = model_name.split("/", 1)[1]

    try:
        # Ensure the HuggingFace model is ready in Ollama
        logger.info(f"Ensuring HuggingFace model {hf_model_id} is ready...")
        success, ollama_model_name = await ensure_hf_model_ready(hf_model_id)

        if not success:
            error_msg = f"Failed to provision HuggingFace model: {hf_model_id}. {ollama_model_name}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        # The model is ready in Ollama, make the API call using the Ollama endpoint
        logger.info(f"Model {hf_model_id} is ready as Ollama model {ollama_model_name}, making API call...")
        return await litellm.acompletion(model=f"ollama/{ollama_model_name}", messages=messages, **model_kwargs)

    except Exception as e:
        error_msg = f"HuggingFace error: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e


async def call_chat_model(model_name: str, messages: List[Dict[str, str]], model_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    """Routes the model request:
    1. Checks OpenAI compatible providers defined in litellm_adapter_ext.py.
    2. Checks specific handlers (Ollama, HuggingFace).
    3. Defaults to standard LiteLLM call.
    """
    model_kwargs = model_kwargs or {}
    model_prefix = model_name.split("/", 1)[0] if "/" in model_name else None

    try:
        if model_prefix and model_prefix in OPENAI_COMPATIBLE_PROVIDERS_CONFIG:
            config = OPENAI_COMPATIBLE_PROVIDERS_CONFIG[model_prefix]
            return await route_openai_compatible_request(model_prefix, config, model_name, messages, model_kwargs)

        # Route based on model prefix
        elif model_name.startswith("ollama/"):
            # Direct Ollama model
            return await route_ollama_request(model_name, messages, model_kwargs)

        elif model_name.startswith("hf/"):
            # HuggingFace models use a different prefix but are served through Ollama;
            # the adapter pulls and provisions them, so they do not need to be
            # downloaded or converted into Ollama manually before use.
            return await route_huggingface_request(model_name, messages, model_kwargs)

        else:
            # Default case - use litellm directly
            try:
                return await litellm.acompletion(model=model_name, messages=messages, **model_kwargs)
            except Exception as e:
                logger.error(f"LiteLLM error: {str(e)}")
                raise RuntimeError(f"Error executing model {model_name}: {str(e)}")
    except Exception as e:
        logger.exception(f"Error in execute_chat_completion for model {model_name}")
        raise RuntimeError(f"Error executing model {model_name}: {str(e)}")


async def build_and_call_chat_model(
    model_name: str, user_instruction: str, system_prompt: Optional[str] = None, model_kwargs: Optional[Dict[str, Any]] = None
) -> Any:
    """A convenience function that combines constructing chat messages and executing chat completion.

    Args:
        model_name (str): The name of the model to use for chat completion.
        user_instruction (str): The input message from the user.
        system_prompt (Optional[str], optional): An optional system prompt to guide the conversation. Defaults to None.
        model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the model. Defaults to None.

    Returns:
        Any: The response from the chat completion API.
    """
    # Construct the messages
    messages = build_chat_messages(user_instruction, system_prompt)

    # Execute the chat completion
    return await call_chat_model(model_name, messages, model_kwargs)
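
# Minimal usage sketch (illustrative only; the model name and kwargs are
# assumptions, and the response shape follows LiteLLM's OpenAI-style ModelResponse):
#
#     import asyncio
#
#     async def _demo():
#         response = await build_and_call_chat_model(
#             model_name="ollama/llama3",
#             user_instruction="Say hello in one sentence.",
#             system_prompt="You are a helpful assistant.",
#             model_kwargs={"temperature": 0.2},
#         )
#         print(response.choices[0].message.content)
#
#     asyncio.run(_demo())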


def is_response_schema_supported(model_name: str) -> bool:
    """Check if a model supports the response_schema parameter.

    Args:
        model_name: Name of the model to check

    Returns:
        bool: True if the model supports response_schema, False otherwise
    """
    try:
        # Use litellm's native function to check if model supports json_schema
        return supports_response_schema(model=model_name)
    except Exception as e:
        logger.warning(f"Error checking response schema support: {e}")
        return False
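
# Example of gating structured output on schema support (illustrative only;
# the JSON schema and model name here are assumptions):
#
#     schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
#     kwargs: Dict[str, Any] = {}
#     if is_response_schema_supported("gpt-4o-mini"):
#         kwargs["response_format"] = {
#             "type": "json_schema",
#             "json_schema": {"name": "answer", "schema": schema},
#         }
#     # messages = build_chat_messages("Answer in JSON.")
#     # await call_chat_model("gpt-4o-mini", messages, kwargs)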