"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional, Dict, Any
from dotenv import load_dotenv
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log


class LLMFactory:
    """Factory class to create appropriate LLM provider based on llm_provider config"""

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return appropriate LLM provider based on config"""
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider

        if not llm_provider or not llm_provider.name:
            raise ValueError("No LLM provider configured")

        provider_name = llm_provider.name
        log(f"🏭 Creating LLM provider: {provider_name}")

        # Get provider config
        provider_config = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # Get API key
        api_key = LLMFactory._get_api_key(provider_name)
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {provider_name} but not configured")

        # Get settings
        settings = llm_provider.settings or {}

        # Create appropriate provider
        if provider_name == "spark":
            return LLMFactory._create_spark_provider(api_key, llm_provider.endpoint, settings)
        elif provider_name in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(provider_name, api_key, settings)
        else:
            raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(api_key: Optional[str], endpoint: Optional[str], settings: Dict[str, Any]) -> SparkLLM:
        """Create Spark LLM provider"""
        if not endpoint:
            raise ValueError("Spark requires endpoint to be configured")

        log("🚀 Creating SparkLLM provider")
        log(f"📍 Endpoint: {endpoint}")

        # Determine provider variant for backward compatibility
        provider_variant = "spark-cloud"
        if not ConfigProvider.get().global_config.is_cloud_mode():
            provider_variant = "spark-onpremise"

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            provider_variant=provider_variant,
            settings=settings
        )

    @staticmethod
    def _create_gpt_provider(model_type: str, api_key: str, settings: Dict[str, Any]) -> GPT4oLLM:
        """Create GPT-4o LLM provider"""
        # Determine model
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"

        log(f"🤖 Creating GPT4oLLM provider with model: {model}")

        return GPT4oLLM(
            api_key=api_key,
            model=model,
            settings=settings
        )
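
    # Example: _create_gpt_provider("gpt4o-mini", key, {}) yields a GPT4oLLM
    # bound to "gpt-4o-mini"; any other recognized model_type falls back to
    # "gpt-4o".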

    @staticmethod
    def _get_api_key(provider_name: str) -> Optional[str]:
        """Get API key from config or environment"""
        cfg = ConfigProvider.get()

        # First check encrypted config
        api_key = cfg.global_config.get_plain_api_key("llm")
        if api_key:
            log("🔑 Using decrypted API key from config")
            return api_key

        # Then check environment based on provider
        env_var_map = {
            "spark": "SPARK_TOKEN",
            "gpt4o": "OPENAI_API_KEY",
            "gpt4o-mini": "OPENAI_API_KEY",
        }

        env_var = env_var_map.get(provider_name)
        if env_var:
            # Check if running in HuggingFace Space
            if os.environ.get("SPACE_ID"):
                api_key = os.environ.get(env_var)
                if api_key:
                    log(f"🔑 Using {env_var} from HuggingFace secrets")
                    return api_key
            else:
                # Local development
                load_dotenv()
                api_key = os.getenv(env_var)
                if api_key:
                    log(f"🔑 Using {env_var} from .env file")
                    return api_key

        return None
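

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not part of Flare's documented
    # entry points): it relies on ConfigProvider already being initializable in
    # this environment, and only exercises provider construction, not generation.
    provider = LLMFactory.create_provider()
    log(f"✅ Created provider: {provider.__class__.__name__}")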