"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional
from dotenv import load_dotenv
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log
class LLMFactory:
    """Factory that instantiates the LLM provider selected in global config.

    Reads ``llm_provider`` from :class:`ConfigProvider`'s global config and
    returns a matching :class:`LLMInterface` implementation (Spark or GPT-4o
    family). API keys are resolved from encrypted config first, then from the
    environment (or a ``.env`` file when not running in cloud mode).
    """

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the appropriate LLM provider based on config.

        Returns:
            A concrete ``LLMInterface`` implementation for the configured
            provider.

        Raises:
            ValueError: If the provider name is unknown or unsupported, or if
                the provider requires an API key that could not be resolved.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider
        # NOTE(review): the leading "π" in these log messages appears to be a
        # mojibake'd emoji from the original encoding; the strings are kept
        # byte-identical so log output does not change — confirm intended glyphs.
        log(f"π Creating LLM provider: {llm_provider}")

        # The provider must be declared in config; fail fast if it is not.
        provider_config = cfg.global_config.get_llm_provider_config()
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {llm_provider}")

        # Resolve the API key (encrypted config first, then environment).
        api_key = LLMFactory._get_api_key()
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {llm_provider} but not configured")

        # Dispatch to the concrete provider constructor.
        if llm_provider == "spark":
            return LLMFactory._create_spark_provider(api_key)
        elif llm_provider in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(llm_provider, api_key)
        else:
            raise ValueError(f"Unsupported LLM provider: {llm_provider}")

    @staticmethod
    def _create_spark_provider(api_key: Optional[str]) -> SparkLLM:
        """Create a Spark LLM provider.

        Args:
            api_key: Spark token, or ``None`` when the provider config does
                not require one (``create_provider`` only validates the key
                when ``requires_api_key`` is set).

        Returns:
            A configured :class:`SparkLLM` instance.

        Raises:
            ValueError: If ``llm_provider_endpoint`` is not configured.
        """
        cfg = ConfigProvider.get()
        endpoint = cfg.global_config.llm_provider_endpoint
        if not endpoint:
            raise ValueError("Spark requires llm_provider_endpoint to be configured")

        # Was an f-string with no placeholders; plain literal is equivalent.
        log("π Creating SparkLLM provider")
        log(f"π Endpoint: {endpoint}")

        # Backward compatibility: Spark distinguishes cloud vs on-premise.
        work_mode = "cloud" if cfg.global_config.is_cloud_mode() else "on-premise"

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            work_mode=work_mode,
        )

    @staticmethod
    def _create_gpt_provider(model_type: str, api_key: str) -> GPT4oLLM:
        """Create a GPT-4o LLM provider.

        Args:
            model_type: Either ``"gpt4o"`` or ``"gpt4o-mini"`` (validated by
                the caller, ``create_provider``).
            api_key: OpenAI API key.

        Returns:
            A configured :class:`GPT4oLLM` instance.
        """
        # Map the config-level provider name onto the OpenAI model id.
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
        log(f"π€ Creating GPT4oLLM provider with model: {model}")

        return GPT4oLLM(
            api_key=api_key,
            model=model,
        )

    @staticmethod
    def _get_api_key() -> Optional[str]:
        """Resolve the API key from config or the environment.

        Resolution order:
            1. Decrypted key stored in global config.
            2. Provider-specific environment variable — read directly from
               ``os.environ`` in cloud mode, or via ``load_dotenv()`` (picking
               up a local ``.env`` file) otherwise.

        Returns:
            The API key, or ``None`` if none could be found.
        """
        cfg = ConfigProvider.get()

        # First choice: decrypted key from encrypted config.
        api_key = cfg.global_config.get_plain_api_key()
        if api_key:
            log("π Using decrypted API key from config")
            return api_key

        # Fallback: provider-specific environment variable.
        llm_provider = cfg.global_config.llm_provider
        env_var_map = {
            "spark": "SPARK_TOKEN",
            "gpt4o": "OPENAI_API_KEY",
            "gpt4o-mini": "OPENAI_API_KEY",
            # Add more mappings as needed
        }

        env_var = env_var_map.get(llm_provider)
        if env_var:
            if cfg.global_config.is_cloud_mode():
                api_key = os.environ.get(env_var)
                if api_key:
                    log(f"π Using {env_var} from environment")
            else:
                # On-premise: also consider a local .env file.
                load_dotenv()
                api_key = os.getenv(env_var)
                if api_key:
                    log(f"π Using {env_var} from .env file")

        return api_key