File size: 5,347 Bytes
05f2374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import logging
from openai import OpenAI
from typing import Dict, Any, Optional
import gradio as gr
from prompts import PROMPT_ANALYZER_TEMPLATE
import time

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Model identifiers tried in order. ModelManager starts at index 0 and steps
# through this list (wrapping around at the end) each time a switch is
# requested; its retry budget equals the length of this list.
FALLBACK_MODELS = [
    "mixtral-8x7b-32768",
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama3-70b-8192",
    "llama3-8b-8192"
]

class ModelManager:
    """Track which entry of ``FALLBACK_MODELS`` is currently selected.

    Attributes are plain instance fields:
      * ``current_model_index`` - position of the active model in the list.
      * ``max_retries`` - number of models in ``FALLBACK_MODELS``, captured
        when the manager is created.
    """

    def __init__(self):
        self.max_retries = len(FALLBACK_MODELS)
        self.current_model_index = 0

    @property
    def current_model(self) -> str:
        """Return the model identifier at the current index."""
        position = self.current_model_index
        return FALLBACK_MODELS[position]

    def next_model(self) -> str:
        """Advance to the following model (wrapping around) and return it."""
        model_count = len(FALLBACK_MODELS)
        advanced = self.current_model_index + 1
        self.current_model_index = advanced % model_count
        logger.info(f"Switching to model: {self.current_model}")
        return self.current_model

class PromptEnhancementAPI:
    """Chat-completions client that retries across ``FALLBACK_MODELS``.

    Each request asks the current model for a JSON object. When the reply is
    not a JSON object, or the API reports a rate limit, the next model from
    the ``ModelManager`` is tried until the retry budget is exhausted.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Create the API client and the model rotation state.

        Args:
            api_key: Key handed to the ``OpenAI`` client.
            base_url: Endpoint to talk to; when falsy,
                ``https://api.groq.com/openai/v1`` is used.
        """
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url or "https://api.groq.com/openai/v1"
        )
        self.model_manager = ModelManager()

    @staticmethod
    def _parse_json_object(content: Optional[str]) -> Dict[str, Any]:
        """Parse *content* as JSON and require the top-level value to be an object.

        Raises:
            ValueError: If *content* is not a string (e.g. ``None``), is not
                valid JSON (``json.JSONDecodeError`` is a ``ValueError``
                subclass), or decodes to something other than a dict.
        """
        # A completion may carry no text at all; report that as an invalid
        # response rather than failing with AttributeError on ``.strip()``.
        if not isinstance(content, str):
            raise ValueError("Response has no text content")
        result = json.loads(content.strip())
        if not isinstance(result, dict):
            raise ValueError("Response is not a valid JSON object")
        return result

    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str,
                        user_directive: str = "", state: Optional[Dict] = None) -> list:
        """Assemble the chat message list sent with every attempt.

        The optional directive is added as a second user message and the
        optional previous state as a JSON-encoded assistant message.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        if user_directive:
            messages.append({"role": "user", "content": f"User directive: {user_directive}"})
        if state:
            messages.append({"role": "assistant", "content": json.dumps(state)})
        return messages

    def _try_parse_json(self, content: Optional[str], retries: int = 0) -> Dict[str, Any]:
        """Parse a model reply, switching models when it is not a JSON object.

        Args:
            content: Raw reply text (may be ``None``).
            retries: Number of attempts already used by the caller.

        Returns:
            The decoded JSON object.

        Raises:
            ValueError: When the reply is missing, malformed, or not an
                object. The model is advanced first unless this was the
                last attempt in the retry budget.
        """
        try:
            return self._parse_json_object(content)
        except ValueError as e:
            if retries < self.model_manager.max_retries - 1:
                logger.warning(f"JSON parsing failed with model {self.model_manager.current_model}. Switching models...")
                self.model_manager.next_model()
                raise
            logger.error(f"JSON parsing failed with all models: {str(e)}")
            raise

    def generate_enhancement(self, system_prompt: str, user_prompt: str, user_directive: str = "", state: Optional[Dict] = None) -> Dict[str, Any]:
        """Request a JSON enhancement, falling back through the model list.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the first user message.
            user_directive: Optional extra instruction, sent as a second
                user message when non-empty.
            state: Optional previous result, sent JSON-encoded as an
                assistant message when truthy.

        Returns:
            The JSON object decoded from the first usable model reply.

        Raises:
            gr.Error: When the API call fails (other than a retryable rate
                limit) or when every attempt returned an unusable reply.
        """
        retries = 0
        last_error = None

        while retries < self.model_manager.max_retries:
            try:
                # Built inside the try so serialization problems with
                # ``state`` are reported the same way as other failures.
                messages = self._build_messages(system_prompt, user_prompt, user_directive, state)

                response = self.client.chat.completions.create(
                    model=self.model_manager.current_model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=4000,
                    response_format={"type": "json_object"}
                )

                return self._try_parse_json(response.choices[0].message.content, retries)

            except ValueError as e:
                # _try_parse_json has already switched models when allowed.
                last_error = e
                retries += 1
                if retries < self.model_manager.max_retries:
                    logger.warning(f"Attempt {retries} failed. Switching models and retrying...")
                    time.sleep(1)  # Brief pause before retry
                    continue
                break

            except Exception as e:
                logger.error(f"API error: {str(e)}")
                attempts_remain = retries < self.model_manager.max_retries - 1
                if attempts_remain and "rate limit" in str(e).lower():
                    self.model_manager.next_model()
                    retries += 1
                    time.sleep(1)
                    continue
                raise gr.Error(f"API request failed: {str(e)}") from e

        logger.error(f"All models failed to generate valid JSON: {str(last_error)}")
        # The previous fallback called a helper that is not defined or imported
        # in this module, which surfaced as NameError; report the failure
        # explicitly instead, consistent with the API-error path above.
        raise gr.Error(
            f"All models failed to generate valid JSON: {str(last_error)}"
        ) from last_error

class PromptEnhancementSystem:
    """Session wrapper around ``PromptEnhancementAPI``.

    Keeps the most recent result in ``current_state`` and every result of the
    current session in ``history``.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Create the underlying API wrapper with an empty session."""
        self.api = PromptEnhancementAPI(api_key, base_url)
        self.history = []
        self.current_state = None

    def _enhance(self, text: str, user_directive: str, **extra: Any) -> Dict[str, Any]:
        """Fill the analyzer template with *text* and run one enhancement call.

        Any keyword arguments in *extra* are forwarded unchanged to
        ``generate_enhancement``.
        """
        system_prompt = PROMPT_ANALYZER_TEMPLATE.format(
            input_prompt=text,
            user_directive=user_directive,
        )
        return self.api.generate_enhancement(
            system_prompt=system_prompt,
            user_prompt=text,
            user_directive=user_directive,
            **extra,
        )

    def start_session(self, prompt: str, user_directive: str = "") -> Dict[str, Any]:
        """Analyze *prompt* and reset the session to hold only this result."""
        outcome = self._enhance(prompt, user_directive)
        self.history = [outcome]
        self.current_state = outcome
        return outcome

    def apply_enhancement(self, choice: str, user_directive: str = "") -> Dict[str, Any]:
        """Analyze *choice* with the current state as context and record the result."""
        outcome = self._enhance(choice, user_directive, state=self.current_state)
        self.history.append(outcome)
        self.current_state = outcome
        return outcome