flare / llm_factory.py
ciyidogan's picture
Update llm_factory.py
b9b2b1e verified
raw
history blame
4.11 kB
"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional
from dotenv import load_dotenv
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log
class LLMFactory:
    """Factory that builds the LLM provider selected in the Flare configuration.

    Every method is static; the class is used purely as a namespace.
    """

    # Provider names served by each concrete LLM implementation.
    _SPARK_PROVIDERS = ("spark", "spark_cloud", "spark_onpremise")
    _GPT_PROVIDERS = ("gpt4o", "gpt4o-mini")

    # Environment variable consulted per provider when the config holds no API key.
    _ENV_VAR_BY_PROVIDER = {
        "spark": "SPARK_TOKEN",
        "spark_cloud": "SPARK_TOKEN",
        "spark_onpremise": "SPARK_TOKEN",
        "gpt4o": "OPENAI_API_KEY",
        "gpt4o-mini": "OPENAI_API_KEY",
    }

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create the LLM provider named in the global configuration.

        Returns:
            A ``SparkLLM`` for the Spark variants or a ``GPT4oLLM`` for the
            GPT-4o variants.

        Raises:
            ValueError: If no provider is configured, the provider has no
                definition, a required API key or endpoint is missing, or the
                provider name is not handled by this factory.
        """
        cfg = ConfigProvider.get()
        llm_config = cfg.global_config.llm_provider
        if not llm_config:
            raise ValueError("No LLM provider configured")

        provider_name = llm_config.name
        log(f"🏭 Creating LLM provider: {provider_name}")

        # The provider definition says which credentials are mandatory.
        provider_def = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_def:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        api_key = LLMFactory._get_api_key(provider_name)
        if not api_key and provider_def.requires_api_key:
            raise ValueError(f"API key required for {provider_name} but not configured")

        endpoint = llm_config.endpoint
        if not endpoint and provider_def.requires_endpoint:
            raise ValueError(f"Endpoint required for {provider_name} but not configured")

        if provider_name in LLMFactory._SPARK_PROVIDERS:
            return LLMFactory._create_spark_provider(
                provider_name, api_key, endpoint, llm_config.settings
            )
        if provider_name in LLMFactory._GPT_PROVIDERS:
            return LLMFactory._create_gpt_provider(
                provider_name, api_key, llm_config.settings
            )
        raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(provider_name: str, api_key: str, endpoint: str, settings: dict) -> SparkLLM:
        """Create a Spark LLM provider.

        Args:
            provider_name: Spark variant name, forwarded as ``provider_variant``.
            api_key: Token forwarded as ``spark_token``.
            endpoint: URL forwarded as ``spark_endpoint``.
            settings: Provider settings taken from the LLM configuration.

        Returns:
            The constructed ``SparkLLM`` instance.
        """
        log(f"🚀 Creating SparkLLM provider: {provider_name}")
        log(f"📍 Endpoint: {endpoint}")
        return SparkLLM(
            spark_endpoint=endpoint,
            spark_token=api_key,
            provider_variant=provider_name,
            settings=settings,
        )

    @staticmethod
    def _create_gpt_provider(model_type: str, api_key: str, settings: dict) -> GPT4oLLM:
        """Create a GPT-4o LLM provider.

        Args:
            model_type: Provider name; ``"gpt4o-mini"`` selects the
                ``gpt-4o-mini`` model, anything else selects ``gpt-4o``.
            api_key: API key passed to ``GPT4oLLM``.
            settings: Provider settings taken from the LLM configuration.

        Returns:
            The constructed ``GPT4oLLM`` instance.
        """
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
        log(f"🤖 Creating GPT4oLLM provider with model: {model}")
        return GPT4oLLM(
            api_key=api_key,
            model=model,
            settings=settings,
        )

    @staticmethod
    def _get_api_key(provider_name: str) -> Optional[str]:
        """Look up the API key, preferring the config over the environment.

        Args:
            provider_name: Provider name used to pick the environment variable.

        Returns:
            The key stored in the configuration if it is non-empty; otherwise
            the value of the provider's environment variable. A falsy value
            is returned when neither source provides a key.
        """
        cfg = ConfigProvider.get()

        # The key stored in the configuration always wins.
        api_key = cfg.global_config.get_plain_api_key("llm")
        if api_key:
            log("🔑 Using decrypted API key from config")
            return api_key

        env_var = LLMFactory._ENV_VAR_BY_PROVIDER.get(provider_name)
        if not env_var:
            # No environment fallback for this provider: hand back the
            # (falsy) value the configuration returned.
            return api_key

        if os.environ.get("SPACE_ID"):
            # SPACE_ID is set: running in a HuggingFace Space, read the
            # variable directly from the environment.
            api_key = os.environ.get(env_var)
            if api_key:
                log(f"🔑 Using {env_var} from HuggingFace secrets")
        else:
            # Local/on-premise deployment: load .env first. The value may
            # also come from a variable already set in the process
            # environment, which load_dotenv() leaves in place by default.
            load_dotenv()
            api_key = os.getenv(env_var)
            if api_key:
                log(f"🔑 Using {env_var} from .env file")
        return api_key