# starfish/llm/proxy/litellm_adapter_ext.py
import os
from typing import Any, Dict, List

import litellm

from starfish.common.logger import get_logger

logger = get_logger(__name__)

# --- Configuration ---
# Convention:
#   Keys   = LiteLLM parameter names.
#   Values:
#     - "$ENV_VAR_NAME" -> fetched DIRECTLY from os.getenv("ENV_VAR_NAME").
#       The env var MUST contain the final desired value (e.g., "Bearer sk-...").
#     - anything else   -> literal static value (int, bool, or a string not starting with "$").
OPENAI_COMPATIBLE_PROVIDERS_CONFIG: Dict[str, Dict[str, Any]] = {
"hyperbolic": {
"api_base": "$HYPERBOLIC_API_BASE", # Env var HYPERBOLIC_API_BASE holds the URL
"headers": {
# Env var HYPERBOLIC_API_KEY MUST be set to "Bearer sk-..."
"Authorization": "$HYPERBOLIC_API_KEY"
},
},
# Add more providers following this ultra-simple convention
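    # Hypothetical sketch of a second entry (commented out; "myprovider" and its
    # env var names are illustrative placeholders, not shipped with this module):
    # "myprovider": {
    #     "api_base": "$MYPROVIDER_API_BASE",  # env var holds the base URL
    #     "timeout": 30,  # literal static value, passed through unchanged
    #     "headers": {"Authorization": "$MYPROVIDER_API_KEY"},  # env var holds the full header value
    # },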
}


def _resolve_config_value(value: Any, description: str) -> Any:
    """Resolve a configuration value using the '$' convention.

    '$VAR_NAME'   -> os.getenv('VAR_NAME')
    anything else -> returned as the literal value.
    """
if isinstance(value, str) and value.startswith("$"):
# Environment Variable lookup: $VAR_NAME
env_var_name = value[1:]
if not env_var_name: # Handle edge case of just "$"
raise ValueError(f"Invalid environment variable specification '$' ({description}).")
env_var_value = os.getenv(env_var_name)
if env_var_value is None:
# Keep the description for helpful error messages
error_msg = f"Required environment variable '{env_var_name}' not set ({description})."
logger.error(error_msg)
raise RuntimeError(error_msg)
return env_var_value # Return the exact value from env var
else:
# It's a literal static value (int, bool, string not starting with $)
return value
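

# Resolution rules at a glance (comments only, nothing runs at import time;
# "demo" is just the description string used in error messages):
#   _resolve_config_value("$HYPERBOLIC_API_BASE", "demo")  # -> os.getenv("HYPERBOLIC_API_BASE"); RuntimeError if unset
#   _resolve_config_value(30, "demo")                      # -> 30 (literal passthrough)
#   _resolve_config_value("$", "demo")                     # -> ValueError (empty env var name)
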
async def route_openai_compatible_request(
provider_prefix: str,
provider_config: Dict[str, Any], # Assumes config dict is valid
model_name: str,
messages: List[Dict[str, str]],
model_kwargs: Dict[str, Any],
) -> Any:
"""Handles requests for configured OpenAI-compatible providers using the
simple '$' convention for environment variable lookup.
"""
litellm_call_kwargs = {}
resolved_headers = {}
# Iterate directly through configured LiteLLM parameters
for param_name, config_value in provider_config.items():
param_desc = f"provider '{provider_prefix}', param '{param_name}'"
try:
if param_name == "headers":
if not isinstance(config_value, dict):
                    logger.warning(f"Headers config is not a dict for {param_desc}. Skipping.")
continue
# Resolve each header value using the simple convention
for header_key, header_config_value in config_value.items():
header_desc = f"provider '{provider_prefix}', header '{header_key}'"
resolved_headers[header_key] = _resolve_config_value(header_config_value, header_desc)
else:
# Resolve other parameter values using the simple convention
litellm_call_kwargs[param_name] = _resolve_config_value(config_value, param_desc)
except (RuntimeError, ValueError): # Catch env var missing or invalid '$'
raise # Propagate critical errors immediately
except Exception as e:
# Catch any unexpected errors during resolution
            logger.error(f"Unexpected error resolving value for {param_desc}: {e}", exc_info=True)
            raise RuntimeError(f"Error resolving value for {param_desc}") from e
if resolved_headers:
litellm_call_kwargs["headers"] = resolved_headers # Add resolved headers if any
    # --- Parameter construction is complete ---
    # model_name is expected as "<provider_prefix>/<model_id>"; strip the prefix
    # so the request routes through LiteLLM's generic OpenAI-compatible handler.
    if "/" not in model_name:
        raise ValueError(f"Expected model name in '<provider>/<model>' format, got '{model_name}'.")
    model_id = model_name.split("/", 1)[1]
# Merge function kwargs, overriding config/resolved values
final_kwargs = {**litellm_call_kwargs, **model_kwargs}
return await litellm.acompletion(model="openai/" + model_id, messages=messages, **final_kwargs)
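

# Usage sketch (illustrative, not part of the module; assumes HYPERBOLIC_API_BASE
# and HYPERBOLIC_API_KEY are exported, and "<model-id>" stands in for a real
# model identifier on that provider):
#
# import asyncio
#
# async def _demo():
#     response = await route_openai_compatible_request(
#         provider_prefix="hyperbolic",
#         provider_config=OPENAI_COMPATIBLE_PROVIDERS_CONFIG["hyperbolic"],
#         model_name="hyperbolic/<model-id>",
#         messages=[{"role": "user", "content": "Hello"}],
#         model_kwargs={"temperature": 0.7},
#     )
#     print(response.choices[0].message.content)
#
# asyncio.run(_demo())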