shchuro commited on
Commit
72f0175
Β·
1 Parent(s): a30697b

Fix colorbar

Browse files
Files changed (1) hide show
  1. src/utils.py +80 -19
src/utils.py CHANGED
@@ -53,12 +53,22 @@ MODEL_CONFIG = {
53
  "sundial-base": ("thuml/sundial-base-128m", "Tsinghua University", True, "DL"),
54
  "ttm-r2": ("ibm-granite/granite-timeseries-ttm-r2", "IBM", True, "DL"),
55
  # Task-specific models
56
- "stat. ensemble": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
 
 
 
 
 
57
  "autoarima": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
58
  "autotheta": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
59
  "autoets": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
60
  "seasonalnaive": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
61
- "seasonal naive": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
 
 
 
 
 
62
  "drift": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
63
  "naive": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
64
  }
@@ -130,7 +140,10 @@ def format_leaderboard(df: pd.DataFrame):
130
  df["zero_shot"] = df["model_name"].apply(get_zero_shot_status)
131
  # Format leakage column: convert to int for all models, 0 for non-zero-shot
132
  df["training_corpus_overlap"] = df.apply(
133
- lambda row: int(round(row["training_corpus_overlap"] * 100)) if row["zero_shot"] == "βœ“" else 0, axis=1
 
 
 
134
  )
135
  df["link"] = df["model_name"].apply(get_model_link)
136
  df["org"] = df["model_name"].apply(get_model_organization)
@@ -150,7 +163,12 @@ def format_leaderboard(df: pd.DataFrame):
150
  return (
151
  df.style.map(highlight_model_type_color, subset=["model_name"])
152
  .map(lambda x: "font-weight: bold", subset=["zero_shot"])
153
- .apply(lambda x: ["background-color: #f8f9fa" if i % 2 == 1 else "" for i in range(len(x))], axis=0)
 
 
 
 
 
154
  )
155
 
156
 
@@ -164,12 +182,18 @@ def construct_bar_chart(df: pd.DataFrame, col: str, metric_name: str):
164
  alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".2f"),
165
  ]
166
 
167
- base_encode = {"y": alt.Y("model_name:N", title="Forecasting Model", sort=None), "tooltip": tooltip}
 
 
 
168
 
169
  bars = (
170
  alt.Chart(df)
171
  .mark_bar(color=COLORS["bar_fill"], cornerRadius=4)
172
- .encode(x=alt.X(f"{col}:Q", title=f"{label} (%)", scale=alt.Scale(zero=False)), **base_encode)
 
 
 
173
  )
174
 
