akera committed · verified
Commit b9c4788 · Parent: b4ca380

Update src/plotting.py

Files changed (1):
  src/plotting.py +211 -143

src/plotting.py CHANGED
@@ -38,86 +38,121 @@ def create_leaderboard_plot(
         )
         return fig
 
-    # Get top N models for this track
-    metric_col = f"{track}_{metric}"
-    ci_lower_col = f"{track}_ci_lower"
-    ci_upper_col = f"{track}_ci_upper"
-
-    if metric_col not in df.columns:
-        fig = go.Figure()
-        fig.add_annotation(
-            text=f"Metric {metric} not available for {track} track",
-            xref="paper", yref="paper",
-            x=0.5, y=0.5, showarrow=False,
-        )
-        return fig
-
-    # Filter and sort
-    valid_models = df[(df[metric_col] > 0)].head(top_n)
-
-    if valid_models.empty:
-        fig = go.Figure()
-        fig.add_annotation(text="No valid models found", x=0.5, y=0.5, showarrow=False)
-        return fig
-
-    # Create color mapping by category
-    colors = [MODEL_CATEGORIES.get(cat, {}).get("color", "#808080") for cat in valid_models["model_category"]]
-
-    # Main bar plot
-    fig = go.Figure()
-
-    # Add bars with error bars if confidence intervals available
-    error_x = None
-    if ci_lower_col in valid_models.columns and ci_upper_col in valid_models.columns:
-        error_x = dict(
-            type="data",
-            array=valid_models[ci_upper_col] - valid_models[metric_col],
-            arrayminus=valid_models[metric_col] - valid_models[ci_lower_col],
-            visible=True,
-            thickness=2,
-            width=4,
-        )
-
-    fig.add_trace(go.Bar(
-        y=valid_models["model_name"],
-        x=valid_models[metric_col],
-        orientation="h",
-        marker=dict(color=colors, line=dict(color="black", width=0.5)),
-        error_x=error_x,
-        text=[f"{score:.3f}" for score in valid_models[metric_col]],
-        textposition="auto",
-        hovertemplate=(
-            "<b>%{y}</b><br>" +
-            f"{metric.title()}: %{{x:.4f}}<br>" +
-            "Category: %{customdata[0]}<br>" +
-            "Author: %{customdata[1]}<br>" +
-            "Samples: %{customdata[2]}<br>" +
-            "<extra></extra>"
-        ),
-        customdata=list(zip(
-            valid_models["model_category"],
-            valid_models["author"],
-            valid_models.get(f"{track}_samples", [0] * len(valid_models))
-        )),
-    ))
-
-    # Customize layout
-    track_info = EVALUATION_TRACKS[track]
-    fig.update_layout(
-        title=f"🏆 {track_info['name']} - {metric.title()} Score",
-        xaxis_title=f"{metric.title()} Score (with 95% CI)",
-        yaxis_title="Models",
-        height=max(400, len(valid_models) * 35 + 100),
-        margin=dict(l=20, r=20, t=60, b=20),
-        paper_bgcolor="rgba(0,0,0,0)",
-        plot_bgcolor="rgba(0,0,0,0)",
-        font=dict(size=12),
-    )
-
-    # Reverse y-axis to show best model at top
-    fig.update_yaxes(autorange="reversed")
-
-    return fig
+    try:
+        # Get top N models for this track
+        metric_col = f"{track}_{metric}"
+        ci_lower_col = f"{track}_ci_lower"
+        ci_upper_col = f"{track}_ci_upper"
+
+        if metric_col not in df.columns:
+            fig = go.Figure()
+            fig.add_annotation(
+                text=f"Metric {metric} not available for {track} track",
+                xref="paper", yref="paper",
+                x=0.5, y=0.5, showarrow=False,
+            )
+            return fig
+
+        # Ensure numeric columns are properly typed
+        numeric_cols = [metric_col, ci_lower_col, ci_upper_col]
+        for col in numeric_cols:
+            if col in df.columns:
+                df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
+
+        # Filter and sort
+        valid_models = df[(df[metric_col] > 0)].head(top_n).copy()
+
+        if valid_models.empty:
+            fig = go.Figure()
+            fig.add_annotation(text="No valid models found", x=0.5, y=0.5, showarrow=False)
+            return fig
+
+        # Create color mapping by category
+        colors = [MODEL_CATEGORIES.get(cat, {}).get("color", "#808080") for cat in valid_models["model_category"]]
+
+        # Main bar plot
+        fig = go.Figure()
+
+        # Add bars with error bars if confidence intervals available
+        error_x = None
+        if ci_lower_col in valid_models.columns and ci_upper_col in valid_models.columns:
+            try:
+                error_x = dict(
+                    type="data",
+                    array=valid_models[ci_upper_col] - valid_models[metric_col],
+                    arrayminus=valid_models[metric_col] - valid_models[ci_lower_col],
+                    visible=True,
+                    thickness=2,
+                    width=4,
+                )
+            except Exception as e:
+                print(f"Error creating error bars: {e}")
+                error_x = None
+
+        # Safely format text values
+        try:
+            text_values = [f"{float(score):.3f}" for score in valid_models[metric_col]]
+        except:
+            text_values = ["0.000"] * len(valid_models)
+
+        # Safely prepare custom data
+        try:
+            samples_col = f"{track}_samples"
+            samples_data = valid_models.get(samples_col, [0] * len(valid_models))
+            customdata = list(zip(
+                valid_models["model_category"].fillna("unknown"),
+                valid_models["author"].fillna("Anonymous"),
+                [int(float(x)) if pd.notnull(x) else 0 for x in samples_data]
+            ))
+        except Exception as e:
+            print(f"Error preparing custom data: {e}")
+            customdata = [("unknown", "Anonymous", 0)] * len(valid_models)
+
+        fig.add_trace(go.Bar(
+            y=valid_models["model_name"],
+            x=valid_models[metric_col],
+            orientation="h",
+            marker=dict(color=colors, line=dict(color="black", width=0.5)),
+            error_x=error_x,
+            text=text_values,
+            textposition="auto",
+            hovertemplate=(
+                "<b>%{y}</b><br>" +
+                f"{metric.title()}: %{{x:.4f}}<br>" +
+                "Category: %{customdata[0]}<br>" +
+                "Author: %{customdata[1]}<br>" +
+                "Samples: %{customdata[2]}<br>" +
+                "<extra></extra>"
+            ),
+            customdata=customdata,
+        ))
+
+        # Customize layout
+        track_info = EVALUATION_TRACKS[track]
+        fig.update_layout(
+            title=f"🏆 {track_info['name']} - {metric.title()} Score",
+            xaxis_title=f"{metric.title()} Score (with 95% CI)",
+            yaxis_title="Models",
+            height=max(400, len(valid_models) * 35 + 100),
+            margin=dict(l=20, r=20, t=60, b=20),
+            paper_bgcolor="rgba(0,0,0,0)",
+            plot_bgcolor="rgba(0,0,0,0)",
+            font=dict(size=12),
+        )
+
+        # Reverse y-axis to show best model at top
+        fig.update_yaxes(autorange="reversed")
+
+        return fig
+
+    except Exception as e:
+        print(f"Error creating leaderboard plot: {e}")
+        fig = go.Figure()
+        fig.add_annotation(
+            text=f"Error creating plot: {str(e)}",
+            x=0.5, y=0.5, showarrow=False
+        )
+        return fig
 
 
  def create_language_pair_heatmap(
@@ -201,79 +236,112 @@ def create_performance_comparison_plot(df: pd.DataFrame, track: str) -> go.Figure:
         fig.add_annotation(text="No data available", x=0.5, y=0.5, showarrow=False)
         return fig
 
-    metric_col = f"{track}_quality"
-    ci_lower_col = f"{track}_ci_lower"
-    ci_upper_col = f"{track}_ci_upper"
-
-    # Filter to models with data for this track
-    valid_models = df[
-        (df[metric_col] > 0) &
-        (df[ci_lower_col].notna()) &
-        (df[ci_upper_col].notna())
-    ].head(10)
-
-    if valid_models.empty:
-        fig = go.Figure()
-        fig.add_annotation(text="No models with confidence intervals", x=0.5, y=0.5, showarrow=False)
-        return fig
-
-    fig = go.Figure()
-
-    # Add confidence intervals as error bars
-    for i, (_, model) in enumerate(valid_models.iterrows()):
-        category = model["model_category"]
-        color = MODEL_CATEGORIES.get(category, {}).get("color", "#808080")
-
-        # Main point
-        fig.add_trace(go.Scatter(
-            x=[model[metric_col]],
-            y=[i],
-            mode="markers",
-            marker=dict(
-                size=12,
-                color=color,
-                line=dict(color="black", width=1),
-            ),
-            name=model["model_name"],
-            showlegend=False,
-            hovertemplate=(
-                f"<b>{model['model_name']}</b><br>" +
-                f"Quality: {model[metric_col]:.4f}<br>" +
-                f"95% CI: [{model[ci_lower_col]:.4f}, {model[ci_upper_col]:.4f}]<br>" +
-                f"Category: {category}<br>" +
-                "<extra></extra>"
-            ),
-        ))
-
-        # Confidence interval line
-        fig.add_trace(go.Scatter(
-            x=[model[ci_lower_col], model[ci_upper_col]],
-            y=[i, i],
-            mode="lines",
-            line=dict(color=color, width=3),
-            showlegend=False,
-            hoverinfo="skip",
-        ))
-
-    # Customize layout
-    track_info = EVALUATION_TRACKS[track]
-    fig.update_layout(
-        title=f"📊 {track_info['name']} - Performance Comparison",
-        xaxis_title="Quality Score",
-        yaxis_title="Models",
-        height=max(400, len(valid_models) * 40 + 100),
-        yaxis=dict(
-            tickmode="array",
-            tickvals=list(range(len(valid_models))),
-            ticktext=valid_models["model_name"].tolist(),
-            autorange="reversed",
-        ),
-        showlegend=False,
-        paper_bgcolor="rgba(0,0,0,0)",
-        plot_bgcolor="rgba(0,0,0,0)",
-    )
-
-    return fig
+    try:
+        metric_col = f"{track}_quality"
+        ci_lower_col = f"{track}_ci_lower"
+        ci_upper_col = f"{track}_ci_upper"
+
+        # Ensure numeric columns are properly typed
+        numeric_cols = [metric_col, ci_lower_col, ci_upper_col]
+        for col in numeric_cols:
+            if col in df.columns:
+                df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
+
+        # Filter to models with data for this track
+        valid_models = df[
+            (df[metric_col] > 0) &
+            (df[ci_lower_col].notna()) &
+            (df[ci_upper_col].notna())
+        ].head(10).copy()
+
+        if valid_models.empty:
+            fig = go.Figure()
+            fig.add_annotation(text="No models with confidence intervals", x=0.5, y=0.5, showarrow=False)
+            return fig
+
+        fig = go.Figure()
+
+        # Add confidence intervals as error bars
+        for i, (_, model) in enumerate(valid_models.iterrows()):
+            try:
+                category = str(model["model_category"])
+                color = MODEL_CATEGORIES.get(category, {}).get("color", "#808080")
+                model_name = str(model["model_name"])
+
+                # Safely extract numeric values
+                quality_val = float(model[metric_col])
+                ci_lower_val = float(model[ci_lower_col])
+                ci_upper_val = float(model[ci_upper_col])
+
+                # Main point
+                fig.add_trace(go.Scatter(
+                    x=[quality_val],
+                    y=[i],
+                    mode="markers",
+                    marker=dict(
+                        size=12,
+                        color=color,
+                        line=dict(color="black", width=1),
+                    ),
+                    name=model_name,
+                    showlegend=False,
+                    hovertemplate=(
+                        f"<b>{model_name}</b><br>" +
+                        f"Quality: {quality_val:.4f}<br>" +
+                        f"95% CI: [{ci_lower_val:.4f}, {ci_upper_val:.4f}]<br>" +
+                        f"Category: {category}<br>" +
+                        "<extra></extra>"
+                    ),
+                ))
+
+                # Confidence interval line
+                fig.add_trace(go.Scatter(
+                    x=[ci_lower_val, ci_upper_val],
+                    y=[i, i],
+                    mode="lines",
+                    line=dict(color=color, width=3),
+                    showlegend=False,
+                    hoverinfo="skip",
+                ))
+
+            except Exception as e:
+                print(f"Error adding model {i} to comparison plot: {e}")
+                continue
+
+        # Safely prepare tick labels
+        try:
+            tick_labels = [str(name) for name in valid_models["model_name"]]
+        except:
+            tick_labels = [f"Model {i}" for i in range(len(valid_models))]
+
+        # Customize layout
+        track_info = EVALUATION_TRACKS[track]
+        fig.update_layout(
+            title=f"📊 {track_info['name']} - Performance Comparison",
+            xaxis_title="Quality Score",
+            yaxis_title="Models",
+            height=max(400, len(valid_models) * 40 + 100),
+            yaxis=dict(
+                tickmode="array",
+                tickvals=list(range(len(valid_models))),
+                ticktext=tick_labels,
+                autorange="reversed",
+            ),
+            showlegend=False,
+            paper_bgcolor="rgba(0,0,0,0)",
+            plot_bgcolor="rgba(0,0,0,0)",
+        )
+
+        return fig
+
+    except Exception as e:
+        print(f"Error creating performance comparison plot: {e}")
+        fig = go.Figure()
+        fig.add_annotation(
+            text=f"Error creating plot: {str(e)}",
+            x=0.5, y=0.5, showarrow=False
+        )
+        return fig
 
 
  def create_language_pair_comparison_plot(pairs_df: pd.DataFrame, track: str) -> go.Figure:
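Both rewritten functions apply the same hardening pattern: coerce the metric columns to numeric with pd.to_numeric(..., errors='coerce').fillna(0.0), filter on a .copy() of the frame, and wrap figure construction in try/except that falls back to an annotated empty figure instead of raising. A minimal, self-contained sketch of that pattern — the toy DataFrame, the "mt_*" column names, and the plot_quality helper are hypothetical illustrations, not code from this repository:

import pandas as pd
import plotly.graph_objects as go

# Hypothetical leaderboard frame with a dirty, mixed-type metric column.
df = pd.DataFrame({
    "model_name": ["a", "b", "c"],
    "mt_quality": ["0.71", None, "oops"],  # strings, missing values, junk
    "mt_ci_lower": [0.68, 0.10, 0.20],
    "mt_ci_upper": [0.74, 0.30, 0.40],
})

def plot_quality(df: pd.DataFrame) -> go.Figure:
    try:
        # Coerce to float: unparseable values become NaN, then 0.0,
        # so the > 0 filter below drops them instead of crashing.
        for col in ["mt_quality", "mt_ci_lower", "mt_ci_upper"]:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)

        # .copy() so later writes don't hit a view of the original frame
        valid = df[df["mt_quality"] > 0].copy()

        return go.Figure(go.Bar(
            x=valid["mt_quality"], y=valid["model_name"], orientation="h",
        ))
    except Exception as e:
        # Fallback: an annotated empty figure rather than an exception
        fig = go.Figure()
        fig.add_annotation(text=f"Error creating plot: {e}", x=0.5, y=0.5, showarrow=False)
        return fig

fig = plot_quality(df)  # the None/"oops" rows coerce to 0.0 and are filtered out

The committed version additionally repeats the try/except inside the per-model loop of create_performance_comparison_plot (with continue on failure), so one malformed row is skipped rather than aborting the whole figure.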