fev-leaderboard / src /utils.py
shchuro's picture
Fix colorbar
72f0175
raw
history blame
13.7 kB
import altair as alt
import fev
import pandas as pd
import pandas.io.formats.style
# Color palette shared by the leaderboard table, pivot tables and charts.
COLORS = dict(
    # Model-name text colors, chosen by model type ("DL" vs "ST").
    dl_text="#5A7FA5",
    st_text="#A5795A",
    # Bar chart layers: bars, confidence-interval error bars, point estimates.
    bar_fill="#8d5eb7",
    error_bar="#222222",
    point="#111111",
    # Text colors (heatmap cell labels and default table text).
    text_white="white",
    text_black="black",
    text_default="#111",
    # Backgrounds for the three best-ranked cells of each pivot-table row.
    gold="#F7D36B",
    silver="#E5E7EB",
    bronze="#E6B089",
    # Text colors marking imputed pivot-table cells.
    leakage_impute="#3B82A0",
    failure_impute="#E07B39",
)
# Vega color scheme name used for the pairwise heatmaps.
HEATMAP_COLOR_SCHEME = "purplegreen"
def _zero_shot_dl(repo_id, org):
    """Config tuple for a zero-shot deep-learning model given by its Hugging Face repo id."""
    return (repo_id, org, True, "DL")


# Shared config for the statsforecast baselines: documentation link, no
# organization ("—"), not zero-shot, statistical ("ST") model type.
_STATSFORECAST_CONFIG = (
    "https://nixtlaverse.nixtla.io/statsforecast/",
    "—",
    False,
    "ST",
)

# Model configuration keyed by lower-case model name:
#     (url or Hugging Face repo id, org, zero_shot, model_type)
# Entries that do not start with "https:" are Hugging Face repo ids.
MODEL_CONFIG = {
    # Chronos Models
    "chronos_tiny": _zero_shot_dl("amazon/chronos-t5-tiny", "AWS"),
    "chronos_mini": _zero_shot_dl("amazon/chronos-t5-mini", "AWS"),
    "chronos_small": _zero_shot_dl("amazon/chronos-t5-small", "AWS"),
    "chronos_base": _zero_shot_dl("amazon/chronos-t5-base", "AWS"),
    "chronos_large": _zero_shot_dl("amazon/chronos-t5-large", "AWS"),
    "chronos_bolt_tiny": _zero_shot_dl("amazon/chronos-bolt-tiny", "AWS"),
    "chronos_bolt_mini": _zero_shot_dl("amazon/chronos-bolt-mini", "AWS"),
    "chronos_bolt_small": _zero_shot_dl("amazon/chronos-bolt-small", "AWS"),
    "chronos_bolt_base": _zero_shot_dl("amazon/chronos-bolt-base", "AWS"),
    "chronos-bolt": _zero_shot_dl("amazon/chronos-bolt-base", "AWS"),
    # Moirai Models
    "moirai_large": _zero_shot_dl("Salesforce/moirai-1.1-R-large", "Salesforce"),
    "moirai_base": _zero_shot_dl("Salesforce/moirai-1.1-R-base", "Salesforce"),
    "moirai_small": _zero_shot_dl("Salesforce/moirai-1.1-R-small", "Salesforce"),
    "moirai-2.0": _zero_shot_dl("Salesforce/moirai-2.0-R-small", "Salesforce"),
    # TimesFM Models
    "timesfm": _zero_shot_dl("google/timesfm-1.0-200m-pytorch", "Google"),
    "timesfm-2.0": _zero_shot_dl("google/timesfm-2.0-500m-pytorch", "Google"),
    "timesfm-2.5": _zero_shot_dl("google/timesfm-2.5-200m-pytorch", "Google"),
    # Toto Models
    "toto-1.0": _zero_shot_dl("Datadog/Toto-Open-Base-1.0", "Datadog"),
    # Other Models
    "tirex": _zero_shot_dl("NX-AI/TiRex", "NX-AI"),
    "tabpfn-ts": _zero_shot_dl("Prior-Labs/TabPFN-v2-reg", "Prior Labs"),
    "sundial-base": _zero_shot_dl("thuml/sundial-base-128m", "Tsinghua University"),
    "ttm-r2": _zero_shot_dl("ibm-granite/granite-timeseries-ttm-r2", "IBM"),
    # Task-specific (statsforecast) models
    "stat. ensemble": _STATSFORECAST_CONFIG,
    "autoarima": _STATSFORECAST_CONFIG,
    "autotheta": _STATSFORECAST_CONFIG,
    "autoets": _STATSFORECAST_CONFIG,
    "seasonalnaive": _STATSFORECAST_CONFIG,
    "seasonal naive": _STATSFORECAST_CONFIG,
    "drift": _STATSFORECAST_CONFIG,
    "naive": _STATSFORECAST_CONFIG,
}
# Supported evaluation metrics, keyed by short name:
#     (display title, Markdown description linking to the AutoGluon metric docs)
ALL_METRICS = {
    "SQL": (
        "SQL: Scaled Quantile Loss",
        "The [Scaled Quantile Loss (SQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.SQL) is a **scale-invariant** metric for evaluating **probabilistic** forecasts.",
    ),
    "MASE": (
        "MASE: Mean Absolute Scaled Error",
        "The [Mean Absolute Scaled Error (MASE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.MASE) is a **scale-invariant** metric for evaluating **point** forecasts.",
    ),
    "WQL": (
        "WQL: Weighted Quantile Loss",
        # No comma after the link: "...(WQL)](...) is a ..." like the other entries.
        "The [Weighted Quantile Loss (WQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WQL) is a **scale-dependent** metric for evaluating **probabilistic** forecasts.",
    ),
    "WAPE": (
        "WAPE: Weighted Absolute Percentage Error",
        "The [Weighted Absolute Percentage Error (WAPE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WAPE) is a **scale-dependent** metric for evaluating **point** forecasts.",
    ),
}