175
  error_bars = (
@@ -207,7 +231,9 @@ def construct_pairwise_chart(df: pd.DataFrame, col: str, metric_name: str):
207
  for c in [col, f"{col}_lower", f"{col}_upper"]:
208
  df[c] *= 100
209
 
210
- model_order = df.groupby("model_1")[col].mean().sort_values(ascending=False).index.tolist()
 
 
211
 
212
  tooltip = [
213
  alt.Tooltip("model_1:N", title="Model 1"),
@@ -218,34 +244,56 @@ def construct_pairwise_chart(df: pd.DataFrame, col: str, metric_name: str):
218
  ]
219
 
220
  base = alt.Chart(df).encode(
221
- x=alt.X("model_2:N", sort=model_order, title="Model 2", axis=alt.Axis(orient="top", labelAngle=-90)),
 
 
 
 
 
222
  y=alt.Y("model_1:N", sort=model_order, title="Model 1"),
223
  )
224
 
225
  heatmap = base.mark_rect().encode(
226
  color=alt.Color(
227
  f"{col}:Q",
228
- legend=alt.Legend(title=f"{cbar_label} (%)", direction="vertical", orient="right"),
229
- scale=alt.Scale(scheme=HEATMAP_COLOR_SCHEME, domain=domain, domainMid=domain_mid, clamp=True),
 
 
 
 
 
230
  ),
231
  tooltip=tooltip,
232
  )
233
 
234
  text_main = base.mark_text(dy=-8, fontSize=8, baseline="top", yOffset=5).encode(
235
  text=alt.Text(f"{col}:Q", format=".1f"),
236
- color=alt.condition(text_condition, alt.value(COLORS["text_white"]), alt.value(COLORS["text_black"])),
 
 
 
 
237
  tooltip=tooltip,
238
  )
239
 
240
  return (
241
  (heatmap + text_main)
242
- .properties(height=550, title={"text": f"Pairwise {cbar_label} ({metric_name}) with 95% CIs", "fontSize": 16})
 
 
 
 
 
 
243
  .configure_axis(labelFontSize=11, titleFontSize=13, titleFontWeight="bold")
244
  .resolve_scale(color="independent")
245
  )
246
 
247
 
248
- def construct_pivot_table_from_df(errors: pd.DataFrame, metric_name: str) -> pd.io.formats.style.Styler:
 
 
249
  """Construct styled pivot table from precomputed DataFrame."""
250
 
251
  def highlight_by_position(styler):
@@ -265,7 +313,8 @@ def construct_pivot_table_from_df(errors: pd.DataFrame, metric_name: str) -> pd.
265
 
266
  if style_parts:
267
  styler = styler.map(
268
- lambda x, s="; ".join(style_parts): s, subset=pd.IndexSlice[row_idx:row_idx, col_idx:col_idx]
 
269
  )
270
  return styler
271
 
@@ -273,11 +322,20 @@ def construct_pivot_table_from_df(errors: pd.DataFrame, metric_name: str) -> pd.
273
 
274
 
275
  def construct_pivot_table(
276
- summaries: pd.DataFrame, metric_name: str, baseline_model: str, leakage_imputation_model: str
 
 
 
277
  ) -> pd.io.formats.style.Styler:
278
- errors = fev.pivot_table(summaries=summaries, metric_column=metric_name, task_columns=["task_name"])
 
 
279
  train_overlap = (
280
- fev.pivot_table(summaries=summaries, metric_column="trained_on_this_dataset", task_columns=["task_name"])
 
 
 
 
281
  .fillna(False)
282
  .astype(bool)
283
  )
@@ -312,12 +370,15 @@ def construct_pivot_table(
312
  style_parts.append(f"color: {COLORS['leakage_impute']}")
313
  elif is_imputed_baseline.loc[row_idx, col_idx]:
314
  style_parts.append(f"color: {COLORS['failure_impute']}")
315
- elif not style_parts or (len(style_parts) == 1 and "font-weight" in style_parts[0]):
 
 
316
  style_parts.append(f"color: {COLORS['text_default']}")
317
 
318
  if style_parts:
319
  styler = styler.map(
320
- lambda x, s="; ".join(style_parts): s, subset=pd.IndexSlice[row_idx:row_idx, col_idx:col_idx]
 
321
  )
322
  return styler
323
 
 
53
  "sundial-base": ("thuml/sundial-base-128m", "Tsinghua University", True, "DL"),
54
  "ttm-r2": ("ibm-granite/granite-timeseries-ttm-r2", "IBM", True, "DL"),
55
  # Task-specific models
56
+ "stat. ensemble": (
57
+ "https://nixtlaverse.nixtla.io/statsforecast/",
58
+ "β€”",
59
+ False,
60
+ "ST",
61
+ ),
62
  "autoarima": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
63
  "autotheta": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
64
  "autoets": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
65
  "seasonalnaive": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
66
+ "seasonal naive": (
67
+ "https://nixtlaverse.nixtla.io/statsforecast/",
68
+ "β€”",
69
+ False,
70
+ "ST",
71
+ ),
72
  "drift": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
73
  "naive": ("https://nixtlaverse.nixtla.io/statsforecast/", "β€”", False, "ST"),
74
  }
 
140
  df["zero_shot"] = df["model_name"].apply(get_zero_shot_status)
141
  # Format leakage column: convert to int for all models, 0 for non-zero-shot
142
  df["training_corpus_overlap"] = df.apply(
143
+ lambda row: int(round(row["training_corpus_overlap"] * 100))
144
+ if row["zero_shot"] == "βœ“"
145
+ else 0,
146
+ axis=1,
147
  )
148
  df["link"] = df["model_name"].apply(get_model_link)
149
  df["org"] = df["model_name"].apply(get_model_organization)
 
163
  return (
164
  df.style.map(highlight_model_type_color, subset=["model_name"])
165
  .map(lambda x: "font-weight: bold", subset=["zero_shot"])
166
+ .apply(
167
+ lambda x: [
168
+ "background-color: #f8f9fa" if i % 2 == 1 else "" for i in range(len(x))
169
+ ],
170
+ axis=0,
171
+ )
172
  )
173
 
174
 
 
182
  alt.Tooltip(f"{col}_upper:Q", title="95% CI Upper", format=".2f"),
183
  ]
184
 
