lcipolina committed on
Commit
52ff713
·
verified ·
1 Parent(s): f7ec534

Added plots

Browse files

Added an additional tab for plots

Files changed (1) hide show
  1. app.py +98 -22
app.py CHANGED
@@ -15,11 +15,11 @@ def find_or_download_db():
15
  if not os.path.exists(db_dir):
16
  os.makedirs(db_dir)
17
  db_files = glob.glob(os.path.join(db_dir, "*.db"))
18
-
19
  # Ensure the random bot database exists
20
  if "results/random_None.db" not in db_files:
21
  raise FileNotFoundError("Please upload results for the random agent in a file named 'random_None.db'.")
22
-
23
  return db_files
24
 
25
  def extract_agent_info(filename: str):
@@ -36,7 +36,7 @@ def get_available_games(include_aggregated=True) -> List[str]:
36
  """Extracts all unique game names from all SQLite databases. Includes 'Aggregated Performance' only when required."""
37
  db_files = find_or_download_db()
38
  game_names = set()
39
-
40
  for db_file in db_files:
41
  conn = sqlite3.connect(db_file)
42
  try:
@@ -47,51 +47,74 @@ def get_available_games(include_aggregated=True) -> List[str]:
47
  pass # Ignore errors if table doesn't exist
48
  finally:
49
  conn.close()
50
-
51
  game_list = sorted(game_names) if game_names else ["No Games Found"]
52
  if include_aggregated:
53
  game_list.insert(0, "Aggregated Performance") # Ensure 'Aggregated Performance' is always first
54
  return game_list
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
57
  """Extract and aggregate leaderboard stats from all SQLite databases."""
58
  db_files = find_or_download_db()
59
  all_stats = []
60
-
61
  for db_file in db_files:
62
  conn = sqlite3.connect(db_file)
63
  agent_type, model_name = extract_agent_info(db_file)
64
-
65
  # Skip random agent rows
66
  if agent_type == "random":
67
  conn.close()
68
  continue
69
-
70
  if game_name == "Aggregated Performance":
71
  query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
72
  "SUM(reward) AS total_rewards " \
73
  "FROM game_results"
74
  df = pd.read_sql_query(query, conn)
75
-
76
- # Compute avg_generation_time across all games instead of a single game
77
- game_query = "SELECT AVG(generation_time) FROM moves"
78
  avg_gen_time = conn.execute(game_query).fetchone()[0] or 0
79
  else:
80
  query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
81
  "SUM(reward) AS total_rewards " \
82
  "FROM game_results WHERE game_name = ?"
83
  df = pd.read_sql_query(query, conn, params=(game_name,))
84
-
85
  # Fetch average generation time from moves table
86
  gen_time_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = ?"
87
  avg_gen_time = conn.execute(gen_time_query, (game_name,)).fetchone()[0] or 0
88
-
89
  # Keep division by 2 for total rewards
90
  df["total_rewards"] = df["total_rewards"].fillna(0).astype(float) / 2
91
-
92
  # Ensure avg_gen_time has decimals
93
  avg_gen_time = round(avg_gen_time, 3)
94
-
95
  # Calculate win rate against random bot using moves table
