"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional
from dotenv import load_dotenv
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log
class LLMFactory:
    """Factory that instantiates the LLM provider selected in config.

    Reads ``llm_provider`` from the global config and returns the matching
    ``LLMInterface`` implementation (``SparkLLM`` or ``GPT4oLLM``).
    All methods are static: the factory holds no state.
    """

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the appropriate LLM provider based on config.

        Returns:
            A concrete ``LLMInterface`` implementation.

        Raises:
            ValueError: if the provider name is unknown/unsupported, or a
                required API key is not configured.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider
        log(f"π Creating LLM provider: {llm_provider}")

        # Provider metadata (e.g. whether an API key is mandatory).
        provider_config = cfg.global_config.get_llm_provider_config()
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {llm_provider}")

        # Resolve the API key (encrypted config first, then environment).
        api_key = LLMFactory._get_api_key()
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {llm_provider} but not configured")

        # Dispatch to the concrete provider constructor.
        if llm_provider == "spark":
            return LLMFactory._create_spark_provider(api_key)
        elif llm_provider in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(llm_provider, api_key)
        else:
            raise ValueError(f"Unsupported LLM provider: {llm_provider}")

    @staticmethod
    def _create_spark_provider(api_key: str) -> SparkLLM:
        """Create a Spark LLM provider.

        Raises:
            ValueError: if ``llm_provider_endpoint`` is not configured.
        """
        cfg = ConfigProvider.get()
        endpoint = cfg.global_config.llm_provider_endpoint
        if not endpoint:
            raise ValueError("Spark requires llm_provider_endpoint to be configured")

        log(f"π Creating SparkLLM provider")
        log(f"π Endpoint: {endpoint}")

        # Determine work mode for Spark (backward compatibility with the
        # legacy cloud/on-premise switch).
        work_mode = "cloud"  # Default
        if not cfg.global_config.is_cloud_mode():
            work_mode = "on-premise"

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            work_mode=work_mode
        )

    @staticmethod
    def _create_gpt_provider(model_type: str, api_key: str) -> GPT4oLLM:
        """Create a GPT-4o LLM provider for the given model type."""
        # Map the provider name onto the concrete OpenAI model id.
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
        log(f"π€ Creating GPT4oLLM provider with model: {model}")

        return GPT4oLLM(
            api_key=api_key,
            model=model
        )

    @staticmethod
    def _get_api_key() -> Optional[str]:
        """Get the API key from config or environment.

        Lookup order: decrypted key in config, then the provider-specific
        environment variable (loaded from ``.env`` in on-premise mode).
        Returns ``None`` when no key is found.
        """
        cfg = ConfigProvider.get()

        # First check encrypted config.
        api_key = cfg.global_config.get_plain_api_key()
        if api_key:
            log("π Using decrypted API key from config")
            return api_key

        # Then check environment based on provider.
        llm_provider = cfg.global_config.llm_provider
        env_var_map = {
            "spark": "SPARK_TOKEN",
            "gpt4o": "OPENAI_API_KEY",
            "gpt4o-mini": "OPENAI_API_KEY",
            # Add more mappings as needed
        }

        env_var = env_var_map.get(llm_provider)
        if env_var:
            if cfg.global_config.is_cloud_mode():
                # Cloud deployments inject secrets directly into the process env.
                api_key = os.environ.get(env_var)
                if api_key:
                    log(f"π Using {env_var} from environment")
            else:
                # On-premise: pick up a local .env file first.
                load_dotenv()
                api_key = os.getenv(env_var)
                if api_key:
                    log(f"π Using {env_var} from .env file")

        return api_key