def format_metric_name(metric_name: str) -> str:
    """Return the display title of *metric_name*, e.g. ``"SQL: Scaled Quantile Loss"``.

    Raises:
        KeyError: if *metric_name* is not a key of ``ALL_METRICS``.
    """
    return ALL_METRICS[metric_name][0]


def get_metric_description(metric_name: str) -> str:
    """Return the Markdown description of *metric_name*.

    Raises:
        KeyError: if *metric_name* is not a key of ``ALL_METRICS``.
    """
    return ALL_METRICS[metric_name][1]
def get_model_link(model_name):
    """Return a web link for *model_name*, or ``""`` if none is configured.

    The name is lower-cased before looking it up in ``MODEL_CONFIG``. The
    stored value is either an absolute URL, returned unchanged, or a Hugging
    Face repo id, expanded to ``https://huggingface.co/<repo id>``.
    """
    config = MODEL_CONFIG.get(model_name.lower())
    if not config or not config[0]:
        return ""
    url = config[0]
    # Treat both schemes as absolute so an "http:" entry is not mistaken for a
    # repo id and mangled into "https://huggingface.co/http:...".
    if url.startswith(("http:", "https:")):
        return url
    return f"https://huggingface.co/{url}"
def get_model_organization(model_name):
    """Return the organization listed for *model_name* (case-insensitive), or "—" if unknown."""
    entry = MODEL_CONFIG.get(model_name.lower())
    if not entry:
        return "—"
    return entry[1]
def get_zero_shot_status(model_name):
    """Return "✓" if *model_name* is configured as zero-shot, else "×" (also for unknown models)."""
    entry = MODEL_CONFIG.get(model_name.lower())
    is_zero_shot = bool(entry) and bool(entry[2])
    return "✓" if is_zero_shot else "×"
def get_model_type(model_name):
    """Return the configured model type ("DL" or "ST") of *model_name*, or "—" if unknown."""
    try:
        return MODEL_CONFIG[model_name.lower()][3]
    except KeyError:
        return "—"
