Spaces:
Running
Running
import altair as alt | |
import fev | |
import pandas as pd | |
import pandas.io.formats.style | |
# Central palette: every color used by the tables and charts below lives here.
COLORS = dict(
    # Leaderboard text colors for the "DL" model type and for every other type.
    dl_text="#5A7FA5",
    st_text="#A5795A",
    # Marks of the bar chart (bars, error bars, point markers).
    bar_fill="#8d5eb7",
    error_bar="#222222",
    point="#111111",
    # Generic text colors.
    text_white="white",
    text_black="black",
    text_default="#111",
    # Background highlights for ranks 1, 2 and 3 in the pivot tables.
    gold="#F7D36B",
    silver="#E5E7EB",
    bronze="#E6B089",
    # Text colors marking imputed pivot-table cells.
    leakage_impute="#3B82A0",
    failure_impute="#E07B39",
)

# Color scheme name handed to the color scale of the pairwise heatmap.
HEATMAP_COLOR_SCHEME = "purplegreen"
# Every statistical baseline shares the same documentation URL and metadata.
_STATSFORECAST_URL = "https://nixtlaverse.nixtla.io/statsforecast/"
_STATISTICAL_ENTRY = (_STATSFORECAST_URL, "—", False, "ST")

# Per-model metadata keyed by lower-case model name.
# Each value is a tuple: (url, org, zero_shot, model_type).
MODEL_CONFIG = {
    # --- Chronos ---
    "chronos_tiny": ("amazon/chronos-t5-tiny", "AWS", True, "DL"),
    "chronos_mini": ("amazon/chronos-t5-mini", "AWS", True, "DL"),
    "chronos_small": ("amazon/chronos-t5-small", "AWS", True, "DL"),
    "chronos_base": ("amazon/chronos-t5-base", "AWS", True, "DL"),
    "chronos_large": ("amazon/chronos-t5-large", "AWS", True, "DL"),
    "chronos_bolt_tiny": ("amazon/chronos-bolt-tiny", "AWS", True, "DL"),
    "chronos_bolt_mini": ("amazon/chronos-bolt-mini", "AWS", True, "DL"),
    "chronos_bolt_small": ("amazon/chronos-bolt-small", "AWS", True, "DL"),
    "chronos_bolt_base": ("amazon/chronos-bolt-base", "AWS", True, "DL"),
    "chronos-bolt": ("amazon/chronos-bolt-base", "AWS", True, "DL"),
    # --- Moirai ---
    "moirai_large": ("Salesforce/moirai-1.1-R-large", "Salesforce", True, "DL"),
    "moirai_base": ("Salesforce/moirai-1.1-R-base", "Salesforce", True, "DL"),
    "moirai_small": ("Salesforce/moirai-1.1-R-small", "Salesforce", True, "DL"),
    "moirai-2.0": ("Salesforce/moirai-2.0-R-small", "Salesforce", True, "DL"),
    # --- TimesFM ---
    "timesfm": ("google/timesfm-1.0-200m-pytorch", "Google", True, "DL"),
    "timesfm-2.0": ("google/timesfm-2.0-500m-pytorch", "Google", True, "DL"),
    "timesfm-2.5": ("google/timesfm-2.5-200m-pytorch", "Google", True, "DL"),
    # --- Toto ---
    "toto-1.0": ("Datadog/Toto-Open-Base-1.0", "Datadog", True, "DL"),
    # --- Other pretrained models ---
    "tirex": ("NX-AI/TiRex", "NX-AI", True, "DL"),
    "tabpfn-ts": ("Prior-Labs/TabPFN-v2-reg", "Prior Labs", True, "DL"),
    "sundial-base": ("thuml/sundial-base-128m", "Tsinghua University", True, "DL"),
    "ttm-r2": ("ibm-granite/granite-timeseries-ttm-r2", "IBM", True, "DL"),
    # --- Task-specific (statistical) models; all share one entry ---
    **{
        name: _STATISTICAL_ENTRY
        for name in (
            "stat. ensemble",
            "autoarima",
            "autotheta",
            "autoets",
            "seasonalnaive",
            "seasonal naive",
            "drift",
            "naive",
        )
    },
}
# Supported metrics keyed by short name.
# Each value is a tuple: (display title, markdown description).
ALL_METRICS = {
    "SQL": (
        "SQL: Scaled Quantile Loss",
        "The [Scaled Quantile Loss (SQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.SQL) is a **scale-invariant** metric for evaluating **probabilistic** forecasts.",
    ),
    "MASE": (
        "MASE: Mean Absolute Scaled Error",
        "The [Mean Absolute Scaled Error (MASE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.MASE) is a **scale-invariant** metric for evaluating **point** forecasts.",
    ),
    "WQL": (
        "WQL: Weighted Quantile Loss",
        # The stray comma after the link was removed so this sentence reads
        # like the other three descriptions.
        "The [Weighted Quantile Loss (WQL)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WQL) is a **scale-dependent** metric for evaluating **probabilistic** forecasts.",
    ),
    "WAPE": (
        "WAPE: Weighted Absolute Percentage Error",
        "The [Weighted Absolute Percentage Error (WAPE)](https://auto.gluon.ai/dev/tutorials/timeseries/forecasting-metrics.html#autogluon.timeseries.metrics.WAPE) is a **scale-dependent** metric for evaluating **point** forecasts.",
    ),
}
def format_metric_name(metric_name: str) -> str:
    """Return the display title stored for ``metric_name`` in ``ALL_METRICS``.

    Raises:
        KeyError: if ``metric_name`` is not a key of ``ALL_METRICS``.
    """
    entry = ALL_METRICS[metric_name]
    # Index 0 of each entry holds the display title.
    return entry[0]
