Spaces:
Sleeping
Sleeping
import gradio as gr | |
import datetime | |
import plotly.graph_objects as go | |
import threading | |
import time | |
from utils import * # Ensure get_wandb_runs and get_scores are defined here. | |
from chain import get_model_info | |
# W&B project and run filter used when fetching validator runs.
project_name = 'ai-factory-validators'
filters = {"State": {"$eq": "running"}}
# Number of samples averaged by the moving-average line on the loss plot.
window_size = 32
# Chronological list of (timestamp, lowest avg_loss) samples shown on the plot.
loss_history = []
# Guards update_results so only one fetch/render runs at a time.
update_lock = threading.Lock()
def moving_average(data, window=10):
    """Return the trailing moving average of ``data``.

    Element ``k`` of the result is the mean of the (up to) ``window`` values
    ending at position ``k``; near the start of the series fewer values are
    averaged. An empty (or falsy) input yields an empty list.
    """
    if not data:
        return []
    averages = []
    for stop in range(1, len(data) + 1):
        # Trailing slice [stop - window, stop), clamped at the series start.
        chunk = data[max(0, stop - window):stop]
        averages.append(sum(chunk) / len(chunk))
    return averages
def _finite_or_none(value):
    """Return ``value`` as a finite float, or None if it is missing, non-numeric, NaN or infinite."""
    import math

    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


