import shutil
from typing import Any, Dict, List, Optional

import litellm
from litellm.utils import supports_response_schema

from starfish.common.logger import get_logger
from starfish.llm.proxy.litellm_adapter_ext import (
    OPENAI_COMPATIBLE_PROVIDERS_CONFIG,
    route_openai_compatible_request,
)

logger = get_logger(__name__)


def get_available_models() -> List[str]:
    """Returns a list of all available models from litellm.

    Returns:
        List[str]: A list of valid model names.
    """
    available_models = litellm.utils.get_valid_models()
    return available_models
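

# Illustrative usage sketch: the list returned by get_available_models() depends on
# which provider API keys litellm can detect in the environment, so it varies per
# machine. The "gpt-" prefix below is only an example filter, not part of the API.
def _example_list_gpt_models() -> List[str]:
    """Return the subset of available model names that start with 'gpt-'."""
    return [name for name in get_available_models() if name.startswith("gpt-")]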


def build_chat_messages(user_instruction: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
    """Constructs a list of chat messages for the LLM.

    Args:
        user_instruction (str): The input message from the user.
        system_prompt (str, optional): An optional system prompt to guide the conversation.

    Returns:
        List[Dict[str, str]]: A list of message dictionaries formatted for the chat API.
    """
    messages: List[Dict[str, str]] = []
    if system_prompt:
        # Add the system prompt to the messages if provided
        messages.append({"role": "system", "content": system_prompt})
    # Add the user's message to the messages
    messages.append({"role": "user", "content": user_instruction})
    return messages
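

# Illustrative sketch: build_chat_messages produces the OpenAI-style role/content
# list that litellm.acompletion expects. The prompt strings below are hypothetical.
def _example_build_chat_messages() -> List[Dict[str, str]]:
    """Build an example message list (illustration only).

    The call below returns:
        [{"role": "system", "content": "You are a concise technical assistant."},
         {"role": "user", "content": "Summarize the release notes in three bullet points."}]
    """
    return build_chat_messages(
        "Summarize the release notes in three bullet points.",
        system_prompt="You are a concise technical assistant.",
    )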
"""
Model router for directing requests to appropriate model backends
"""
import shutil
from typing import Dict, List, Optional
from starfish.common.logger import get_logger
logger = get_logger(__name__)
# Installation guides by platform
OLLAMA_INSTALLATION_GUIDE = """
Ollama is not installed. Please install it:

Mac:
    curl -fsSL https://ollama.com/install.sh | sh

Linux:
    curl -fsSL https://ollama.com/install.sh | sh

Windows:
    Download the installer from https://ollama.com/download

After installation, restart your application.
"""


async def route_ollama_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any:
    """Handle Ollama-specific model requests.

    Args:
        model_name: The full model name (e.g., "ollama/llama3")
        messages: The messages to send to the model
        model_kwargs: Additional keyword arguments for the model

    Returns:
        The response from the Ollama model
    """
    from starfish.llm.backend.ollama_adapter import OllamaError, ensure_model_ready

    # Extract the actual model name
    ollama_model_name = model_name.split("/", 1)[1]

    try:
        # Check if Ollama is installed
        ollama_bin = shutil.which("ollama")
        if not ollama_bin:
            logger.error("Ollama is not installed")
            raise OllamaError(f"Ollama is not installed.\n{OLLAMA_INSTALLATION_GUIDE}")

        # Ensure the Ollama model is ready before using
        logger.info(f"Ensuring Ollama model {ollama_model_name} is ready...")
        model_ready = await ensure_model_ready(ollama_model_name)
        if not model_ready:
            error_msg = f"Failed to provision Ollama model: {ollama_model_name}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        # The model is ready, make the API call
        logger.info(f"Model {ollama_model_name} is ready, making API call...")
        return await litellm.acompletion(model=model_name, messages=messages, **model_kwargs)
    except OllamaError as e:
        error_msg = f"Ollama error: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e
    except Exception as e:
        error_msg = f"Unexpected error with Ollama model: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e


async def route_huggingface_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any:
    """Handle HuggingFace model requests by importing the model into Ollama.

    Args:
        model_name: The full model name (e.g., "hf/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
        messages: The messages to send to the model
        model_kwargs: Additional keyword arguments for the model

    Returns:
        The response from the model, served through Ollama
    """
    from starfish.llm.model_hub.huggingface_adapter import ensure_hf_model_ready

    # Extract the HuggingFace model ID (everything after "hf/")
    hf_model_id = model_name.split("/", 1)[1]

    try:
        # Ensure the HuggingFace model is ready in Ollama
        logger.info(f"Ensuring HuggingFace model {hf_model_id} is ready...")
        success, ollama_model_name = await ensure_hf_model_ready(hf_model_id)
        if not success:
            error_msg = f"Failed to provision HuggingFace model: {hf_model_id}. {ollama_model_name}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        # The model is ready in Ollama, make the API call using the Ollama endpoint
        logger.info(f"Model {hf_model_id} is ready as Ollama model {ollama_model_name}, making API call...")
        return await litellm.acompletion(model=f"ollama/{ollama_model_name}", messages=messages, **model_kwargs)
    except Exception as e:
        error_msg = f"HuggingFace error: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e


async def call_chat_model(model_name: str, messages: List[Dict[str, str]], model_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    """Routes the model request:

    1. Checks OpenAI-compatible providers defined in litellm_adapter_ext.py.
    2. Checks specific handlers (Ollama, HuggingFace).
    3. Defaults to a standard LiteLLM call.
    """
    model_kwargs = model_kwargs or {}
    model_prefix = model_name.split("/", 1)[0] if "/" in model_name else None

    try:
        if model_prefix and model_prefix in OPENAI_COMPATIBLE_PROVIDERS_CONFIG:
            config = OPENAI_COMPATIBLE_PROVIDERS_CONFIG[model_prefix]
            return await route_openai_compatible_request(model_prefix, config, model_name, messages, model_kwargs)
        # Route based on model prefix
        elif model_name.startswith("ollama/"):
            # Direct Ollama model
            return await route_ollama_request(model_name, messages, model_kwargs)
        elif model_name.startswith("hf/"):
            # HuggingFace models are served through Ollama, but with a different prefix.
            # This lets callers reference HF models directly without setting them up in Ollama
            # by hand; the HuggingFace API is not called directly - the model is provisioned
            # and served through Ollama.
            return await route_huggingface_request(model_name, messages, model_kwargs)
        else:
            # Default case - use litellm directly
            try:
                return await litellm.acompletion(model=model_name, messages=messages, **model_kwargs)
            except Exception as e:
                logger.error(f"LiteLLM error: {str(e)}")
                raise RuntimeError(f"Error executing model {model_name}: {str(e)}") from e
    except Exception as e:
        logger.exception(f"Error in call_chat_model for model {model_name}")
        raise RuntimeError(f"Error executing model {model_name}: {str(e)}") from e
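

# Illustrative sketch of how model-name prefixes select a backend in call_chat_model.
# The model names below are hypothetical examples; any provider-specific setup
# (API keys, a running Ollama daemon, etc.) is assumed to be in place.
async def _example_routing() -> None:
    """Show the three routing paths (illustration only)."""
    messages = build_chat_messages("Say hello in one short sentence.")

    # 1. Plain litellm model names (no special prefix) go straight to litellm.acompletion;
    #    prefixes listed in OPENAI_COMPATIBLE_PROVIDERS_CONFIG take the OpenAI-compatible route.
    await call_chat_model("gpt-4o-mini", messages)

    # 2. "ollama/..." names are checked against the local Ollama install first.
    await call_chat_model("ollama/llama3", messages)

    # 3. "hf/..." names are provisioned into Ollama, then served from there.
    await call_chat_model("hf/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", messages)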


async def build_and_call_chat_model(
    model_name: str, user_instruction: str, system_prompt: Optional[str] = None, model_kwargs: Optional[Dict[str, Any]] = None
) -> Any:
    """A convenience function that combines constructing chat messages and executing chat completion.

    Args:
        model_name (str): The name of the model to use for chat completion.
        user_instruction (str): The input message from the user.
        system_prompt (Optional[str], optional): An optional system prompt to guide the conversation. Defaults to None.
        model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the model. Defaults to None.

    Returns:
        Any: The response from the chat completion API.
    """
    # Construct the messages
    messages = build_chat_messages(user_instruction, system_prompt)
    # Execute the chat completion
    return await call_chat_model(model_name, messages, model_kwargs)


def is_response_schema_supported(model_name: str) -> bool:
    """Check if a model supports the response_schema parameter.

    Args:
        model_name: Name of the model to check

    Returns:
        bool: True if the model supports response_schema, False otherwise
    """
    try:
        # Use litellm's native function to check if model supports json_schema
        return supports_response_schema(model=model_name)
    except Exception as e:
        logger.warning(f"Error checking response schema support: {e}")
        return False
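

# Minimal end-to-end sketch, runnable as a script. The model name and prompts are
# hypothetical placeholders; a real run needs the matching provider credentials
# (e.g., an API key in the environment) or a local backend such as Ollama.
if __name__ == "__main__":
    import asyncio

    demo_model = "gpt-4o-mini"  # hypothetical example; substitute any model litellm can reach

    print(f"{len(get_available_models())} models visible to litellm")
    print(f"response_schema supported by {demo_model}: {is_response_schema_supported(demo_model)}")

    response = asyncio.run(
        build_and_call_chat_model(
            demo_model,
            "List two uses of a local LLM in one sentence each.",
            system_prompt="You are a concise technical assistant.",
        )
    )
    # litellm returns an OpenAI-style response object
    print(response.choices[0].message.content)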