def get_metric_description(metric_name: str) -> str:
    """Return the markdown description stored for ``metric_name`` in ``ALL_METRICS``.

    Raises:
        KeyError: if ``metric_name`` is not a key of ``ALL_METRICS``.
    """
    entry = ALL_METRICS[metric_name]
    # Index 1 of each entry holds the markdown description.
    return entry[1]
def get_model_link(model_name):
    """Return the URL for ``model_name``, or ``""`` when none is configured.

    The name is lower-cased before the ``MODEL_CONFIG`` lookup. A configured
    value that already carries an ``http:``/``https:`` scheme is returned
    unchanged; any other value is treated as a Hugging Face repository id
    and prefixed with ``https://huggingface.co/``.

    Args:
        model_name: Model name; matched case-insensitively.

    Returns:
        The absolute URL, or an empty string for unknown models and for
        entries whose URL field is empty.
    """
    config = MODEL_CONFIG.get(model_name.lower())
    if not config or not config[0]:
        return ""
    url = config[0]
    # Accept both schemes: a plain "http:" URL must not be mistaken for a
    # repository id and glued onto the Hugging Face prefix.
    if url.startswith(("http:", "https:")):
        return url
    return f"https://huggingface.co/{url}"
def get_model_organization(model_name):
    """Return the organization configured for ``model_name``.

    The lookup in ``MODEL_CONFIG`` uses the lower-cased name; models without
    an entry yield the placeholder ``"—"``.
    """
    entry = MODEL_CONFIG.get(model_name.lower())
    if entry:
        return entry[1]
    return "—"
def get_zero_shot_status(model_name):
    """Return ``"✓"`` when ``model_name`` is configured as zero-shot, else ``"×"``.

    The lookup in ``MODEL_CONFIG`` uses the lower-cased name; models without
    an entry are reported as ``"×"``.
    """
    entry = MODEL_CONFIG.get(model_name.lower())
    if entry and entry[2]:
        return "✓"
    return "×"
def get_model_type(model_name):
    """Return the model-type tag configured for ``model_name``.

    The lookup in ``MODEL_CONFIG`` uses the lower-cased name; models without
    an entry yield the placeholder ``"—"``.
    """
    entry = MODEL_CONFIG.get(model_name.lower())
    if entry:
        return entry[3]
    return "—"
def highlight_model_type_color(cell):
    """Return the CSS used to render a model-name cell.

    The text is always bold. When the lower-cased cell value has an entry in
    ``MODEL_CONFIG``, a text color is added: ``COLORS["dl_text"]`` for the
    ``"DL"`` model type and ``COLORS["st_text"]`` for every other type.
    """
    entry = MODEL_CONFIG.get(cell.lower())
    if not entry:
        return "font-weight: bold"
    if entry[3] == "DL":
        color = COLORS["dl_text"]
    else:
        color = COLORS["st_text"]
    return f"font-weight: bold; color: {color}"
