import gradio as gr
import os
import json
import random
import time
from huggingface_hub import InferenceClient
import google.generativeai as genai

# --- Helper Functions ---

def log_message(message, type='info'):
    """Helper to format log messages with a timestamp."""
    timestamp = time.strftime("%H:%M:%S")
    # Simple prefixes for log clarity in the textbox
    if type == 'success':
        return f"[{timestamp}] SUCCESS: {message}"
    if type == 'fail':
        return f"[{timestamp}] FAIL: {message}"
    if type == 'best':
        return f"[{timestamp}] *** {message} ***"
    return f"[{timestamp}] {message}"

# --- Core GEPA Functions (Python Implementation) ---

def run_huggingface_rollout(client, model_id, prompt, input_text):
    """
    Calls the Hugging Face Inference API for the target model (e.g., Gemma).
    This function performs a "rollout" for a given prompt and input.
    """
    # Check if model is a chat model and format accordingly
    if "gemma" in model_id.lower() or "llama" in model_id.lower() or "chat" in model_id.lower():
        # Use chat template format for chat models
        messages = [
            {"role": "user", "content": f"{prompt}\n\nText: \"{input_text}\""}
        ]
        try:
            response = client.chat_completion(
                model=model_id,
                messages=messages,
                max_tokens=100,
                temperature=0.7,
                top_p=0.95
            )
            return response.choices[0].message.content
        except Exception:
            # Fall back to plain text generation below if chat completion fails
            pass
    
    # Use standard text generation format
    full_prompt = f"{prompt}\n\nText: \"{input_text}\"\n\nResponse:"
    
    try:
        response = client.text_generation(
            model=model_id,
            prompt=full_prompt,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            return_full_text=False  # Only return the generated text, not the full prompt
        )
        return response
    except Exception as e:
        # Provide a more specific error message for common issues
        err_str = str(e).lower()
        if "authorization" in err_str or "401" in err_str:
            raise gr.Error(f"Hugging Face API Error: Authorization failed. Please ensure your HF Token is correct and that you have accepted the terms for the model '{model_id}' on its Hugging Face page.")
        if "not found" in err_str or "404" in err_str:
             raise gr.Error(f"Hugging Face API Error: Model '{model_id}' not found or requires a Pro subscription.")
        if "rate limit" in err_str or "429" in err_str:
            raise gr.Error(f"Hugging Face API Error: Rate limit exceeded. Please wait a moment and try again.")
        if "loading" in err_str or "503" in err_str:
            raise gr.Error(f"Hugging Face API Error: Model '{model_id}' is currently loading. Please wait a few minutes and try again.")
        raise gr.Error(f"Hugging Face API Error: {str(e)}")


def evaluation_and_feedback_function(output, task):
    """
    The evaluation function (μ_f in the paper).
    This function scores the model's output and provides textual feedback.
    IMPORTANT: This is the most critical part to customize for a specific task.
    """
    # Handle empty or None output
    if not output or not isinstance(output, str):
        return {
            "score": 0.0,
            "feedback": "No valid output generated by the model."
        }
    
    # --- CUSTOMIZE THIS FUNCTION ---
    # This example checks for keyword presence. For a real task, you might use
    # regex, semantic similarity, code compilation, etc.
    score = 0.0
    feedback = ""
    found_keywords = 0
    expected_keywords = task.get("expected_keywords", [])

    if not expected_keywords:
        return {
            "score": 0.0,
            "feedback": "No evaluation criteria (expected_keywords) found in training data for this task."
        }

    for keyword in expected_keywords:
        if keyword.lower() in output.lower():
            found_keywords += 1
            feedback += f"SUCCESS: Output correctly contained the keyword '{keyword}'.\n"
        else:
            feedback += f"FAILURE: Output was missing the required keyword '{keyword}'.\n"

    # expected_keywords is guaranteed non-empty here (checked above)
    score = found_keywords / len(expected_keywords)
    feedback += f"Final Score for this task: {score:.2f}"
    # --- END CUSTOMIZATION ---
    return {"score": score, "feedback": feedback}

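
# Worked example of the keyword metric above: with expected_keywords
# ["Eiffel Tower", "Paris"] and an output that mentions only Paris, the score
# is 1/2 = 0.50 and the feedback lists one SUCCESS and one FAILURE line.
#
# Illustrative sketch only (not wired into the app): the evaluator could instead
# match a regex per task. The "expected_pattern" field is a hypothetical
# extension of the training-data schema, not something the UI example uses.
def regex_evaluation_function(output, task):
    import re
    pattern = task.get("expected_pattern")
    if not isinstance(output, str) or not output or not pattern:
        return {"score": 0.0, "feedback": "No output or no expected_pattern to evaluate against."}
    matched = re.search(pattern, output, flags=re.IGNORECASE) is not None
    feedback = (f"SUCCESS: Output matched the pattern '{pattern}'." if matched
                else f"FAILURE: Output did not match the pattern '{pattern}'.")
    return {"score": 1.0 if matched else 0.0, "feedback": feedback}
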

def reflect_and_propose_new_prompt(gemini_model, current_prompt, examples):
    """
    Performs the Reflective Prompt Mutation step using a powerful LLM (Gemini).
    """
    examples_text = '---'.join(
        f'Task Input: "{e["input"]}"\nGenerated Output: "{e["output"]}"\nFeedback:\n{e["feedback"]}\n\n'
        for e in examples
    )

    reflection_prompt = f"""You are an expert prompt engineer. Your task is to refine a prompt to improve its performance based on feedback from previous attempts.

Here is the current prompt that needs improvement:
--- CURRENT PROMPT ---
{current_prompt}
--------------------

Here are examples of how the prompt performed on a few tasks, along with feedback on what went wrong or right:
--- EXAMPLES & FEEDBACK ---
{examples_text}
-------------------------

Based on this analysis, your task is to write a new, improved prompt. The new prompt should be a complete set of instructions that directly addresses the failures and incorporates the successful strategies observed in the feedback. Do not just give suggestions; provide the full, ready-to-use prompt.
Your response should ONLY contain the new prompt text, and nothing else."""

    try:
        response = gemini_model.generate_content(reflection_prompt)
        return response.text.strip()
    except Exception as e:
        raise gr.Error(f"Gemini API Error: {str(e)}. Check your Gemini API Key.")


def select_candidate_for_mutation(candidate_pool, num_tasks):
    """
    Selects the next candidate to mutate based on the Pareto-based strategy.
    """
    if len(candidate_pool) == 1:
        return candidate_pool[0]

    best_scores_per_task = [-1.0] * num_tasks
    for candidate in candidate_pool:
        for i in range(num_tasks):
            if candidate["scores"][i] > best_scores_per_task[i]:
                best_scores_per_task[i] = candidate["scores"][i]

    pareto_front_ids = set()
    for i in range(num_tasks):
        for candidate in candidate_pool:
            if abs(candidate["scores"][i] - best_scores_per_task[i]) < 1e-6:
                pareto_front_ids.add(candidate["id"])

    if not pareto_front_ids:
        return max(candidate_pool, key=lambda c: c["avg_score"])

    selected_id = random.choice(list(pareto_front_ids))
    return next(c for c in candidate_pool if c["id"] == selected_id)

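# Worked example of select_candidate_for_mutation: with two tasks and per-task
# scores A=[1.0, 0.0], B=[0.5, 0.5], C=[0.0, 1.0], the best score on each task
# is 1.0, so the Pareto front is {A, C}; B is never selected even though its
# average ties the others. One front member is then chosen uniformly at random.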

def test_model_connection(hf_client, model_id):
    """Test if the model is accessible and working"""
    test_prompt = "Hello, world!"
    try:
        response = run_huggingface_rollout(hf_client, model_id, "Say hello", test_prompt)
        return True, response
    except Exception as e:
        return False, str(e)


# --- Main Gradio Application Logic ---

def run_gepa_optimization(hf_token_from_input, gemini_key_from_input, model_id, seed_prompt, training_data_str, budget):
    """
    The main function that orchestrates the GEPA optimization process.
    This is a generator function that yields updates to the Gradio UI.
    """
    # --- Get API Keys from Secrets or Inputs ---
    hf_token = os.environ.get("HF_TOKEN") or hf_token_from_input
    gemini_key = os.environ.get("GEMINI_API_KEY") or gemini_key_from_input

    # --- Validate Inputs ---
    if not hf_token:
        raise gr.Error("Hugging Face API Token is required. Add it as a Space Secret named HF_TOKEN or enter it in the textbox.")
    if not gemini_key:
        raise gr.Error("Google Gemini API Key is required. Add it as a Space Secret named GEMINI_API_KEY or enter it in the textbox.")
    try:
        training_data = json.loads(training_data_str)
        if not isinstance(training_data, list) or not training_data or not all(isinstance(item, dict) for item in training_data):
            raise ValueError()
    except (json.JSONDecodeError, ValueError):
        raise gr.Error("Training Data is not valid JSON. It should be a non-empty list of objects.")

    # --- Initialization ---
    log_history = []
    hf_client = InferenceClient(token=hf_token)
    genai.configure(api_key=gemini_key)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')

    rollout_count = 0
    candidate_pool = []
    best_candidate = {
        "prompt": "Initializing...",
        "avg_score": 0.0
    }

    def get_current_state():
        return "\n".join(log_history), best_candidate["prompt"], f"{best_candidate['avg_score']:.2f}"

    # --- Test Model Connection First ---
    log_history.append(log_message("Testing model connection..."))
    yield get_current_state()
    
    connection_ok, test_result = test_model_connection(hf_client, model_id)
    if not connection_ok:
        log_history.append(log_message(f"Model connection failed: {test_result}", 'fail'))
        yield get_current_state()
        raise gr.Error(f"Cannot connect to model '{model_id}': {test_result}")
    
    log_history.append(log_message("Model connection successful!", 'success'))
    yield get_current_state()

    # --- Initial Evaluation of Seed Prompt ---
    log_history.append(log_message("Initializing with seed prompt..."))
    yield get_current_state()

    initial_candidate = {"id": 0, "prompt": seed_prompt, "parentId": None, "scores": [0.0] * len(training_data), "avg_score": 0.0}
    total_score = 0.0
    for i, task in enumerate(training_data):
        log_history.append(log_message(f"  - Evaluating seed on task {i+1}..."))
        yield get_current_state()
        
        try:
            output = run_huggingface_rollout(hf_client, model_id, initial_candidate["prompt"], task["input"])
            eval_result = evaluation_and_feedback_function(output, task)
            initial_candidate["scores"][i] = eval_result["score"]
            total_score += eval_result["score"]
            rollout_count += 1
        except Exception as e:
            log_history.append(log_message(f"Error on task {i+1}: {str(e)}", 'fail'))
            yield get_current_state()
            # Continue with next task but record 0 score
            initial_candidate["scores"][i] = 0.0
            rollout_count += 1

    initial_candidate["avg_score"] = total_score / len(training_data) if training_data else 0.0
    candidate_pool.append(initial_candidate)
    best_candidate = initial_candidate

    log_history.append(log_message(f"Seed prompt initial score: {initial_candidate['avg_score']:.2f}", 'best'))
    yield get_current_state()

    # --- Main Optimization Loop ---
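    # Each iteration below performs one GEPA step:
    #   1. select a parent prompt from the Pareto front of the candidate pool,
    #   2. run it on one randomly chosen training task (a single rollout),
    #   3. ask Gemini to reflect on the feedback and propose a mutated prompt,
    #   4. evaluate the new prompt on all training tasks,
    #   5. keep it only if its average score beats the parent's.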
    while rollout_count < budget:
        log_history.append(log_message(f"--- Iteration Start (Rollouts: {rollout_count}/{budget}) ---"))
        yield get_current_state()

        parent_candidate = select_candidate_for_mutation(candidate_pool, len(training_data))
        log_history.append(log_message(f"Selected candidate #{parent_candidate['id']} (Score: {parent_candidate['avg_score']:.2f}) for mutation."))
        yield get_current_state()

        task_index = random.randint(0, len(training_data) - 1)
        reflection_task = training_data[task_index]
        log_history.append(log_message(f"Performing reflective mutation using task {task_index + 1}..."))
        yield get_current_state()

        try:
            rollout_output = run_huggingface_rollout(hf_client, model_id, parent_candidate["prompt"], reflection_task["input"])
            rollout_count += 1
            eval_result = evaluation_and_feedback_function(rollout_output, reflection_task)

            new_prompt = reflect_and_propose_new_prompt(gemini_model, parent_candidate["prompt"], [{
                "input": reflection_task["input"],
                "output": rollout_output,
                "feedback": eval_result["feedback"]
            }])
            
            new_candidate = {"id": len(candidate_pool), "prompt": new_prompt, "parentId": parent_candidate["id"], "scores": [0.0] * len(training_data), "avg_score": 0.0}
            log_history.append(log_message(f"Generated new candidate prompt #{new_candidate['id']}."))
            yield get_current_state()

            new_total_score = 0.0
            for i, task in enumerate(training_data):
                if rollout_count >= budget: break
                try:
                    output = run_huggingface_rollout(hf_client, model_id, new_candidate["prompt"], task["input"])
                    eval_result = evaluation_and_feedback_function(output, task)
                    new_candidate["scores"][i] = eval_result["score"]
                    new_total_score += eval_result["score"]
                    rollout_count += 1
                except Exception as e:
                    log_history.append(log_message(f"Error evaluating new candidate on task {i+1}: {str(e)}", 'fail'))
                    new_candidate["scores"][i] = 0.0
                    rollout_count += 1
                    
            new_candidate["avg_score"] = new_total_score / len(training_data) if training_data else 0.0

            if new_candidate["avg_score"] > parent_candidate["avg_score"]:
                log_history.append(log_message(f"New candidate #{new_candidate['id']} improved! Score: {new_candidate['avg_score']:.2f} > {parent_candidate['avg_score']:.2f}", 'success'))
                candidate_pool.append(new_candidate)
                if new_candidate["avg_score"] > best_candidate["avg_score"]:
                    best_candidate = new_candidate
                    log_history.append(log_message("NEW BEST PROMPT FOUND!", 'best'))
                    yield get_current_state()
            else:
                log_history.append(log_message(f"New candidate #{new_candidate['id']} did not improve. Score: {new_candidate['avg_score']:.2f}. Discarding.", 'fail'))
                
        except Exception as e:
            log_history.append(log_message(f"Error in optimization iteration: {str(e)}", 'fail'))
            rollout_count += 1  # Count the failed attempt
        
        yield get_current_state()

    log_history.append(log_message("Optimization budget exhausted. Finished.", 'best'))
    yield get_current_state()


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="GEPA Prompt Optimizer") as demo:
    gr.Markdown("""
    # GEPA Prompt Optimizer for Hugging Face Models
    This Space implements the **GEPA (Genetic-Pareto)** framework to automatically optimize prompts for a target model (like Gemma) hosted on Hugging Face.
    It uses a powerful LLM (Gemini) for the "reflection" step to propose high-quality prompt improvements.
    
    **Important Notes:**
    - Make sure you have accepted the model's license on Hugging Face
    - Some models require a Pro subscription for Inference API access
    - Start with a small budget (5-10) to test connectivity first
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 1. Configuration")
            hf_token_input = gr.Textbox(label="Hugging Face API Token (Optional)", type="password", info="Leave blank if HF_TOKEN is set as a Space Secret.")
            gemini_key_input = gr.Textbox(label="Google Gemini API Key (Optional)", type="password", info="Leave blank if GEMINI_API_KEY is set as a Space Secret.")
            model_id_input = gr.Textbox(
                label="Target Model ID", 
                value="microsoft/DialoGPT-medium", 
                info="The Hugging Face model to optimize for. Try 'microsoft/DialoGPT-medium' or 'gpt2' for free models."
            )
            seed_prompt_input = gr.Textbox(
                label="Initial Seed Prompt", 
                lines=5, 
                value="You are a helpful assistant that summarizes text. Given the following text, provide a one-sentence summary that captures the main points."
            )
            training_data_input = gr.Code(
                label="Training Data (JSON)",
                language="json",
                lines=10,
                value="""[
    {
        "input": "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France. It is named after the engineer Gustave Eiffel, whose company designed and built the tower.",
        "expected_keywords": ["Eiffel Tower", "Paris"]
    },
    {
        "input": "The Great Wall of China is a series of fortifications that were built across the historical northern borders of ancient Chinese states and Imperial China as protection against various nomadic groups from the Eurasian Steppe.",
        "expected_keywords": ["Great Wall", "China", "fortifications"]
    },
    {
        "input": "The Colosseum is an oval amphitheatre in the centre of the city of Rome, Italy, just east of the Roman Forum. It is the largest ancient amphitheatre ever built, and is still the largest standing amphitheatre in the world today, despite its age.",
        "expected_keywords": ["Colosseum", "Rome", "amphitheatre"]
    }
]"""
            )
            budget_input = gr.Slider(label="Optimization Budget (Total Rollouts)", minimum=5, maximum=50, value=10, step=1)
            start_button = gr.Button("Start Optimization", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("## 2. Results")
            best_prompt_output = gr.Textbox(label="Best Prompt Found", lines=8, interactive=False)
            best_score_output = gr.Textbox(label="Best Score", interactive=False)
            log_output = gr.Textbox(label="Optimization Log", lines=20, interactive=False, autoscroll=True)

    start_button.click(
        fn=run_gepa_optimization,
        inputs=[hf_token_input, gemini_key_input, model_id_input, seed_prompt_input, training_data_input, budget_input],
        outputs=[log_output, best_prompt_output, best_score_output]
    )

if __name__ == "__main__":
    demo.launch()