def _format_number(value, spec):
    """Format ``value`` with the format ``spec``; return "N/A" when it cannot be formatted (e.g. None)."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return "N/A"


def _group_scores_by_competition(scores):
    """Group per-UID score dicts into {competition_id: [row, ...]}, each list sorted by UID."""
    tables = {}
    for uid, data in scores.items():
        comp_id = data.get("competition_id", "unknown")
        tables.setdefault(comp_id, []).append({
            "uid": uid,
            "avg_loss": data.get("avg_loss"),
            "win_rate": data.get("win_rate"),
            "model": get_model_info(uid),
        })
    for rows in tables.values():
        rows.sort(key=lambda row: row["uid"])
    return tables


def _record_lowest_loss(scores):
    """Append (now, lowest finite avg_loss) to loss_history when that value changed; cap at 10000 samples."""
    candidates = [_finite_or_none(data.get("avg_loss")) for data in scores.values()]
    candidates = [loss for loss in candidates if loss is not None]
    if not candidates:
        # Nothing usable this round (empty fetch or missing losses). Recording None/inf
        # here would break moving_average() on every subsequent update, so skip it.
        return
    min_loss = min(candidates)
    if not loss_history or loss_history[-1][1] != min_loss:
        loss_history.append((datetime.datetime.now(), min_loss))
        if len(loss_history) > 10000:
            loss_history[:] = loss_history[-10000:]


def _build_loss_figure():
    """Build the Plotly figure of the recorded lowest avg_loss and its moving average."""
    times = [stamp for stamp, _ in loss_history]
    losses = [loss for _, loss in loss_history]
    ma_losses = moving_average(losses, window=window_size)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=times, y=losses, mode='lines+markers', name='Lowest avg_loss'))
    fig.add_trace(go.Scatter(x=times, y=ma_losses, mode='lines', name=f'Moving Average (window={window_size})'))
    fig.update_layout(
        title="Lowest Avg Loss Over Time",
        xaxis_title="Time",
        yaxis_title="Lowest Avg Loss",
        template="plotly_white",
        height=400,
    )
    return fig


def _build_leaderboard_html(tables):
    """Render one HTML table per competition, highlighting the row(s) with the best win rate."""
    import html

    parts = ["<h1>AI Factory Leaderboard</h1>"]
    for comp_id, rows in tables.items():
        # Ignore missing win rates so a single None does not crash max().
        win_rates = [row["win_rate"] for row in rows if row["win_rate"] is not None]
        best_win_rate = max(win_rates) if win_rates else None
        comp_title = f"Competition ID: {comp_id}"
        if comp_id == 0:
            comp_title += " (Research Track)"
        parts.append(f"<h3>{html.escape(comp_title)}</h3>")
        parts.append(
            "<table border='1' style='border-collapse: collapse; width: 100%;'>"
            "<tr><th>UID</th><th>Avg Loss</th><th>Win Rate</th><th>Model Name</th></tr>"
        )
        for row in rows:
            is_best = best_win_rate is not None and row["win_rate"] == best_win_rate
            style = "background-color: #ffeb99;" if is_best else ""  # Light yellow background.
            crown = " \N{CROWN}" if is_best else ""
            # UIDs and model names come from external data: escape before embedding in HTML.
            parts.append(
                f"<tr style='{style}'>"
                f"<td>{html.escape(str(row['uid']))}</td>"
                f"<td>{_format_number(row['avg_loss'], '.4f')}</td>"
                f"<td>{_format_number(row['win_rate'], '.2f')}</td>"
                f"<td>{html.escape(str(row['model']))}{crown}</td>"
                "</tr>"
            )
        parts.append("</table><br>")
    return "".join(parts)


def update_results():
    """Fetch runs and scores, then rebuild the leaderboard HTML and the loss plot.

    Serialized with ``update_lock`` so concurrent callers (the background loop and
    the manual refresh button) never interleave updates to ``loss_history``.

    Returns:
        A ``(html_content, fig)`` tuple: the leaderboard HTML string and a Plotly figure.
    """
    with update_lock:
        runs = get_wandb_runs(project_name, filters)
        scores = get_scores(list(range(256)), runs)
        tables = _group_scores_by_competition(scores)
        _record_lowest_loss(scores)
        fig = _build_loss_figure()
        html_content = _build_leaderboard_html(tables)
        return html_content, fig
# Most recent leaderboard HTML and Plotly figure, cached for the UI callbacks.
latest_html, latest_fig = "", None
def background_update(interval=10):
    """Refresh the cached leaderboard outputs forever, once every ``interval`` seconds.

    Each iteration calls update_results() and, on success, stores its
    (html, figure) pair in the module globals ``latest_html`` / ``latest_fig``.
    A failure is printed together with its full traceback and the loop keeps
    running, so a transient fetch error never stops the updater.

    Args:
        interval: Seconds to sleep between iterations (default 10).
    """
    global latest_html, latest_fig
    import traceback

    while True:
        try:
            html_content, fig = update_results()
        except Exception as e:
            print("Error during background update:", e)
            # The bare message alone hides where the error came from.
            traceback.print_exc()
        else:
            latest_html, latest_fig = html_content, fig
        time.sleep(interval)
# Launch the polling loop on a daemon thread so it never blocks interpreter exit.
updater_thread = threading.Thread(target=background_update, daemon=True)
updater_thread.start()
def get_latest_results():
    """Return the cached ``(leaderboard_html, plotly_figure)`` pair.

    Only reads the module-level cache that background_update() writes; it does
    not fetch anything itself, so it is cheap to call from UI events.
    """
    snapshot = (latest_html, latest_fig)
    return snapshot
def refresh_now():
    """Run update_results() immediately and cache its outputs for the periodic poll.

    Without storing the result, the next periodic refresh would overwrite the
    manually refreshed view with the older cached outputs.
    """
    global latest_html, latest_fig
    html_content, fig = update_results()
    latest_html, latest_fig = html_content, fig
    return html_content, fig


with gr.Blocks() as demo:
    # Hide any unwanted refresh button in the DOM.
    gr.HTML("<style>#refresh_button {display: none;}</style>")
    # Outputs poll the cached results every 10 seconds. In Gradio, `every` only has an
    # effect when `value` is a callable; the former hidden Textbox had no callable
    # value, so its `.change` trigger never fired.
    tables_output = gr.HTML(value=lambda: get_latest_results()[0], every=10)
    graph_output = gr.Plot(value=lambda: get_latest_results()[1], every=10)
    # Manual refresh fetches fresh data and updates the shared cache.
    manual_refresh = gr.Button("Refresh Now")
    manual_refresh.click(fn=refresh_now, inputs=[], outputs=[tables_output, graph_output])
    # Load results once on startup.
    demo.load(fn=get_latest_results, inputs=[], outputs=[tables_output, graph_output])
# Periodic (`every`) updates require the queue on older Gradio versions.
demo.queue()
demo.launch(share=True)