"""
LLM Provider Interface for Flare
"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List
import httpx
from openai import AsyncOpenAI
from utils import log

class LLMInterface(ABC):
    """Abstract base class for LLM providers"""
    
    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from LLM"""
        pass
    
    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Initialize LLM with project config"""
        pass
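
# Hedged sketch (not in the original file): a minimal in-process provider that
# satisfies the LLMInterface contract, handy as a stand-in during tests. The
# EchoLLM name and behavior are illustrative assumptions, not project code.
class EchoLLM(LLMInterface):
    """Hypothetical test double that echoes the user input back"""
    
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        # No model call; return the input unchanged so callers can be exercised
        return user_input
    
    async def startup(self, project_config: Dict) -> bool:
        # Nothing to initialize for the echo provider
        return True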

class SparkLLM(LLMInterface):
    """Existing Spark integration"""
    

    def __init__(self, spark_endpoint: str, spark_token: str, work_mode: str = "cloud"):
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.work_mode = work_mode
        log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")
    
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }
        
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()
                
                # Try different response fields, tolerating missing or None values
                raw = (data.get("model_answer") or "").strip()
                if not raw:
                    raw = (data.get("assistant") or data.get("text") or "").strip()
                
                return raw
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise
    
    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        # Existing Spark startup logic
        return True

class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""
    
    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        log(f"βœ… Initialized GPT LLM with model: {model}")
    
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response via the OpenAI chat completions API"""
        # Build the chat message list: system prompt first, then prior turns,
        # then the new user input. Assumes context entries are already
        # {"role": ..., "content": ...} dicts in OpenAI message format.
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(context)
        messages.append({"role": "user", "content": user_input})
        
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages
            )
            content = response.choices[0].message.content
            return (content or "").strip()
        except Exception as e:
            log(f"❌ OpenAI error: {e}")
            raise
    
    async def startup(self, project_config: Dict) -> bool:
        """Validate API key"""
        try:
            # Test API key with a simple request
            test_response = await self.client.models.list()
            log(f"βœ… OpenAI API key validated, available models: {len(test_response.data)}")
            return True
        except Exception as e:
            log(f"❌ Invalid OpenAI API key: {e}")
            return False
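
# Illustrative wiring sketch (not part of the original module): pick a provider
# from environment variables and run a single generation. The variable names
# used here ("LLM_PROVIDER", "SPARK_ENDPOINT", "SPARK_TOKEN", "OPENAI_API_KEY")
# are assumptions for the example, not settings the codebase is known to use.
if __name__ == "__main__":
    import asyncio
    
    async def _demo() -> None:
        if os.environ.get("LLM_PROVIDER", "gpt4o") == "spark":
            llm: LLMInterface = SparkLLM(
                spark_endpoint=os.environ.get("SPARK_ENDPOINT", "http://localhost:8000"),
                spark_token=os.environ.get("SPARK_TOKEN", "")
            )
        else:
            llm = GPT4oLLM(api_key=os.environ.get("OPENAI_API_KEY", ""))
        
        if await llm.startup({}):
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Hello!",
                context=[]  # prior turns as {"role": ..., "content": ...} dicts
            )
            log(f"Demo answer: {answer}")
    
    asyncio.run(_demo())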