Spaces:
Sleeping
Sleeping
Added plots
Browse files
Added an additional tab for plots
app.py
CHANGED
@@ -15,11 +15,11 @@ def find_or_download_db():
|
|
15 |
if not os.path.exists(db_dir):
|
16 |
os.makedirs(db_dir)
|
17 |
db_files = glob.glob(os.path.join(db_dir, "*.db"))
|
18 |
-
|
19 |
# Ensure the random bot database exists
|
20 |
if "results/random_None.db" not in db_files:
|
21 |
raise FileNotFoundError("Please upload results for the random agent in a file named 'random_None.db'.")
|
22 |
-
|
23 |
return db_files
|
24 |
|
25 |
def extract_agent_info(filename: str):
|
@@ -36,7 +36,7 @@ def get_available_games(include_aggregated=True) -> List[str]:
|
|
36 |
"""Extracts all unique game names from all SQLite databases. Includes 'Aggregated Performance' only when required."""
|
37 |
db_files = find_or_download_db()
|
38 |
game_names = set()
|
39 |
-
|
40 |
for db_file in db_files:
|
41 |
conn = sqlite3.connect(db_file)
|
42 |
try:
|
@@ -47,51 +47,74 @@ def get_available_games(include_aggregated=True) -> List[str]:
|
|
47 |
pass # Ignore errors if table doesn't exist
|
48 |
finally:
|
49 |
conn.close()
|
50 |
-
|
51 |
game_list = sorted(game_names) if game_names else ["No Games Found"]
|
52 |
if include_aggregated:
|
53 |
game_list.insert(0, "Aggregated Performance") # Ensure 'Aggregated Performance' is always first
|
54 |
return game_list
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
57 |
"""Extract and aggregate leaderboard stats from all SQLite databases."""
|
58 |
db_files = find_or_download_db()
|
59 |
all_stats = []
|
60 |
-
|
61 |
for db_file in db_files:
|
62 |
conn = sqlite3.connect(db_file)
|
63 |
agent_type, model_name = extract_agent_info(db_file)
|
64 |
-
|
65 |
# Skip random agent rows
|
66 |
if agent_type == "random":
|
67 |
conn.close()
|
68 |
continue
|
69 |
-
|
70 |
if game_name == "Aggregated Performance":
|
71 |
query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
|
72 |
"SUM(reward) AS total_rewards " \
|
73 |
"FROM game_results"
|
74 |
df = pd.read_sql_query(query, conn)
|
75 |
-
|
76 |
-
#
|
77 |
-
game_query = "SELECT AVG(generation_time) FROM moves"
|
78 |
avg_gen_time = conn.execute(game_query).fetchone()[0] or 0
|
79 |
else:
|
80 |
query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
|
81 |
"SUM(reward) AS total_rewards " \
|
82 |
"FROM game_results WHERE game_name = ?"
|
83 |
df = pd.read_sql_query(query, conn, params=(game_name,))
|
84 |
-
|
85 |
# Fetch average generation time from moves table
|
86 |
gen_time_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = ?"
|
87 |
avg_gen_time = conn.execute(gen_time_query, (game_name,)).fetchone()[0] or 0
|
88 |
-
|
89 |
# Keep division by 2 for total rewards
|
90 |
df["total_rewards"] = df["total_rewards"].fillna(0).astype(float) / 2
|
91 |
-
|
92 |
# Ensure avg_gen_time has decimals
|
93 |
avg_gen_time = round(avg_gen_time, 3)
|
94 |
-
|
95 |
# Calculate win rate against random bot using moves table
|
96 |
vs_random_query = """
|
97 |
SELECT COUNT(DISTINCT gr.episode) FROM game_results gr
|
@@ -106,22 +129,24 @@ def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
|
106 |
wins_vs_random = conn.execute(vs_random_query).fetchone()[0] or 0
|
107 |
total_vs_random = conn.execute(total_vs_random_query).fetchone()[0] or 0
|
108 |
vs_random_rate = (wins_vs_random / total_vs_random * 100) if total_vs_random > 0 else 0
|
109 |
-
|
110 |
df.insert(0, "agent_name", model_name) # Ensure agent_name is the first column
|
111 |
df.insert(1, "agent_type", agent_type) # Ensure agent_type is second column
|
112 |
df["avg_generation_time (sec)"] = avg_gen_time
|
113 |
df["win vs_random (%)"] = round(vs_random_rate, 2)
|
114 |
-
|
115 |
all_stats.append(df)
|
116 |
conn.close()
|
117 |
-
|
118 |
leaderboard_df = pd.concat(all_stats, ignore_index=True) if all_stats else pd.DataFrame()
|
119 |
-
|
120 |
if leaderboard_df.empty:
|
121 |
leaderboard_df = pd.DataFrame(columns=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"])
|
122 |
-
|
123 |
return leaderboard_df
|
124 |
|
|
|
|
|
125 |
with gr.Blocks() as interface:
|
126 |
# Tab for playing games against LLMs
|
127 |
with gr.Tab("Game Arena"):
|
@@ -134,10 +159,10 @@ with gr.Blocks() as interface:
|
|
134 |
play_button = gr.Button("Start Game")
|
135 |
# Textbox to display the game log
|
136 |
game_output = gr.Textbox(label="Game Log")
|
137 |
-
|
138 |
# Event to start the game when the button is clicked
|
139 |
play_button.click(lambda game, opponent: f"Game {game} started against {opponent}", inputs=[game_dropdown, opponent_dropdown], outputs=[game_output])
|
140 |
-
|
141 |
# Tab for leaderboard and performance tracking
|
142 |
with gr.Tab("Leaderboard"):
|
143 |
gr.Markdown("# LLM Model Leaderboard\nTrack performance across different games!")
|
@@ -147,6 +172,57 @@ with gr.Blocks() as interface:
|
|
147 |
leaderboard_table = gr.Dataframe(value=extract_leaderboard_stats("Aggregated Performance"), headers=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"], every=5)
|
148 |
# Update the leaderboard when a new game is selected
|
149 |
leaderboard_game_dropdown.change(fn=extract_leaderboard_stats, inputs=[leaderboard_game_dropdown], outputs=[leaderboard_table])
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
# Launch the Gradio interface
|
152 |
interface.launch()
|
|
|
15 |
if not os.path.exists(db_dir):
|
16 |
os.makedirs(db_dir)
|
17 |
db_files = glob.glob(os.path.join(db_dir, "*.db"))
|
18 |
+
|
19 |
# Ensure the random bot database exists
|
20 |
if "results/random_None.db" not in db_files:
|
21 |
raise FileNotFoundError("Please upload results for the random agent in a file named 'random_None.db'.")
|
22 |
+
|
23 |
return db_files
|
24 |
|
25 |
def extract_agent_info(filename: str):
|
|
|
36 |
"""Extracts all unique game names from all SQLite databases. Includes 'Aggregated Performance' only when required."""
|
37 |
db_files = find_or_download_db()
|
38 |
game_names = set()
|
39 |
+
|
40 |
for db_file in db_files:
|
41 |
conn = sqlite3.connect(db_file)
|
42 |
try:
|
|
|
47 |
pass # Ignore errors if table doesn't exist
|
48 |
finally:
|
49 |
conn.close()
|
50 |
+
|
51 |
game_list = sorted(game_names) if game_names else ["No Games Found"]
|
52 |
if include_aggregated:
|
53 |
game_list.insert(0, "Aggregated Performance") # Ensure 'Aggregated Performance' is always first
|
54 |
return game_list
|
55 |
|
56 |
+
def extract_illegal_moves_summary() -> pd.DataFrame:
    """Count the illegal moves recorded for each non-random agent.

    Every database returned by ``find_or_download_db()`` is inspected; files
    whose agent type is ``"random"`` are skipped. For each remaining file the
    rows of its ``illegal_moves`` table are counted. A database whose table is
    missing or cannot be queried is reported with a count of 0 instead of
    aborting the whole summary.

    Returns:
        pd.DataFrame: One row per agent with columns
            ``[agent_name, illegal_moves]``. Both columns are present even
            when no agent database is found, so callers can index them
            without a ``KeyError``.
    """
    summary = []
    for db_file in find_or_download_db():
        agent_type, model_name = extract_agent_info(db_file)
        if agent_type == "random":
            continue  # The random baseline is excluded from this analysis.

        conn = sqlite3.connect(db_file)
        try:
            # COUNT(*) always yields exactly one row when the table exists.
            row = conn.execute("SELECT COUNT(*) FROM illegal_moves").fetchone()
            count = int(row[0]) if row else 0
        except sqlite3.Error:
            # Best effort: a missing table or unreadable database counts as 0.
            count = 0
        finally:
            # Release the handle on every path, including unexpected errors.
            conn.close()
        summary.append({"agent_name": model_name, "illegal_moves": count})

    # Pin the schema so an empty result still exposes both columns.
    return pd.DataFrame(summary, columns=["agent_name", "illegal_moves"])
|
78 |
+
|
79 |
def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
|
80 |
"""Extract and aggregate leaderboard stats from all SQLite databases."""
|
81 |
db_files = find_or_download_db()
|
82 |
all_stats = []
|
83 |
+
|
84 |
for db_file in db_files:
|
85 |
conn = sqlite3.connect(db_file)
|
86 |
agent_type, model_name = extract_agent_info(db_file)
|
87 |
+
|
88 |
# Skip random agent rows
|
89 |
if agent_type == "random":
|
90 |
conn.close()
|
91 |
continue
|
92 |
+
|
93 |
if game_name == "Aggregated Performance":
|
94 |
query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
|
95 |
"SUM(reward) AS total_rewards " \
|
96 |
"FROM game_results"
|
97 |
df = pd.read_sql_query(query, conn)
|
98 |
+
|
99 |
+
# Use avg_generation_time from a specific game (e.g., Kuhn Poker)
|
100 |
+
game_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = 'kuhn_poker'"
|
101 |
avg_gen_time = conn.execute(game_query).fetchone()[0] or 0
|
102 |
else:
|
103 |
query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
|
104 |
"SUM(reward) AS total_rewards " \
|
105 |
"FROM game_results WHERE game_name = ?"
|
106 |
df = pd.read_sql_query(query, conn, params=(game_name,))
|
107 |
+
|
108 |
# Fetch average generation time from moves table
|
109 |
gen_time_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = ?"
|
110 |
avg_gen_time = conn.execute(gen_time_query, (game_name,)).fetchone()[0] or 0
|
111 |
+
|
112 |
# Keep division by 2 for total rewards
|
113 |
df["total_rewards"] = df["total_rewards"].fillna(0).astype(float) / 2
|
114 |
+
|
115 |
# Ensure avg_gen_time has decimals
|
116 |
avg_gen_time = round(avg_gen_time, 3)
|
117 |
+
|
118 |
# Calculate win rate against random bot using moves table
|
119 |
vs_random_query = """
|
120 |
SELECT COUNT(DISTINCT gr.episode) FROM game_results gr
|
|
|
129 |
wins_vs_random = conn.execute(vs_random_query).fetchone()[0] or 0
|
130 |
total_vs_random = conn.execute(total_vs_random_query).fetchone()[0] or 0
|
131 |
vs_random_rate = (wins_vs_random / total_vs_random * 100) if total_vs_random > 0 else 0
|
132 |
+
|
133 |
df.insert(0, "agent_name", model_name) # Ensure agent_name is the first column
|
134 |
df.insert(1, "agent_type", agent_type) # Ensure agent_type is second column
|
135 |
df["avg_generation_time (sec)"] = avg_gen_time
|
136 |
df["win vs_random (%)"] = round(vs_random_rate, 2)
|
137 |
+
|
138 |
all_stats.append(df)
|
139 |
conn.close()
|
140 |
+
|
141 |
leaderboard_df = pd.concat(all_stats, ignore_index=True) if all_stats else pd.DataFrame()
|
142 |
+
|
143 |
if leaderboard_df.empty:
|
144 |
leaderboard_df = pd.DataFrame(columns=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"])
|
145 |
+
|
146 |
return leaderboard_df
|
147 |
|
148 |
+
|
149 |
+
##########################################################
|
150 |
with gr.Blocks() as interface:
|
151 |
# Tab for playing games against LLMs
|
152 |
with gr.Tab("Game Arena"):
|
|
|
159 |
play_button = gr.Button("Start Game")
|
160 |
# Textbox to display the game log
|
161 |
game_output = gr.Textbox(label="Game Log")
|
162 |
+
|
163 |
# Event to start the game when the button is clicked
|
164 |
play_button.click(lambda game, opponent: f"Game {game} started against {opponent}", inputs=[game_dropdown, opponent_dropdown], outputs=[game_output])
|
165 |
+
|
166 |
# Tab for leaderboard and performance tracking
|
167 |
with gr.Tab("Leaderboard"):
|
168 |
gr.Markdown("# LLM Model Leaderboard\nTrack performance across different games!")
|
|
|
172 |
leaderboard_table = gr.Dataframe(value=extract_leaderboard_stats("Aggregated Performance"), headers=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"], every=5)
|
173 |
# Update the leaderboard when a new game is selected
|
174 |
leaderboard_game_dropdown.change(fn=extract_leaderboard_stats, inputs=[leaderboard_game_dropdown], outputs=[leaderboard_table])
|
175 |
+
|
176 |
+
# Tab for visual insights and performance metrics
|
177 |
+
with gr.Tab("Metrics Dashboard"):
|
178 |
+
gr.Markdown("# 📊 Metrics Dashboard\nVisual summaries of LLM performance across games.")
|
179 |
+
|
180 |
+
# Extract data for visualizations
|
181 |
+
metrics_df = extract_leaderboard_stats("Aggregated Performance")
|
182 |
+
|
183 |
+
with gr.Row():
|
184 |
+
gr.BarPlot(
|
185 |
+
x=metrics_df["agent_name"],
|
186 |
+
y=metrics_df["win vs_random (%)"],
|
187 |
+
title="Win Rate vs Random Bot",
|
188 |
+
x_label="LLM Model",
|
189 |
+
y_label="Win Rate (%)"
|
190 |
+
)
|
191 |
+
|
192 |
+
with gr.Row():
|
193 |
+
gr.BarPlot(
|
194 |
+
x=metrics_df["agent_name"],
|
195 |
+
y=metrics_df["avg_generation_time (sec)"],
|
196 |
+
title="Average Generation Time",
|
197 |
+
x_label="LLM Model",
|
198 |
+
y_label="Time (sec)"
|
199 |
+
)
|
200 |
+
|
201 |
+
with gr.Row():
|
202 |
+
gr.Dataframe(value=metrics_df, label="Performance Summary")
|
203 |
+
|
204 |
+
# Tab for LLM reasoning and illegal move analysis
|
205 |
+
with gr.Tab("Analysis of LLM Reasoning"):
|
206 |
+
gr.Markdown("# 🧠 Analysis of LLM Reasoning\nInsights into move legality and decision behavior.")
|
207 |
+
|
208 |
+
# Load illegal move stats using global function
|
209 |
+
illegal_df = extract_illegal_moves_summary()
|
210 |
+
|
211 |
+
with gr.Row():
|
212 |
+
gr.BarPlot(
|
213 |
+
x=illegal_df["agent_name"],
|
214 |
+
y=illegal_df["illegal_moves"],
|
215 |
+
title="Illegal Moves by Model",
|
216 |
+
x_label="LLM Model",
|
217 |
+
y_label="# of Illegal Moves"
|
218 |
+
)
|
219 |
+
|
220 |
+
with gr.Row():
|
221 |
+
gr.Dataframe(value=illegal_df, label="Illegal Move Summary")
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
# Launch the Gradio interface
|
228 |
interface.launch()
|