File size: 6,035 Bytes
eec6b5e
 
 
3ef1210
 
 
eec6b5e
 
 
 
 
3ef1210
eec6b5e
 
 
5f11885
3ef1210
 
 
eec6b5e
 
 
 
 
 
 
 
 
 
 
 
 
3ef1210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eec6b5e
3ef1210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744a850
3ef1210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eec6b5e
 
3ef1210
eec6b5e
 
3ef1210
744a850
eec6b5e
3ef1210
eec6b5e
 
 
3ef1210
eec6b5e
3ef1210
eec6b5e
3ef1210
eec6b5e
 
 
3ef1210
 
eec6b5e
cbde7e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import datetime
import plotly.graph_objects as go
import threading
import time
from utils import *  # Ensure get_wandb_runs and get_scores are defined here.
from chain import get_model_info

# Weights & Biases source: validator runs in this project that are still running.
project_name = "ai-factory-validators"
filters = {"State": {"$eq": "running"}}

# (timestamp, lowest avg_loss) samples; update_results() appends a new sample
# only when the lowest loss differs from the previously recorded one.
loss_history = []

# Number of samples averaged by the plot's moving-average trace.
window_size = 32

# Serialises update_results(): it is invoked from the background thread and
# from the "Refresh Now" button, and it mutates loss_history.
update_lock = threading.Lock()

def moving_average(data, window=10):
    """Return the trailing moving average of *data* as a list.

    Entry ``i`` is the mean of up to ``window`` consecutive values ending at
    position ``i`` (``data[max(0, i - window + 1):i + 1]``), so the first
    ``window - 1`` entries average over fewer values. A falsy *data*
    (empty sequence or ``None``) yields ``[]``.
    """
    if not data:
        return []

    averages = []
    for stop, _ in enumerate(data, start=1):
        # Half-open slice [stop - window, stop), clamped at the front.
        chunk = data[max(0, stop - window):stop]
        averages.append(sum(chunk) / len(chunk))
    return averages

def _group_scores_by_competition(scores):
    """Group per-UID score dicts by ``competition_id``; each group sorted by UID."""
    tables = {}
    for uid, data in scores.items():
        tables.setdefault(data.get("competition_id", "unknown"), []).append({
            "uid": uid,
            "avg_loss": data.get("avg_loss"),
            "win_rate": data.get("win_rate"),
            "model": get_model_info(uid),
        })
    for rows in tables.values():
        rows.sort(key=lambda row: row["uid"])
    return tables


def _record_min_loss(scores, max_len=10000):
    """Append (now, lowest avg_loss) to ``loss_history`` when the lowest loss changed.

    UIDs with a missing/None ``avg_loss`` are ignored. If no UID reports a loss,
    nothing is recorded, so ``loss_history`` never holds ``None``. A ``None``
    entry would make ``moving_average`` raise on every later update.
    The history is capped at *max_len* most recent samples.
    """
    losses = [data.get("avg_loss") for data in scores.values()]
    losses = [loss for loss in losses if loss is not None]
    if not losses:
        return
    min_loss = min(losses)
    if not loss_history or loss_history[-1][1] != min_loss:
        loss_history.append((datetime.datetime.now(), min_loss))
        if len(loss_history) > max_len:
            del loss_history[:-max_len]


def _build_loss_figure():
    """Plot ``loss_history`` and its ``window_size`` moving average with Plotly."""
    times = [sample[0] for sample in loss_history]
    losses = [sample[1] for sample in loss_history]
    ma_losses = moving_average(losses, window=window_size)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=times, y=losses, mode='lines+markers', name='Lowest avg_loss'))
    fig.add_trace(go.Scatter(x=times, y=ma_losses, mode='lines', name=f'Moving Average (window={window_size})'))
    fig.update_layout(
        title="Lowest Avg Loss Over Time",
        xaxis_title="Time",
        yaxis_title="Lowest Avg Loss",
        template="plotly_white",
        height=400,
    )
    return fig


def _format_number(value, spec):
    """Format *value* with *spec*, or return "N/A" when it is None."""
    return "N/A" if value is None else format(value, spec)


