Spaces:
Running
Running
File size: 4,271 Bytes
2ebef02 394611c 2ebef02 394611c 2ebef02 394611c 2ebef02 394611c b9b2b1e 2ebef02 394611c b9b2b1e 394611c b9b2b1e 2ebef02 b9b2b1e 394611c b9b2b1e 394611c 2ebef02 394611c b9b2b1e 394611c 2ebef02 b9b2b1e 2ebef02 394611c 2ebef02 394611c 2ebef02 394611c 2ebef02 394611c 2ebef02 394611c b9b2b1e 2ebef02 394611c 2ebef02 b9b2b1e 2ebef02 b9b2b1e 2ebef02 b9b2b1e 2ebef02 b9b2b1e 2ebef02 b9b2b1e 2ebef02 b9b2b1e 394611c 2ebef02 394611c 2ebef02 394611c 2ebef02 394611c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional, Dict, Any
from dotenv import load_dotenv
from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log
class LLMFactory:
    """Factory that builds the LLM provider selected in the global config.

    The provider is chosen by ``cfg.global_config.llm_provider.name``; this
    class knows how to construct ``"spark"``, ``"gpt4o"`` and ``"gpt4o-mini"``.
    """

    # NOTE(review): the "π" / "π€" prefixes in the log messages below look like
    # mis-encoded emoji; they are kept verbatim so log output is unchanged.

    # Provider name -> OpenAI model identifier for the GPT-family providers.
    # Used both to recognise GPT providers and to pick the model string.
    _GPT_MODELS: Dict[str, str] = {
        "gpt4o": "gpt-4o",
        "gpt4o-mini": "gpt-4o-mini",
    }

    # Provider name -> environment variable consulted for its API key.
    _API_KEY_ENV_VARS: Dict[str, str] = {
        "spark": "SPARK_TOKEN",
        "gpt4o": "OPENAI_API_KEY",
        "gpt4o-mini": "OPENAI_API_KEY",
    }

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create the LLM provider configured in ``global_config.llm_provider``.

        Returns:
            A ``SparkLLM`` or ``GPT4oLLM`` instance.

        Raises:
            ValueError: if no provider is configured, the global config has no
                provider config for the name, a required API key is missing,
                or the name has no construction branch in this factory.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider
        if not llm_provider or not llm_provider.name:
            raise ValueError("No LLM provider configured")

        provider_name = llm_provider.name
        log(f"π Creating LLM provider: {provider_name}")

        provider_config = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # The key may legitimately be None when the provider does not require one.
        api_key = LLMFactory._get_api_key(provider_name)
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {provider_name} but not configured")

        settings = llm_provider.settings or {}

        if provider_name == "spark":
            return LLMFactory._create_spark_provider(api_key, llm_provider.endpoint, settings)
        if provider_name in LLMFactory._GPT_MODELS:
            return LLMFactory._create_gpt_provider(provider_name, api_key, settings)
        raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(
        api_key: Optional[str], endpoint: Optional[str], settings: Dict[str, Any]
    ) -> SparkLLM:
        """Create a ``SparkLLM`` talking to ``endpoint``.

        The variant passed to ``SparkLLM`` is ``"spark-cloud"`` unless the
        global config reports it is not in cloud mode, in which case it is
        ``"spark-onpremise"``.

        Raises:
            ValueError: if ``endpoint`` is empty or None.
        """
        if not endpoint:
            raise ValueError("Spark requires endpoint to be configured")

        log("π Creating SparkLLM provider")
        log(f"π Endpoint: {endpoint}")

        # Provider variant string is kept for backward compatibility.
        is_cloud = ConfigProvider.get().global_config.is_cloud_mode()
        provider_variant = "spark-cloud" if is_cloud else "spark-onpremise"

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            provider_variant=provider_variant,
            settings=settings,
        )

    @staticmethod
    def _create_gpt_provider(
        model_type: str, api_key: Optional[str], settings: Dict[str, Any]
    ) -> GPT4oLLM:
        """Create a ``GPT4oLLM`` for ``model_type``.

        ``"gpt4o-mini"`` maps to ``"gpt-4o-mini"``; any other value falls back
        to ``"gpt-4o"``.
        """
        model = LLMFactory._GPT_MODELS.get(model_type, "gpt-4o")
        log(f"π€ Creating GPT4oLLM provider with model: {model}")
        return GPT4oLLM(api_key=api_key, model=model, settings=settings)

    @staticmethod
    def _get_api_key(provider_name: str) -> Optional[str]:
        """Return the API key for ``provider_name``, or None if none is found.

        Lookup order:
          1. the decrypted key from the global config (``get_plain_api_key("llm")``);
          2. the provider's environment variable. When ``SPACE_ID`` is set
             (HuggingFace Space) the process environment is read as-is;
             otherwise ``load_dotenv()`` is called first so a local ``.env``
             file can supply the variable.
        """
        cfg = ConfigProvider.get()

        api_key = cfg.global_config.get_plain_api_key("llm")
        if api_key:
            log("π Using decrypted API key from config")
            return api_key

        env_var = LLMFactory._API_KEY_ENV_VARS.get(provider_name)
        if not env_var:
            return None

        if os.environ.get("SPACE_ID"):
            source = "HuggingFace secrets"
        else:
            # Local development: populate os.environ from a .env file.
            load_dotenv()
            source = ".env file"

        api_key = os.environ.get(env_var)
        if api_key:
            log(f"π Using {env_var} from {source}")
            return api_key
        return None