Spaces:
Sleeping
Sleeping
File size: 6,035 Bytes
eec6b5e 3ef1210 eec6b5e 3ef1210 eec6b5e 5f11885 3ef1210 eec6b5e 3ef1210 eec6b5e 3ef1210 744a850 3ef1210 eec6b5e 3ef1210 eec6b5e 3ef1210 744a850 eec6b5e 3ef1210 eec6b5e 3ef1210 eec6b5e 3ef1210 eec6b5e 3ef1210 eec6b5e 3ef1210 eec6b5e cbde7e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import gradio as gr
import datetime
import plotly.graph_objects as go
import threading
import time
from utils import * # Ensure get_wandb_runs and get_scores are defined here.
from chain import get_model_info
# Chronological (timestamp, lowest avg_loss) samples that feed the loss plot.
loss_history: list = []

# Project name and run filter handed to get_wandb_runs(); the filter keeps
# only runs whose "State" equals "running".
project_name: str = "ai-factory-validators"
filters: dict = {"State": {"$eq": "running"}}

# Number of samples averaged by the moving-average trace of the plot.
window_size: int = 32

# Serialises update_results() so that only one refresh runs at a time.
update_lock = threading.Lock()
def moving_average(data, window=10):
    """Return the trailing moving average of ``data``.

    Element ``i`` of the result is the mean of ``data[max(0, i - window + 1)]``
    through ``data[i]``, so the first ``window - 1`` entries average fewer
    than ``window`` values.

    Args:
        data: Sequence of numbers to smooth.
        window: Maximum number of trailing values per average; must be >= 1
            whenever ``data`` is non-empty.

    Returns:
        A list of averages with the same length as ``data``; an empty list
        for empty input.

    Raises:
        ValueError: If ``data`` is non-empty and ``window`` is smaller than 1.
    """
    if not data:
        return []
    if window < 1:
        # A non-positive window selects an empty slice, which would otherwise
        # surface as an opaque ZeroDivisionError.
        raise ValueError(f"window must be >= 1, got {window}")
    averages = []
    for end in range(1, len(data) + 1):
        chunk = data[max(0, end - window):end]
        averages.append(sum(chunk) / len(chunk))
    return averages
def _format_metric(value, spec):
    """Format a numeric leaderboard cell, or return "N/A" when the value is missing."""
    return "N/A" if value is None else format(value, spec)


