vis data commited on
Commit
3ef1210
·
verified ·
1 Parent(s): 5f11885

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -104
app.py CHANGED
@@ -1,23 +1,22 @@
1
  import gradio as gr
2
  import datetime
3
  import plotly.graph_objects as go
4
- from utils import *
 
 
5
  from chain import get_model_info
6
 
7
  # Global history list to record the lowest avg_loss over time.
8
  loss_history = []
9
 
10
- # Set your project name and filter for runs.
11
  project_name = 'ai-factory-validators'
12
  filters = {"State": {"$eq": "running"}}
13
 
14
  window_size = 32
15
- # NOTE: Ensure that get_wandb_runs and get_scores are defined in your environment.
16
- # For example:
17
- # def get_wandb_runs(project_name, filters):
18
- # # Your implementation here
19
- # def get_scores(ids, runs):
20
- # # Your implementation here
21
 
22
  def moving_average(data, window=10):
23
  """Compute the moving average of data using a sliding window."""
@@ -31,115 +30,130 @@ def moving_average(data, window=10):
31
  return ma
32
 
33
  def update_results():
34
- """
35
- Fetches the latest runs and scores using your provided snippet,
36
- groups data by competition, highlights the best (lowest avg_loss) entry,
37
- records the current lowest avg_loss (skipping duplicate values and limiting the history length),
38
- computes the moving average, and returns both an HTML table and a Plotly graph.
39
- """
40
- # Load new results using your provided snippet.
41
- runs = get_wandb_runs(project_name, filters)
42
- scores = get_scores(list(range(256)), runs)
43
-
44
- # Group scores by competition_id and filter columns: uid, avg_loss, win_rate.
45
- tables = {}
46
- for uid, data in scores.items():
47
- comp_id = data.get("competition_id", "unknown")
48
- if comp_id not in tables:
49
- tables[comp_id] = []
50
- tables[comp_id].append({
51
- "uid": uid,
52
- "avg_loss": data.get("avg_loss"),
53
- "win_rate": data.get("win_rate"),
54
- "model": get_model_info(uid)
55
- })
56
- # Sort each table by uid.
57
- for comp_id in tables:
58
- tables[comp_id] = sorted(tables[comp_id], key=lambda x: x["uid"])
59
-
60
- # Determine the lowest avg_loss in the current scores.
61
- try:
62
- min_loss = min(data.get("avg_loss", float("inf")) for data in scores.values())
63
- except ValueError:
64
- min_loss = None # if scores is empty
65
-
66
- # Record the current time.
67
- now = datetime.datetime.now()
68
- # Only append if loss_history is empty or the last value is different from current min_loss.
69
- if not loss_history or loss_history[-1][1] != min_loss:
70
- loss_history.append((now, min_loss))
71
- # Limit loss_history length to 10000
72
- if len(loss_history) > 10000:
73
- loss_history[:] = loss_history[-10000:]
74
-
75
- # Create time series data.
76
- times = [t[0] for t in loss_history]
77
- losses = [t[1] for t in loss_history]
78
-
79
- # Compute the moving average with window of 10.
80
- ma_losses = moving_average(losses, window=window_size)
81
-
82
- # Create a Plotly line graph with both the raw lowest avg_loss and its moving average.
83
- fig = go.Figure()
84
- fig.add_trace(go.Scatter(x=times, y=losses, mode='lines+markers', name='Lowest avg_loss'))
85
- fig.add_trace(go.Scatter(x=times, y=ma_losses, mode='lines', name=f'Moving Average (window={window_size})'))
86
- fig.update_layout(
87
- title="Lowest Avg Loss Over Time",
88
- xaxis_title="Time",
89
- yaxis_title="Lowest Avg Loss",
90
- template="plotly_white",
91
- height=400
92
- )
93
-
94
- # Build HTML content: one table per competition_id.
95
- html_content = "<h1>AI Factory Leaderboard</h1>"
96
- for comp_id, rows in tables.items():
97
- # Identify best (lowest avg_loss) entry in the current competition.
98
- best_loss = min(row["avg_loss"] for row in rows)
99
- # For competition 0, mark it as Research Track.
100
- comp_title = f"Competition ID: {comp_id}"
101
- if comp_id == 0:
102
- comp_title += " (Research Track)"
103
- html_content += f"<h3>{comp_title}</h3>"
104
- html_content += """
105
- <table border='1' style='border-collapse: collapse; width: 100%;'>
106
- <tr>
107
- <th>UID</th>
108
- <th>Avg Loss</th>
109
- <th>Win Rate</th>
110
- <th>Model Name</th>
111
- </tr>
112
- """
113
- for row in rows:
114
- # Highlight the row if it has the best avg_loss.
115
- style = "background-color: #d4edda;" if row["avg_loss"] == best_loss else ""
116
- html_content += f"<tr style='{style}'><td>{row['uid']}</td><td>{row['avg_loss']:.4f}</td><td>{row['win_rate']:.2f}</td><td>{row['model']}</td></tr>"
117
- html_content += "</table><br>"
118
 
