# src/starfish/llm/proxy/litellm_adapter.py
import shutil
from typing import Any, Dict, List, Optional

import litellm
from litellm.utils import supports_response_schema

from starfish.common.logger import get_logger
from starfish.llm.proxy.litellm_adapter_ext import (
OPENAI_COMPATIBLE_PROVIDERS_CONFIG,
route_openai_compatible_request,
)
logger = get_logger(__name__)
def get_available_models() -> List[str]:
"""Returns a list of all available models from litellm.
Returns:
List[str]: A list of valid model names.
"""
available_models = litellm.utils.get_valid_models()
return available_models
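
# Illustrative usage sketch (not part of the adapter API): list the models litellm
# currently considers valid and filter by family. Assumes the relevant provider API
# keys are set so litellm can report those models; the "gpt-" prefix is just an example.
def _example_list_gpt_models() -> List[str]:
    return [m for m in get_available_models() if m.startswith("gpt-")]
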
def build_chat_messages(user_instruction: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
"""Constructs a list of chat messages for the LLM.
Args:
        user_instruction (str): The input message from the user.
system_prompt (str, optional): An optional system prompt to guide the conversation.
Returns:
List[Dict[str, str]]: A list of message dictionaries formatted for the chat API.
"""
messages: List[Dict[str, str]] = []
if system_prompt:
# Add the system prompt to the messages if provided
messages.append({"role": "system", "content": system_prompt})
# Add the user's message to the messages
messages.append({"role": "user", "content": user_instruction})
return messages
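
# Illustrative sketch of the payload shape build_chat_messages produces, e.g.
# [{"role": "system", "content": "You are concise."}, {"role": "user", "content": "Summarize ..."}].
def _example_build_messages() -> List[Dict[str, str]]:
    return build_chat_messages("Summarize the attached report in two sentences.", system_prompt="You are concise.")
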
"""
Model router for directing requests to appropriate model backends
"""
import shutil
from typing import Dict, List, Optional
from starfish.common.logger import get_logger
logger = get_logger(__name__)
# Installation guides by platform
OLLAMA_INSTALLATION_GUIDE = """
Ollama is not installed. Please install it:

Mac:
    Download the app from https://ollama.com/download (or run: brew install ollama)

Linux:
    curl -fsSL https://ollama.com/install.sh | sh

Windows:
    Download the installer from https://ollama.com/download

After installation, restart your application.
"""
async def route_ollama_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any:
"""Handle Ollama-specific model requests.
Args:
model_name: The full model name (e.g., "ollama/llama3")
messages: The messages to send to the model
model_kwargs: Additional keyword arguments for the model
Returns:
The response from the Ollama model
"""
from starfish.llm.backend.ollama_adapter import OllamaError, ensure_model_ready
# Extract the actual model name
ollama_model_name = model_name.split("/", 1)[1]
try:
# Check if Ollama is installed
ollama_bin = shutil.which("ollama")
if not ollama_bin:
logger.error("Ollama is not installed")
raise OllamaError(f"Ollama is not installed.\n{OLLAMA_INSTALLATION_GUIDE}")
# Ensure the Ollama model is ready before using
logger.info(f"Ensuring Ollama model {ollama_model_name} is ready...")
model_ready = await ensure_model_ready(ollama_model_name)
if not model_ready:
error_msg = f"Failed to provision Ollama model: {ollama_model_name}"
logger.error(error_msg)
raise RuntimeError(error_msg)
# The model is ready, make the API call
logger.info(f"Model {ollama_model_name} is ready, making API call...")
return await litellm.acompletion(model=model_name, messages=messages, **model_kwargs)
    except OllamaError as e:
        error_msg = f"Ollama error: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e
    except Exception as e:
        error_msg = f"Unexpected error with Ollama model: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e
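
# Illustrative sketch: calling the Ollama route directly. "ollama/llama3" follows the
# prefix convention from the docstring above; any locally available Ollama model works,
# and the kwargs are passed straight through to litellm.acompletion.
async def _example_ollama_call() -> Any:
    messages = build_chat_messages("Say hello in one word.")
    return await route_ollama_request("ollama/llama3", messages, {"temperature": 0.0})
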
async def route_huggingface_request(model_name: str, messages: List[Dict[str, str]], model_kwargs: Dict[str, Any]) -> Any:
"""Handle HuggingFace model requests by importing into Ollama.
Args:
model_name: The full model name (e.g., "hf/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
messages: The messages to send to the model
        model_kwargs: Additional keyword arguments for the model

    Returns:
        The response from the model served through Ollama
    """
from starfish.llm.model_hub.huggingface_adapter import ensure_hf_model_ready
# Extract the HuggingFace model ID (everything after "hf/")
hf_model_id = model_name.split("/", 1)[1]
try:
# Ensure the HuggingFace model is ready in Ollama
logger.info(f"Ensuring HuggingFace model {hf_model_id} is ready...")
success, ollama_model_name = await ensure_hf_model_ready(hf_model_id)
if not success:
error_msg = f"Failed to provision HuggingFace model: {hf_model_id}. {ollama_model_name}"
logger.error(error_msg)
raise RuntimeError(error_msg)
# The model is ready in Ollama, make the API call using the Ollama endpoint
logger.info(f"Model {hf_model_id} is ready as Ollama model {ollama_model_name}, making API call...")
return await litellm.acompletion(model=f"ollama/{ollama_model_name}", messages=messages, **model_kwargs)
    except Exception as e:
        error_msg = f"HuggingFace error: {str(e)}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e
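
# Illustrative sketch: the "hf/" prefix provisions the model into Ollama first and then
# serves it through the Ollama endpoint. The model id is the one from the docstring above.
async def _example_hf_call() -> Any:
    messages = build_chat_messages("What is 2 + 2?")
    return await route_huggingface_request("hf/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", messages, {})
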
async def call_chat_model(model_name: str, messages: List[Dict[str, str]], model_kwargs: Optional[Dict[str, Any]] = None) -> Any:
"""Routes the model request:
1. Checks OpenAI compatible providers defined in litellm_adapter_ext.py.
2. Checks specific handlers (Ollama, HuggingFace).
3. Defaults to standard LiteLLM call.
"""
model_kwargs = model_kwargs or {}
model_prefix = model_name.split("/", 1)[0] if "/" in model_name else None
try:
if model_prefix and model_prefix in OPENAI_COMPATIBLE_PROVIDERS_CONFIG:
config = OPENAI_COMPATIBLE_PROVIDERS_CONFIG[model_prefix]
return await route_openai_compatible_request(model_prefix, config, model_name, messages, model_kwargs)
# Route based on model prefix
elif model_name.startswith("ollama/"):
# Direct Ollama model
return await route_ollama_request(model_name, messages, model_kwargs)
elif model_name.startswith("hf/"):
# HuggingFace models are served through Ollama, but with a different prefix
# This allows using HF models directly without downloading them first
# These are not actually directly using HF API - they're served through Ollama
return await route_huggingface_request(model_name, messages, model_kwargs)
else:
# Default case - use litellm directly
try:
return await litellm.acompletion(model=model_name, messages=messages, **model_kwargs)
            except Exception as e:
                logger.error(f"LiteLLM error: {str(e)}")
                raise RuntimeError(f"Error executing model {model_name}: {str(e)}") from e
    except RuntimeError:
        # Already logged and wrapped by the handlers above; re-raise without double-wrapping.
        raise
    except Exception as e:
        logger.exception(f"Error in call_chat_model for model {model_name}")
        raise RuntimeError(f"Error executing model {model_name}: {str(e)}") from e
async def build_and_call_chat_model(
model_name: str, user_instruction: str, system_prompt: Optional[str] = None, model_kwargs: Optional[Dict[str, Any]] = None
) -> Any:
"""A convenience function that combines constructing chat messages and executing chat completion.
Args:
model_name (str): The name of the model to use for chat completion.
user_instruction (str): The input message from the user.
system_prompt (Optional[str], optional): An optional system prompt to guide the conversation. Defaults to None.
model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the model. Defaults to None.
Returns:
Any: The response from the chat completion API.
"""
# Construct the messages
messages = build_chat_messages(user_instruction, system_prompt)
# Execute the chat completion
return await call_chat_model(model_name, messages, model_kwargs)
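
# Illustrative sketch: the one-call convenience path, equivalent to build_chat_messages
# followed by call_chat_model. The model name and kwargs here are placeholders, not
# project defaults.
async def _example_build_and_call() -> Any:
    return await build_and_call_chat_model(
        model_name="ollama/llama3",
        user_instruction="Give one fun fact about octopuses.",
        system_prompt="Keep answers under 20 words.",
        model_kwargs={"temperature": 0.7},
    )
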
def is_response_schema_supported(model_name: str) -> bool:
"""Check if a model supports the response_schema parameter.
Args:
model_name: Name of the model to check
Returns:
bool: True if the model supports response_schema, False otherwise
"""
try:
# Use litellm's native function to check if model supports json_schema
return supports_response_schema(model=model_name)
except Exception as e:
logger.warning(f"Error checking response schema support: {e}")
return False
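
# Minimal manual smoke test (illustrative; the model name is an assumption, not a project
# default). Run with: python -m starfish.llm.proxy.litellm_adapter
if __name__ == "__main__":
    import asyncio

    candidate = "gpt-4o-mini"
    print(f"{candidate} supports response_schema: {is_response_schema_supported(candidate)}")
    print(asyncio.run(build_and_call_chat_model(candidate, "Reply with the single word: ready")))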