def _build_loss_figure(times, losses, ma_losses):
    """Build the Plotly figure showing the lowest avg_loss and its moving average."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=times, y=losses, mode='lines+markers', name='Lowest avg_loss'))
    fig.add_trace(go.Scatter(x=times, y=ma_losses, mode='lines', name=f'Moving Average (window={window_size})'))
    fig.update_layout(
        title="Lowest Avg Loss Over Time",
        xaxis_title="Time",
        yaxis_title="Lowest Avg Loss",
        template="plotly_white",
        height=400,
    )
    return fig


def _render_leaderboard(tables):
    """Render one HTML table per competition, highlighting the best win rate."""
    table_header = (
        "<table border='1' style='border-collapse: collapse; width: 100%;'>"
        "<tr><th>UID</th><th>Avg Loss</th><th>Win Rate</th><th>Model Name</th></tr>"
    )
    parts = ["<h1>AI Factory Leaderboard</h1>"]
    for comp_id, rows in tables.items():
        # Rows without a win rate cannot compete for the highlight.
        win_rates = [row["win_rate"] for row in rows if row["win_rate"] is not None]
        best_win_rate = max(win_rates, default=None)

        comp_title = f"Competition ID: {comp_id}"
        if comp_id == 0:
            comp_title += " (Research Track)"
        parts.append(f"<h3>{comp_title}</h3>")
        parts.append(table_header)

        for row in rows:
            is_best = best_win_rate is not None and row["win_rate"] == best_win_rate
            style = "background-color: #ffeb99;" if is_best else ""  # Light yellow.
            # Crown emoji, written as an escape so it survives encoding round-trips.
            crown = " \U0001F451" if is_best else ""
            avg_loss = _format_metric(row["avg_loss"], ".4f")
            win_rate = _format_metric(row["win_rate"], ".2f")
            # NOTE(review): row['model'] is inserted unescaped; if get_model_info
            # can return externally supplied text, escape it before rendering.
            parts.append(
                f"<tr style='{style}'><td>{row['uid']}</td><td>{avg_loss}</td>"
                f"<td>{win_rate}</td><td>{row['model']}{crown}</td></tr>"
            )
        parts.append("</table><br>")
    return "".join(parts)


def update_results():
    """Fetch runs and scores, then rebuild the leaderboard HTML and the loss plot.

    The whole refresh runs under ``update_lock`` so that concurrent callers
    cannot interleave their updates of ``loss_history``.

    Returns:
        A ``(html_content, fig)`` tuple: the leaderboard markup and the Plotly
        figure of the lowest avg_loss over time.
    """
    with update_lock:
        runs = get_wandb_runs(project_name, filters)
        # Scores are requested for UIDs 0 through 255.
        scores = get_scores(list(range(256)), runs)

        # Group scores by competition_id, keeping only the displayed fields.
        tables = {}
        for uid, data in scores.items():
            comp_id = data.get("competition_id", "unknown")
            tables.setdefault(comp_id, []).append({
                "uid": uid,
                "avg_loss": data.get("avg_loss"),
                "win_rate": data.get("win_rate"),
                "model": get_model_info(uid),
            })
        for rows in tables.values():
            rows.sort(key=lambda row: row["uid"])

        # Lowest avg_loss among the entries that actually report one.
        known_losses = [
            data["avg_loss"]
            for data in scores.values()
            if data.get("avg_loss") is not None
        ]
        min_loss = min(known_losses, default=None)

        # Record a sample only when there is a real value and it changed.
        # A None sample must never enter the history: moving_average() sums
        # the samples and would fail on every later refresh.
        if min_loss is not None and (not loss_history or loss_history[-1][1] != min_loss):
            loss_history.append((datetime.datetime.now(), min_loss))
            if len(loss_history) > 10000:
                del loss_history[:-10000]  # Keep only the newest 10000 samples.

        times = [sample[0] for sample in loss_history]
        losses = [sample[1] for sample in loss_history]
        ma_losses = moving_average(losses, window=window_size)

        fig = _build_loss_figure(times, losses, ma_losses)
        html_content = _render_leaderboard(tables)
        return html_content, fig
# Most recent leaderboard HTML and Plotly figure produced by a refresh; both
# keep these placeholder values until the first refresh has stored a result.
latest_html, latest_fig = "", None
def background_update(interval=10):
    """Refresh the cached leaderboard forever, pausing between runs.

    Each iteration calls ``update_results()`` and, on success, stores its
    outputs in the module-level ``latest_html`` and ``latest_fig``. A failed
    refresh is reported and the previous outputs are kept.

    Args:
        interval: Seconds to sleep after each attempt (default 10).

    This function never returns; run it in a daemon thread.
    """
    global latest_html, latest_fig
    import traceback  # Local import: only used on the failure path.

    while True:
        try:
            html_content, fig = update_results()
        except Exception as e:
            # Top-level boundary of the worker: report the failure with its
            # traceback and keep looping so a transient error does not stop
            # all future refreshes.
            print("Error during background update:", e)
            traceback.print_exc()
        else:
            latest_html, latest_fig = html_content, fig
        time.sleep(interval)
# Launch the refresher as a daemon thread so it does not keep the process alive.
_refresh_thread = threading.Thread(target=background_update, daemon=True)
_refresh_thread.start()
def get_latest_results():
    """Return the cached ``(latest_html, latest_fig)`` pair without refreshing it."""
    snapshot = (latest_html, latest_fig)
    return snapshot
with gr.Blocks() as demo:
    # Hide any unwanted refresh button in the DOM.
    gr.HTML("<style>#refresh_button {display: none;}</style>")

    # Output components for the leaderboard tables and the loss plot.
    tables_output = gr.HTML()
    graph_output = gr.Plot()

    # Hidden heartbeat. Gradio honours `every` only when the component's
    # `value` is a callable, so the textbox is fed a timestamp that differs on
    # each 10-second tick, giving the `change` listener something to react to.
    # NOTE(review): confirm that the deployed Gradio version dispatches
    # `change` events for components created with visible=False.
    trigger = gr.Textbox(value=lambda: str(time.time()), visible=False, every=10)
    trigger.change(fn=get_latest_results, inputs=[], outputs=[tables_output, graph_output])

    # Manual refresh button: runs a full update_results() instead of reading
    # the cached outputs.
    manual_refresh = gr.Button("Refresh Now")
    manual_refresh.click(fn=update_results, inputs=[], outputs=[tables_output, graph_output])

    # Show the cached results once when the page loads.
    demo.load(fn=get_latest_results, inputs=[], outputs=[tables_output, graph_output])

demo.launch(share=True)
|