File size: 4,271 Bytes
2ebef02
 
 
 
394611c
2ebef02
 
 
 
 
 
 
394611c
2ebef02
 
 
 
 
394611c
2ebef02
394611c
b9b2b1e
2ebef02
394611c
b9b2b1e
 
394611c
 
 
b9b2b1e
2ebef02
 
b9b2b1e
394611c
b9b2b1e
 
394611c
 
2ebef02
 
394611c
 
b9b2b1e
394611c
2ebef02
b9b2b1e
2ebef02
 
394611c
2ebef02
394611c
 
 
 
2ebef02
 
394611c
 
 
 
 
2ebef02
394611c
2ebef02
394611c
b9b2b1e
2ebef02
 
 
394611c
2ebef02
 
 
 
 
 
 
 
b9b2b1e
 
2ebef02
 
 
b9b2b1e
2ebef02
 
 
 
b9b2b1e
2ebef02
 
 
 
 
 
 
 
 
 
 
b9b2b1e
2ebef02
b9b2b1e
 
2ebef02
 
b9b2b1e
394611c
2ebef02
394611c
2ebef02
 
 
 
394611c
2ebef02
394611c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional, Dict, Any
from dotenv import load_dotenv

from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log

class LLMFactory:
    """Factory class to create appropriate LLM provider based on llm_provider config"""

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return appropriate LLM provider based on config.

        Reads ``cfg.global_config.llm_provider`` to decide which concrete
        implementation to instantiate ("spark", "gpt4o", or "gpt4o-mini").

        Returns:
            A concrete ``LLMInterface`` implementation.

        Raises:
            ValueError: if no provider is configured, the provider name is
                unknown, a required API key is missing, or the provider name
                is not one of the supported variants.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider

        if not llm_provider or not llm_provider.name:
            raise ValueError("No LLM provider configured")

        provider_name = llm_provider.name
        log(f"🏭 Creating LLM provider: {provider_name}")

        # Get provider config — validates the name against the known set.
        provider_config = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # Get API key. May legitimately be None when the provider does not
        # require one (requires_api_key is falsy).
        api_key = LLMFactory._get_api_key(provider_name)
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {provider_name} but not configured")

        # Get settings (provider-specific options; shape defined by config)
        settings = llm_provider.settings or {}

        # Create appropriate provider
        if provider_name == "spark":
            return LLMFactory._create_spark_provider(api_key, llm_provider.endpoint, settings)
        elif provider_name in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(provider_name, api_key, settings)
        else:
            raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(
        api_key: Optional[str],
        endpoint: Optional[str],
        settings: Dict[str, Any],
    ) -> SparkLLM:
        """Create Spark LLM provider.

        Args:
            api_key: Token for Spark, or None when no key is configured
                (``create_provider`` only rejects a missing key when the
                provider config requires one).
            endpoint: Spark service URL; mandatory for this provider.
            settings: Provider-specific options passed through to SparkLLM.

        Raises:
            ValueError: if ``endpoint`` is not configured.
        """
        if not endpoint:
            raise ValueError("Spark requires endpoint to be configured")

        log("πŸš€ Creating SparkLLM provider")
        log(f"πŸ“ Endpoint: {endpoint}")

        # Determine provider variant for backward compatibility:
        # cloud deployments use "spark-cloud", everything else "spark-onpremise".
        provider_variant = "spark-cloud"
        if not ConfigProvider.get().global_config.is_cloud_mode():
            provider_variant = "spark-onpremise"

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            provider_variant=provider_variant,
            settings=settings
        )

    @staticmethod
    def _create_gpt_provider(
        model_type: str,
        api_key: Optional[str],
        settings: Dict[str, Any],
    ) -> GPT4oLLM:
        """Create GPT-4o LLM provider.

        Args:
            model_type: Either "gpt4o" or "gpt4o-mini"; mapped to the
                corresponding OpenAI model name.
            api_key: OpenAI API key, or None when no key is configured.
            settings: Provider-specific options passed through to GPT4oLLM.
        """
        # Determine model from the configured provider variant
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"

        log(f"πŸ€– Creating GPT4oLLM provider with model: {model}")

        return GPT4oLLM(
            api_key=api_key,
            model=model,
            settings=settings
        )

    @staticmethod
    def _get_api_key(provider_name: str) -> Optional[str]:
        """Get API key from config or environment.

        Resolution order:
            1. Decrypted key from the encrypted config store.
            2. Provider-specific environment variable — read directly from
               the environment when running in a HuggingFace Space
               (``SPACE_ID`` set), otherwise loaded via ``.env``.

        Returns:
            The API key string, or None if none is configured anywhere.
        """
        cfg = ConfigProvider.get()

        # First check encrypted config
        api_key = cfg.global_config.get_plain_api_key("llm")
        if api_key:
            log("πŸ”‘ Using decrypted API key from config")
            return api_key

        # Then check environment based on provider
        env_var_map = {
            "spark": "SPARK_TOKEN",
            "gpt4o": "OPENAI_API_KEY",
            "gpt4o-mini": "OPENAI_API_KEY",
        }

        env_var = env_var_map.get(provider_name)
        if env_var:
            # Check if running in HuggingFace Space — secrets are injected
            # directly into the environment there, so skip .env loading.
            if os.environ.get("SPACE_ID"):
                api_key = os.environ.get(env_var)
                if api_key:
                    log(f"πŸ”‘ Using {env_var} from HuggingFace secrets")
                    return api_key
            else:
                # Local development: pull variables from a .env file first
                load_dotenv()
                api_key = os.getenv(env_var)
                if api_key:
                    log(f"πŸ”‘ Using {env_var} from .env file")
                    return api_key

        return None