def format_leaderboard(df: pd.DataFrame):
    """Build the styled leaderboard table.

    Works on a copy of ``df``; the caller's frame is not modified.

    Args:
        df: Leaderboard data. Must contain the columns ``model_name``,
            ``win_rate``, ``skill_score``, ``median_inference_time_s``,
            ``training_corpus_overlap`` and ``num_failures``.

    Returns:
        A pandas ``Styler`` over the selected columns with colored bold
        model names, a bold zero-shot column and zebra-striped rows.
    """
    df = df.copy()
    df["skill_score"] = df["skill_score"].round(1)
    df["win_rate"] = df["win_rate"].round(1)
    df["zero_shot"] = df["model_name"].apply(get_zero_shot_status)

    # Leakage column: value * 100 rounded to an int for zero-shot models and
    # 0 for every other model (presumably the input is a fraction — confirm
    # against the caller). Computed column-wise so an empty leaderboard also
    # works; a row-wise ``df.apply(..., axis=1)`` returns a DataFrame for an
    # empty frame and cannot be assigned to a single column.
    is_zero_shot = df["zero_shot"] == "✓"
    df["training_corpus_overlap"] = (
        df["training_corpus_overlap"]
        .where(is_zero_shot, 0)
        .mul(100)
        .round()
        .astype(int)
    )

    df["link"] = df["model_name"].apply(get_model_link)
    df["org"] = df["model_name"].apply(get_model_organization)
    df = df[
        [
            "model_name",
            "win_rate",
            "skill_score",
            "median_inference_time_s",
            "training_corpus_overlap",
            "num_failures",
            "zero_shot",
            "org",
            "link",
        ]
    ]

    def zebra_stripes(column):
        # Shade every second row (positions 1, 3, 5, ...) of each column.
        return [
            "background-color: #f8f9fa" if i % 2 == 1 else ""
            for i in range(len(column))
        ]

    return (
        df.style.map(highlight_model_type_color, subset=["model_name"])
        .map(lambda x: "font-weight: bold", subset=["zero_shot"])
        .apply(zebra_stripes, axis=0)
    )
def construct_bar_chart(df: pd.DataFrame, col: str, metric_name: str):
    """Build a horizontal bar chart of ``col`` per model with 95% CI error bars.

    Args:
        df: Data with the columns ``model_name``, ``col``, ``{col}_lower``
            and ``{col}_upper``.
        col: Column to plot. ``"skill_score"`` is labelled "Skill Score";
            any other value is labelled "Win Rate".
        metric_name: Metric name shown in the chart title.

    Returns:
        A layered Altair chart: bars, error bars and point markers.
    """
    if col == "skill_score":
        label = "Skill Score"
    else:
        label = "Win Rate"
    value_axis_title = f"{label} (%)"

    # The same tooltip is attached to all three layers.
    hover_fields = [
        alt.Tooltip("model_name:N"),
        alt.Tooltip(f"{col}:Q", format=".2f"),
        alt.Tooltip(f"{col}_lower:Q", title="95% CI Lower", format=".2f"),
        alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".2f"),
    ]
    # Encodings shared by the bar and point layers; ``sort=None`` keeps the
    # row order of ``df``.
    shared_encoding = {
        "y": alt.Y("model_name:N", title="Forecasting Model", sort=None),
        "tooltip": hover_fields,
    }

    bar_layer = alt.Chart(df).mark_bar(color=COLORS["bar_fill"], cornerRadius=4)
    bar_layer = bar_layer.encode(
        x=alt.X(f"{col}:Q", title=value_axis_title, scale=alt.Scale(zero=False)),
        **shared_encoding,
    )

    interval_layer = alt.Chart(df).mark_errorbar(
        ticks={"height": 5}, color=COLORS["error_bar"]
    )
    interval_layer = interval_layer.encode(
        y=alt.Y("model_name:N", title=None, sort=None),
        x=alt.X(f"{col}_lower:Q", title=value_axis_title),
        x2=alt.X2(f"{col}_upper:Q"),
        tooltip=hover_fields,
    )

    point_layer = alt.Chart(df).mark_point(filled=True, color=COLORS["point"])
    point_layer = point_layer.encode(
        x=alt.X(f"{col}:Q", title=value_axis_title), **shared_encoding
    )

    layered = bar_layer + interval_layer + point_layer
    layered = layered.properties(
        height=500, title=f"{label} ({metric_name}) with 95% CIs"
    )
    return layered.configure_title(fontSize=16)
