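# Residue from the Hugging Face file-viewer page this file was copied from,
# commented out so the module parses:
# m-ric's picture
# Format
# 79f07b1
# raw
# history blame
# 12.1 kB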
import os
import pickle
import pandas as pd
import numpy as np
import gradio as gr
from datetime import datetime
from huggingface_hub import HfApi
from apscheduler.schedulers.background import BackgroundScheduler
import plotly.graph_objects as go
from utils import (
KEY_TO_CATEGORY_NAME,
CAT_NAME_TO_EXPLANATION,
download_latest_data_from_space,
get_constants,
update_release_date_mapping,
format_data,
get_trendlines,
find_crossover_point,
sigmoid_transition
)
###################
### Initialize scheduler
###################
def restart_space(repo_id=None):
    """Restart a Hugging Face Space and log when it happened.

    Args:
        repo_id (str, optional): Space to restart, as "owner/name". When None,
            the ``SPACE_ID`` environment variable is used (set by the Spaces
            runtime to the running Space's own id). If that is unset too, the
            function falls back to the upstream Space this app was adapted from.
            Hard-coding only the upstream id would make this copy restart
            someone else's Space instead of itself.

    The Hub client is authenticated with the ``HF_TOKEN`` environment variable
    (``None`` is passed when it is unset).
    """
    if repo_id is None:
        repo_id = os.getenv("SPACE_ID") or "andrewrreed/closed-vs-open-arena-elo"
    HfApi(token=os.getenv("HF_TOKEN", None)).restart_space(repo_id=repo_id)
    print(f"Space restarted on {datetime.now()}")
# Restart the Space every day at 07:00 so the module-level data download below
# runs again on fresh leaderboard files. The hour is interpreted in the
# scheduler's timezone; none is passed here, so APScheduler uses the host's
# local timezone.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "cron", day_of_week="mon-sun", hour=7, minute=0)
scheduler.start()
###################
### Load Data
###################
# gather ELO data
latest_elo_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="pkl"
)
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. This pickle comes from the lmsys/chatbot-arena-leaderboard Space (via
# download_latest_data_from_space), so it is only as trustworthy as that repo;
# switch to a non-executable format if one is published.
with open(latest_elo_file_local, "rb") as fin:
    elo_results = pickle.load(fin)
# TODO: also include the vision leaderboard; only the "text" results are used.
elo_results = elo_results["text"]
# Map each known category key to its display name, skipping categories that
# are absent from this snapshot.
arena_dfs = {
    KEY_TO_CATEGORY_NAME[key]: elo_results[key]["leaderboard_table_df"]
    for key in KEY_TO_CATEGORY_NAME
    if key in elo_results
}
# gather the Chatbot Arena leaderboard table (CSV from the same lmsys Space);
# its rows carry per-model metadata and are matched to the ELO tables on "key"
latest_leaderboard_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="csv"
)
leaderboard_df = pd.read_csv(latest_leaderboard_file_local)
# load release date mapping data (records with at least "key" and
# "Release Date" columns); the relative path is resolved against the current
# working directory
release_date_mapping = pd.read_json("release_date_mapping.json", orient="records")
###################
### Prepare Data
###################
# update release date mapping with new models:
# find models rated in the Overall ELO table that have no release-date entry.
# The known keys are collected into a set once so each membership test is O(1)
# instead of rebuilding a list for every model.
known_model_keys = set(release_date_mapping["key"])
new_model_keys_to_add = [
    model
    for model in arena_dfs["Overall"].index.to_list()
    if model not in known_model_keys
]
if new_model_keys_to_add:
    release_date_mapping = update_release_date_mapping(
        new_model_keys_to_add, leaderboard_df, release_date_mapping
    )
# For every category: attach leaderboard metadata to the ELO table (the ELO
# table's index is matched against leaderboard_df["key"]), rank by rating,
# attach release dates, then format. Both merges are inner joins, so models
# missing from leaderboard_df or from release_date_mapping are dropped.
merged_dfs = {}
for category, arena_df in arena_dfs.items():
    ranked = (
        pd.merge(arena_df, leaderboard_df, left_index=True, right_on="key")
        .sort_values("rating", ascending=False)
        .reset_index(drop=True)
    )
    dated = pd.merge(ranked, release_date_mapping[["key", "Release Date"]], on="key")
    merged_dfs[category] = format_data(dated)
# get constants
min_elo_score, max_elo_score, _ = get_constants(merged_dfs)
# date part (text before the first space) of the snapshot's update timestamp
date_updated = elo_results["full"]["last_updated_datetime"].split(" ")[0]
orgs = merged_dfs["Overall"].Organization.unique().tolist()
###################
### Build and Plot Data
###################
# Plot only the 11 organizations whose best model rates highest, keep models
# rated above 1000, then discard models that have no release date.
df = merged_dfs["Overall"]
org_best_rating = df.groupby("Organization")["rating"].max()
top_orgs = org_best_rating.nlargest(11).index.tolist()
keep = df["Organization"].isin(top_orgs) & (df["rating"] > 1000)
df = df.loc[keep]
print(df)
df = df[df["Release Date"].notna()]
def get_data_split(dfs, set_name):
    """Return an independent deep copy of ``dfs[set_name]`` re-indexed 0..n-1."""
    selected = dfs[set_name]
    return selected.copy(deep=True).reset_index(drop=True)
def clean_df_for_display(df):
    """Return a display-ready copy of a merged leaderboard DataFrame.

    Keeps only the user-facing columns, renames "rating" to "ELO Score" and
    "MT-bench (score)" to "MT-Bench", converts "Release Date" to strings, and
    orders rows by ELO Score (highest first) with a fresh 0..n-1 index.
    """
    display_columns = [
        "Model",
        "rating",
        "MMLU",
        "MT-bench (score)",
        "Release Date",
        "Organization",
        "License",
        "Link",
    ]
    renames = {"rating": "ELO Score", "MT-bench (score)": "MT-Bench"}
    table = df[display_columns].rename(columns=renames)
    table["Release Date"] = table["Release Date"].astype(str)
    return table.sort_values("ELO Score", ascending=False).reset_index(drop=True)
def format_data(df):
    """
    Return a formatted copy of the given DataFrame; the input is not modified.

    Operations performed on the copy:
    - Converts the 'License' column values to 'Proprietary LLM' if they are in
      PROPRIETARY_LICENSES, otherwise 'Open LLM' (missing values count as 'Open LLM').
    - Converts the 'Release Date' column to datetime format.
    - Adds a 'Month-Year' column holding the monthly period of 'Release Date'.
    - Rounds the 'rating' column to the nearest integer (halves round to even;
      the column keeps its float dtype).
    - Resets the index to 0..n-1.

    Working on a copy keeps the caller's frame intact and avoids pandas'
    SettingWithCopyWarning when ``df`` is a slice of another DataFrame.

    NOTE: this definition shadows the ``format_data`` imported from ``utils``.
    The module-level data preparation runs before this ``def`` executes, so it
    uses the ``utils`` version; as far as this file shows, this one is unused.

    Args:
        df (pandas.DataFrame): Must contain 'License', 'Release Date' and
            'rating' columns.

    Returns:
        pandas.DataFrame: The formatted copy.
    """
    # "Proprietory" also matches a misspelling, presumably present in the source data.
    PROPRIETARY_LICENSES = ["Proprietary", "Proprietory"]
    df = df.copy()
    is_proprietary = df["License"].isin(PROPRIETARY_LICENSES)
    df["License"] = is_proprietary.map({True: "Proprietary LLM", False: "Open LLM"})
    df["Release Date"] = pd.to_datetime(df["Release Date"])
    df["Month-Year"] = df["Release Date"].dt.to_period("M")
    df["rating"] = df["rating"].round()
    return df.reset_index(drop=True)
# Organization name -> (trace colour, country flag emoji). The colour is used for
# the organization's line and markers in the plot and the flag is appended to
# its legend entry; organizations missing here are drawn in grey with no flag.
org_info = {
    "OpenAI": ("#00A67E", "🇺🇸"),  # Teal
    "Google": ("#4285F4", "🇺🇸"),  # Google Blue
    "xAI": ("black", "🇺🇸"),  # Black
    "Anthropic": ("#cc785c", "🇺🇸"),  # Brown
    "Meta": ("#0064E0", "🇺🇸"),  # Facebook Blue
    "Alibaba": ("#6958cf", "🇨🇳"),  # Slate purple
    "DeepSeek": ("#C70039", "🇨🇳"),  # Crimson
    "01 AI": ("#11871e", "🇨🇳"),  # Green
    "DeepSeek AI": ("#9900CC", "🇨🇳"),  # Purple
    "Mistral": ("#ff7000", "🇫🇷"),  # Mistral Orange
    "AI21 Labs": ("#1E90FF", "🇮🇱"),  # Dodger Blue
    # NOTE(review): Reka AI and Zhipu AI share #FFC300 (amber), so they cannot
    # be told apart by colour if both appear in the plot.
    "Reka AI": ("#FFC300", "🇺🇸"),  # Amber
    "Zhipu AI": ("#FFC300", "🇨🇳"),  # Amber
}
def _org_progression(org_data):
    """Trace one organization's running-best rating over time.

    Args:
        org_data (pandas.DataFrame): rows of a single organization with
            "Release Date" (datetime-like) and "rating" columns.

    Returns:
        tuple: (x_values, y_values, best_models). x/y are the points of the
        line; each new best is joined to the previous one by a sigmoid-shaped
        ramp. best_models lists the rows (pandas.Series) that set a new best,
        in date order. All three are empty when no row has both a date and a
        rating.
    """
    # Keep only the highest-rated model of each release date. Sorting on rating
    # explicitly (instead of groupby(...).first(), which depends on the caller's
    # row order and takes the first non-null value column by column) guarantees
    # the kept row is the best one and that all its fields come from one model.
    daily_best = (
        org_data.dropna(subset=["Release Date", "rating"])
        .sort_values("rating", ascending=False, kind="stable")
        .drop_duplicates(subset="Release Date")
        .sort_values("Release Date")
    )
    x_values, y_values, best_models = [], [], []
    current_best = -np.inf
    for _, row in daily_best.iterrows():
        if row["rating"] <= current_best:
            continue  # not an improvement on the organization's best so far
        release_date = row["Release Date"]
        if x_values:
            # Ramp smoothly from the previous best to this one, using at least
            # 100 points (one per day for longer gaps).
            transition_days = (release_date - x_values[-1]).days
            transition_points = pd.date_range(
                x_values[-1], release_date, periods=max(100, transition_days)
            )
            x_values.extend(transition_points)
            y_values.extend(
                current_best
                + (row["rating"] - current_best)
                * sigmoid_transition(
                    np.linspace(-6, 6, len(transition_points)), 0, k=1
                )
            )
        x_values.append(release_date)
        y_values.append(row["rating"])
        current_best = row["rating"]
        best_models.append(row)
    return x_values, y_values, best_models


def make_figure(df):
    """Plot each organization's best Chatbot Arena ELO score over time.

    Organizations are ordered, and numbered in the legend, by their best rating
    (highest first). Each one gets a line that follows its running-best rating
    and is extended flat to today, plus hoverable markers on the models that
    set a new best.

    Args:
        df (pandas.DataFrame): model rows with "Organization", "rating",
            "Release Date" (datetime-like) and "Model" columns.

    Returns:
        tuple: (plotly.graph_objects.Figure, df), i.e. the figure and the
        unchanged input DataFrame, matching the Gradio outputs (plot, raw-data
        table).
    """
    fig = go.Figure()
    # Bound before the loop so it exists even when df is empty; it is also the
    # right edge of the x-axis range below.
    current_date = pd.Timestamp.now()

    org_order = (
        df.groupby("Organization")["rating"].max().sort_values(ascending=False).index
    )
    for rank, org in enumerate(org_order, start=1):
        org_data = df[df["Organization"] == org]
        x_values, y_values, best_models = _org_progression(org_data)
        if not x_values:
            continue  # no row with both a release date and a rating
        # Extend the line to the current date at the latest best rating.
        if x_values[-1] < current_date:
            x_values.append(current_date)
            y_values.append(y_values[-1])
        # Organizations without an entry fall back to grey with no flag.
        color, flag = org_info.get(org, ("#808080", ""))
        # Line for the running best; hover is carried by the markers instead.
        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=y_values,
                mode="lines",
                name=f"{rank}. {org} {flag}",
                line=dict(color=color, width=2),
                hoverinfo="skip",
            )
        )
        # Markers on the models that set a new best.
        best_models_df = pd.DataFrame(best_models)
        fig.add_trace(
            go.Scatter(
                x=best_models_df["Release Date"],
                y=best_models_df["rating"],
                mode="markers",
                name=org,
                showlegend=False,
                marker=dict(color=color, size=8, symbol="circle"),
                text=best_models_df["Model"],
                hovertemplate="<b>%{text}</b><br>Date: %{x}<br>ELO Score: %{y:.2f}<extra></extra>",
            )
        )

    # Update layout. The view window (Jan 2024 -> today), the y-range and the
    # "November 2024" legend title are hard-coded and need manual updates.
    speak_french = False  # flip to True for French labels
    if speak_french:
        fig.update_layout(
            xaxis_title="Date",
            title="La course au classement",
            yaxis_title="Score ELO",
            legend_title="Classement en Novembre 2024",
            xaxis_range=[pd.Timestamp("2024-01-01"), current_date],
            yaxis_range=[1103, 1350],
        )
    else:
        fig.update_layout(
            xaxis_title="Date",
            yaxis_title="ELO score on Chatbot Arena",
            legend_title="Ranking as of November 2024",
            title="The race for the best LLM",
            hovermode="closest",
            xaxis_range=[pd.Timestamp("2024-01-01"), current_date],
            yaxis_range=[1103, 1350],
        )
    fig.update_xaxes(
        tickformat="%m-%Y",
    )
    print(fig)
    return fig, df
def filter_df():
    """Return the module-level plotting DataFrame ``df`` as-is.

    No filtering happens here. The function serves as the ``demo.load``
    callback that fills the ``filtered_df`` state.
    """
    return df
# JavaScript passed to gr.Blocks(js=...) and run on page load. If the URL's
# `__theme` query parameter is not 'dark', it sets it and reloads the page.
# Gradio reads `__theme` to choose light or dark mode, so this forces dark mode.
set_dark_mode = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
# Gradio UI: an intro panel, then Plot / Raw Data tabs that are filled on page load.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.sky,
        secondary_hue=gr.themes.colors.green,
        text_size=gr.themes.sizes.text_sm,
        font=[
            gr.themes.GoogleFont("Open Sans"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
    ),
    js=set_dark_mode,
) as demo:
    # HTML kept at column 0 inside the string so Markdown does not treat the
    # lines as indented code blocks.
    gr.Markdown(
        """
<div style="text-align: center; max-width: 650px; margin: auto;">
<h1 style="font-weight: 900; margin-top: 5px;">🚀 The race for the best LLM 🚀</h1>
<p style="text-align: left; margin-top: 30px; margin-bottom: 30px; line-height: 20px;">
This app visualizes the progress of LLMs over time as scored by the <a href="https://leaderboard.lmsys.org/">LMSYS Chatbot Arena</a>.
The app is adapted from <a href="https://huggingface.co/spaces/andrewrreed/closed-vs-open-arena-elo">this app</a> by Andrew Reed,
and is intended to stay up-to-date as new models are released and evaluated.
</p>
<div style="text-align: left;">
<strong>Plot info:</strong>
<br>
<ul style="padding-left: 20px;">
<li> The ELO score (y-axis) is a measure of the relative strength of a model based on its performance against other models in the arena. </li>
<li> The Release Date (x-axis) corresponds to when the model was first publicly released or when its ELO results were first reported (for ease of automated updates). </li>
<li> Each line tracks an organization's best ELO score over time; markers show the models that set a new best. </li>
</ul>
</div>
</div>
"""
    )
    # Holds the DataFrame returned by filter_df so make_figure can consume it.
    filtered_df = gr.State()
    with gr.Group():
        with gr.Tab("Plot"):
            plot = gr.Plot(show_label=False)
        with gr.Tab("Raw Data"):
            display_df = gr.DataFrame()
    # On page load: fetch the plotting data, then draw the figure and fill the table.
    demo.load(
        fn=filter_df,
        inputs=[],
        outputs=filtered_df,
    ).then(
        fn=make_figure,
        inputs=[filtered_df],
        outputs=[plot, display_df],
    )
demo.launch()