# flare/llm_interface.py
"""
LLM Provider Interface for Flare
"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
import httpx
from openai import AsyncOpenAI
from utils import log


class LLMInterface(ABC):
    """Abstract base class for LLM providers"""

    def __init__(self, settings: Optional[Dict[str, Any]] = None):
        """Initialize with provider settings"""
        self.settings = settings or {}
        self.internal_prompt = self.settings.get("internal_prompt", "")
        self.parameter_collection_config = self.settings.get("parameter_collection_config", {})

    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from LLM"""
        pass

    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Initialize LLM with project config"""
        pass
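

# --- Hedged usage sketch (not part of the original module) ----------------
# One way a caller might select a concrete provider from configuration.
# The function name and the "spark"/"gpt4o" keys are illustrative
# assumptions, not an API this project defines; SparkLLM and GPT4oLLM are
# resolved at call time, so their definitions below this point are fine.
def create_llm_provider(provider: str, **kwargs) -> LLMInterface:
    """Instantiate an LLM provider by name (illustrative helper)."""
    if provider == "spark":
        return SparkLLM(**kwargs)   # expects spark_endpoint, spark_token, ...
    if provider == "gpt4o":
        return GPT4oLLM(**kwargs)   # expects api_key, optional model/settings
    raise ValueError(f"Unknown LLM provider: {provider}")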


class SparkLLM(LLMInterface):
    """Spark LLM integration"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"🔌 SparkLLM initialized with endpoint: {self.spark_endpoint}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from Spark LLM"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        # Build payload
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }
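        # Note (added): `context` is assumed to be a list of
        # {"role": ..., "content": ...} dicts, the same message shape that
        # GPT4oLLM.generate consumes below; Spark receives it verbatim.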
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()
                # Try different response fields; guard against null values
                raw = (data.get("model_answer") or "").strip()
                if not raw:
                    raw = (data.get("assistant") or data.get("text") or "").strip()
                return raw
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        # Extract required fields from project config
        body = {
            "work_mode": self.provider_variant,
            "cloud_token": self.spark_token,
            "project_name": project_config.get("name"),
            "project_version": project_config.get("version_id"),
            "repo_id": project_config.get("repo_id"),
            "generation_config": project_config.get("generation_config", {}),
            "use_fine_tune": project_config.get("use_fine_tune", False),
            "fine_tune_zip": project_config.get("fine_tune_zip", "")
        }
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=body,
                    headers=headers
                )
                if response.status_code >= 400:
                    log(f"❌ Spark startup failed: {response.status_code} - {response.text}")
                    return False
                log(f"✅ Spark acknowledged startup ({response.status_code})")
                return True
        except Exception as e:
            log(f"⚠️ Spark startup error: {e}")
            return False
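
# Note (added): the Spark HTTP contract assumed above, inferred from this
# file alone — POST {endpoint}/generate returns JSON carrying one of
# "model_answer", "assistant", or "text"; POST {endpoint}/startup accepts
# the body fields built in startup() and replies with a 2xx on success.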


class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = self._map_model_name(model)
        self.client = AsyncOpenAI(api_key=api_key)
        # Extract model-specific settings (self.settings is never None after super().__init__)
        self.temperature = self.settings.get("temperature", 0.7)
        self.max_tokens = self.settings.get("max_tokens", 4096)
        log(f"✅ Initialized GPT LLM with model: {self.model}")

    def _map_model_name(self, model: str) -> str:
        """Map provider name to actual model name"""
        mappings = {
            "gpt4o": "gpt-4o",        # was "gpt-4", which pointed at the wrong model
            "gpt4o-mini": "gpt-4o-mini"
        }
        return mappings.get(model, model)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from OpenAI"""
        try:
            # Build messages
            messages = [{"role": "system", "content": system_prompt}]
            # Add context
            for msg in context:
                messages.append({
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", "")
                })
            # Add current user input
            messages.append({"role": "user", "content": user_input})
            # Call OpenAI
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
            # content can be None (e.g. refusals), so guard before strip()
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            log(f"❌ OpenAI error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """GPT doesn't need startup, always return True"""
        log("✅ GPT provider ready (no startup needed)")
        return True
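

# --- Hedged smoke-test sketch (not part of the original module) -----------
# Exercises GPT4oLLM end to end with a single generate() call. Assumes the
# OPENAI_API_KEY environment variable is set; the prompt text is
# illustrative only.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = GPT4oLLM(api_key=os.environ["OPENAI_API_KEY"])
        if await llm.startup({}):
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Say hello in one short sentence.",
                context=[],
            )
            log(f"Demo answer: {answer}")

    asyncio.run(_demo())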