flare / llm_interface.py
ciyidogan's picture
Update llm_interface.py
c31df77 verified
raw
history blame
4.91 kB
"""
LLM Provider Interface for Flare
"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import httpx
from openai import AsyncOpenAI
from utils import log
class LLMInterface(ABC):
    """Contract every LLM backend (Spark, OpenAI, ...) must implement."""

    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Produce an assistant reply for the given prompt, user input and chat history."""
        ...

    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Prepare the backend from project configuration; return True on success."""
        ...
class SparkLLM(LLMInterface):
    """LLM backend that forwards generation requests to a remote Spark service."""

    def __init__(self, spark_endpoint: str, spark_token: str, work_mode: str = "cloud"):
        # Normalize the endpoint so path joins below never produce "//".
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.work_mode = work_mode
        log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """POST prompt + history to Spark's /generate and return the reply text."""
        request_headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json",
        }
        request_body = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context,
        }
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                resp = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=request_body,
                    headers=request_headers,
                )
                resp.raise_for_status()
                body = resp.json()
                # Spark deployments differ in which field carries the answer;
                # prefer "model_answer", then "assistant", then "text".
                answer = body.get("model_answer", "").strip()
                if answer:
                    return answer
                fallback = body.get("assistant") or body.get("text", "")
                return fallback.strip()
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        # Existing Spark startup logic
        return True
class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration (chat-completions backend)."""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        log(f"βœ… Initialized GPT LLM with model: {model}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a reply via the OpenAI chat-completions API.

        BUG FIX: the previous body was copy-pasted Spark HTTP code referencing
        self.spark_token / self.work_mode / self.spark_endpoint, none of which
        exist on this class, so every call raised AttributeError. This version
        uses the AsyncOpenAI client created in __init__ and matches the
        LLMInterface.generate signature.

        Args:
            system_prompt: System instruction prepended to the conversation.
            user_input:    Current user turn.
            context:       Prior chat-format messages ({"role", "content"} dicts).
                           NOTE(review): assumed to already be in OpenAI message
                           format — confirm against callers.

        Returns:
            The assistant's reply text ("" if the model returned no content).

        Raises:
            Exception: any OpenAI client error is logged and re-raised.
        """
        messages: List[Dict] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.extend(context or [])
        messages.append({"role": "user", "content": user_input})
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
            )
            # content may be None (e.g. refusals) — normalize to "".
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            log(f"❌ OpenAI error: {str(e)}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Validate the API key by listing available models; False on failure."""
        try:
            # Cheap authenticated request that exercises the key.
            test_response = await self.client.models.list()
            log(f"βœ… OpenAI API key validated, available models: {len(test_response.data)}")
            return True
        except Exception as e:
            log(f"❌ Invalid OpenAI API key: {e}")
            return False