def highlight_model_type_color(cell):
    """Return CSS for a model-name table cell.

    Known models get bold text colored by type: ``dl_text`` for "DL",
    ``st_text`` for anything else. Unknown models are only bolded.
    """
    entry = MODEL_CONFIG.get(cell.lower())
    if not entry:
        return "font-weight: bold"
    text_color = COLORS["dl_text"] if entry[3] == "DL" else COLORS["st_text"]
    return f"font-weight: bold; color: {text_color}"
def format_leaderboard(df: pd.DataFrame):
    """Return a pandas ``Styler`` presenting the leaderboard summary ``df``.

    Expects ``model_name``, ``win_rate``, ``skill_score``,
    ``median_inference_time_s``, ``training_corpus_overlap`` and
    ``num_failures`` columns.
    - Scores are rounded to one decimal.
    - The training-corpus overlap is shown as an integer percentage, or 0 for
      models that are not zero-shot.
    - ``zero_shot``, ``org`` and ``link`` columns are derived from the model
      name.

    Model names are colored by type, zero-shot marks are bold and odd rows are
    shaded. ``df`` itself is not modified.
    """
    table = df.copy()
    for score_col in ("skill_score", "win_rate"):
        table[score_col] = table[score_col].round(1)
    table["zero_shot"] = table["model_name"].apply(get_zero_shot_status)

    def overlap_percent(row):
        # Fraction -> whole percentage; non-zero-shot models are reported as 0.
        if row["zero_shot"] != "✓":
            return 0
        return int(round(row["training_corpus_overlap"] * 100))

    table["training_corpus_overlap"] = table.apply(overlap_percent, axis=1)
    table["link"] = table["model_name"].apply(get_model_link)
    table["org"] = table["model_name"].apply(get_model_organization)

    shown_columns = [
        "model_name",
        "win_rate",
        "skill_score",
        "median_inference_time_s",
        "training_corpus_overlap",
        "num_failures",
        "zero_shot",
        "org",
        "link",
    ]
    table = table[shown_columns]

    def zebra_stripes(column):
        # Shade every row at an odd position; applied to every column.
        return [
            "background-color: #f8f9fa" if pos % 2 == 1 else ""
            for pos, _ in enumerate(column)
        ]

    styled = table.style.map(highlight_model_type_color, subset=["model_name"])
    styled = styled.map(lambda _: "font-weight: bold", subset=["zero_shot"])
    return styled.apply(zebra_stripes, axis=0)
def construct_bar_chart(df: pd.DataFrame, col: str, metric_name: str):
    """Build a horizontal Altair bar chart of ``col`` per model with 95% CIs.

    ``df`` must provide ``model_name``, ``col``, ``{col}_lower`` and
    ``{col}_upper`` columns. The label is "Skill Score" when ``col`` is
    ``"skill_score"`` and "Win Rate" otherwise. Models keep the row order of
    ``df`` (``sort=None``).

    The chart layers bars at the estimate, error bars spanning
    ``[lower, upper]`` and a filled point at the estimate.
    """
    label = "Win Rate"
    if col == "skill_score":
        label = "Skill Score"
    x_title = f"{label} (%)"
    value_field = f"{col}:Q"
    lower_field = f"{col}_lower:Q"
    upper_field = f"{col}_upper:Q"

    tooltip = [
        alt.Tooltip("model_name:N"),
        alt.Tooltip(value_field, format=".2f"),
        alt.Tooltip(lower_field, title="95% CI Lower", format=".2f"),
        alt.Tooltip(upper_field, title="95% CI Upper", format=".2f"),
    ]
    model_y = alt.Y("model_name:N", title="Forecasting Model", sort=None)
    chart = alt.Chart(df)

    # zero=False lets the x axis start near the data instead of at 0.
    bars = chart.mark_bar(color=COLORS["bar_fill"], cornerRadius=4).encode(
        x=alt.X(value_field, title=x_title, scale=alt.Scale(zero=False)),
        y=model_y,
        tooltip=tooltip,
    )
    error_bars = chart.mark_errorbar(
        ticks={"height": 5}, color=COLORS["error_bar"]
    ).encode(
        y=alt.Y("model_name:N", title=None, sort=None),
        x=alt.X(lower_field, title=x_title),
        x2=alt.X2(upper_field),
        tooltip=tooltip,
    )
    points = chart.mark_point(filled=True, color=COLORS["point"]).encode(
        x=alt.X(value_field, title=x_title),
        y=model_y,
        tooltip=tooltip,
    )

    titled = (bars + error_bars + points).properties(
        height=500, title=f"{label} ({metric_name}) with 95% CIs"
    )
    return titled.configure_title(fontSize=16)