def construct_pairwise_chart(df: pd.DataFrame, col: str, metric_name: str):
    """Build a model-vs-model heatmap of ``col`` with the value printed per cell.

    Args:
        df: Pairwise results with the columns ``model_1``, ``model_2``,
            ``col``, ``{col}_lower`` and ``{col}_upper``. The three value
            columns are multiplied by 100 on a copy before plotting
            (presumably fractions in the input — confirm against the caller).
        col: Either ``"win_rate"`` or ``"skill_score"``.
        metric_name: Metric name shown in the chart title.

    Returns:
        A layered Altair chart: heatmap rectangles plus text labels.

    Raises:
        KeyError: if ``col`` is not one of the two supported columns.
    """
    # Per column: (label, color domain, domain midpoint, Vega expression that
    # selects the cells whose text is drawn in white).
    settings_by_column = {
        "win_rate": ("Win Rate", [0, 100], 50, f"abs(datum.{col} - 50) > 30"),
        "skill_score": ("Skill Score", [-15, 15], 0, f"abs(datum.{col}) > 10"),
    }
    cbar_label, color_domain, color_midpoint, white_text_test = settings_by_column[
        col
    ]

    scaled = df.copy()
    for value_column in (col, f"{col}_lower", f"{col}_upper"):
        scaled[value_column] = scaled[value_column] * 100

    # Order both axes by each model's mean value as ``model_1``, best first.
    mean_per_model = scaled.groupby("model_1")[col].mean()
    model_order = mean_per_model.sort_values(ascending=False).index.tolist()

    hover_fields = [
        alt.Tooltip("model_1:N", title="Model 1"),
        alt.Tooltip("model_2:N", title="Model 2"),
        alt.Tooltip(f"{col}:Q", title=cbar_label.split(" ")[0], format=".1f"),
        alt.Tooltip(f"{col}_lower:Q", title="95% CI Lower", format=".1f"),
        alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".1f"),
    ]

    x_axis = alt.X(
        "model_2:N",
        sort=model_order,
        title="Model 2",
        axis=alt.Axis(orient="top", labelAngle=-90),
    )
    y_axis = alt.Y("model_1:N", sort=model_order, title="Model 1")
    grid = alt.Chart(scaled).encode(x=x_axis, y=y_axis)

    # Values outside the domain are clamped to the end colors.
    color_scale = alt.Scale(
        scheme=HEATMAP_COLOR_SCHEME,
        domain=color_domain,
        domainMid=color_midpoint,
        clamp=True,
    )
    cells = grid.mark_rect().encode(
        color=alt.Color(f"{col}:Q", legend=None, scale=color_scale),
        tooltip=hover_fields,
    )

    label_color = alt.condition(
        white_text_test,
        alt.value(COLORS["text_white"]),
        alt.value(COLORS["text_black"]),
    )
    labels = grid.mark_text(dy=-8, fontSize=8, baseline="top", yOffset=5).encode(
        text=alt.Text(f"{col}:Q", format=".1f"),
        color=label_color,
        tooltip=hover_fields,
    )

    chart_title = {
        "text": f"Pairwise {cbar_label} ({metric_name}) with 95% CIs",
        "fontSize": 16,
    }
    layered = (cells + labels).properties(height=550, title=chart_title)
    layered = layered.configure_axis(
        labelFontSize=11, titleFontSize=13, titleFontWeight="bold"
    )
    # The heatmap fill and the label color use separate color scales.
    return layered.resolve_scale(color="independent")
def construct_pivot_table_from_df(
    errors: pd.DataFrame, metric_name: str
) -> pd.io.formats.style.Styler:
    """Construct styled pivot table from precomputed DataFrame.

    Within each row the values are ranked in ascending order
    (``method="min"``, so ties share the lowest rank). Ranks 1, 2 and 3 get
    the gold, silver and bronze background; every other cell, including
    cells without a rank (NaN), gets the default text color.

    Args:
        errors: Table of values to rank row by row.
        metric_name: Not used by this function; kept so existing callers
            keep working.

    Returns:
        A ``Styler`` over ``errors`` with the rank styling applied and
        values formatted to 3 decimals.
    """
    rank_colors = {1: COLORS["gold"], 2: COLORS["silver"], 3: COLORS["bronze"]}
    default_style = f"color: {COLORS['text_default']}"

    def style_for_rank(rank) -> str:
        # A NaN rank fails the comparison and falls through to the default.
        if rank <= 3:
            return f"background-color: {rank_colors[int(rank)]}"
        return default_style

    # Compute every cell's CSS once and register a single table-wide style
    # function, instead of registering one Styler.map call per cell.
    row_ranks = errors.rank(axis=1, method="min").to_numpy()
    styles = pd.DataFrame(
        [[style_for_rank(rank) for rank in ranks] for ranks in row_ranks],
        index=errors.index,
        columns=errors.columns,
    )
    return errors.style.apply(lambda _data: styles, axis=None).format(precision=3)
