Spaces:
Running
Running
""" | |
LLM Provider Factory for Flare | |
""" | |
import os | |
from typing import Optional, Dict, Any | |
from dotenv import load_dotenv | |
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM | |
from config_provider import ConfigProvider | |
from utils import log | |
class LLMFactory:
    """Factory class to create appropriate LLM provider based on llm_provider config.

    The factory holds no state; every method is a ``@staticmethod`` and is
    meant to be called on the class, e.g. ``LLMFactory.create_provider()``.
    """

    # Environment variable consulted for each provider when the config does
    # not carry a decrypted API key.
    _API_KEY_ENV_VARS: Dict[str, str] = {
        "spark": "SPARK_TOKEN",
        "gpt4o": "OPENAI_API_KEY",
        "gpt4o-mini": "OPENAI_API_KEY",
    }

    # Configured GPT provider name -> model identifier handed to GPT4oLLM.
    _GPT_MODELS: Dict[str, str] = {
        "gpt4o": "gpt-4o",
        "gpt4o-mini": "gpt-4o-mini",
    }

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the LLM provider selected in the global config.

        Returns:
            A ``SparkLLM`` when the configured provider is ``"spark"``, or a
            ``GPT4oLLM`` for ``"gpt4o"`` / ``"gpt4o-mini"``.

        Raises:
            ValueError: if no provider is configured, the provider has no
                config entry, a required API key is missing, or the provider
                name is not one this factory can build.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider
        if not llm_provider or not llm_provider.name:
            raise ValueError("No LLM provider configured")

        provider_name = llm_provider.name
        log(f"π Creating LLM provider: {provider_name}")

        # Get provider config
        provider_config = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # A missing key is only fatal when the provider config says one is
        # required; otherwise api_key may be passed on as None.
        api_key = LLMFactory._get_api_key(provider_name)
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {provider_name} but not configured")

        settings = llm_provider.settings or {}

        if provider_name == "spark":
            return LLMFactory._create_spark_provider(api_key, llm_provider.endpoint, settings)
        if provider_name in LLMFactory._GPT_MODELS:
            return LLMFactory._create_gpt_provider(provider_name, api_key, settings)
        raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(
        api_key: Optional[str], endpoint: Optional[str], settings: Dict[str, Any]
    ) -> SparkLLM:
        """Create a Spark LLM provider.

        Args:
            api_key: Passed to SparkLLM as ``spark_token``; may be None when
                the provider config does not require a key.
            endpoint: Spark endpoint; converted with ``str()`` before use.
            settings: Provider settings, forwarded unchanged.

        Raises:
            ValueError: if ``endpoint`` is empty or None.
        """
        if not endpoint:
            raise ValueError("Spark requires endpoint to be configured")

        log("π Creating SparkLLM provider")
        log(f"π Endpoint: {endpoint}")

        # Provider variant kept for backward compatibility:
        # cloud mode -> "spark-cloud", anything else -> "spark-onpremise".
        provider_variant = (
            "spark-cloud"
            if ConfigProvider.get().global_config.is_cloud_mode()
            else "spark-onpremise"
        )

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            provider_variant=provider_variant,
            settings=settings,
        )

    @staticmethod
    def _create_gpt_provider(
        model_type: str, api_key: Optional[str], settings: Dict[str, Any]
    ) -> GPT4oLLM:
        """Create a GPT-4o family LLM provider.

        Args:
            model_type: Configured provider name. ``"gpt4o-mini"`` selects
                ``gpt-4o-mini``; any other value falls back to ``gpt-4o``.
            api_key: API key forwarded to GPT4oLLM (may be None).
            settings: Provider settings, forwarded unchanged.
        """
        model = LLMFactory._GPT_MODELS.get(model_type, "gpt-4o")
        log(f"π€ Creating GPT4oLLM provider with model: {model}")
        return GPT4oLLM(api_key=api_key, model=model, settings=settings)

    @staticmethod
    def _get_api_key(provider_name: str) -> Optional[str]:
        """Return the API key for ``provider_name``, or None if none is found.

        Lookup order:
          1. The decrypted ``"llm"`` key from the global config (used for any
             provider).
          2. The provider's environment variable. When ``SPACE_ID`` is set
             (HuggingFace Space), it is read straight from the process
             environment. Otherwise a local ``.env`` file is loaded first.
        """
        cfg = ConfigProvider.get()

        # First check encrypted config
        api_key = cfg.global_config.get_plain_api_key("llm")
        if api_key:
            log("π Using decrypted API key from config")
            return api_key

        env_var = LLMFactory._API_KEY_ENV_VARS.get(provider_name)
        if not env_var:
            return None

        if os.environ.get("SPACE_ID"):
            # Running in a HuggingFace Space: secrets are exposed as env vars.
            api_key = os.environ.get(env_var)
            if api_key:
                log(f"π Using {env_var} from HuggingFace secrets")
                return api_key
        else:
            # Local development. NOTE: load_dotenv() does not override
            # variables that are already set, so the value may also come from
            # the real environment rather than the .env file.
            load_dotenv()
            api_key = os.getenv(env_var)
            if api_key:
                log(f"π Using {env_var} from .env file")
                return api_key

        return None