185
+ base_encode = {
186
+ "y": alt.Y("model_name:N", title="Forecasting Model", sort=None),
187
+ "tooltip": tooltip,
188
+ }
189
 
190
  bars = (
191
  alt.Chart(df)
192
  .mark_bar(color=COLORS["bar_fill"], cornerRadius=4)
193
+ .encode(
194
+ x=alt.X(f"{col}:Q", title=f"{label} (%)", scale=alt.Scale(zero=False)),
195
+ **base_encode,
196
+ )
197
  )
198
 
199
  error_bars = (
 
231
  for c in [col, f"{col}_lower", f"{col}_upper"]:
232
  df[c] *= 100
233
 
234
+ model_order = (
235
+ df.groupby("model_1")[col].mean().sort_values(ascending=False).index.tolist()
236
+ )
237
 
238
  tooltip = [
239
  alt.Tooltip("model_1:N", title="Model 1"),
 
244
  ]
245
 
246
  base = alt.Chart(df).encode(
247
+ x=alt.X(
248
+ "model_2:N",
249
+ sort=model_order,
250
+ title="Model 2",
251
+ axis=alt.Axis(orient="top", labelAngle=-90),
252
+ ),
253
  y=alt.Y("model_1:N", sort=model_order, title="Model 1"),
254
  )
255
 
256
  heatmap = base.mark_rect().encode(
257
  color=alt.Color(
258
  f"{col}:Q",
259
+ legend=None,
260
+ scale=alt.Scale(
261
+ scheme=HEATMAP_COLOR_SCHEME,
262
+ domain=domain,
263
+ domainMid=domain_mid,
264
+ clamp=True,
265
+ ),
266
  ),
267
  tooltip=tooltip,
268
  )
269
 
270
  text_main = base.mark_text(dy=-8, fontSize=8, baseline="top", yOffset=5).encode(
271
  text=alt.Text(f"{col}:Q", format=".1f"),
272
+ color=alt.condition(
273
+ text_condition,
274
+ alt.value(COLORS["text_white"]),
275
+ alt.value(COLORS["text_black"]),
276
+ ),
277
  tooltip=tooltip,
278
  )
279
 
280
  return (
281
  (heatmap + text_main)
282
+ .properties(
283
+ height=550,
284
+ title={
285
+ "text": f"Pairwise {cbar_label} ({metric_name}) with 95% CIs",
286
+ "fontSize": 16,
287
+ },
288
+ )
289
  .configure_axis(labelFontSize=11, titleFontSize=13, titleFontWeight="bold")
290
  .resolve_scale(color="independent")
291
  )
292
 
293
 
294
+ def construct_pivot_table_from_df(
295
+ errors: pd.DataFrame, metric_name: str
296
+ ) -> pd.io.formats.style.Styler:
297
  """Construct styled pivot table from precomputed DataFrame."""
298
 
299
  def highlight_by_position(styler):
 
313
 
314
  if style_parts:
315
  styler = styler.map(
316
+ lambda x, s="; ".join(style_parts): s,
317
+ subset=pd.IndexSlice[row_idx:row_idx, col_idx:col_idx],
318
  )
319
  return styler
320
 
 
322
 
323
 
324
  def construct_pivot_table(
325
+ summaries: pd.DataFrame,
326
+ metric_name: str,
327
+ baseline_model: str,
328
+ leakage_imputation_model: str,
329
  ) -> pd.io.formats.style.Styler:
330
+ errors = fev.pivot_table(
331
+ summaries=summaries, metric_column=metric_name, task_columns=["task_name"]
332
+ )
333
  train_overlap = (
334
+ fev.pivot_table(
335
+ summaries=summaries,
336
+ metric_column="trained_on_this_dataset",
337
+ task_columns=["task_name"],
338
+ )
339
  .fillna(False)
340
  .astype(bool)
341
  )
 
370
  style_parts.append(f"color: {COLORS['leakage_impute']}")
371
  elif is_imputed_baseline.loc[row_idx, col_idx]:
372
  style_parts.append(f"color: {COLORS['failure_impute']}")
373
+ elif not style_parts or (
374
+ len(style_parts) == 1 and "font-weight" in style_parts[0]
375
+ ):
376
  style_parts.append(f"color: {COLORS['text_default']}")
377
 
378
  if style_parts:
379
  styler = styler.map(
380
+ lambda x, s="; ".join(style_parts): s,
381
+ subset=pd.IndexSlice[row_idx:row_idx, col_idx:col_idx],
382
  )
383
  return styler
384