| """ | |
| OpenAI GPT Implementation | |
| """ | |
| import os | |
| import openai | |
| from typing import Dict, List, Any | |
| from llm_interface import LLMInterface | |
| from logger import log_info, log_error, log_warning, log_debug | |
| DEFAULT_LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT_SECONDS", "60")) | |
class OpenAILLM(LLMInterface):
    """OpenAI GPT integration with improved error handling."""

    def __init__(self, api_key: str, model: str = "gpt-4", settings: Optional[Dict[str, Any]] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = model
        self.timeout = self.settings.get("timeout", DEFAULT_LLM_TIMEOUT)
        # Create the async client once and reuse it across requests.
        # The key is passed to the client directly (openai>=1.0 style),
        # so the legacy module-level `openai.api_key` is not needed.
        self.client = openai.AsyncOpenAI(api_key=api_key, timeout=self.timeout)
        log_info("OpenAI LLM initialized", model=self.model, timeout=self.timeout)
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate a response with consistent error handling."""
        # Build the messages list in OpenAI chat format.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add recent conversation context (last 10 messages).
        for msg in context[-10:]:
            role = "assistant" if msg.get("role") == "assistant" else "user"
            messages.append({"role": role, "content": msg.get("content", "")})

        # Add the current user input.
        messages.append({"role": "user", "content": user_input})
        try:
            with LogTimer(f"OpenAI {self.model} request"):
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.settings.get("max_tokens", 2048),
                    temperature=self.settings.get("temperature", 0.7),
                    stream=False,
                )

            # Extract content; the API may return None for an empty completion.
            content = response.choices[0].message.content or ""

            # Truncate overlong responses.
            if len(content) > MAX_RESPONSE_LENGTH:
                log_warning("Response exceeded max length, truncating",
                            original_length=len(content),
                            max_length=MAX_RESPONSE_LENGTH)
                content = content[:MAX_RESPONSE_LENGTH] + "..."

            # Log token usage when the API reports it.
            if response.usage:
                log_info("Token usage",
                         prompt_tokens=response.usage.prompt_tokens,
                         completion_tokens=response.usage.completion_tokens,
                         total_tokens=response.usage.total_tokens)

            return content
        except openai.RateLimitError as e:
            log_warning("OpenAI rate limit", error=str(e))
            raise
        except openai.APITimeoutError as e:
            log_error("OpenAI timeout", error=str(e))
            raise
        except openai.APIError as e:
            log_error("OpenAI API error",
                      status_code=getattr(e, "status_code", None),
                      error=str(e))
            raise
        except Exception as e:
            log_error("OpenAI unexpected error", error=str(e))
            raise
    async def startup(self, project_config: Dict) -> bool:
        """OpenAI is a hosted API and needs no startup; this is a no-op."""
        log_info("OpenAI startup called (no-op)")
        return True

    def get_provider_name(self) -> str:
        return f"openai-{self.model}"

    def get_model_info(self) -> Dict[str, Any]:
        return {
            "provider": "openai",
            "model": self.model,
            "max_tokens": self.settings.get("max_tokens", 2048),
            "temperature": self.settings.get("temperature", 0.7),
        }
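

# Minimal usage sketch, not part of the original module: it shows how a caller
# might drive OpenAILLM.generate() from synchronous code. It assumes an
# OPENAI_API_KEY environment variable and that LLMInterface accepts the
# settings dict shown here (as implied by self.settings.get(...) above).
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = OpenAILLM(
            api_key=os.environ["OPENAI_API_KEY"],
            model="gpt-4",
            settings={"temperature": 0.3, "max_tokens": 512},
        )
        reply = await llm.generate(
            system_prompt="You are a concise assistant.",
            user_input="Summarize what this module does.",
            context=[],  # no prior conversation
        )
        print(reply)

    asyncio.run(_demo())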