119
- return html_content, fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  with gr.Blocks() as demo:
122
- # Add CSS to hide the refresh button but keep it in the DOM.
123
  gr.HTML("<style>#refresh_button {display: none;}</style>")
124
 
125
- # Display the title at the top.
126
  gr.HTML("<h1 style='text-align:center;'>AI Factory Leaderboard</h1>")
127
 
128
- # Outputs: HTML tables and Plotly graph.
129
  tables_output = gr.HTML()
130
  graph_output = gr.Plot()
131
 
132
- # Hidden component to trigger periodic updates
133
  trigger = gr.Textbox(visible=False, every=10)
134
-
135
- # Set up the function to run every 10 seconds
136
- trigger.change(fn=update_results, inputs=[], outputs=[tables_output, graph_output])
137
 
138
- # Provide a manual refresh button as well.
139
  manual_refresh = gr.Button("Refresh Now")
140
  manual_refresh.click(fn=update_results, inputs=[], outputs=[tables_output, graph_output])
141
 
142
- # Load the results once on startup.
143
- demo.load(fn=update_results, inputs=[], outputs=[tables_output, graph_output])
144
 
145
- demo.launch()
 
1
  import gradio as gr
2
  import datetime
3
  import plotly.graph_objects as go
4
+ import threading
5
+ import time
6
+ from utils import * # Ensure get_wandb_runs and get_scores are defined here.
7
  from chain import get_model_info
8
 
9
  # Global history list to record the lowest avg_loss over time.
10
  loss_history = []
11
 
12
+ # Set your project name and filter.
13
  project_name = 'ai-factory-validators'
14
  filters = {"State": {"$eq": "running"}}
15
 
16
  window_size = 32
17
+
18
+ # Create a global lock so that update_results runs in mutual exclusion.
19
+ update_lock = threading.Lock()
 
 
 
20
 
21
  def moving_average(data, window=10):
22
  """Compute the moving average of data using a sliding window."""
 
30
  return ma
31
 
32
  def update_results():
