Arena / app.py
FredOru's picture
first working draft
330f0b8
import gradio as gr
import pandas as pd
import os
import time
from threading import Thread
from arena import PromptArena
LABEL_A = "Proposition A"
LABEL_B = "Proposition B"
class PromptArenaApp:
"""
Classe pour encapsuler l'arène et gérer l'interface Gradio.
"""
def __init__(self, arena: PromptArena) -> None:
"""
Initialise l'application et charge les prompts depuis le fichier CSV.
"""
self.arena: PromptArena = arena
def select_and_display_match(self):
"""
Sélectionne un match et l'affiche.
Returns:
Tuple contenant:
- Le texte du premier prompt
- Le texte du second prompt
- Un dictionnaire d'état contenant les IDs des prompts
"""
try:
prompt_a_id, prompt_b_id = self.arena.select_match()
prompt_a_text = self.arena.prompts.get(prompt_a_id, "")
prompt_b_text = self.arena.prompts.get(prompt_b_id, "")
state = {"prompt_a_id": prompt_a_id, "prompt_b_id": prompt_b_id}
return (
prompt_a_text,
prompt_b_text,
state,
gr.update(interactive=True), # button A
gr.update(interactive=True), # button B
gr.update(interactive=False), # match button
)
except Exception as e:
return f"Erreur lors de la sélection d'un match: {str(e)}", "", "", {}
def record_winner_a(self, state: dict[str, str]):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
self.arena.record_result(
prompt_a_id, prompt_b_id
) # Mettre à jour la progression et le classement
progress_info = self.get_progress_info()
rankings_table = self.get_rankings_table()
return (
f"Vous avez choisi : {LABEL_A}",
progress_info,
rankings_table,
gr.update(interactive=False), # button A
gr.update(interactive=False), # button B
gr.update(interactive=True), # match button
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def record_winner_b(self, state: dict[str, str]):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
self.arena.record_result(
prompt_b_id, prompt_a_id
) # Mettre à jour la progression et le classement
progress_info = self.get_progress_info()
rankings_table = self.get_rankings_table()
return (
f"Vous avez choisi : {LABEL_B}",
progress_info,
rankings_table,
gr.update(interactive=False), # button A
gr.update(interactive=False), # button B
gr.update(interactive=True), # match button
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def get_progress_info(self) -> str:
"""
Obtient les informations sur la progression du tournoi.
Returns:
str: Message formaté contenant les statistiques de progression
"""
if not self.arena:
return "Aucune arène initialisée. Veuillez d'abord charger des prompts."
try:
progress = self.arena.get_progress()
info = f"Prompts: {progress['total_prompts']}\n"
info += f"Matchs joués: {progress['total_matches']}\n"
info += f"Progression: {progress['progress']:.2f}%\n"
info += (
f"Matchs restants estimés: {progress['estimated_remaining_matches']}\n"
)
info += f"Incertitude moyenne (σ): {progress['avg_sigma']:.4f}"
return info
except Exception as e:
return f"Erreur lors de la récupération de la progression: {str(e)}"
def get_rankings_table(self) -> pd.DataFrame:
"""
Obtient le classement des prompts sous forme de tableau.
Returns:
pd.DataFrame: Tableau de classement des prompts
"""
if not self.arena:
return pd.DataFrame([{"Erreur": "Aucune arène initialisée"}])
try:
rankings = self.arena.get_rankings()
df = pd.DataFrame(rankings)
df = df[["rank", "prompt_id", "score"]]
df = df.rename(
columns={
"rank": "Rang",
"prompt_id": "ID",
"score": "Score",
}
)
return df
except Exception as e:
return pd.DataFrame([{"Erreur": str(e)}])
def create_ui(self) -> gr.Blocks:
"""
Crée l'interface utilisateur Gradio.
Returns:
gr.Blocks: L'application Gradio configurée
"""
with gr.Blocks(title="Prompt Arena", theme=gr.themes.Ocean()) as app:
gr.Markdown('<h1 style="text-align:center;">🥊 Prompt Arena 🥊</h1>')
with gr.Row():
select_btn = gr.Button("Lancer un nouveau match", variant="primary")
with gr.Row():
proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
proposition_b = gr.Textbox(label=LABEL_B, interactive=False)
with gr.Row():
vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)
result = gr.Textbox("Résultat", interactive=False)
progress_info = gr.Textbox(
label="Progression du concours", interactive=False
)
rankings_table = gr.DataFrame(label="Classement des prompts")
state = gr.State() # contient les IDs des prompts du match en cours
select_btn.click(
self.select_and_display_match,
inputs=[],
outputs=[
proposition_a,
proposition_b,
state,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
vote_a_btn.click(
self.record_winner_a,
inputs=[state],
outputs=[
result,
progress_info,
rankings_table,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
vote_b_btn.click(
self.record_winner_b,
inputs=[state],
outputs=[
result,
progress_info,
rankings_table,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
gr.Row([progress_info, rankings_table])
return app
# Exemple d'utilisation
if __name__ == "__main__":
# load the prompts from the CSV file
prompts = pd.read_csv("prompts.csv", header=None).iloc[:, 0].tolist()
arena = PromptArena(prompts=prompts)
app_instance = PromptArenaApp(arena=arena)
app = app_instance.create_ui()
app.launch()