File size: 4,187 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# starfish/adapters/litellm_adapter_ext.py

import os
from typing import Any, Dict, List

import litellm

from starfish.common.logger import get_logger

# Module-level logger; this module uses it for the warning/error messages emitted
# while resolving provider configuration.
logger = get_logger(__name__)

# --- Provider configuration ---
# Each provider name maps LiteLLM keyword-argument names to raw values:
#   * A string starting with "$" names an environment variable. Its value is
#     read verbatim with the process environment when a request is routed, so
#     the variable must already contain the final value (e.g. "Bearer sk-...").
#   * Anything else (int, bool, or a string without a leading "$") is passed
#     through literally.
# The "headers" entry is a dict; each of its values follows the same rule.

OPENAI_COMPATIBLE_PROVIDERS_CONFIG: Dict[str, Dict[str, Any]] = {
    "hyperbolic": dict(
        api_base="$HYPERBOLIC_API_BASE",  # HYPERBOLIC_API_BASE holds the endpoint URL
        # HYPERBOLIC_API_KEY must hold the complete header value, e.g. "Bearer sk-...".
        headers={"Authorization": "$HYPERBOLIC_API_KEY"},
    ),
    # Register further providers here, following the same convention.
}


def _resolve_config_value(value: Any, description: str) -> Any:
    """Return the concrete value for a single configuration entry.

    A string of the form ``"$NAME"`` is looked up in the process environment
    and the environment value is returned exactly as stored. Every other value
    (ints, bools, dicts, strings without a leading ``$``) is returned unchanged.

    Args:
        value: Raw configuration value.
        description: Human-readable location of the value, embedded in error
            messages.

    Returns:
        The environment value for ``"$NAME"`` entries, otherwise ``value``.

    Raises:
        ValueError: ``value`` is exactly ``"$"`` (no variable name given).
        RuntimeError: the named environment variable is not set; the message
            is logged at error level before raising.
    """
    # Not a "$"-prefixed string -> it is a literal and passes straight through.
    if not (isinstance(value, str) and value.startswith("$")):
        return value

    var_name = value[1:]
    if var_name == "":
        raise ValueError(f"Invalid environment variable specification '$' ({description}).")

    # A variable that is set to "" is still a value; only absence is an error.
    resolved = os.environ.get(var_name)
    if resolved is not None:
        return resolved

    message = f"Required environment variable '{var_name}' not set ({description})."
    logger.error(message)
    raise RuntimeError(message)


async def route_openai_compatible_request(
    provider_prefix: str,
    provider_config: Dict[str, Any],  # Assumes config dict is valid
    model_name: str,
    messages: List[Dict[str, str]],
    model_kwargs: Dict[str, Any],
) -> Any:
    """Route a chat-completion request to a configured OpenAI-compatible provider.

    Every entry of ``provider_config`` is resolved with the '$' convention
    (``_resolve_config_value``) and passed to ``litellm.acompletion`` as a
    keyword argument. The ``"headers"`` entry is resolved value by value. The
    model is sent as ``"openai/<model id>"``, LiteLLM's prefix for
    OpenAI-compatible endpoints.

    Args:
        provider_prefix: Provider name; used here only in log/error messages.
        provider_config: Provider settings shaped like an entry of
            ``OPENAI_COMPATIBLE_PROVIDERS_CONFIG`` (LiteLLM kwarg -> raw value).
        model_name: ``"<provider>/<model id>"``. Everything after the first
            "/" is forwarded as the model id and may itself contain "/".
        messages: Chat messages, forwarded unchanged.
        model_kwargs: Extra LiteLLM kwargs. They override resolved config
            values. The exception is a ``"headers"`` dict, which is merged key
            by key on top of the configured headers so that configured auth
            headers are not dropped. ``None`` is treated as no overrides.

    Returns:
        Whatever ``litellm.acompletion`` returns.

    Raises:
        RuntimeError: a referenced environment variable is unset, or value
            resolution failed unexpectedly.
        ValueError: a config value is an invalid bare ``"$"``, or
            ``model_name`` has no model id after the first "/".
    """
    litellm_call_kwargs: Dict[str, Any] = {}
    resolved_headers: Dict[str, Any] = {}

    # Resolve every configured LiteLLM parameter with the '$' convention.
    for param_name, config_value in provider_config.items():
        param_desc = f"provider '{provider_prefix}', param '{param_name}'"
        try:
            if param_name == "headers":
                if not isinstance(config_value, dict):
                    logger.warning(f"Headers config is not a dictionary {param_desc}. Skipping.")
                    continue
                # Header values are resolved one at a time.
                for header_key, header_config_value in config_value.items():
                    header_desc = f"provider '{provider_prefix}', header '{header_key}'"
                    resolved_headers[header_key] = _resolve_config_value(header_config_value, header_desc)
            else:
                litellm_call_kwargs[param_name] = _resolve_config_value(config_value, param_desc)
        except (RuntimeError, ValueError):
            # Missing env var / invalid '$' spec: already descriptive, propagate as-is.
            raise
        except Exception as e:
            logger.error(f"Unexpected error resolving value {param_desc}: {e}", exc_info=True)
            raise RuntimeError(f"Error resolving value {param_desc}") from e

    if resolved_headers:
        litellm_call_kwargs["headers"] = resolved_headers

    # Everything after the first "/" is the provider-side model id. Fail with a
    # clear message instead of an opaque IndexError (no "/") or an "openai/"
    # request with an empty model id (trailing "/").
    _, separator, model_id = model_name.partition("/")
    if not separator or not model_id:
        raise ValueError(
            f"Invalid model name '{model_name}' for provider '{provider_prefix}': "
            "expected '<provider>/<model id>'."
        )

    # Caller kwargs override resolved config values.
    overrides = model_kwargs or {}
    final_kwargs = {**litellm_call_kwargs, **overrides}

    # A plain override would replace the configured headers wholesale and
    # silently drop e.g. the Authorization header. Merge caller headers on top
    # instead; callers can still override individual header keys.
    caller_headers = overrides.get("headers")
    if resolved_headers and isinstance(caller_headers, dict):
        final_kwargs["headers"] = {**resolved_headers, **caller_headers}

    return await litellm.acompletion(model="openai/" + model_id, messages=messages, **final_kwargs)