Spaces:
Running
Running
Fixed human and cleaning code
Browse files
app.py
CHANGED
@@ -2,6 +2,9 @@
|
|
2 |
"""
|
3 |
Game Reasoning Arena — Hugging Face Spaces Gradio App
|
4 |
|
|
|
|
|
|
|
5 |
Pipeline:
|
6 |
User clicks "Start Game" in Gradio
|
7 |
↓
|
@@ -12,21 +15,32 @@ ui/gradio_config_generator.py (run_game_with_existing_infrastructure)
|
|
12 |
src/game_reasoning_arena/ (core game infrastructure)
|
13 |
↓
|
14 |
Game results + metrics displayed in Gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
"""
|
16 |
|
17 |
from __future__ import annotations
|
18 |
|
19 |
-
|
|
|
|
|
20 |
|
|
|
|
|
21 |
import sys
|
22 |
import shutil
|
23 |
from pathlib import Path
|
24 |
from typing import List, Dict, Any, Tuple, Generator, TypedDict
|
25 |
|
|
|
26 |
import pandas as pd
|
27 |
import gradio as gr
|
28 |
|
29 |
-
# Logging
|
30 |
import logging
|
31 |
logging.basicConfig(level=logging.INFO)
|
32 |
log = logging.getLogger("arena_space")
|
@@ -37,27 +51,34 @@ try:
|
|
37 |
except Exception:
|
38 |
pass
|
39 |
|
|
|
|
|
|
|
|
|
40 |
# Make sure src is on PYTHONPATH
|
41 |
src_path = Path(__file__).parent / "src"
|
42 |
if str(src_path) not in sys.path:
|
43 |
sys.path.insert(0, str(src_path))
|
44 |
|
45 |
-
#
|
46 |
-
from game_reasoning_arena.arena.games.registry import
|
|
|
|
|
47 |
from game_reasoning_arena.backends.huggingface_backend import (
|
48 |
-
|
49 |
-
|
50 |
from game_reasoning_arena.backends import (
|
51 |
-
|
52 |
-
|
53 |
|
54 |
-
|
|
|
|
|
55 |
|
56 |
-
#
|
57 |
-
|
58 |
-
# -----------------------------------------------------------------------------
|
59 |
|
60 |
-
#
|
61 |
HUGGINGFACE_MODELS: Dict[str, str] = {
|
62 |
"gpt2": "gpt2",
|
63 |
"distilgpt2": "distilgpt2",
|
@@ -65,26 +86,31 @@ HUGGINGFACE_MODELS: Dict[str, str] = {
|
|
65 |
"EleutherAI/gpt-neo-125M": "EleutherAI/gpt-neo-125M",
|
66 |
}
|
67 |
|
|
|
68 |
GAMES_REGISTRY: Dict[str, Any] = {}
|
69 |
|
|
|
70 |
db_dir = Path(__file__).resolve().parent / "results"
|
71 |
|
|
|
72 |
LEADERBOARD_COLUMNS = [
|
73 |
"agent_name", "agent_type", "# games", "total rewards",
|
74 |
# "avg_generation_time (sec)", # Commented out - needs fixing
|
75 |
"win-rate", "win vs_random (%)",
|
76 |
]
|
77 |
|
78 |
-
#
|
79 |
-
#
|
80 |
-
#
|
81 |
|
|
|
82 |
huggingface_backend = None
|
83 |
if BACKEND_SYSTEM_AVAILABLE:
|
84 |
try:
|
85 |
huggingface_backend = HuggingFaceBackend()
|
86 |
initialize_llm_registry()
|
87 |
|
|
|
88 |
for model_name in HUGGINGFACE_MODELS.keys():
|
89 |
if huggingface_backend.is_model_available(model_name):
|
90 |
registry_key = f"hf_{model_name}"
|
@@ -97,10 +123,11 @@ if BACKEND_SYSTEM_AVAILABLE:
|
|
97 |
log.error("Failed to initialize HuggingFace backend: %s", e)
|
98 |
huggingface_backend = None
|
99 |
|
100 |
-
#
|
101 |
-
#
|
102 |
-
#
|
103 |
|
|
|
104 |
try:
|
105 |
if games_registry is not None:
|
106 |
GAMES_REGISTRY = {
|
@@ -113,33 +140,46 @@ except Exception as e:
|
|
113 |
log.warning("Failed to load games registry: %s", e)
|
114 |
GAMES_REGISTRY = {}
|
115 |
|
|
|
116 |
def _get_game_display_mapping() -> Dict[str, str]:
|
117 |
"""
|
118 |
-
Build a mapping from internal game keys to their human
|
119 |
-
If the registry is not available or a game has no
|
120 |
-
fall back to a title
|
|
|
|
|
|
|
|
|
121 |
"""
|
122 |
mapping: Dict[str, str] = {}
|
123 |
if games_registry is not None and hasattr(games_registry, "_registry"):
|
124 |
for key, info in games_registry._registry.items():
|
125 |
-
|
|
|
|
|
|
|
126 |
if not display:
|
127 |
display = key.replace("_", " ").title()
|
128 |
mapping[key] = display
|
129 |
return mapping
|
130 |
|
131 |
|
132 |
-
#
|
133 |
-
#
|
134 |
-
#
|
135 |
-
|
136 |
|
137 |
def ensure_results_dir() -> None:
|
|
|
138 |
db_dir.mkdir(parents=True, exist_ok=True)
|
139 |
|
140 |
|
141 |
def iter_agent_databases() -> Generator[Tuple[str, str, str], None, None]:
|
142 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
143 |
for db_file in find_or_download_db():
|
144 |
agent_type, model_name = extract_agent_info(db_file)
|
145 |
if agent_type != "random":
|
@@ -147,7 +187,12 @@ def iter_agent_databases() -> Generator[Tuple[str, str, str], None, None]:
|
|
147 |
|
148 |
|
149 |
def find_or_download_db() -> List[str]:
|
150 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
151 |
ensure_results_dir()
|
152 |
|
153 |
random_db_path = db_dir / "random_None.db"
|
@@ -174,6 +219,15 @@ def find_or_download_db() -> List[str]:
|
|
174 |
|
175 |
|
176 |
def extract_agent_info(filename: str) -> Tuple[str, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
base_name = Path(filename).stem
|
178 |
parts = base_name.split("_", 1)
|
179 |
if len(parts) == 2:
|
@@ -182,7 +236,15 @@ def extract_agent_info(filename: str) -> Tuple[str, str]:
|
|
182 |
|
183 |
|
184 |
def get_available_games(include_aggregated: bool = True) -> List[str]:
|
185 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
if GAMES_REGISTRY:
|
187 |
game_list = sorted(GAMES_REGISTRY.keys())
|
188 |
else:
|
@@ -193,7 +255,12 @@ def get_available_games(include_aggregated: bool = True) -> List[str]:
|
|
193 |
|
194 |
|
195 |
def extract_illegal_moves_summary() -> pd.DataFrame:
|
196 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
197 |
summary = []
|
198 |
for db_file, agent_type, model_name in iter_agent_databases():
|
199 |
conn = sqlite3.connect(db_file)
|
@@ -211,17 +278,19 @@ def extract_illegal_moves_summary() -> pd.DataFrame:
|
|
211 |
|
212 |
|
213 |
|
214 |
-
#
|
215 |
-
#
|
216 |
-
#
|
217 |
|
218 |
class PlayerConfigData(TypedDict, total=False):
|
|
|
219 |
player_types: List[str]
|
220 |
player_type_display: Dict[str, str]
|
221 |
available_models: List[str]
|
222 |
|
223 |
|
224 |
class GameArenaConfig(TypedDict, total=False):
|
|
|
225 |
available_games: List[str]
|
226 |
player_config: PlayerConfigData
|
227 |
model_info: str
|
@@ -231,10 +300,23 @@ class GameArenaConfig(TypedDict, total=False):
|
|
231 |
def setup_player_config(
|
232 |
player_type: str, player_model: str, player_id: str
|
233 |
) -> Dict[str, Any]:
|
234 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
# Create a temporary config to get the display-to-key mapping
|
236 |
temp_config = create_player_config()
|
237 |
-
display_to_key = {
|
|
|
|
|
|
|
238 |
|
239 |
# Map display label back to internal key
|
240 |
internal_key = display_to_key.get(player_type, player_type)
|
@@ -267,6 +349,15 @@ def setup_player_config(
|
|
267 |
|
268 |
|
269 |
def create_player_config(include_aggregated: bool = False) -> GameArenaConfig:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
# Internal names for arena dropdown
|
271 |
available_keys = get_available_games(include_aggregated=include_aggregated)
|
272 |
|
@@ -284,8 +375,14 @@ def create_player_config(include_aggregated: bool = False) -> GameArenaConfig:
|
|
284 |
available_games.append(name)
|
285 |
seen.add(name)
|
286 |
|
|
|
287 |
player_types = ["human", "random_bot"]
|
288 |
-
player_type_display = {
|
|
|
|
|
|
|
|
|
|
|
289 |
if BACKEND_SYSTEM_AVAILABLE:
|
290 |
for model_key in HUGGINGFACE_MODELS.keys():
|
291 |
key = f"hf_{model_key}"
|
@@ -320,9 +417,9 @@ def create_player_config(include_aggregated: bool = False) -> GameArenaConfig:
|
|
320 |
}
|
321 |
|
322 |
|
323 |
-
#
|
324 |
-
#
|
325 |
-
#
|
326 |
|
327 |
def play_game(
|
328 |
game_name: str,
|
@@ -333,6 +430,21 @@ def play_game(
|
|
333 |
rounds: int = 1,
|
334 |
seed: int | None = None,
|
335 |
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
if game_name == "No Games Found":
|
337 |
return "No games available. Please add game databases."
|
338 |
|
@@ -348,7 +460,8 @@ def play_game(
|
|
348 |
|
349 |
# Map human‑friendly game name back to internal key if needed
|
350 |
config = create_player_config()
|
351 |
-
if "game_display_to_key" in config and
|
|
|
352 |
game_name = config["game_display_to_key"][game_name]
|
353 |
|
354 |
# Map display labels for player types back to keys
|
@@ -381,16 +494,30 @@ def play_game(
|
|
381 |
except Exception as e:
|
382 |
return f"Error during game simulation: {e}"
|
383 |
|
|
|
|
|
|
|
|
|
|
|
384 |
def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
all_stats = []
|
386 |
for db_file, agent_type, model_name in iter_agent_databases():
|
387 |
conn = sqlite3.connect(db_file)
|
388 |
try:
|
389 |
if game_name == "Aggregated Performance":
|
390 |
-
#
|
391 |
df = pd.read_sql_query(
|
392 |
-
"SELECT COUNT(DISTINCT episode) AS games_played,
|
393 |
-
"FROM game_results",
|
394 |
conn,
|
395 |
)
|
396 |
# avg_time = conn.execute(
|
@@ -405,20 +532,22 @@ def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
|
405 |
"WHERE opponent = 'random_None'",
|
406 |
).fetchone()[0] or 0
|
407 |
else:
|
408 |
-
#
|
409 |
df = pd.read_sql_query(
|
410 |
-
"SELECT COUNT(DISTINCT episode) AS games_played,
|
|
|
411 |
"FROM game_results WHERE game_name = ?",
|
412 |
conn,
|
413 |
params=(game_name,),
|
414 |
)
|
415 |
# avg_time = conn.execute(
|
416 |
-
# "SELECT AVG(generation_time) FROM moves
|
417 |
-
# (game_name,),
|
418 |
# ).fetchone()[0] or 0
|
419 |
wins_vs_random = conn.execute(
|
420 |
"SELECT COUNT(*) FROM game_results "
|
421 |
-
"WHERE opponent = 'random_None' AND reward > 0
|
|
|
422 |
(game_name,),
|
423 |
).fetchone()[0] or 0
|
424 |
total_vs_random = conn.execute(
|
@@ -455,7 +584,8 @@ def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
|
455 |
finally:
|
456 |
conn.close()
|
457 |
|
458 |
-
# Concatenate all rows; if all_stats is empty, return an empty DataFrame
|
|
|
459 |
if not all_stats:
|
460 |
return pd.DataFrame(columns=LEADERBOARD_COLUMNS)
|
461 |
|
@@ -463,10 +593,9 @@ def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
|
463 |
return leaderboard_df[LEADERBOARD_COLUMNS]
|
464 |
|
465 |
|
466 |
-
|
467 |
-
#
|
468 |
-
#
|
469 |
-
# -----------------------------------------------------------------------------
|
470 |
|
471 |
def create_bar_plot(
|
472 |
data: pd.DataFrame,
|
@@ -477,7 +606,21 @@ def create_bar_plot(
|
|
477 |
y_label: str,
|
478 |
horizontal: bool = False,
|
479 |
) -> gr.BarPlot:
|
480 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
if horizontal:
|
482 |
# Swap x and y for horizontal bars
|
483 |
return gr.BarPlot(
|
@@ -498,12 +641,21 @@ def create_bar_plot(
|
|
498 |
y_label=y_label,
|
499 |
)
|
500 |
|
501 |
-
# -----------------------------------------------------------------------------
|
502 |
-
# Upload handler (save .db files to scripts/results/)
|
503 |
-
# -----------------------------------------------------------------------------
|
504 |
|
|
|
|
|
|
|
505 |
|
506 |
def handle_db_upload(files: list[gr.File]) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
ensure_results_dir()
|
508 |
saved = []
|
509 |
for f in files or []:
|
@@ -515,14 +667,36 @@ def handle_db_upload(files: list[gr.File]) -> str:
|
|
515 |
)
|
516 |
|
517 |
|
518 |
-
#
|
519 |
-
#
|
520 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
|
522 |
with gr.Blocks() as interface:
|
|
|
|
|
|
|
|
|
523 |
with gr.Tab("Game Arena"):
|
524 |
config = create_player_config(include_aggregated=False)
|
525 |
|
|
|
526 |
gr.Markdown("# LLM Game Arena")
|
527 |
gr.Markdown("Play games against LLMs or watch LLMs compete!")
|
528 |
gr.Markdown(
|
@@ -531,6 +705,7 @@ with gr.Blocks() as interface:
|
|
531 |
"No API tokens required!"
|
532 |
)
|
533 |
|
|
|
534 |
with gr.Row():
|
535 |
game_dropdown = gr.Dropdown(
|
536 |
choices=config["available_games"],
|
@@ -550,6 +725,7 @@ with gr.Blocks() as interface:
|
|
550 |
)
|
551 |
|
552 |
def player_selector_block(label: str):
|
|
|
553 |
gr.Markdown(f"### {label}")
|
554 |
# Create display choices (what user sees)
|
555 |
display_choices = [
|
@@ -571,13 +747,18 @@ with gr.Blocks() as interface:
|
|
571 |
)
|
572 |
return dd_type, dd_model
|
573 |
|
|
|
574 |
with gr.Row():
|
575 |
p1_type, p1_model = player_selector_block("Player 0")
|
576 |
p2_type, p2_model = player_selector_block("Player 1")
|
577 |
|
578 |
def _vis(player_type: str):
|
|
|
579 |
# Map display label back to internal key
|
580 |
-
display_to_key = {
|
|
|
|
|
|
|
581 |
internal_key = display_to_key.get(player_type, player_type)
|
582 |
|
583 |
is_llm = (
|
@@ -592,10 +773,11 @@ with gr.Blocks() as interface:
|
|
592 |
)
|
593 |
return gr.update(visible=is_llm)
|
594 |
|
|
|
595 |
p1_type.change(_vis, inputs=p1_type, outputs=p1_model)
|
596 |
p2_type.change(_vis, inputs=p2_type, outputs=p2_model)
|
597 |
|
598 |
-
#
|
599 |
game_state = gr.State(value=None)
|
600 |
human_choices_p0 = gr.State([])
|
601 |
human_choices_p1 = gr.State([])
|
@@ -644,10 +826,15 @@ with gr.Blocks() as interface:
|
|
644 |
visible=False
|
645 |
)
|
646 |
|
647 |
-
#
|
648 |
play_button = gr.Button("🎮 Start Game", variant="primary")
|
649 |
-
start_btn = gr.Button(
|
|
|
|
|
|
|
|
|
650 |
|
|
|
651 |
game_output = gr.Textbox(
|
652 |
label="Game Log",
|
653 |
lines=20,
|
@@ -657,7 +844,10 @@ with gr.Blocks() as interface:
|
|
657 |
def check_for_human_players(p1_type, p2_type):
|
658 |
"""Show/hide interactive controls based on player types."""
|
659 |
# Map display labels back to internal keys
|
660 |
-
display_to_key = {
|
|
|
|
|
|
|
661 |
p1_key = display_to_key.get(p1_type, p1_type)
|
662 |
p2_key = display_to_key.get(p2_type, p2_type)
|
663 |
|
@@ -692,14 +882,19 @@ with gr.Blocks() as interface:
|
|
692 |
)
|
693 |
|
694 |
# Interactive game functions
|
695 |
-
def start_interactive_game(
|
|
|
|
|
696 |
"""Initialize an interactive game session."""
|
697 |
try:
|
698 |
from ui.gradio_config_generator import start_game_interactive
|
699 |
import time
|
700 |
|
701 |
# Map display labels back to internal keys
|
702 |
-
display_to_key = {
|
|
|
|
|
|
|
703 |
p1_key = display_to_key.get(p1_type, p1_type)
|
704 |
p2_key = display_to_key.get(p2_type, p2_type)
|
705 |
|
@@ -721,12 +916,18 @@ with gr.Blocks() as interface:
|
|
721 |
)
|
722 |
|
723 |
# Store choices in state for reliable mapping
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
730 |
|
731 |
# Show/hide dropdowns based on whether each player is human
|
732 |
p0_is_human = (p1_key == "human")
|
@@ -737,8 +938,16 @@ with gr.Blocks() as interface:
|
|
737 |
p0_choices, # human_choices_p0
|
738 |
p1_choices, # human_choices_p1
|
739 |
log, # board_display
|
740 |
-
gr.update(
|
741 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
742 |
gr.update(visible=True), # submit_btn
|
743 |
gr.update(visible=True), # reset_game_btn
|
744 |
)
|
@@ -768,35 +977,39 @@ with gr.Blocks() as interface:
|
|
768 |
gr.update(visible=False)
|
769 |
)
|
770 |
|
771 |
-
#
|
|
|
|
|
|
|
|
|
772 |
log_append, new_state, next_p0, next_p1 = submit_human_move(
|
773 |
-
action_p0=p0_action,
|
774 |
-
action_p1=p1_action,
|
775 |
state=state,
|
776 |
)
|
777 |
|
778 |
-
#
|
779 |
-
|
780 |
-
|
|
|
781 |
|
782 |
-
#
|
783 |
p0_dropdown_choices = [(label, action_id) for action_id, label in new_choices_p0]
|
784 |
p1_dropdown_choices = [(label, action_id) for action_id, label in new_choices_p1]
|
785 |
|
786 |
-
# Check if game is finished
|
787 |
game_over = (new_state.get("terminated", False) or
|
788 |
-
new_state.get("truncated", False)
|
789 |
-
(len(new_choices_p0) == 0 and len(new_choices_p1) == 0))
|
790 |
|
791 |
return (
|
792 |
new_state, # game_state
|
793 |
new_choices_p0, # human_choices_p0
|
794 |
new_choices_p1, # human_choices_p1
|
795 |
log_append, # board_display (append to current)
|
796 |
-
gr.update(choices=p0_dropdown_choices, visible=len(p0_dropdown_choices) > 0 and not game_over, value=None),
|
797 |
-
gr.update(choices=p1_dropdown_choices, visible=len(p1_dropdown_choices) > 0 and not game_over, value=None),
|
798 |
gr.update(visible=not game_over), # submit_btn
|
799 |
-
gr.update(visible=True), # reset_game_btn
|
800 |
)
|
801 |
except Exception as e:
|
802 |
return (
|
|
|
2 |
"""
|
3 |
Game Reasoning Arena — Hugging Face Spaces Gradio App
|
4 |
|
5 |
+
This module provides a web interface for playing games between humans and AI agents,
|
6 |
+
analyzing LLM performance, and visualizing game statistics.
|
7 |
+
|
8 |
Pipeline:
|
9 |
User clicks "Start Game" in Gradio
|
10 |
↓
|
|
|
15 |
src/game_reasoning_arena/ (core game infrastructure)
|
16 |
↓
|
17 |
Game results + metrics displayed in Gradio
|
18 |
+
|
19 |
+
Features:
|
20 |
+
- Interactive human vs AI gameplay
|
21 |
+
- LLM leaderboards and performance metrics
|
22 |
+
- Real-time game visualization
|
23 |
+
- Database management for results
|
24 |
"""
|
25 |
|
26 |
from __future__ import annotations
|
27 |
|
28 |
+
# =============================================================================
|
29 |
+
# IMPORTS
|
30 |
+
# =============================================================================
|
31 |
|
32 |
+
# Standard library imports
|
33 |
+
import sqlite3
|
34 |
import sys
|
35 |
import shutil
|
36 |
from pathlib import Path
|
37 |
from typing import List, Dict, Any, Tuple, Generator, TypedDict
|
38 |
|
39 |
+
# Third-party imports
|
40 |
import pandas as pd
|
41 |
import gradio as gr
|
42 |
|
43 |
+
# Logging configuration
|
44 |
import logging
|
45 |
logging.basicConfig(level=logging.INFO)
|
46 |
log = logging.getLogger("arena_space")
|
|
|
51 |
except Exception:
|
52 |
pass
|
53 |
|
54 |
+
# =============================================================================
|
55 |
+
# PATH SETUP & CORE IMPORTS
|
56 |
+
# =============================================================================
|
57 |
+
|
58 |
# Make sure src is on PYTHONPATH
|
59 |
src_path = Path(__file__).parent / "src"
|
60 |
if str(src_path) not in sys.path:
|
61 |
sys.path.insert(0, str(src_path))
|
62 |
|
63 |
+
# Game arena core imports
|
64 |
+
from game_reasoning_arena.arena.games.registry import (
|
65 |
+
registry as games_registry
|
66 |
+
)
|
67 |
from game_reasoning_arena.backends.huggingface_backend import (
|
68 |
+
HuggingFaceBackend,
|
69 |
+
)
|
70 |
from game_reasoning_arena.backends import (
|
71 |
+
initialize_llm_registry, LLM_REGISTRY,
|
72 |
+
)
|
73 |
|
74 |
+
# =============================================================================
|
75 |
+
# GLOBAL CONFIGURATION
|
76 |
+
# =============================================================================
|
77 |
|
78 |
+
# Backend availability flag
|
79 |
+
BACKEND_SYSTEM_AVAILABLE = True
|
|
|
80 |
|
81 |
+
# HuggingFace demo-safe tiny models (CPU friendly)
|
82 |
HUGGINGFACE_MODELS: Dict[str, str] = {
|
83 |
"gpt2": "gpt2",
|
84 |
"distilgpt2": "distilgpt2",
|
|
|
86 |
"EleutherAI/gpt-neo-125M": "EleutherAI/gpt-neo-125M",
|
87 |
}
|
88 |
|
89 |
+
# Global registries
|
90 |
GAMES_REGISTRY: Dict[str, Any] = {}
|
91 |
|
92 |
+
# Database configuration
|
93 |
db_dir = Path(__file__).resolve().parent / "results"
|
94 |
|
95 |
+
# Leaderboard display columns
|
96 |
LEADERBOARD_COLUMNS = [
|
97 |
"agent_name", "agent_type", "# games", "total rewards",
|
98 |
# "avg_generation_time (sec)", # Commented out - needs fixing
|
99 |
"win-rate", "win vs_random (%)",
|
100 |
]
|
101 |
|
102 |
+
# =============================================================================
|
103 |
+
# BACKEND INITIALIZATION
|
104 |
+
# =============================================================================
|
105 |
|
106 |
+
# Initialize HuggingFace backend and register models
|
107 |
huggingface_backend = None
|
108 |
if BACKEND_SYSTEM_AVAILABLE:
|
109 |
try:
|
110 |
huggingface_backend = HuggingFaceBackend()
|
111 |
initialize_llm_registry()
|
112 |
|
113 |
+
# Register available HuggingFace models
|
114 |
for model_name in HUGGINGFACE_MODELS.keys():
|
115 |
if huggingface_backend.is_model_available(model_name):
|
116 |
registry_key = f"hf_{model_name}"
|
|
|
123 |
log.error("Failed to initialize HuggingFace backend: %s", e)
|
124 |
huggingface_backend = None
|
125 |
|
126 |
+
# =============================================================================
|
127 |
+
# GAMES REGISTRY SETUP
|
128 |
+
# =============================================================================
|
129 |
|
130 |
+
# Load available games from the registry
|
131 |
try:
|
132 |
if games_registry is not None:
|
133 |
GAMES_REGISTRY = {
|
|
|
140 |
log.warning("Failed to load games registry: %s", e)
|
141 |
GAMES_REGISTRY = {}
|
142 |
|
143 |
+
|
144 |
def _get_game_display_mapping() -> Dict[str, str]:
|
145 |
"""
|
146 |
+
Build a mapping from internal game keys to their human-friendly
|
147 |
+
display names. If the registry is not available or a game has no
|
148 |
+
explicit display_name, fall back to a title-cased version of the
|
149 |
+
internal key.
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
Dict mapping internal game keys to display names
|
153 |
"""
|
154 |
mapping: Dict[str, str] = {}
|
155 |
if games_registry is not None and hasattr(games_registry, "_registry"):
|
156 |
for key, info in games_registry._registry.items():
|
157 |
+
if isinstance(info, dict):
|
158 |
+
display = info.get("display_name")
|
159 |
+
else:
|
160 |
+
display = None
|
161 |
if not display:
|
162 |
display = key.replace("_", " ").title()
|
163 |
mapping[key] = display
|
164 |
return mapping
|
165 |
|
166 |
|
167 |
+
# =============================================================================
|
168 |
+
# DATABASE HELPER FUNCTIONS
|
169 |
+
# =============================================================================
|
|
|
170 |
|
171 |
def ensure_results_dir() -> None:
|
172 |
+
"""Create the results directory if it doesn't exist."""
|
173 |
db_dir.mkdir(parents=True, exist_ok=True)
|
174 |
|
175 |
|
176 |
def iter_agent_databases() -> Generator[Tuple[str, str, str], None, None]:
|
177 |
+
"""
|
178 |
+
Yield (db_file, agent_type, model_name) for non-random agents.
|
179 |
+
|
180 |
+
Yields:
|
181 |
+
Tuple of (database file path, agent type, model name)
|
182 |
+
"""
|
183 |
for db_file in find_or_download_db():
|
184 |
agent_type, model_name = extract_agent_info(db_file)
|
185 |
if agent_type != "random":
|
|
|
187 |
|
188 |
|
189 |
def find_or_download_db() -> List[str]:
|
190 |
+
"""
|
191 |
+
Return .db files; ensure random_None.db exists with minimal schema.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
List of database file paths
|
195 |
+
"""
|
196 |
ensure_results_dir()
|
197 |
|
198 |
random_db_path = db_dir / "random_None.db"
|
|
|
219 |
|
220 |
|
221 |
def extract_agent_info(filename: str) -> Tuple[str, str]:
|
222 |
+
"""
|
223 |
+
Extract agent type and model name from database filename.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
filename: Database filename (e.g., "llm_gpt2.db")
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
Tuple of (agent_type, model_name)
|
230 |
+
"""
|
231 |
base_name = Path(filename).stem
|
232 |
parts = base_name.split("_", 1)
|
233 |
if len(parts) == 2:
|
|
|
236 |
|
237 |
|
238 |
def get_available_games(include_aggregated: bool = True) -> List[str]:
|
239 |
+
"""
|
240 |
+
Return only games from the registry.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
include_aggregated: Whether to include "Aggregated Performance" option
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
List of available game names
|
247 |
+
"""
|
248 |
if GAMES_REGISTRY:
|
249 |
game_list = sorted(GAMES_REGISTRY.keys())
|
250 |
else:
|
|
|
255 |
|
256 |
|
257 |
def extract_illegal_moves_summary() -> pd.DataFrame:
|
258 |
+
"""
|
259 |
+
Extract summary of illegal moves per agent.
|
260 |
+
|
261 |
+
Returns:
|
262 |
+
DataFrame with agent names and illegal move counts
|
263 |
+
"""
|
264 |
summary = []
|
265 |
for db_file, agent_type, model_name in iter_agent_databases():
|
266 |
conn = sqlite3.connect(db_file)
|
|
|
278 |
|
279 |
|
280 |
|
281 |
+
# =============================================================================
|
282 |
+
# PLAYER CONFIGURATION & TYPE DEFINITIONS
|
283 |
+
# =============================================================================
|
284 |
|
285 |
class PlayerConfigData(TypedDict, total=False):
|
286 |
+
"""Type definition for player configuration data."""
|
287 |
player_types: List[str]
|
288 |
player_type_display: Dict[str, str]
|
289 |
available_models: List[str]
|
290 |
|
291 |
|
292 |
class GameArenaConfig(TypedDict, total=False):
|
293 |
+
"""Type definition for game arena configuration."""
|
294 |
available_games: List[str]
|
295 |
player_config: PlayerConfigData
|
296 |
model_info: str
|
|
|
300 |
def setup_player_config(
|
301 |
player_type: str, player_model: str, player_id: str
|
302 |
) -> Dict[str, Any]:
|
303 |
+
"""
|
304 |
+
Map dropdown selection to agent config for the runner.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
player_type: Display label for player type
|
308 |
+
player_model: Model name if LLM type
|
309 |
+
player_id: Player identifier
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
Agent configuration dictionary
|
313 |
+
"""
|
314 |
# Create a temporary config to get the display-to-key mapping
|
315 |
temp_config = create_player_config()
|
316 |
+
display_to_key = {
|
317 |
+
v: k for k, v in
|
318 |
+
temp_config["player_config"]["player_type_display"].items()
|
319 |
+
}
|
320 |
|
321 |
# Map display label back to internal key
|
322 |
internal_key = display_to_key.get(player_type, player_type)
|
|
|
349 |
|
350 |
|
351 |
def create_player_config(include_aggregated: bool = False) -> GameArenaConfig:
|
352 |
+
"""
|
353 |
+
Create player and game configuration for the arena.
|
354 |
+
|
355 |
+
Args:
|
356 |
+
include_aggregated: Whether to include aggregated stats option
|
357 |
+
|
358 |
+
Returns:
|
359 |
+
Complete game arena configuration
|
360 |
+
"""
|
361 |
# Internal names for arena dropdown
|
362 |
available_keys = get_available_games(include_aggregated=include_aggregated)
|
363 |
|
|
|
375 |
available_games.append(name)
|
376 |
seen.add(name)
|
377 |
|
378 |
+
# Define available player types
|
379 |
player_types = ["human", "random_bot"]
|
380 |
+
player_type_display = {
|
381 |
+
"human": "Human Player",
|
382 |
+
"random_bot": "Random Bot"
|
383 |
+
}
|
384 |
+
|
385 |
+
# Add HuggingFace models if backend is available
|
386 |
if BACKEND_SYSTEM_AVAILABLE:
|
387 |
for model_key in HUGGINGFACE_MODELS.keys():
|
388 |
key = f"hf_{model_key}"
|
|
|
417 |
}
|
418 |
|
419 |
|
420 |
+
# =============================================================================
|
421 |
+
# MAIN GAME LOGIC
|
422 |
+
# =============================================================================
|
423 |
|
424 |
def play_game(
|
425 |
game_name: str,
|
|
|
430 |
rounds: int = 1,
|
431 |
seed: int | None = None,
|
432 |
) -> str:
|
433 |
+
"""
|
434 |
+
Execute a complete game simulation between two players.
|
435 |
+
|
436 |
+
Args:
|
437 |
+
game_name: Name of the game to play
|
438 |
+
player1_type: Type of player 1 (human, random, llm)
|
439 |
+
player2_type: Type of player 2 (human, random, llm)
|
440 |
+
player1_model: Model name for player 1 if LLM
|
441 |
+
player2_model: Model name for player 2 if LLM
|
442 |
+
rounds: Number of rounds to play
|
443 |
+
seed: Random seed for reproducibility
|
444 |
+
|
445 |
+
Returns:
|
446 |
+
Game result log as string
|
447 |
+
"""
|
448 |
if game_name == "No Games Found":
|
449 |
return "No games available. Please add game databases."
|
450 |
|
|
|
460 |
|
461 |
# Map human‑friendly game name back to internal key if needed
|
462 |
config = create_player_config()
|
463 |
+
if ("game_display_to_key" in config and
|
464 |
+
game_name in config["game_display_to_key"]):
|
465 |
game_name = config["game_display_to_key"][game_name]
|
466 |
|
467 |
# Map display labels for player types back to keys
|
|
|
494 |
except Exception as e:
|
495 |
return f"Error during game simulation: {e}"
|
496 |
|
497 |
+
|
498 |
+
# =============================================================================
|
499 |
+
# LEADERBOARD & ANALYTICS
|
500 |
+
# =============================================================================
|
501 |
+
|
502 |
def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
503 |
+
"""
|
504 |
+
Extract leaderboard statistics for a specific game or all games.
|
505 |
+
|
506 |
+
Args:
|
507 |
+
game_name: Name of the game or "Aggregated Performance"
|
508 |
+
|
509 |
+
Returns:
|
510 |
+
DataFrame with leaderboard statistics
|
511 |
+
"""
|
512 |
all_stats = []
|
513 |
for db_file, agent_type, model_name in iter_agent_databases():
|
514 |
conn = sqlite3.connect(db_file)
|
515 |
try:
|
516 |
if game_name == "Aggregated Performance":
|
517 |
+
# Get totals across all games in this DB
|
518 |
df = pd.read_sql_query(
|
519 |
+
"SELECT COUNT(DISTINCT episode) AS games_played, "
|
520 |
+
"SUM(reward) AS total_rewards FROM game_results",
|
521 |
conn,
|
522 |
)
|
523 |
# avg_time = conn.execute(
|
|
|
532 |
"WHERE opponent = 'random_None'",
|
533 |
).fetchone()[0] or 0
|
534 |
else:
|
535 |
+
# Filter by the selected game
|
536 |
df = pd.read_sql_query(
|
537 |
+
"SELECT COUNT(DISTINCT episode) AS games_played, "
|
538 |
+
"SUM(reward) AS total_rewards "
|
539 |
"FROM game_results WHERE game_name = ?",
|
540 |
conn,
|
541 |
params=(game_name,),
|
542 |
)
|
543 |
# avg_time = conn.execute(
|
544 |
+
# "SELECT AVG(generation_time) FROM moves "
|
545 |
+
# "WHERE game_name = ?", (game_name,),
|
546 |
# ).fetchone()[0] or 0
|
547 |
wins_vs_random = conn.execute(
|
548 |
"SELECT COUNT(*) FROM game_results "
|
549 |
+
"WHERE opponent = 'random_None' AND reward > 0 "
|
550 |
+
"AND game_name = ?",
|
551 |
(game_name,),
|
552 |
).fetchone()[0] or 0
|
553 |
total_vs_random = conn.execute(
|
|
|
584 |
finally:
|
585 |
conn.close()
|
586 |
|
587 |
+
# Concatenate all rows; if all_stats is empty, return an empty DataFrame
|
588 |
+
# with columns.
|
589 |
if not all_stats:
|
590 |
return pd.DataFrame(columns=LEADERBOARD_COLUMNS)
|
591 |
|
|
|
593 |
return leaderboard_df[LEADERBOARD_COLUMNS]
|
594 |
|
595 |
|
596 |
+
# =============================================================================
|
597 |
+
# VISUALIZATION HELPERS
|
598 |
+
# =============================================================================
|
|
|
599 |
|
600 |
def create_bar_plot(
|
601 |
data: pd.DataFrame,
|
|
|
606 |
y_label: str,
|
607 |
horizontal: bool = False,
|
608 |
) -> gr.BarPlot:
|
609 |
+
"""
|
610 |
+
Create a bar plot with optional horizontal orientation.
|
611 |
+
|
612 |
+
Args:
|
613 |
+
data: DataFrame containing the data
|
614 |
+
x_col: Column name for x-axis
|
615 |
+
y_col: Column name for y-axis
|
616 |
+
title: Plot title
|
617 |
+
x_label: X-axis label
|
618 |
+
y_label: Y-axis label
|
619 |
+
horizontal: Whether to create horizontal bars
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
Gradio BarPlot component
|
623 |
+
"""
|
624 |
if horizontal:
|
625 |
# Swap x and y for horizontal bars
|
626 |
return gr.BarPlot(
|
|
|
641 |
y_label=y_label,
|
642 |
)
|
643 |
|
|
|
|
|
|
|
644 |
|
645 |
+
# =============================================================================
|
646 |
+
# FILE UPLOAD HANDLERS
|
647 |
+
# =============================================================================
|
648 |
|
649 |
def handle_db_upload(files: list[gr.File]) -> str:
|
650 |
+
"""
|
651 |
+
Handle upload of database files to the results directory.
|
652 |
+
|
653 |
+
Args:
|
654 |
+
files: List of uploaded files
|
655 |
+
|
656 |
+
Returns:
|
657 |
+
Status message about upload success
|
658 |
+
"""
|
659 |
ensure_results_dir()
|
660 |
saved = []
|
661 |
for f in files or []:
|
|
|
667 |
)
|
668 |
|
669 |
|
670 |
+
# =============================================================================
|
671 |
+
# GRADIO USER INTERFACE
|
672 |
+
# =============================================================================
|
673 |
+
|
674 |
+
"""
|
675 |
+
This section defines the complete Gradio web interface with the following tabs:
|
676 |
+
1. Game Arena: Interactive gameplay between humans and AI
|
677 |
+
2. Leaderboard: Performance statistics and rankings
|
678 |
+
3. Metrics Dashboard: Visual analytics and charts
|
679 |
+
4. Analysis of LLM Reasoning: Illegal moves and behavior analysis
|
680 |
+
5. About: Documentation and information
|
681 |
+
|
682 |
+
The interface supports:
|
683 |
+
- Real-time human vs AI gameplay
|
684 |
+
- Automatic AI move processing
|
685 |
+
- Dynamic dropdown population
|
686 |
+
- State management for interactive games
|
687 |
+
- File upload for database results
|
688 |
+
- Interactive visualizations
|
689 |
+
"""
|
690 |
|
691 |
with gr.Blocks() as interface:
|
692 |
+
# =========================================================================
|
693 |
+
# TAB 1: GAME ARENA
|
694 |
+
# =========================================================================
|
695 |
+
|
696 |
with gr.Tab("Game Arena"):
|
697 |
config = create_player_config(include_aggregated=False)
|
698 |
|
699 |
+
# Header and introduction
|
700 |
gr.Markdown("# LLM Game Arena")
|
701 |
gr.Markdown("Play games against LLMs or watch LLMs compete!")
|
702 |
gr.Markdown(
|
|
|
705 |
"No API tokens required!"
|
706 |
)
|
707 |
|
708 |
+
# Game selection and configuration
|
709 |
with gr.Row():
|
710 |
game_dropdown = gr.Dropdown(
|
711 |
choices=config["available_games"],
|
|
|
725 |
)
|
726 |
|
727 |
def player_selector_block(label: str):
|
728 |
+
"""Create player selection UI block."""
|
729 |
gr.Markdown(f"### {label}")
|
730 |
# Create display choices (what user sees)
|
731 |
display_choices = [
|
|
|
747 |
)
|
748 |
return dd_type, dd_model
|
749 |
|
750 |
+
# Player configuration
|
751 |
with gr.Row():
|
752 |
p1_type, p1_model = player_selector_block("Player 0")
|
753 |
p2_type, p2_model = player_selector_block("Player 1")
|
754 |
|
755 |
def _vis(player_type: str):
|
756 |
+
"""Show/hide model dropdown based on player type."""
|
757 |
# Map display label back to internal key
|
758 |
+
display_to_key = {
|
759 |
+
v: k for k, v in
|
760 |
+
config["player_config"]["player_type_display"].items()
|
761 |
+
}
|
762 |
internal_key = display_to_key.get(player_type, player_type)
|
763 |
|
764 |
is_llm = (
|
|
|
773 |
)
|
774 |
return gr.update(visible=is_llm)
|
775 |
|
776 |
+
# Wire up model dropdown visibility
|
777 |
p1_type.change(_vis, inputs=p1_type, outputs=p1_model)
|
778 |
p2_type.change(_vis, inputs=p2_type, outputs=p2_model)
|
779 |
|
780 |
+
# Game state management
|
781 |
game_state = gr.State(value=None)
|
782 |
human_choices_p0 = gr.State([])
|
783 |
human_choices_p1 = gr.State([])
|
|
|
826 |
visible=False
|
827 |
)
|
828 |
|
829 |
+
# Game control buttons
|
830 |
play_button = gr.Button("🎮 Start Game", variant="primary")
|
831 |
+
start_btn = gr.Button(
|
832 |
+
"🎯 Start Interactive Game",
|
833 |
+
variant="secondary",
|
834 |
+
visible=False
|
835 |
+
)
|
836 |
|
837 |
+
# Game output display
|
838 |
game_output = gr.Textbox(
|
839 |
label="Game Log",
|
840 |
lines=20,
|
|
|
844 |
def check_for_human_players(p1_type, p2_type):
|
845 |
"""Show/hide interactive controls based on player types."""
|
846 |
# Map display labels back to internal keys
|
847 |
+
display_to_key = {
|
848 |
+
v: k for k, v in
|
849 |
+
config["player_config"]["player_type_display"].items()
|
850 |
+
}
|
851 |
p1_key = display_to_key.get(p1_type, p1_type)
|
852 |
p2_key = display_to_key.get(p2_type, p2_type)
|
853 |
|
|
|
882 |
)
|
883 |
|
884 |
# Interactive game functions
|
885 |
+
def start_interactive_game(
|
886 |
+
game_name, p1_type, p2_type, p1_model, p2_model, rounds
|
887 |
+
):
|
888 |
"""Initialize an interactive game session."""
|
889 |
try:
|
890 |
from ui.gradio_config_generator import start_game_interactive
|
891 |
import time
|
892 |
|
893 |
# Map display labels back to internal keys
|
894 |
+
display_to_key = {
|
895 |
+
v: k for k, v in
|
896 |
+
config["player_config"]["player_type_display"].items()
|
897 |
+
}
|
898 |
p1_key = display_to_key.get(p1_type, p1_type)
|
899 |
p2_key = display_to_key.get(p2_type, p2_type)
|
900 |
|
|
|
916 |
)
|
917 |
|
918 |
# Store choices in state for reliable mapping
|
919 |
+
# [(action_id, label), ...] from _legal_actions_with_labels()
|
920 |
+
p0_choices = legal_p0
|
921 |
+
p1_choices = legal_p1
|
922 |
+
|
923 |
+
# Create Gradio dropdown choices: user sees OpenSpiel action
|
924 |
+
# labels, selects action IDs
|
925 |
+
p0_dropdown_choices = [
|
926 |
+
(label, action_id) for action_id, label in p0_choices
|
927 |
+
]
|
928 |
+
p1_dropdown_choices = [
|
929 |
+
(label, action_id) for action_id, label in p1_choices
|
930 |
+
]
|
931 |
|
932 |
# Show/hide dropdowns based on whether each player is human
|
933 |
p0_is_human = (p1_key == "human")
|
|
|
938 |
p0_choices, # human_choices_p0
|
939 |
p1_choices, # human_choices_p1
|
940 |
log, # board_display
|
941 |
+
gr.update(
|
942 |
+
choices=p0_dropdown_choices,
|
943 |
+
visible=p0_is_human,
|
944 |
+
value=None
|
945 |
+
), # human_move_p0
|
946 |
+
gr.update(
|
947 |
+
choices=p1_dropdown_choices,
|
948 |
+
visible=p1_is_human,
|
949 |
+
value=None
|
950 |
+
), # human_move_p1
|
951 |
gr.update(visible=True), # submit_btn
|
952 |
gr.update(visible=True), # reset_game_btn
|
953 |
)
|
|
|
977 |
gr.update(visible=False)
|
978 |
)
|
979 |
|
980 |
+
# The submit_human_move function already handles:
|
981 |
+
# 1. Taking human actions for human players
|
982 |
+
# 2. Computing AI actions for AI players
|
983 |
+
# 3. Advancing the game with both actions
|
984 |
+
# 4. Returning the next legal moves
|
985 |
log_append, new_state, next_p0, next_p1 = submit_human_move(
|
986 |
+
action_p0=p0_action, # None if P0 is AI, action_id if P0 is human
|
987 |
+
action_p1=p1_action, # None if P1 is AI, action_id if P1 is human
|
988 |
state=state,
|
989 |
)
|
990 |
|
991 |
+
# next_p0 and next_p1 are from _legal_actions_with_labels()
|
992 |
+
# Format: [(action_id, label), ...] where label comes from OpenSpiel
|
993 |
+
new_choices_p0 = next_p0
|
994 |
+
new_choices_p1 = next_p1
|
995 |
|
996 |
+
# Create Gradio dropdown choices: user sees OpenSpiel labels, selects action IDs
|
997 |
p0_dropdown_choices = [(label, action_id) for action_id, label in new_choices_p0]
|
998 |
p1_dropdown_choices = [(label, action_id) for action_id, label in new_choices_p1]
|
999 |
|
1000 |
+
# Check if game is finished
|
1001 |
game_over = (new_state.get("terminated", False) or
|
1002 |
+
new_state.get("truncated", False))
|
|
|
1003 |
|
1004 |
return (
|
1005 |
new_state, # game_state
|
1006 |
new_choices_p0, # human_choices_p0
|
1007 |
new_choices_p1, # human_choices_p1
|
1008 |
log_append, # board_display (append to current)
|
1009 |
+
gr.update(choices=p0_dropdown_choices, visible=len(p0_dropdown_choices) > 0 and not game_over, value=None),
|
1010 |
+
gr.update(choices=p1_dropdown_choices, visible=len(p1_dropdown_choices) > 0 and not game_over, value=None),
|
1011 |
gr.update(visible=not game_over), # submit_btn
|
1012 |
+
gr.update(visible=True), # reset_game_btn
|
1013 |
)
|
1014 |
except Exception as e:
|
1015 |
return (
|