def construct_pairwise_chart(df: pd.DataFrame, col: str, metric_name: str):
    """Build an annotated Altair heatmap of pairwise ``col`` values.

    ``df`` must provide ``model_1``, ``model_2``, ``col``, ``{col}_lower`` and
    ``{col}_upper`` columns. ``col`` must be ``"win_rate"`` or
    ``"skill_score"``; any other value raises ``KeyError``.

    The value and CI columns are multiplied by 100 on a copy of ``df``. Both
    axes order models by their mean ``col`` as ``model_1``, highest first.
    Cell text turns white where the value is far from the color midpoint.
    """
    # col -> (label, color domain, domain midpoint, Vega test for white text)
    settings = {
        "win_rate": ("Win Rate", [0, 100], 50, f"abs(datum.{col} - 50) > 30"),
        "skill_score": ("Skill Score", [-15, 15], 0, f"abs(datum.{col}) > 10"),
    }
    label, color_domain, color_mid, white_text_test = settings[col]

    scaled = df.copy()
    for field in (col, f"{col}_lower", f"{col}_upper"):
        scaled[field] = scaled[field] * 100

    mean_by_model = scaled.groupby("model_1")[col].mean()
    model_order = mean_by_model.sort_values(ascending=False).index.tolist()

    value_title = label.split(" ")[0]
    tooltip = [
        alt.Tooltip("model_1:N", title="Model 1"),
        alt.Tooltip("model_2:N", title="Model 2"),
        alt.Tooltip(f"{col}:Q", title=value_title, format=".1f"),
        alt.Tooltip(f"{col}_lower:Q", title="95% CI Lower", format=".1f"),
        alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".1f"),
    ]
    grid = alt.Chart(scaled).encode(
        x=alt.X(
            "model_2:N",
            sort=model_order,
            title="Model 2",
            axis=alt.Axis(orient="top", labelAngle=-90),
        ),
        y=alt.Y("model_1:N", sort=model_order, title="Model 1"),
    )

    # clamp=True pins out-of-domain values to the end colors.
    color_scale = alt.Scale(
        scheme=HEATMAP_COLOR_SCHEME,
        domain=color_domain,
        domainMid=color_mid,
        clamp=True,
    )
    cells = grid.mark_rect().encode(
        color=alt.Color(f"{col}:Q", legend=None, scale=color_scale),
        tooltip=tooltip,
    )
    cell_labels = grid.mark_text(
        dy=-8, fontSize=8, baseline="top", yOffset=5
    ).encode(
        text=alt.Text(f"{col}:Q", format=".1f"),
        color=alt.condition(
            white_text_test,
            alt.value(COLORS["text_white"]),
            alt.value(COLORS["text_black"]),
        ),
        tooltip=tooltip,
    )

    chart_title = {
        "text": f"Pairwise {label} ({metric_name}) with 95% CIs",
        "fontSize": 16,
    }
    layered = (cells + cell_labels).properties(height=550, title=chart_title)
    configured = layered.configure_axis(
        labelFontSize=11, titleFontSize=13, titleFontWeight="bold"
    )
    return configured.resolve_scale(color="independent")
