"""
LLM Provider Interface for Flare.

Defines the abstract ``LLMInterface`` and two concrete providers:
``SparkLLM`` (HTTP calls to a Spark ``/generate`` endpoint) and
``GPT4oLLM`` (OpenAI chat completions via ``AsyncOpenAI``).
"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import httpx
from openai import AsyncOpenAI

from utils import log


class LLMInterface(ABC):
    """Abstract base class for LLM providers."""

    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response from the LLM.

        Args:
            system_prompt: Instructions for the model.
            user_input: The latest user message.
            context: Prior conversation entries (provider-specific format).

        Returns:
            The generated answer text.
        """
        pass

    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Initialize the LLM with the project config; return True on success."""
        pass


class SparkLLM(LLMInterface):
    """Existing Spark integration.

    Sends generation requests to ``POST {spark_endpoint}/generate`` with a
    bearer token.
    """

    def __init__(self, spark_endpoint: str, spark_token: str, work_mode: str = "cloud"):
        # Trailing slash removed so f"{endpoint}/generate" never yields "//generate".
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        # NOTE(review): stored but not currently included in generate()'s payload.
        self.work_mode = work_mode
        log(f"🔌 SparkLLM initialized with endpoint: {self.spark_endpoint}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Send a generation request to Spark and return the answer text.

        Args:
            system_prompt: Instructions for the model.
            user_input: The latest user message.
            context: Prior conversation entries, forwarded to Spark verbatim.

        Returns:
            The stripped value of ``model_answer`` if non-empty, otherwise the
            stripped value of ``assistant`` or ``text`` (first truthy one),
            otherwise "".

        Raises:
            httpx.HTTPStatusError: Spark returned a non-2xx status.
            Exception: Any transport/JSON error; it is logged, then re-raised.
        """
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }

        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()

            # Try different response fields. `or ""` guards against keys that are
            # present but JSON null, which previously crashed on `.strip()`.
            raw = (data.get("model_answer") or "").strip()
            if not raw:
                raw = (data.get("assistant") or data.get("text") or "").strip()
            return raw

        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark (currently a no-op that returns True)."""
        # Existing Spark startup logic
        return True


class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration using the async OpenAI client."""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        self.api_key = api_key
        self.model = model
        self.client = AsyncOpenAI(api_key=api_key)
        log(f"✅ Initialized GPT LLM with model: {model}")

    # NOTE(review): this signature differs from LLMInterface.generate
    # (system_prompt, user_input, context). It is kept unchanged so existing
    # callers keep working; consider aligning the two.
    async def generate(self, project_name: str, user_input: str,
                       system_prompt: str, context: List[Dict],
                       version_config: Optional[Dict] = None) -> str:
        """Generate a response with OpenAI chat completions.

        The previous body was a copy of the Spark HTTP client and read
        ``self.spark_token`` / ``self.work_mode`` / ``self.spark_endpoint``,
        which this class never sets, so every call failed.

        Args:
            project_name: Project identifier; used only for logging here.
            user_input: The latest user message.
            system_prompt: Sent as the leading ``system`` message.
            context: Prior conversation entries. Entries are assumed to be
                OpenAI-style ``{"role": ..., "content": ...}`` dicts (TODO
                confirm against callers); entries lacking either key are skipped.
            version_config: Currently unused; accepted for signature
                compatibility. Its Spark-specific ``llm`` fields presumably do
                not apply to OpenAI.

        Returns:
            The stripped text of the first completion choice, or "" if the
            response has no choices or no content.

        Raises:
            Exception: Any OpenAI client error; it is logged, then re-raised.
        """
        messages: List[Dict] = [{"role": "system", "content": system_prompt}]
        for entry in context or []:
            role = entry.get("role")
            content = entry.get("content")
            # Forward only the keys the chat API expects.
            if role and content is not None:
                messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": user_input})

        try:
            log(f"📤 OpenAI request for project '{project_name}' with {len(messages)} messages")
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
            )
        except Exception as e:
            log(f"❌ OpenAI error: {e}")
            raise

        if not response.choices:
            return ""
        return (response.choices[0].message.content or "").strip()

    async def startup(self, project_config: Dict) -> bool:
        """Validate the API key by listing models; return False on any error."""
        try:
            # Test API key with a simple request
            test_response = await self.client.models.list()
            log(f"✅ OpenAI API key validated, available models: {len(test_response.data)}")
            return True
        except Exception as e:
            log(f"❌ Invalid OpenAI API key: {e}")
            return False