def _build_leaderboard_html(tables):
    """Render one HTML table per competition, highlighting the best win rate.

    Competition IDs, UIDs and model names (the latter come from
    ``get_model_info`` and may contain arbitrary text) are HTML-escaped before
    being embedded. Missing numeric values render as "N/A".
    """
    from html import escape

    parts = ["<h1>AI Factory Leaderboard</h1>"]
    for comp_id, rows in tables.items():
        # Rows without a win rate cannot be the best one.
        known_rates = [row["win_rate"] for row in rows if row["win_rate"] is not None]
        best_win_rate = max(known_rates) if known_rates else None

        comp_title = f"Competition ID: {comp_id}"
        if comp_id == 0:
            comp_title += " (Research Track)"
        parts.append(f"<h3>{escape(comp_title)}</h3>")
        parts.append(
            "<table border='1' style='border-collapse: collapse; width: 100%;'>"
            "<tr><th>UID</th><th>Avg Loss</th><th>Win Rate</th><th>Model Name</th></tr>"
        )
        for row in rows:
            is_best = best_win_rate is not None and row["win_rate"] == best_win_rate
            style = "background-color: #ffeb99;" if is_best else ""  # Light yellow background.
            crown = " \U0001F451" if is_best else ""  # Crown emoji.
            parts.append(
                f"<tr style='{style}'>"
                f"<td>{escape(str(row['uid']))}</td>"
                f"<td>{_format_number(row['avg_loss'], '.4f')}</td>"
                f"<td>{_format_number(row['win_rate'], '.2f')}</td>"
                f"<td>{escape(str(row['model']))}{crown}</td>"
                "</tr>"
            )
        parts.append("</table><br>")
    return "".join(parts)


def update_results():
    """Fetch runs and scores, then rebuild the leaderboard HTML and the loss plot.

    The whole update runs under ``update_lock`` so that only one call executes
    at a time. It is invoked both by the background thread and by the manual
    refresh button, and it mutates the module-level ``loss_history``.

    Returns:
        tuple: ``(html_content, fig)``, the leaderboard as an HTML string and a
        Plotly figure of the lowest avg_loss over time.
    """
    with update_lock:
        runs = get_wandb_runs(project_name, filters)
        scores = get_scores(list(range(256)), runs)

        tables = _group_scores_by_competition(scores)
        _record_min_loss(scores)
        fig = _build_loss_figure()
        html_content = _build_leaderboard_html(tables)
        return html_content, fig

# Last successfully computed (leaderboard HTML, Plotly figure) pair. It is
# published by the background updater and served to the UI by
# get_latest_results(); it stays empty until the first update succeeds.
latest_html, latest_fig = "", None

def background_update(interval=10):
    """Recompute the leaderboard forever, sleeping *interval* seconds between runs.

    Each successful update_results() call is published to the module-level
    ``latest_html`` / ``latest_fig`` that the UI reads. A failed update is
    reported together with its traceback, the previous snapshot is kept, and
    the loop continues, so one bad fetch does not kill the thread.

    Args:
        interval: Seconds to wait between updates (default 10).
    """
    global latest_html, latest_fig
    import traceback

    while True:
        try:
            html_content, fig = update_results()
        except Exception as e:
            # Deliberately best-effort: keep running, but keep the stack trace
            # so failures (e.g. network or W&B errors) are debuggable.
            print("Error during background update:", e)
            traceback.print_exc()
        else:
            latest_html, latest_fig = html_content, fig
        time.sleep(interval)

# Launch the periodic refresher at import time. As a daemon thread it will not
# keep the interpreter alive once the main thread exits.
_refresher = threading.Thread(target=background_update, daemon=True)
_refresher.start()

def get_latest_results():
    """Return the cached ``(latest_html, latest_fig)`` snapshot without recomputing."""
    snapshot = (latest_html, latest_fig)
    return snapshot

with gr.Blocks() as demo:
    # Hide any unwanted refresh button in the DOM.
    gr.HTML("<style>#refresh_button {display: none;}</style>")

    # Leaderboard tables and the lowest-loss plot.
    tables_output = gr.HTML()
    graph_output = gr.Plot()

    # Hidden ticker. Gradio only honours `every` when `value` is a callable
    # (it has no effect otherwise), so give it a changing timestamp. Each tick
    # changes the value, which fires .change and pulls the cached snapshot.
    trigger = gr.Textbox(value=lambda: str(time.time()), visible=False, every=10)
    trigger.change(fn=get_latest_results, inputs=[], outputs=[tables_output, graph_output])

    def refresh_now():
        """Recompute results immediately and publish them as the cached snapshot.

        Storing the result keeps the next periodic tick from reverting the page
        to an older snapshot.
        """
        global latest_html, latest_fig
        latest_html, latest_fig = update_results()
        return latest_html, latest_fig

    # Manual refresh button: recompute now instead of waiting for the thread.
    manual_refresh = gr.Button("Refresh Now")
    manual_refresh.click(fn=refresh_now, inputs=[], outputs=[tables_output, graph_output])

    # Show the cached results once on page load.
    demo.load(fn=get_latest_results, inputs=[], outputs=[tables_output, graph_output])

# Periodic (`every`) events require the queue. Newer Gradio enables it by
# default, and calling queue() explicitly keeps older versions working too.
demo.queue().launch(share=True)