33
+ """Fetch runs and scores, update the leaderboard and plot, ensuring that only one call runs at a time."""
34
+ with update_lock:
35
+ # Load new results using provided snippets.
36
+ runs = get_wandb_runs(project_name, filters)
37
+ scores = get_scores(list(range(256)), runs)
38
+
39
+ # Group scores by competition_id with required fields.
40
+ tables = {}
41
+ for uid, data in scores.items():
42
+ comp_id = data.get("competition_id", "unknown")
43
+ if comp_id not in tables:
44
+ tables[comp_id] = []
45
+ tables[comp_id].append({
46
+ "uid": uid,
47
+ "avg_loss": data.get("avg_loss"),
48
+ "win_rate": data.get("win_rate"),
49
+ "model": get_model_info(uid)
50
+ })
51
+ # Sort each table by UID.
52
+ for comp_id in tables:
53
+ tables[comp_id] = sorted(tables[comp_id], key=lambda x: x["uid"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # Determine the current lowest avg_loss (for plotting).
56
+ try:
57
+ min_loss = min(data.get("avg_loss", float("inf")) for data in scores.values())
58
+ except ValueError:
59
+ min_loss = None
60
+
61
+ # Record the current time and update loss_history.
62
+ now = datetime.datetime.now()
63
+ if not loss_history or loss_history[-1][1] != min_loss:
64
+ loss_history.append((now, min_loss))
65
+ if len(loss_history) > 10000:
66
+ loss_history[:] = loss_history[-10000:]
67
+
68
+ # Create time series and compute moving average.
69
+ times = [t[0] for t in loss_history]
70
+ losses = [t[1] for t in loss_history]
71
+ ma_losses = moving_average(losses, window=window_size)
72
+
73
+ # Build the Plotly graph.
74
+ fig = go.Figure()
75
+ fig.add_trace(go.Scatter(x=times, y=losses, mode='lines+markers', name='Lowest avg_loss'))
76
+ fig.add_trace(go.Scatter(x=times, y=ma_losses, mode='lines', name=f'Moving Average (window={window_size})'))
77
+ fig.update_layout(
78
+ title="Lowest Avg Loss Over Time",
79
+ xaxis_title="Time",
80
+ yaxis_title="Lowest Avg Loss",
81
+ template="plotly_white",
82
+ height=400
83
+ )
84
+
85
+ # Build the HTML content for the leaderboard.
86
+ html_content = "<h1>AI Factory Leaderboard</h1>"
87
+ for comp_id, rows in tables.items():
88
+ # Identify the row with the highest win_rate.
89
+ best_win_rate = max(row["win_rate"] for row in rows)
90
+ comp_title = f"Competition ID: {comp_id}"
91
+ if comp_id == 0:
92
+ comp_title += " (Research Track)"
93
+ html_content += f"<h3>{comp_title}</h3>"
94
+ html_content += """
95
+ <table border='1' style='border-collapse: collapse; width: 100%;'>
96
+ <tr>
97
+ <th>UID</th>
98
+ <th>Avg Loss</th>
99
+ <th>Win Rate</th>
100
+ <th>Model Name</th>
101
+ </tr>
102
+ """
103
+ for row in rows:
104
+ if row["win_rate"] == best_win_rate:
105
+ style = "background-color: #ffeb99;" # Light yellow background.
106
+ crown = " 👑"
107
+ else:
108
+ style = ""
109
+ crown = ""
110
+ html_content += f"<tr style='{style}'><td>{row['uid']}</td><td>{row['avg_loss']:.4f}</td><td>{row['win_rate']:.2f}</td><td>{row['model']}{crown}</td></tr>"
111
+ html_content += "</table><br>"
112
+
113
+ return html_content, fig
114
+
115
+ # Global variables to store the latest outputs.
116
+ latest_html = ""
117
+ latest_fig = None
118
+
119
+ def background_update():
120
+ """Background thread that runs update_results every 10 seconds and stores its outputs."""
121
+ global latest_html, latest_fig
122
+ while True:
123
+ try:
124
+ html_content, fig = update_results()
125
+ latest_html, latest_fig = html_content, fig
126
+ except Exception as e:
127
+ print("Error during background update:", e)
128
+ time.sleep(10)
129
+
130
+ # Start the background update thread.
131
+ threading.Thread(target=background_update, daemon=True).start()
132
+
133
+ def get_latest_results():
134
+ """Return the latest HTML and Plotly graph."""
135
+ return latest_html, latest_fig
136
 
137
  with gr.Blocks() as demo:
138
+ # Hide any unwanted refresh button in the DOM.
139
  gr.HTML("<style>#refresh_button {display: none;}</style>")
140
 
141
+ # Display the title.
142
  gr.HTML("<h1 style='text-align:center;'>AI Factory Leaderboard</h1>")
143
 
144
+ # Define the outputs.
145
  tables_output = gr.HTML()
146
  graph_output = gr.Plot()
147
 
148
+ # A hidden textbox triggers periodic updates every 10 seconds.
149
  trigger = gr.Textbox(visible=False, every=10)
150
+ trigger.change(fn=get_latest_results, inputs=[], outputs=[tables_output, graph_output])
 
 
151
 
152
+ # Manual refresh button that also calls update_results.
153
  manual_refresh = gr.Button("Refresh Now")
154
  manual_refresh.click(fn=update_results, inputs=[], outputs=[tables_output, graph_output])
155
 
156
+ # Load results once on startup.
157
+ demo.load(fn=get_latest_results, inputs=[], outputs=[tables_output, graph_output])
158
 
159
+ demo.launch()