File size: 3,796 Bytes
2ebef02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
LLM Provider Factory for Flare
"""
import os
from typing import Optional
from dotenv import load_dotenv

from llm_interface import LLMInterface, SparkLLM, GPT4oLLM
from config_provider import ConfigProvider
from utils import log

class LLMFactory:
    """Factory class to create the LLM provider selected by ``global_config.llm_provider``.

    Provider names handled by :meth:`create_provider`:

    * ``"spark"``      -> :class:`SparkLLM`
    * ``"gpt4o"``      -> :class:`GPT4oLLM` with model ``"gpt-4o"``
    * ``"gpt4o-mini"`` -> :class:`GPT4oLLM` with model ``"gpt-4o-mini"``
    """

    # Environment variable consulted per provider when the config holds no API key.
    _ENV_VAR_MAP = {
        "spark": "SPARK_TOKEN",
        "gpt4o": "OPENAI_API_KEY",
        "gpt4o-mini": "OPENAI_API_KEY",
        # Add more mappings as needed
    }

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the LLM provider named in the global config.

        Returns:
            A :class:`SparkLLM` or :class:`GPT4oLLM` instance.

        Raises:
            ValueError: if ``get_llm_provider_config()`` yields nothing for the
                configured provider, if the provider requires an API key and
                none could be resolved, or if the provider name is not one this
                factory knows how to build.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider

        log(f"🏭 Creating LLM provider: {llm_provider}")

        # A falsy provider config is treated as an unknown provider.
        provider_config = cfg.global_config.get_llm_provider_config()
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {llm_provider}")

        # The key may legitimately be missing when the provider does not need one.
        api_key = LLMFactory._get_api_key()
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {llm_provider} but not configured")

        if llm_provider == "spark":
            return LLMFactory._create_spark_provider(api_key)
        if llm_provider in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(llm_provider, api_key)
        raise ValueError(f"Unsupported LLM provider: {llm_provider}")

    @staticmethod
    def _create_spark_provider(api_key: Optional[str]) -> SparkLLM:
        """Create a SparkLLM provider for ``global_config.llm_provider_endpoint``.

        Args:
            api_key: Passed to SparkLLM as ``spark_token``. May be ``None`` when
                the provider config does not require a key.

        Returns:
            A configured :class:`SparkLLM`.

        Raises:
            ValueError: if no endpoint is configured.
        """
        cfg = ConfigProvider.get()

        endpoint = cfg.global_config.llm_provider_endpoint
        if not endpoint:
            raise ValueError("Spark requires llm_provider_endpoint to be configured")

        log("🚀 Creating SparkLLM provider")
        log(f"📍 Endpoint: {endpoint}")

        # SparkLLM still takes a work_mode argument (backward compatibility);
        # derive it from the global deployment mode.
        work_mode = "cloud" if cfg.global_config.is_cloud_mode() else "on-premise"

        return SparkLLM(
            # str() in case the config stores the endpoint as a non-str type
            # (e.g. a URL object) — presumably; confirm against the config model.
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            work_mode=work_mode,
        )

    @staticmethod
    def _create_gpt_provider(model_type: str, api_key: Optional[str]) -> GPT4oLLM:
        """Create a GPT-4o family LLM provider.

        Args:
            model_type: Provider name; ``"gpt4o-mini"`` selects ``"gpt-4o-mini"``,
                anything else selects ``"gpt-4o"``.
            api_key: OpenAI API key forwarded to :class:`GPT4oLLM`.

        Returns:
            A configured :class:`GPT4oLLM`.
        """
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"

        log(f"🤖 Creating GPT4oLLM provider with model: {model}")

        return GPT4oLLM(
            api_key=api_key,
            model=model,
        )

    @staticmethod
    def _get_api_key() -> Optional[str]:
        """Resolve the API key for the configured provider.

        Lookup order:

        1. The decrypted key from ``global_config.get_plain_api_key()``.
        2. The provider's environment variable (see ``_ENV_VAR_MAP``). In cloud
           mode it is read from the process environment only. Otherwise a
           ``.env`` file is loaded first via ``load_dotenv()``.

        Returns:
            The key, or a falsy value (``None`` or whatever falsy value the
            config returned) when nothing was found.
        """
        cfg = ConfigProvider.get()

        api_key = cfg.global_config.get_plain_api_key()
        if api_key:
            log("🔑 Using decrypted API key from config")
            return api_key

        env_var = LLMFactory._ENV_VAR_MAP.get(cfg.global_config.llm_provider)
        if not env_var:
            # No known env var for this provider: fall back to the config value as-is.
            return api_key

        if cfg.global_config.is_cloud_mode():
            api_key = os.environ.get(env_var)
            source = "environment"
        else:
            # NOTE: load_dotenv() does not override variables already present in
            # the process environment by default, so the value may not actually
            # come from the .env file.
            load_dotenv()
            api_key = os.getenv(env_var)
            source = ".env file"

        if api_key:
            log(f"🔑 Using {env_var} from {source}")
        return api_key