96
  vs_random_query = """
97
  SELECT COUNT(DISTINCT gr.episode) FROM game_results gr
@@ -106,22 +129,24 @@ def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
106
  wins_vs_random = conn.execute(vs_random_query).fetchone()[0] or 0
107
  total_vs_random = conn.execute(total_vs_random_query).fetchone()[0] or 0
108
  vs_random_rate = (wins_vs_random / total_vs_random * 100) if total_vs_random > 0 else 0
109
-
110
  df.insert(0, "agent_name", model_name) # Ensure agent_name is the first column
111
  df.insert(1, "agent_type", agent_type) # Ensure agent_type is second column
112
  df["avg_generation_time (sec)"] = avg_gen_time
113
  df["win vs_random (%)"] = round(vs_random_rate, 2)
114
-
115
  all_stats.append(df)
116
  conn.close()
117
-
118
  leaderboard_df = pd.concat(all_stats, ignore_index=True) if all_stats else pd.DataFrame()
119
-
120
  if leaderboard_df.empty:
121
  leaderboard_df = pd.DataFrame(columns=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"])
122
-
123
  return leaderboard_df
124
 
 
 
125
  with gr.Blocks() as interface:
126
  # Tab for playing games against LLMs
127
  with gr.Tab("Game Arena"):
@@ -134,10 +159,10 @@ with gr.Blocks() as interface:
134
  play_button = gr.Button("Start Game")
135
  # Textbox to display the game log
136
  game_output = gr.Textbox(label="Game Log")
137
-
138
  # Event to start the game when the button is clicked
139
  play_button.click(lambda game, opponent: f"Game {game} started against {opponent}", inputs=[game_dropdown, opponent_dropdown], outputs=[game_output])
140
-
141
  # Tab for leaderboard and performance tracking
142
  with gr.Tab("Leaderboard"):
143
  gr.Markdown("# LLM Model Leaderboard\nTrack performance across different games!")
@@ -147,6 +172,57 @@ with gr.Blocks() as interface:
147
  leaderboard_table = gr.Dataframe(value=extract_leaderboard_stats("Aggregated Performance"), headers=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"], every=5)
148
  # Update the leaderboard when a new game is selected
149
  leaderboard_game_dropdown.change(fn=extract_leaderboard_stats, inputs=[leaderboard_game_dropdown], outputs=[leaderboard_table])
150
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  # Launch the Gradio interface
152
  interface.launch()
 
15
  if not os.path.exists(db_dir):
16
  os.makedirs(db_dir)
17
  db_files = glob.glob(os.path.join(db_dir, "*.db"))
18
+
19
  # Ensure the random bot database exists
20
  if "results/random_None.db" not in db_files:
21
  raise FileNotFoundError("Please upload results for the random agent in a file named 'random_None.db'.")
22
+
23
  return db_files
24
 
25
  def extract_agent_info(filename: str):
 
36
  """Extracts all unique game names from all SQLite databases. Includes 'Aggregated Performance' only when required."""
37
  db_files = find_or_download_db()
38
  game_names = set()
39
+
40
  for db_file in db_files:
41
  conn = sqlite3.connect(db_file)
42
  try:
 
47
  pass # Ignore errors if table doesn't exist
48
  finally:
49
  conn.close()
50
+
51
  game_list = sorted(game_names) if game_names else ["No Games Found"]
52
  if include_aggregated:
53
  game_list.insert(0, "Aggregated Performance") # Ensure 'Aggregated Performance' is always first
54
  return game_list
55
 
56
def extract_illegal_moves_summary() -> pd.DataFrame:
    """Count the illegal moves recorded for each non-random agent.

    Every database returned by ``find_or_download_db`` is inspected, except
    those whose agent type (per ``extract_agent_info``) is ``"random"``.
    The count is the number of rows in that database's ``illegal_moves``
    table.

    Returns:
        pd.DataFrame: One row per agent with columns ``agent_name`` and
        ``illegal_moves``. Both columns are present even when no agent
        database is found, so callers can index them unconditionally.
    """
    columns = ["agent_name", "illegal_moves"]
    summary = []

    for db_file in find_or_download_db():
        agent_type, model_name = extract_agent_info(db_file)
        if agent_type == "random":
            continue  # The random baseline is excluded from this analysis

        conn = sqlite3.connect(db_file)
        try:
            row = conn.execute("SELECT COUNT(*) FROM illegal_moves").fetchone()
            count = int(row[0]) if row else 0
        except sqlite3.Error:
            # Best effort: a database without an illegal_moves table (or one
            # that cannot be read) is reported as zero instead of aborting
            # the whole summary.
            count = 0
        finally:
            # Always release the handle, whether or not the query succeeded.
            conn.close()

        summary.append({"agent_name": model_name, "illegal_moves": count})

    # Passing `columns` keeps the schema stable when `summary` is empty;
    # a bare pd.DataFrame([]) would have no columns at all.
    return pd.DataFrame(summary, columns=columns)
78
+
79
  def extract_leaderboard_stats(game_name: str) -> pd.DataFrame:
80
  """Extract and aggregate leaderboard stats from all SQLite databases."""
81
  db_files = find_or_download_db()
82
  all_stats = []
83
+
84
  for db_file in db_files:
85
  conn = sqlite3.connect(db_file)
86
  agent_type, model_name = extract_agent_info(db_file)
87
+
88
  # Skip random agent rows
89
  if agent_type == "random":
90
  conn.close()
91
  continue
92
+
93
  if game_name == "Aggregated Performance":
94
  query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
95
  "SUM(reward) AS total_rewards " \
96
  "FROM game_results"
97
  df = pd.read_sql_query(query, conn)
98
+
99
+ # Use avg_generation_time from a specific game (e.g., Kuhn Poker)
100
+ game_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = 'kuhn_poker'"
101
  avg_gen_time = conn.execute(game_query).fetchone()[0] or 0
102
  else:
103
  query = "SELECT COUNT(DISTINCT episode) AS games_played, " \
104
  "SUM(reward) AS total_rewards " \
105
  "FROM game_results WHERE game_name = ?"
106
  df = pd.read_sql_query(query, conn, params=(game_name,))
107
+
108
  # Fetch average generation time from moves table
109
  gen_time_query = "SELECT AVG(generation_time) FROM moves WHERE game_name = ?"
110
  avg_gen_time = conn.execute(gen_time_query, (game_name,)).fetchone()[0] or 0
111
+
112
  # Keep division by 2 for total rewards
113
  df["total_rewards"] = df["total_rewards"].fillna(0).astype(float) / 2
114
+
115
  # Ensure avg_gen_time has decimals
116
  avg_gen_time = round(avg_gen_time, 3)
117
+
118
  # Calculate win rate against random bot using moves table
119
  vs_random_query = """
