# NOTE(review): removed non-Python page-scrape artifacts ("Spaces:" /
# "Runtime error") that preceded the module; they would break import.
# Standard library
import json
import logging
import time
from typing import Any, Dict, Optional

# Third-party
import gradio as gr
from openai import OpenAI

# Local
from prompts import PROMPT_ANALYZER_TEMPLATE

logger = logging.getLogger(__name__)

# Groq model IDs, in preference order. When a model fails (bad JSON or a
# rate limit), callers rotate to the next entry via ModelManager.
FALLBACK_MODELS = [
    "mixtral-8x7b-32768",
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama3-70b-8192",
    "llama3-8b-8192",
]
class ModelManager:
    """Round-robin selector over FALLBACK_MODELS.

    Tracks which model is currently in use and rotates to the next one
    (wrapping around) when the caller reports a failure.
    """

    def __init__(self) -> None:
        # Index of the active model within FALLBACK_MODELS.
        self.current_model_index = 0
        # One attempt per available model before giving up entirely.
        self.max_retries = len(FALLBACK_MODELS)

    def current_model(self) -> str:
        """Return the model ID currently selected."""
        return FALLBACK_MODELS[self.current_model_index]

    def next_model(self) -> str:
        """Advance to the next fallback model (wrapping) and return its ID.

        BUG FIX: the original logged and returned the *bound method*
        ``self.current_model`` instead of calling it, so ``next_model``
        never actually produced a model-ID string despite its ``-> str``
        annotation. Both uses now call the method.
        """
        self.current_model_index = (self.current_model_index + 1) % len(FALLBACK_MODELS)
        logger.info(f"Switching to model: {self.current_model()}")
        return self.current_model()
class PromptEnhancementAPI:
    """Wrapper around an OpenAI-compatible chat-completions endpoint
    (Groq by default) that retries across FALLBACK_MODELS until one of
    them produces a valid JSON object.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Create the API client.

        Args:
            api_key: API key for the endpoint.
            base_url: Override endpoint; defaults to Groq's OpenAI-compatible URL.
        """
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url or "https://api.groq.com/openai/v1"
        )
        self.model_manager = ModelManager()

    def _try_parse_json(self, content: str, retries: int = 0) -> Dict[str, Any]:
        """Parse *content* as a JSON object.

        On parse failure with models still remaining, rotate to the next
        fallback model and re-raise so the caller's retry loop can try
        again; once all models are exhausted, log and propagate the error.

        Raises:
            json.JSONDecodeError: if content is not valid JSON.
            ValueError: if content parses but is not a JSON object.
        """
        try:
            # str.strip() already removes leading newlines, so the original
            # extra .lstrip('\n') was redundant and has been dropped.
            result = json.loads(content.strip())
            if not isinstance(result, dict):
                raise ValueError("Response is not a valid JSON object")
            return result
        except (json.JSONDecodeError, ValueError) as e:
            if retries < self.model_manager.max_retries - 1:
                # BUG FIX: current_model is a method; the original interpolated
                # the bound-method object instead of its string result.
                logger.warning(f"JSON parsing failed with model {self.model_manager.current_model()}. Switching models...")
                self.model_manager.next_model()
                raise e
            logger.error(f"JSON parsing failed with all models: {str(e)}")
            raise

    def generate_enhancement(self, system_prompt: str, user_prompt: str, user_directive: str = "", state: Optional[Dict] = None) -> Dict[str, Any]:
        """Request a JSON "enhancement" completion from the chat API.

        Retries across FALLBACK_MODELS on JSON-parse failures and on
        rate-limit errors; any other API error is surfaced as gr.Error.

        Args:
            system_prompt: System message content.
            user_prompt: Primary user message content.
            user_directive: Optional extra user instruction, appended as its
                own user message.
            state: Optional previous result, replayed as assistant context so
                the model can build on the prior round.

        Returns:
            The parsed JSON dict, or a fallback error response once every
            model has failed.

        Raises:
            gr.Error: on non-recoverable API failures.
        """
        retries = 0
        last_error = None
        while retries < self.model_manager.max_retries:
            try:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
                if user_directive:
                    messages.append({"role": "user", "content": f"User directive: {user_directive}"})
                if state:
                    messages.append({
                        "role": "assistant",
                        "content": json.dumps(state)
                    })
                response = self.client.chat.completions.create(
                    # BUG FIX: call current_model() — the original passed the
                    # bound-method object itself as the model name.
                    model=self.model_manager.current_model(),
                    messages=messages,
                    temperature=0.7,
                    max_tokens=4000,
                    response_format={"type": "json_object"}
                )
                result = self._try_parse_json(response.choices[0].message.content, retries)
                return result
            except (json.JSONDecodeError, ValueError) as e:
                # _try_parse_json already rotated the model before re-raising.
                last_error = e
                retries += 1
                if retries < self.model_manager.max_retries:
                    logger.warning(f"Attempt {retries} failed. Switching models and retrying...")
                    time.sleep(1)  # Brief pause before retry
                    continue
                break
            except Exception as e:
                logger.error(f"API error: {str(e)}")
                if "rate limit" in str(e).lower():
                    if retries < self.model_manager.max_retries - 1:
                        self.model_manager.next_model()
                        retries += 1
                        time.sleep(1)
                        continue
                raise gr.Error(f"API request failed: {str(e)}")
        logger.error(f"All models failed to generate valid JSON: {str(last_error)}")
        # NOTE(review): create_error_response is neither defined nor imported
        # in this chunk — presumably defined elsewhere in the module; verify.
        return create_error_response(user_prompt, user_directive)
class PromptEnhancementSystem:
    """High-level session wrapper: formats analyzer prompts, delegates to
    PromptEnhancementAPI, and tracks the enhancement state and history.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Create the backing API client and an empty session."""
        self.api = PromptEnhancementAPI(api_key, base_url)
        self.current_state = None  # most recent enhancement result, or None
        self.history = []          # every result produced this session, in order

    def start_session(self, prompt: str, user_directive: str = "") -> Dict[str, Any]:
        """Begin a fresh session for *prompt*, resetting state and history."""
        analyzer_prompt = PROMPT_ANALYZER_TEMPLATE.format(
            input_prompt=prompt,
            user_directive=user_directive,
        )
        outcome = self.api.generate_enhancement(
            system_prompt=analyzer_prompt,
            user_prompt=prompt,
            user_directive=user_directive,
        )
        # A new session discards any earlier history.
        self.current_state, self.history = outcome, [outcome]
        return outcome

    def apply_enhancement(self, choice: str, user_directive: str = "") -> Dict[str, Any]:
        """Run one more enhancement round on *choice*, carrying prior state."""
        analyzer_prompt = PROMPT_ANALYZER_TEMPLATE.format(
            input_prompt=choice,
            user_directive=user_directive,
        )
        outcome = self.api.generate_enhancement(
            system_prompt=analyzer_prompt,
            user_prompt=choice,
            user_directive=user_directive,
            state=self.current_state,
        )
        self.current_state = outcome
        self.history.append(outcome)
        return outcome