flare / llm_factory.py
ciyidogan's picture
Create llm_factory.py
2ebef02 verified
raw
history blame
3.8 kB
"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional
from dotenv import load_dotenv
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log
class LLMFactory:
    """Factory class to create the appropriate LLM provider based on the llm_provider config."""

    # Environment variable consulted for each provider when the config holds
    # no decrypted API key. Built once instead of on every lookup.
    _ENV_VAR_MAP = {
        "spark": "SPARK_TOKEN",
        "gpt4o": "OPENAI_API_KEY",
        "gpt4o-mini": "OPENAI_API_KEY",
        # Add more mappings as needed
    }

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the LLM provider selected by the global config.

        Returns:
            A ``SparkLLM`` for ``"spark"``, or a ``GPT4oLLM`` for
            ``"gpt4o"`` / ``"gpt4o-mini"``.

        Raises:
            ValueError: if the provider has no provider config, if it
                requires an API key and none could be found, or if the
                provider name is not handled by this factory.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider
        log(f"🏭 Creating LLM provider: {llm_provider}")

        # Get provider config
        provider_config = cfg.global_config.get_llm_provider_config()
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {llm_provider}")

        # Get API key (config first, then environment / .env file)
        api_key = LLMFactory._get_api_key()
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {llm_provider} but not configured")

        # Create appropriate provider
        if llm_provider == "spark":
            return LLMFactory._create_spark_provider(api_key)
        elif llm_provider in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(llm_provider, api_key)
        else:
            raise ValueError(f"Unsupported LLM provider: {llm_provider}")

    @staticmethod
    def _create_spark_provider(api_key: Optional[str]) -> SparkLLM:
        """Create the Spark LLM provider.

        Args:
            api_key: Token passed to ``SparkLLM`` as ``spark_token``. May be
                falsy when the provider config does not require an API key.

        Raises:
            ValueError: if ``llm_provider_endpoint`` is not configured.
        """
        cfg = ConfigProvider.get()
        endpoint = cfg.global_config.llm_provider_endpoint
        if not endpoint:
            raise ValueError("Spark requires llm_provider_endpoint to be configured")

        log("🚀 Creating SparkLLM provider")
        log(f"📍 Endpoint: {endpoint}")

        # Work mode for Spark (backward compatibility): derived from the
        # global cloud-mode flag.
        work_mode = "cloud" if cfg.global_config.is_cloud_mode() else "on-premise"

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            work_mode=work_mode,
        )

    @staticmethod
    def _create_gpt_provider(model_type: str, api_key: Optional[str]) -> GPT4oLLM:
        """Create a GPT-4o family LLM provider.

        Args:
            model_type: Provider name from config. ``"gpt4o-mini"`` selects
                the ``gpt-4o-mini`` model; any other value selects ``gpt-4o``.
            api_key: API key forwarded to ``GPT4oLLM``.
        """
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
        log(f"🤖 Creating GPT4oLLM provider with model: {model}")
        return GPT4oLLM(
            api_key=api_key,
            model=model,
        )

    @staticmethod
    def _get_api_key() -> Optional[str]:
        """Return the API key for the configured provider, or a falsy value if none is found.

        Lookup order:
          1. The decrypted key from the config (``get_plain_api_key()``).
          2. The provider's environment variable from ``_ENV_VAR_MAP``. In
             cloud mode it is read straight from the process environment;
             otherwise ``load_dotenv()`` is called first and the variable
             is then read from the environment.
        """
        cfg = ConfigProvider.get()

        # First check encrypted config
        api_key = cfg.global_config.get_plain_api_key()
        if api_key:
            log("🔑 Using decrypted API key from config")
            return api_key

        # Then check environment based on provider
        env_var = LLMFactory._ENV_VAR_MAP.get(cfg.global_config.llm_provider)
        if env_var:
            if cfg.global_config.is_cloud_mode():
                api_key = os.environ.get(env_var)
                if api_key:
                    log(f"🔑 Using {env_var} from environment")
            else:
                # NOTE: by default load_dotenv() does not override variables
                # already present in the environment, so the value found here
                # may not actually originate from the .env file.
                load_dotenv()
                api_key = os.getenv(env_var)
                if api_key:
                    log(f"🔑 Using {env_var} from .env file")
        return api_key