def _rank_highlight_styles(
    errors: pd.DataFrame, leakage_imputed=None, failure_imputed=None
) -> pd.DataFrame:
    """Return a frame of per-cell CSS strings for a task x model pivot table.

    Args:
        errors: Metric values; rows are tasks, columns are models. Lower
            values rank higher.
        leakage_imputed: Optional boolean frame. True cells get the
            ``leakage_impute`` text color; this takes precedence over
            ``failure_imputed``.
        failure_imputed: Optional boolean frame. True cells get the
            ``failure_impute`` text color.

    Returns:
        A frame with the same index and columns as ``errors``.
        - Within each row, ranks 1-3 get a gold / silver / bronze background.
          Ranks use ``method="min"``, so ties share the lower rank.
        - Cells outside the top 3 (including NaN values) that are not imputed
          get the default text color.
        - Masks are aligned by label. Cells missing from a mask count as not
          imputed.
    """
    ranks = errors.rank(axis=1, method="min")
    blank = pd.DataFrame("", index=errors.index, columns=errors.columns)

    def aligned(mask):
        # Match the labels of ``errors``; cells absent from the mask are False.
        return mask.reindex(
            index=errors.index, columns=errors.columns, fill_value=False
        ).astype(bool)

    background = blank
    podium_colors = (COLORS["gold"], COLORS["silver"], COLORS["bronze"])
    for place, color in enumerate(podium_colors, start=1):
        background = background.mask(ranks == place, f"background-color: {color}")

    # NaN ranks compare False, so they count as outside the top 3.
    text = blank.mask(~(ranks <= 3), f"color: {COLORS['text_default']}")
    if failure_imputed is not None:
        text = text.mask(aligned(failure_imputed), f"color: {COLORS['failure_impute']}")
    # Applied last so leakage imputation wins over failure imputation.
    if leakage_imputed is not None:
        text = text.mask(aligned(leakage_imputed), f"color: {COLORS['leakage_impute']}")

    separator = blank.mask((background != "") & (text != ""), "; ")
    return background + separator + text


def construct_pivot_table_from_df(
    errors: pd.DataFrame, metric_name: str
) -> pd.io.formats.style.Styler:
    """Construct styled pivot table from precomputed DataFrame.

    Each row's three lowest values get gold / silver / bronze backgrounds and
    values are shown with 3 decimals. ``metric_name`` is currently unused.
    """
    css = _rank_highlight_styles(errors)
    # One vectorized styling pass instead of one Styler.map call per cell.
    return errors.style.apply(lambda _: css, axis=None).format(precision=3)


def construct_pivot_table(
    summaries: pd.DataFrame,
    metric_name: str,
    baseline_model: str,
    leakage_imputation_model: str,
) -> pd.io.formats.style.Styler:
    """Pivot per-task results for ``metric_name`` into a styled task x model table.

    Imputation:
    * A model's result on a task it was trained on (``trained_on_this_dataset``)
      is replaced by ``leakage_imputation_model``'s result on that task.
      These cells are shown in the ``leakage_impute`` color.
    * Missing results of models other than ``baseline_model`` are filled
      with ``baseline_model``'s result. Cells that were missing before
      imputation are shown in the ``failure_impute`` color, unless
      leakage-imputed.

    Columns are ordered by mean per-task rank, lowest first. Each row's three
    lowest values get gold / silver / bronze backgrounds.
    """
    errors = fev.pivot_table(
        summaries=summaries, metric_column=metric_name, task_columns=["task_name"]
    )
    train_overlap = (
        fev.pivot_table(
            summaries=summaries,
            metric_column="trained_on_this_dataset",
            task_columns=["task_name"],
        )
        .fillna(False)
        .astype(bool)
    )
    # Record missing cells before imputation so they can be marked later.
    is_imputed_baseline = errors.isna()

    # Handle imputations
    errors = errors.mask(train_overlap, errors[leakage_imputation_model], axis=0)
    for col in errors.columns:
        if col != baseline_model:
            errors[col] = errors[col].fillna(errors[baseline_model])
    errors = errors[errors.rank(axis=1).mean().sort_values().index]
    # rename_axis returns a new frame instead of renaming an Index object in
    # place, which could be shared with other frames.
    errors = errors.rename_axis("Task name")

    css = _rank_highlight_styles(
        errors,
        leakage_imputed=train_overlap,
        failure_imputed=is_imputed_baseline,
    )
    return errors.style.apply(lambda _: css, axis=None).format(precision=3)