def construct_pivot_table(
    summaries: pd.DataFrame,
    metric_name: str,
    baseline_model: str,
    leakage_imputation_model: str,
) -> pd.io.formats.style.Styler:
    """Build the styled per-task table of model scores with imputation.

    Steps, in order:

    1. Pivot ``summaries`` with ``fev.pivot_table`` into a table of
       ``metric_name`` values (presumably tasks as rows and models as
       columns — confirm against ``fev``).
    2. Replace every cell flagged in the ``trained_on_this_dataset`` pivot
       with the ``leakage_imputation_model`` value of the same row.
    3. Fill the remaining missing values of every non-baseline column with
       the ``baseline_model`` value of the same row.
    4. Order the columns by mean row-wise rank, lowest first.
    5. Style the table: gold/silver/bronze backgrounds for row ranks 1-3,
       and a text color marking leakage-imputed cells, cells that were
       missing before imputation, or plain unranked cells.

    Args:
        summaries: Input passed to ``fev.pivot_table``.
        metric_name: Column of ``summaries`` holding the metric values.
        baseline_model: Column used to fill missing results.
        leakage_imputation_model: Column used to replace flagged results.

    Returns:
        A ``Styler`` over the imputed table, values formatted to 3 decimals.
    """
    errors = fev.pivot_table(
        summaries=summaries, metric_column=metric_name, task_columns=["task_name"]
    )
    train_overlap = (
        fev.pivot_table(
            summaries=summaries,
            metric_column="trained_on_this_dataset",
            task_columns=["task_name"],
        )
        .fillna(False)
        .astype(bool)
    )

    # Record which cells get imputed BEFORE the values are overwritten.
    is_imputed_baseline = errors.isna()
    is_leakage_imputed = train_overlap

    # Handle imputations. ``axis=0`` aligns the replacement Series with the
    # row index, so each flagged cell takes its own row's value.
    errors = errors.mask(train_overlap, errors[leakage_imputation_model], axis=0)
    for col in errors.columns:
        if col != baseline_model:
            errors[col] = errors[col].fillna(errors[baseline_model])
    errors = errors[errors.rank(axis=1).mean().sort_values().index]
    errors.index.rename("Task name", inplace=True)

    rank_colors = {1: COLORS["gold"], 2: COLORS["silver"], 3: COLORS["bronze"]}
    leakage_style = f"color: {COLORS['leakage_impute']}"
    failure_style = f"color: {COLORS['failure_impute']}"
    default_style = f"color: {COLORS['text_default']}"

    def cell_style(rank, leaked, failed) -> str:
        parts = []
        # Rank background colors (a NaN rank fails the comparison).
        if rank <= 3:
            parts.append(f"background-color: {rank_colors[int(rank)]}")
        # Imputation text colors; leakage takes precedence over failure. The
        # default text color is only used when nothing else applies.
        if leaked:
            parts.append(leakage_style)
        elif failed:
            parts.append(failure_style)
        elif not parts:
            parts.append(default_style)
        return "; ".join(parts)

    # Compute every cell's CSS once and register a single table-wide style
    # function, instead of registering one Styler.map call per cell. The
    # masks are looked up by label because the columns were reordered; a
    # missing label raises KeyError here.
    rank_rows = errors.rank(axis=1, method="min").to_numpy()
    leaked_rows = is_leakage_imputed.loc[errors.index, errors.columns].to_numpy()
    failed_rows = is_imputed_baseline.loc[errors.index, errors.columns].to_numpy()
    styles = pd.DataFrame(
        [
            [cell_style(*cell) for cell in zip(ranks, leaked, failed)]
            for ranks, leaked, failed in zip(rank_rows, leaked_rows, failed_rows)
        ],
        index=errors.index,
        columns=errors.columns,
    )
    return errors.style.apply(lambda _data: styles, axis=None).format(precision=3)