| """ | |
| Spark LLM Implementation | |
| """ | |
| import os | |
| import httpx | |
| import json | |
| from typing import Dict, List, Any, AsyncIterator | |
| from llm_interface import LLMInterface | |
| from logger import log_info, log_error, log_warning, log_debug | |
| # Get timeout from environment | |
| DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60")) | |
| MAX_RESPONSE_LENGTH = int(os.getenv("LLM_MAX_RESPONSE_LENGTH", "4096")) | |
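# A minimal configuration sketch: both limits come from the environment, so
# they can be tuned per deployment. The values below are illustrative, not
# required settings:
#
#   export LLM_TIMEOUT_SECONDS=60
#   export LLM_MAX_RESPONSE_LENGTH=4096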

class SparkLLM(LLMInterface):
    """Spark LLM integration with improved error handling"""

    def __init__(
        self,
        spark_endpoint: str,
        spark_token: str,
        provider_variant: str = "cloud",
        settings: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        log_info("SparkLLM initialized", endpoint=self.spark_endpoint, timeout=self.timeout)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response with improved error handling (streaming not yet supported)"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

        # Build the message list: system prompt, recent context, then the new user input
        messages = []
        if system_prompt:
            messages.append({
                "role": "system",
                "content": system_prompt
            })
        for msg in context[-10:]:  # Keep only the last 10 messages for context
            messages.append({
                "role": msg.get("role", "user"),
                "content": msg.get("content", "")
            })
        messages.append({
            "role": "user",
            "content": user_input
        })

        payload = {
            "messages": messages,
            "mode": self.provider_variant,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
            "stream": False  # For now, no streaming
        }
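        # Example of the serialized request body (illustrative; the exact Spark
        # schema is assumed from the fields assembled above):
        # {
        #   "messages": [{"role": "system", "content": "..."},
        #                {"role": "user", "content": "..."}],
        #   "mode": "cloud",
        #   "max_tokens": 2048,
        #   "temperature": 0.7,
        #   "stream": false
        # }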
        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                with LogTimer("Spark LLM request"):
                    response = await client.post(
                        f"{self.spark_endpoint}/generate",
                        json=payload,
                        headers=headers
                    )

                # Check for rate limiting before the generic status check
                if response.status_code == 429:
                    retry_after = response.headers.get("Retry-After", "60")
                    log_warning("Rate limited by Spark", retry_after=retry_after)
                    raise httpx.HTTPStatusError(
                        f"Rate limited. Retry after {retry_after}s",
                        request=response.request,
                        response=response
                    )

                response.raise_for_status()
                result = response.json()

                # Extract the answer from the expected "model_answer" field
                content = result.get("model_answer", "")

                # Enforce the configured response-length cap
                if len(content) > MAX_RESPONSE_LENGTH:
                    log_warning(
                        "Response exceeded max length, truncating",
                        original_length=len(content),
                        max_length=MAX_RESPONSE_LENGTH
                    )
                    content = content[:MAX_RESPONSE_LENGTH] + "..."

                return content

        except httpx.TimeoutException:
            log_error("Spark request timed out", timeout=self.timeout)
            raise
        except httpx.HTTPStatusError as e:
            log_error(
                "Spark HTTP error",
                status_code=e.response.status_code,
                response=e.response.text[:500]
            )
            raise
        except Exception as e:
            log_error("Spark unexpected error", error=str(e))
            raise

    def get_provider_name(self) -> str:
        return f"spark-{self.provider_variant}"

    def get_model_info(self) -> Dict[str, Any]:
        return {
            "provider": "spark",
            "variant": self.provider_variant,
            "endpoint": self.spark_endpoint,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7)
        }
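
# Minimal usage sketch. SPARK_ENDPOINT, SPARK_TOKEN, and the example URL are
# placeholders assumed for illustration, not values defined by this module.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        llm = SparkLLM(
            spark_endpoint=os.getenv("SPARK_ENDPOINT", "https://spark.example.com"),
            spark_token=os.getenv("SPARK_TOKEN", ""),
            provider_variant="cloud",
        )
        answer = await llm.generate(
            system_prompt="You are a helpful assistant.",
            user_input="Hello!",
            context=[],  # No prior conversation history
        )
        print(answer)

    asyncio.run(_demo())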