120
  SELECT COUNT(DISTINCT gr.episode) FROM game_results gr
 
129
  wins_vs_random = conn.execute(vs_random_query).fetchone()[0] or 0
130
  total_vs_random = conn.execute(total_vs_random_query).fetchone()[0] or 0
131
  vs_random_rate = (wins_vs_random / total_vs_random * 100) if total_vs_random > 0 else 0
132
+
133
  df.insert(0, "agent_name", model_name) # Ensure agent_name is the first column
134
  df.insert(1, "agent_type", agent_type) # Ensure agent_type is second column
135
  df["avg_generation_time (sec)"] = avg_gen_time
136
  df["win vs_random (%)"] = round(vs_random_rate, 2)
137
+
138
  all_stats.append(df)
139
  conn.close()
140
+
141
  leaderboard_df = pd.concat(all_stats, ignore_index=True) if all_stats else pd.DataFrame()
142
+
143
  if leaderboard_df.empty:
144
  leaderboard_df = pd.DataFrame(columns=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"])
145
+
146
  return leaderboard_df
147
 
148
+
149
+ ##########################################################
150
  with gr.Blocks() as interface:
151
  # Tab for playing games against LLMs
152
  with gr.Tab("Game Arena"):
 
159
  play_button = gr.Button("Start Game")
160
  # Textbox to display the game log
161
  game_output = gr.Textbox(label="Game Log")
162
+
163
  # Event to start the game when the button is clicked
164
  play_button.click(lambda game, opponent: f"Game {game} started against {opponent}", inputs=[game_dropdown, opponent_dropdown], outputs=[game_output])
165
+
166
  # Tab for leaderboard and performance tracking
167
  with gr.Tab("Leaderboard"):
168
  gr.Markdown("# LLM Model Leaderboard\nTrack performance across different games!")
 
172
  leaderboard_table = gr.Dataframe(value=extract_leaderboard_stats("Aggregated Performance"), headers=["agent_name", "agent_type", "# games", "total rewards", "avg_generation_time (sec)", "win-rate", "win vs_random (%)"], every=5)
173
  # Update the leaderboard when a new game is selected
174
  leaderboard_game_dropdown.change(fn=extract_leaderboard_stats, inputs=[leaderboard_game_dropdown], outputs=[leaderboard_table])
175
+
176
+ # Tab for visual insights and performance metrics
177
+ with gr.Tab("Metrics Dashboard"):
178
+ gr.Markdown("# 📊 Metrics Dashboard\nVisual summaries of LLM performance across games.")
179
+
180
+ # Extract data for visualizations
181
+ metrics_df = extract_leaderboard_stats("Aggregated Performance")
182
+
183
+ with gr.Row():
184
+ gr.BarPlot(
185
+ x=metrics_df["agent_name"],
186
+ y=metrics_df["win vs_random (%)"],
187
+ title="Win Rate vs Random Bot",
188
+ x_label="LLM Model",
189
+ y_label="Win Rate (%)"
190
+ )
191
+
192
+ with gr.Row():
193
+ gr.BarPlot(
194
+ x=metrics_df["agent_name"],
195
+ y=metrics_df["avg_generation_time (sec)"],
196
+ title="Average Generation Time",
197
+ x_label="LLM Model",
198
+ y_label="Time (sec)"
199
+ )
200
+
201
+ with gr.Row():
202
+ gr.Dataframe(value=metrics_df, label="Performance Summary")
203
+
204
+ # Tab for LLM reasoning and illegal move analysis
205
+ with gr.Tab("Analysis of LLM Reasoning"):
206
+ gr.Markdown("# 🧠 Analysis of LLM Reasoning\nInsights into move legality and decision behavior.")
207
+
208
+ # Load illegal move stats using global function
209
+ illegal_df = extract_illegal_moves_summary()
210
+
211
+ with gr.Row():
212
+ gr.BarPlot(
213
+ x=illegal_df["agent_name"],
214
+ y=illegal_df["illegal_moves"],
215
+ title="Illegal Moves by Model",
216
+ x_label="LLM Model",
217
+ y_label="# of Illegal Moves"
218
+ )
219
+
220
+ with gr.Row():
221
+ gr.Dataframe(value=illegal_df, label="Illegal Move Summary")
222
+
223
+
224
+
225
+
226
+
227
  # Launch the Gradio interface
228
  interface.launch()