""" | |
LLM Provider Interface for Flare | |
""" | |
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import httpx
from openai import AsyncOpenAI

from utils import log
class LLMInterface(ABC):
    """Abstract base class for LLM providers"""

    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from LLM"""
        pass

    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Initialize LLM with project config"""
        pass

class SparkLLM(LLMInterface):
    """Existing Spark integration"""

    def __init__(self, spark_endpoint: str, spark_token: str, work_mode: str = "cloud"):
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.work_mode = work_mode
        log(f"🚀 SparkLLM initialized with endpoint: {self.spark_endpoint}")
    async def generate(
        self,
        system_prompt: str,
        user_input: str,
        context: List[Dict],
        project_name: Optional[str] = None,
        version_config: Optional[Dict] = None
    ) -> str:
        """Generate response from Spark, optionally with project/version context"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        # Build payload with all required fields for Spark
        payload = {
            "work_mode": self.work_mode,
            "cloud_token": self.spark_token,
            "project_name": project_name,
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }
        # Add version-specific config if available
        if version_config:
            llm_config = version_config.get("llm", {})
            payload.update({
                "project_version": version_config.get("version_id"),
                "repo_id": llm_config.get("repo_id"),
                "generation_config": llm_config.get("generation_config"),
                "use_fine_tune": llm_config.get("use_fine_tune"),
                "fine_tune_zip": llm_config.get("fine_tune_zip")
            })
        try:
            log(f"📤 Spark request payload keys: {list(payload.keys())}")
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()
                # Spark may return the answer under different field names
                raw = (data.get("model_answer") or data.get("assistant") or data.get("text") or "").strip()
                return raw
        except httpx.TimeoutException:
            log("⏱️ Spark timeout")
            raise
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise
    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        # Existing Spark startup logic
        return True
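

# For reference, the response shape generate() above expects from the Spark
# /generate endpoint, inferred from the parsing logic (the actual Spark
# schema may differ):
#
#     {
#         "model_answer": "...",   # preferred field
#         "assistant": "...",      # first fallback
#         "text": "..."            # second fallback
#     }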

class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        log(f"✅ Initialized GPT LLM with model: {model}")
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response via OpenAI chat completion"""
        # Build the chat history: system prompt, prior context, then the new
        # user turn. context is assumed to already be a list of
        # {"role": ..., "content": ...} dicts.
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(context)
        messages.append({"role": "user", "content": user_input})
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages
            )
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            log(f"❌ OpenAI error: {e}")
            raise
    async def startup(self, project_config: Dict) -> bool:
        """Validate API key"""
        try:
            # Test the API key with a lightweight request
            test_response = await self.client.models.list()
            log(f"✅ OpenAI API key validated, available models: {len(test_response.data)}")
            return True
        except Exception as e:
            log(f"❌ Invalid OpenAI API key: {e}")
            return False
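

# A minimal usage sketch. The SPARK_ENDPOINT, SPARK_TOKEN and OPENAI_API_KEY
# environment variable names and the provider-selection rule are assumptions
# for illustration, not part of Flare.
def create_llm() -> LLMInterface:
    """Pick a provider from the environment (illustrative)"""
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key:
        return GPT4oLLM(api_key=api_key)
    return SparkLLM(
        spark_endpoint=os.getenv("SPARK_ENDPOINT", "http://localhost:8000"),
        spark_token=os.getenv("SPARK_TOKEN", "")
    )


if __name__ == "__main__":
    import asyncio

    async def _demo():
        llm = create_llm()
        if await llm.startup({}):
            answer = await llm.generate(
                system_prompt="You are a helpful assistant.",
                user_input="Hello!",
                context=[]
            )
            log(f"Answer: {answer}")

    asyncio.run(_demo())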