Spaces:
Paused
Paused
Update llm_factory.py
Browse files- llm_factory.py +30 -57
llm_factory.py
CHANGED
|
@@ -5,7 +5,9 @@ import os
|
|
| 5 |
from typing import Optional
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
| 8 |
-
from llm_interface import LLMInterface
|
|
|
|
|
|
|
| 9 |
from config_provider import ConfigProvider
|
| 10 |
from utils import log
|
| 11 |
|
|
@@ -39,63 +41,44 @@ class LLMFactory:
|
|
| 39 |
raise ValueError(f"Unsupported LLM provider: {provider_name}")
|
| 40 |
|
| 41 |
@staticmethod
|
| 42 |
-
def _create_spark_provider(llm_config, api_key
|
| 43 |
"""Create Spark LLM provider"""
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# Extract work mode variant (for backward compatibility)
|
| 51 |
-
provider_variant = "cloud" # Default
|
| 52 |
-
if os.getenv("SPACE_ID"): # HuggingFace Space
|
| 53 |
-
provider_variant = "hfcloud"
|
| 54 |
-
|
| 55 |
-
log(f"π Initializing SparkLLM: {llm_config.endpoint}")
|
| 56 |
-
log(f"π§ Provider variant: {provider_variant}")
|
| 57 |
|
| 58 |
return SparkLLM(
|
| 59 |
-
spark_endpoint=
|
| 60 |
spark_token=api_key,
|
| 61 |
-
provider_variant=
|
| 62 |
settings=llm_config.settings
|
| 63 |
)
|
| 64 |
|
| 65 |
@staticmethod
|
| 66 |
-
def _create_gpt_provider(llm_config, api_key
|
| 67 |
-
"""Create GPT
|
| 68 |
-
|
| 69 |
-
raise ValueError("OpenAI API key is required")
|
| 70 |
-
|
| 71 |
-
# Get model-specific settings
|
| 72 |
-
settings = llm_config.settings or {}
|
| 73 |
-
model = provider_def.name # gpt4o or gpt4o-mini
|
| 74 |
-
|
| 75 |
-
log(f"π€ Initializing GPT4oLLM with model: {model}")
|
| 76 |
-
|
| 77 |
-
return GPT4oLLM(
|
| 78 |
api_key=api_key,
|
| 79 |
-
model=
|
| 80 |
-
settings=settings
|
| 81 |
)
|
| 82 |
|
| 83 |
@staticmethod
|
| 84 |
-
def _get_api_key(provider_name: str,
|
| 85 |
"""Get API key from config or environment"""
|
| 86 |
-
# First
|
| 87 |
-
if
|
| 88 |
-
|
| 89 |
-
|
| 90 |
from encryption_utils import decrypt
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
return decrypted
|
| 94 |
-
else:
|
| 95 |
-
log(f"π Using plain API key from config: ***{config_key[-4:]}")
|
| 96 |
-
return config_key
|
| 97 |
|
| 98 |
-
# Then
|
| 99 |
env_mappings = {
|
| 100 |
"spark": "SPARK_TOKEN",
|
| 101 |
"gpt4o": "OPENAI_API_KEY",
|
|
@@ -104,19 +87,9 @@ class LLMFactory:
|
|
| 104 |
|
| 105 |
env_var = env_mappings.get(provider_name)
|
| 106 |
if env_var:
|
| 107 |
-
|
| 108 |
-
if
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
if api_key:
|
| 112 |
-
log(f"π Using API key from HuggingFace secrets: {env_var}")
|
| 113 |
-
return api_key
|
| 114 |
-
else:
|
| 115 |
-
# Local mode - use dotenv
|
| 116 |
-
load_dotenv()
|
| 117 |
-
api_key = os.getenv(env_var)
|
| 118 |
-
if api_key:
|
| 119 |
-
log(f"π Using API key from .env: {env_var}")
|
| 120 |
-
return api_key
|
| 121 |
|
| 122 |
-
|
|
|
|
| 5 |
from typing import Optional
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
| 8 |
+
from llm_interface import LLMInterface
|
| 9 |
+
from llm_spark import SparkLLM
|
| 10 |
+
from llm_openai import OpenAILLM
|
| 11 |
from config_provider import ConfigProvider
|
| 12 |
from utils import log
|
| 13 |
|
|
|
|
| 41 |
raise ValueError(f"Unsupported LLM provider: {provider_name}")
|
| 42 |
|
| 43 |
@staticmethod
|
| 44 |
+
def _create_spark_provider(llm_config, api_key, provider_def):
|
| 45 |
"""Create Spark LLM provider"""
|
| 46 |
+
endpoint = llm_config.endpoint
|
| 47 |
+
if not endpoint:
|
| 48 |
+
raise ValueError("Spark endpoint not configured")
|
| 49 |
|
| 50 |
+
# Determine variant based on environment
|
| 51 |
+
is_cloud = bool(os.environ.get("SPACE_ID"))
|
| 52 |
+
variant = "hfcloud" if is_cloud else "on-premise"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
return SparkLLM(
|
| 55 |
+
spark_endpoint=endpoint,
|
| 56 |
spark_token=api_key,
|
| 57 |
+
provider_variant=variant,
|
| 58 |
settings=llm_config.settings
|
| 59 |
)
|
| 60 |
|
| 61 |
@staticmethod
|
| 62 |
+
def _create_gpt_provider(llm_config, api_key, provider_def):
|
| 63 |
+
"""Create OpenAI GPT provider"""
|
| 64 |
+
return OpenAILLM(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
api_key=api_key,
|
| 66 |
+
model=llm_config.name,
|
| 67 |
+
settings=llm_config.settings
|
| 68 |
)
|
| 69 |
|
| 70 |
@staticmethod
|
| 71 |
+
def _get_api_key(provider_name: str, configured_key: Optional[str]) -> str:
|
| 72 |
"""Get API key from config or environment"""
|
| 73 |
+
# First try configured key
|
| 74 |
+
if configured_key:
|
| 75 |
+
# Handle encrypted keys
|
| 76 |
+
if configured_key.startswith("enc:"):
|
| 77 |
from encryption_utils import decrypt
|
| 78 |
+
return decrypt(configured_key)
|
| 79 |
+
return configured_key
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
# Then try environment variables
|
| 82 |
env_mappings = {
|
| 83 |
"spark": "SPARK_TOKEN",
|
| 84 |
"gpt4o": "OPENAI_API_KEY",
|
|
|
|
| 87 |
|
| 88 |
env_var = env_mappings.get(provider_name)
|
| 89 |
if env_var:
|
| 90 |
+
key = os.environ.get(env_var)
|
| 91 |
+
if key:
|
| 92 |
+
log(f"π Using API key from environment: {env_var}")
|
| 93 |
+
return key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
+
raise ValueError(f"No API key found for provider: {provider_name}")
|