ai-forever committed on
Commit 36c47f8 · verified · 1 parent: aff180f

Update leaderboard display

Files changed (1): app.py (+803, −0)
app.py ADDED
@@ -0,0 +1,803 @@
import gradio as gr
import json
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
import traceback
from datetime import datetime
from packaging import version

# Color scheme for charts
COLORS = px.colors.qualitative.Plotly

# Line colors for radar charts
line_colors = [
    "#EE4266",
    "#00a6ed",
    "#ECA72C",
    "#B42318",
    "#3CBBB1",
]

# Fill colors for radar charts
fill_colors = [
    "rgba(238,66,102,0.05)",
    "rgba(0,166,237,0.05)",
    "rgba(236,167,44,0.05)",
    "rgba(180,35,24,0.05)",
    "rgba(60,187,177,0.05)",
]

# Define the question categories
QUESTION_CATEGORIES = ["simple", "set", "mh", "cond", "comp"]
METRIC_TYPES = ["retrieval", "generation"]

def load_results():
    """Load results from the results.json file."""
    try:
        # Get the directory of the current script
        script_dir = os.path.dirname(os.path.abspath(__file__))
        # Build the path to results.json
        results_path = os.path.join(script_dir, 'results.json')

        print(f"Loading results from: {results_path}")

        with open(results_path, 'r', encoding='utf-8') as f:
            results = json.load(f)
        print(f"Successfully loaded results with {len(results.get('items', {}))} version(s)")
        return results
    except FileNotFoundError:
        # Return an empty structure if the file doesn't exist
        print("Results file not found, creating empty structure")
        return {"items": {}, "last_version": "1.0", "n_questions": "0"}
    except Exception as e:
        print(f"Error loading results: {e}")
        print(traceback.format_exc())
        return {"items": {}, "last_version": "1.0", "n_questions": "0"}

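# For reference, a minimal sketch of the results.json structure this loader
# expects, inferred from the accessors used in this file; the concrete field
# values below are illustrative, not real data:
#
# {
#   "last_version": "1.1",
#   "n_questions": "100",
#   "date_title": "---",
#   "items": {
#     "1.1": {                                 # dataset version
#       "<submit_id>": {                       # one entry per submission
#         "model_name": "...",
#         "timestamp": "2024-01-01T00:00:00Z",
#         "config": {
#           "embedding_model": "...",
#           "retriever_type": "...",
#           "retrieval_config": {"top_k": 5}
#         },
#         "metrics": {
#           "simple": {"retrieval": {"<metric>": 0.5}, "generation": {"<metric>": 0.5}},
#           "set": {...}, "mh": {...}, "cond": {...}, "comp": {...}
#         }
#       }
#     }
#   }
# }
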
def filter_and_process_results(results, n_versions, only_actual_versions):
    """Filter results by version and process them for display."""
    if not results or "items" not in results:
        return pd.DataFrame(), [], [], []

    all_items = results["items"]
    last_version_str = results.get("last_version", "1.0")
    last_version = version.parse(last_version_str)

    print(f"Last version: {last_version_str}")

    # Group items by model_name
    model_groups = {}

    for version_str, version_items in all_items.items():
        version_obj = version.parse(version_str)
        for item_id, item in version_items.items():
            model_name = item.get("model_name", "Unknown")

            if model_name not in model_groups:
                model_groups[model_name] = []

            # Add version info to the item (both as a string and as a parsed
            # version object for comparison)
            item["version_str"] = version_str
            item["version_obj"] = version_obj
            model_groups[model_name].append(item)

    rows = []
    for model_name, items in model_groups.items():
        # Sort items by version (newest first)
        items.sort(key=lambda x: x["version_obj"], reverse=True)

        # Filter versions based on the selection
        filtered_items = []

        if only_actual_versions:
            # Get the n most recent actual dataset versions
            all_versions = sorted([version.parse(v_str) for v_str in all_items.keys()], reverse=True)
            # Take at most n_versions
            versions_to_consider = all_versions[:n_versions] if all_versions else []

            # Keep only items that match those versions
            filtered_items = [item for item in items if any(item["version_obj"] == v for v in versions_to_consider)]
        else:
            # Consider the n_versions most recent items for this model
            filtered_items = items[:n_versions]

        if not filtered_items:
            continue

        config = filtered_items[0]["config"]  # Use the config from the most recent version

        # Create a row with basic info
        row = {
            'Model': model_name,
            'Embeddings': config.get('embedding_model', 'N/A'),
            'Retriever': config.get('retriever_type', 'N/A'),
            'Top-K': config.get('retrieval_config', {}).get('top_k', 'N/A'),
            'Versions': ", ".join([item["version_str"] for item in filtered_items]),
            'Last Updated': filtered_items[0].get("timestamp", "")
        }

        # Format the timestamp if available
        if row['Last Updated']:
            try:
                dt = datetime.fromisoformat(row['Last Updated'].replace('Z', '+00:00'))
                row['Last Updated'] = dt.strftime("%Y-%m-%d")
            except ValueError:
                # Leave the raw timestamp if it isn't ISO formatted
                pass

        # Accumulate per-category metric sums (note: distinct from the
        # category_metrics list of column names built after this loop)
        category_sums = {
            category: {
                metric_type: {
                    "avg": 0.0,
                    "count": 0
                } for metric_type in METRIC_TYPES
            } for category in QUESTION_CATEGORIES
        }

        # Collect metrics by category
        for item in filtered_items:
            metrics = item.get("metrics", {})
            for category in QUESTION_CATEGORIES:
                if category in metrics:
                    for metric_type in METRIC_TYPES:
                        # Skip empty metric dicts to avoid division by zero
                        if metric_type in metrics[category] and metrics[category][metric_type]:
                            metric_values = metrics[category][metric_type]
                            avg_value = sum(metric_values.values()) / len(metric_values)

                            # Add to the running sum for this category and metric type
                            category_sums[category][metric_type]["avg"] += avg_value
                            category_sums[category][metric_type]["count"] += 1

        # Calculate averages and add them to the row
        for category in QUESTION_CATEGORIES:
            for metric_type in METRIC_TYPES:
                metric_data = category_sums[category][metric_type]
                if metric_data["count"] > 0:
                    avg_value = metric_data["avg"] / metric_data["count"]
                    # Add to the row with an appropriate column name
                    col_name = f"{category}_{metric_type}"
                    row[col_name] = round(avg_value, 4)

        # Calculate overall averages for each metric type
        for metric_type in METRIC_TYPES:
            total_sum = 0
            total_count = 0

            for category in QUESTION_CATEGORIES:
                metric_data = category_sums[category][metric_type]
                if metric_data["count"] > 0:
                    total_sum += metric_data["avg"]
                    total_count += metric_data["count"]

            if total_count > 0:
                row[f"{metric_type}_avg"] = round(total_sum / total_count, 4)

        rows.append(row)

    # Create the DataFrame
    df = pd.DataFrame(rows)

    # Get the list of metric columns for each category
    category_metrics = []
    for category in QUESTION_CATEGORIES:
        metrics = []
        for metric_type in METRIC_TYPES:
            col_name = f"{category}_{metric_type}"
            if col_name in df.columns:
                metrics.append(col_name)
        if metrics:
            category_metrics.append((category, metrics))

    # Define the retrieval and generation columns for the radar charts
    retrieval_metrics = [f"{category}_retrieval" for category in QUESTION_CATEGORIES if f"{category}_retrieval" in df.columns]
    generation_metrics = [f"{category}_generation" for category in QUESTION_CATEGORIES if f"{category}_generation" in df.columns]

    return df, retrieval_metrics, generation_metrics, category_metrics

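# Example of the processed output (hypothetical data): with a single model
# evaluated on one dataset version,
#   filter_and_process_results(results, n_versions=1, only_actual_versions=True)
# returns a one-row DataFrame with columns like
#   Model, Embeddings, Retriever, Top-K, Versions, Last Updated,
#   simple_retrieval, simple_generation, ..., retrieval_avg, generation_avg
# plus the retrieval/generation column-name lists consumed by the radar charts.
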
def create_radar_chart(df, selected_models, metrics, title):
    """Create a radar chart for the selected models and metrics."""
    if not metrics or len(selected_models) == 0:
        # Return an empty figure if no metrics or models are selected
        fig = go.Figure()
        fig.update_layout(
            title=title,
            title_font_size=16,
            height=400,
            width=500,
            margin=dict(l=30, r=30, t=50, b=30)
        )
        return fig

    # Filter the dataframe for the selected models
    filtered_df = df[df['Model'].isin(selected_models)]

    if filtered_df.empty:
        # Return an empty figure if there is no data
        fig = go.Figure()
        fig.update_layout(
            title=title,
            title_font_size=16,
            height=400,
            width=500,
            margin=dict(l=30, r=30, t=50, b=30)
        )
        return fig

    # Limit to the top 5 models for better readability
    if len(filtered_df) > 5:
        filtered_df = filtered_df.head(5)

    # Prepare data for the radar chart
    categories = [m.split('_', 1)[0] for m in metrics]  # Category names (simple, set, etc.)

    fig = go.Figure()

    # Add one trace per model
    for i, (_, row) in enumerate(filtered_df.iterrows()):
        values = [row[m] for m in metrics]
        # Close the loop for the radar chart
        values.append(values[0])
        categories_loop = categories + [categories[0]]

        fig.add_trace(go.Scatterpolar(
            name=row['Model'],
            r=values,
            theta=categories_loop,
            showlegend=True,
            mode="lines",
            line=dict(width=2, color=line_colors[i % len(line_colors)]),
            fill="toself",
            fillcolor=fill_colors[i % len(fill_colors)]
        ))

    fig.update_layout(
        font=dict(size=13, color="black"),
        template="plotly_white",
        polar=dict(
            radialaxis=dict(
                visible=True,
                gridcolor="black",
                linecolor="rgba(0,0,0,0)",
                gridwidth=1,
                showticklabels=False,
                ticks="",
                range=[0, 1]  # Ensure a consistent range for scores
            ),
            angularaxis=dict(
                gridcolor="black",
                gridwidth=1.5,
                linecolor="rgba(0,0,0,0)"
            ),
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.35,
            xanchor="center",
            x=0.4,
            itemwidth=30,
            font=dict(size=13),
            entrywidth=0.6,
            entrywidthmode="fraction",
        ),
        margin=dict(l=0, r=16, t=30, b=30),
        autosize=True,
    )

    return fig

def create_summary_df(df, retrieval_metrics, generation_metrics):
    """Create a summary dataframe with averaged metrics for display."""
    if df.empty:
        return pd.DataFrame()

    summary_df = df.copy()

    # Add the retrieval average
    if retrieval_metrics:
        retrieval_avg = summary_df[retrieval_metrics].mean(axis=1).round(4)
        summary_df['Retrieval (avg)'] = retrieval_avg

    # Add the generation average
    if generation_metrics:
        generation_avg = summary_df[generation_metrics].mean(axis=1).round(4)
        summary_df['Generation (avg)'] = generation_avg

    # Add the total score if both averages exist
    if 'Retrieval (avg)' in summary_df.columns and 'Generation (avg)' in summary_df.columns:
        summary_df['Total Score'] = summary_df['Retrieval (avg)'] + summary_df['Generation (avg)']
        summary_df = summary_df.sort_values('Total Score', ascending=False)

    # Select columns for display
    summary_cols = ['Model', 'Embeddings', 'Retriever', 'Top-K']
    if 'Retrieval (avg)' in summary_df.columns:
        summary_cols.append('Retrieval (avg)')
    if 'Generation (avg)' in summary_df.columns:
        summary_cols.append('Generation (avg)')
    if 'Total Score' in summary_df.columns:
        summary_cols.append('Total Score')
    if 'Versions' in summary_df.columns:
        summary_cols.append('Versions')
    if 'Last Updated' in summary_df.columns:
        summary_cols.append('Last Updated')

    return summary_df[summary_cols]

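# Note: 'Total Score' is the unweighted sum of the two averages, so it ranges
# over [0, 2] assuming individual metrics lie in [0, 1] (the radar charts
# above assume that range as well).
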
def create_category_df(df, category, retrieval_col, generation_col):
    """Create a dataframe for a specific category with detailed metrics."""
    if df.empty or retrieval_col not in df.columns or generation_col not in df.columns:
        return pd.DataFrame()

    category_df = df.copy()

    # Calculate the total score for this category
    category_df[f'{category} Score'] = category_df[retrieval_col] + category_df[generation_col]

    # Sort by the total score
    category_df = category_df.sort_values(f'{category} Score', ascending=False)

    # Select columns for display
    category_cols = ['Model', 'Embeddings', 'Retriever', retrieval_col, generation_col, f'{category} Score']

    # Rename columns for display
    category_df = category_df[category_cols].rename(columns={
        retrieval_col: 'Retrieval',
        generation_col: 'Generation'
    })

    return category_df

# Load initial data
results = load_results()
last_version = results.get("last_version", "1.0")
n_questions = results.get("n_questions", "100")
date_title = results.get("date_title", "---")

# Initial data processing
df, retrieval_metrics, generation_metrics, category_metrics = filter_and_process_results(
    results, n_versions=1, only_actual_versions=True
)

# Pre-generate charts for the initial display
default_models = df['Model'].head(5).tolist() if not df.empty else []
initial_gen_chart = create_radar_chart(df, default_models, generation_metrics, "Performance on Generation Tasks")
initial_ret_chart = create_radar_chart(df, default_models, retrieval_metrics, "Performance on Retrieval Tasks")

# Create the summary dataframe
summary_df = create_summary_df(df, retrieval_metrics, generation_metrics)

with gr.Blocks(css="""
.title-container {
    text-align: center;
    margin-bottom: 10px;
}
.description-text {
    text-align: left;
    padding: 10px;
    margin-bottom: 0px;
}
.version-info {
    text-align: center;
    padding: 10px;
    background-color: #f0f0f0;
    border-radius: 8px;
    margin-bottom: 15px;
}
.version-selector {
    padding: 15px;
    border: 1px solid #ddd;
    border-radius: 8px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
    height: 100%;
}
.citation-block {
    padding: 15px;
    border: 1px solid #ddd;
    border-radius: 8px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
    font-family: monospace;
    font-size: 14px;
    overflow-x: auto;
    height: 100%;
}
.flex-row-container {
    display: flex;
    justify-content: space-between;
    gap: 20px;
    width: 100%;
}
.charts-container {
    display: flex;
    gap: 20px;
    margin-bottom: 20px;
}
.chart-box {
    flex: 1;
    border: 1px solid #eee;
    border-radius: 8px;
    padding: 10px;
    background-color: white;
    min-height: 550px; /* Increased height to accommodate the legend at the bottom */
}
.metrics-table {
    border: 1px solid #eee;
    border-radius: 8px;
    padding: 15px;
    background-color: white;
}
.info-text {
    font-size: 0.9em;
    font-style: italic;
    color: #666;
    margin-top: 5px;
}
footer {
    text-align: center;
    margin-top: 30px;
    font-size: 0.9em;
    color: #666;
}
/* Style for selected rows */
table tbody tr.selected {
    background-color: rgba(25, 118, 210, 0.1) !important;
    border-left: 3px solid #1976d2;
}
/* This class is added via JavaScript */
.gr-table tbody tr.selected td:first-child {
    font-weight: bold;
    color: #1976d2;
}
.category-tab {
    padding: 10px;
}
.chart-title {
    font-size: 1.2em;
    font-weight: bold;
    margin-bottom: 10px;
    text-align: center;
}
.clear-charts-button {
    display: flex;
    justify-content: center;
    margin-top: 10px;
    margin-bottom: 20px;
}
""") as demo:
    # Title
    with gr.Row(elem_classes=["title-container"]):
        gr.Markdown("# 🐙 Dynamic RAG Benchmark")

    # Description
    with gr.Row(elem_classes=["description-text"]):
        gr.Markdown("This leaderboard compares RAG systems on generation and retrieval metrics across question types (simple questions, comparisons, multi-hop, conditional, and others). <li>Questions are generated automatically from news sources.</li><li>The question dataset is updated regularly, and all metrics for open models are recomputed.</li><li>For user submissions, the most recently computed metrics are used.</li><li>To re-evaluate a previously submitted configuration on the latest data version, use the submit_id returned by the client on first submission (see the instructions below).</li>")

    # Version info
    with gr.Row(elem_classes=["version-info"]):
        gr.Markdown(f"## Version {last_version} → {n_questions} questions generated from news sources → {date_title}")

    # Radar charts
    with gr.Row(elem_classes=["charts-container"]):
        with gr.Column(elem_classes=["chart-box"]):
            gr.Markdown("### Generation metrics", elem_classes=["chart-title"])
            generation_chart = gr.Plot(value=initial_gen_chart)

        with gr.Column(elem_classes=["chart-box"]):
            gr.Markdown("### Retrieval metrics", elem_classes=["chart-title"])
            retrieval_chart = gr.Plot(value=initial_ret_chart)

    # Clear-charts button
    with gr.Row(elem_classes=["clear-charts-button"]):
        clear_charts_btn = gr.Button("Clear charts", variant="secondary")

    # Metrics table with tabs
    with gr.Tabs(elem_classes=["metrics-table"]) as metrics_tabs:
        with gr.TabItem("Summary table"):
            selected_models = gr.State(default_models)

            # If the dataframe is empty, show a message
            if df.empty:
                gr.Markdown("No data available. Please submit some results.")
                metrics_table = gr.DataFrame()
            else:
                metrics_table = gr.DataFrame(
                    value=summary_df,
                    headers=summary_df.columns.tolist(),
                    datatype=["str"] * len(summary_df.columns),
                    row_count=(min(10, len(summary_df)) if not summary_df.empty else 0),
                    col_count=(len(summary_df.columns) if not summary_df.empty else 0),
                    interactive=False,
                    wrap=True
                )

        with gr.TabItem("By question type"):
            # Create tabs for each category
            category_tabs = gr.Tabs()
            category_tables = {}

            # Map category codes to display names
            category_display_names = {
                "simple": "Simple Questions",
                "set": "Set-based",
                "mh": "Multi-hop",
                "cond": "Conditional",
                "comp": "Comparison"
            }

            with category_tabs:
                for category, _ in category_metrics:
                    if f"{category}_retrieval" in df.columns and f"{category}_generation" in df.columns:
                        with gr.TabItem(category_display_names.get(category, category.capitalize()), elem_classes=["category-tab"]):
                            # Create the dataframe for this category
                            category_df = create_category_df(df, category, f"{category}_retrieval", f"{category}_generation")

                            if category_df.empty:
                                gr.Markdown(f"No data available for {category_display_names.get(category, category)} category.")
                                category_tables[category] = gr.DataFrame()
                            else:
                                gr.Markdown(f"#### Performance on {category_display_names.get(category, category)}")
                                category_tables[category] = gr.DataFrame(
                                    value=category_df,
                                    headers=category_df.columns.tolist(),
                                    datatype=["str"] * len(category_df.columns),
                                    row_count=(min(10, len(category_df)) if not category_df.empty else 0),
                                    col_count=(len(category_df.columns) if not category_df.empty else 0),
                                    interactive=False,
                                    wrap=True
                                )

    # Version selector and citation block in a flex container
    with gr.Row():
        # Citation block (left side)
        with gr.Column(scale=1, elem_classes=["citation-block"]):
            gr.Markdown("### Citation")
            gr.Markdown("""
```
@article{dynamic-rag-benchmark,
    title={Dynamic RAG Benchmark},
    author={RAG Benchmark Team},
    journal={arXiv preprint},
    year={2024},
    url={https://github.com/rag-benchmark}
}
```

Template for citing our benchmark.
""")

        # Version selector (right side)
        with gr.Column(scale=1, elem_classes=["version-selector"]):
            gr.Markdown("### Version selection")
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        only_actual_versions = gr.Checkbox(
                            label="Only current versions",
                            value=True,
                            info="Compute starting from the current dataset version"
                        )
                    with gr.Column(scale=5):
                        n_versions_slider = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=1,
                            step=1,
                            label="Use the n most recent versions",
                            info="Number of versions used to compute the metrics"
                        )
                with gr.Row():
                    filter_btn = gr.Button("Apply filter", variant="primary")

            gr.Markdown(
                "Click models in the table to add them to the charts",
                elem_classes=["info-text"]
            )

    # Footer
    with gr.Row():
        gr.Markdown("""
        <footer>Dynamic RAG Benchmark Leaderboard</footer>
        """)

    # Handle row selection for the radar charts. The slider and checkbox are
    # passed as inputs so the handler sees their current values (reading
    # component.value inside the callback would only return the initial values).
    def update_charts(evt: gr.SelectData, selected_models, n_versions, only_actual):
        try:
            # Get current data with the latest filters
            current_df, current_ret_metrics, current_gen_metrics, _ = filter_and_process_results(
                results, n_versions=n_versions, only_actual_versions=only_actual
            )

            # Debug info
            print(f"Selection event: {evt}, type: {type(evt)}")

            selected_model = None

            # Extract the selected model based on the row index
            try:
                # Get the table component that was clicked
                component = evt.target

                # Get the row index
                row_idx = evt.index[0] if isinstance(evt.index, list) else evt.index
                print(f"Row index: {row_idx}")

                # Determine which table was clicked and extract the model name.
                # First check whether it is the summary table
                if component is metrics_table:
                    # The summary table was clicked
                    if isinstance(summary_df, pd.DataFrame) and 0 <= row_idx < len(summary_df):
                        selected_model = summary_df.iloc[row_idx]['Model']
                        print(f"Selected from summary table: {selected_model}")
                else:
                    # Check whether it is a category table
                    for category, table in category_tables.items():
                        if component is table:
                            # Rebuild the category dataframe
                            category_df = create_category_df(
                                current_df,
                                category,
                                f"{category}_retrieval",
                                f"{category}_generation"
                            )

                            if isinstance(category_df, pd.DataFrame) and 0 <= row_idx < len(category_df):
                                selected_model = category_df.iloc[row_idx]['Model']
                                print(f"Selected from {category} table: {selected_model}")
                            break

                # If we still couldn't identify the model, try the raw component data
                if selected_model is None and hasattr(component, "value"):
                    table_value = component.value
                    if isinstance(table_value, pd.DataFrame) and 0 <= row_idx < len(table_value):
                        selected_model = table_value.iloc[row_idx]['Model']
                    elif isinstance(table_value, list) and 0 <= row_idx < len(table_value):
                        selected_model = table_value[row_idx][0]  # Assuming Model is the first column
                    elif isinstance(table_value, dict) and 'data' in table_value and 0 <= row_idx < len(table_value['data']):
                        selected_model = table_value['data'][row_idx][0]
            except Exception as e:
                print(f"Error extracting model name: {e}")
                traceback.print_exc()

            # If we found a model name, toggle its selection
            if selected_model:
                print(f"Selected model: {selected_model}")

                # Make sure the model exists in the current dataframe
                available_models = current_df['Model'].tolist() if not current_df.empty else []

                if selected_model in available_models:
                    # Remove the model if it is already selected, otherwise add it (toggle)
                    if selected_model in selected_models:
                        selected_models.remove(selected_model)
                    else:
                        selected_models.append(selected_model)
                else:
                    print(f"Model {selected_model} not found in current dataframe")

            # Keep only models that exist in the current dataframe
            available_models = current_df['Model'].tolist() if not current_df.empty else []
            selected_models = [model for model in selected_models if model in available_models]

            # If no models are selected after filtering, use the first available model
            if not selected_models and available_models:
                selected_models = [available_models[0]]

            # Create the radar charts using the current dataframe and metrics
            gen_chart = create_radar_chart(current_df, selected_models, current_gen_metrics, "Performance on Generation Tasks")
            ret_chart = create_radar_chart(current_df, selected_models, current_ret_metrics, "Performance on Retrieval Tasks")

            return selected_models, gen_chart, ret_chart
        except Exception as e:
            print(f"Error in update_charts: {e}")
            print(traceback.format_exc())
            return selected_models, generation_chart.value, retrieval_chart.value

    # Wire up row selection for the summary table
    metrics_table.select(
        fn=update_charts,
        inputs=[selected_models, n_versions_slider, only_actual_versions],
        outputs=[selected_models, generation_chart, retrieval_chart]
    )

    # Add selection handlers for the category tables too
    for category_table in category_tables.values():
        category_table.select(
            fn=update_charts,
            inputs=[selected_models, n_versions_slider, only_actual_versions],
            outputs=[selected_models, generation_chart, retrieval_chart]
        )

    # Handle version filter changes
    def update_data(n_versions, only_actual, current_selected_models):
        # Update the module-level data so later callbacks (clear_charts, row
        # selection) work against the freshly filtered results
        global df, retrieval_metrics, generation_metrics, summary_df
        try:
            # Get the updated data
            new_df, new_ret_metrics, new_gen_metrics, new_category_metrics = filter_and_process_results(
                results, n_versions=n_versions, only_actual_versions=only_actual
            )

            # Get the available models
            available_models = new_df['Model'].tolist() if not new_df.empty else []

            # Keep only previously selected models that exist in the new dataset
            filtered_selected_models = [model for model in current_selected_models if model in available_models]

            # If no previously selected models remain, select the top models
            if not filtered_selected_models and available_models:
                filtered_selected_models = available_models[:min(5, len(available_models))]

            # Create the radar charts
            gen_chart = create_radar_chart(new_df, filtered_selected_models, new_gen_metrics, "Performance on Generation Tasks")
            ret_chart = create_radar_chart(new_df, filtered_selected_models, new_ret_metrics, "Performance on Retrieval Tasks")

            # Create the summary dataframe (a local name, so the except branch
            # below can still return the last good global summary_df)
            new_summary_df = create_summary_df(new_df, new_ret_metrics, new_gen_metrics)

            # Build the category tables for output
            category_tables_output = {}

            # First initialize all tables to an empty DataFrame
            for category in category_tables.keys():
                category_tables_output[category] = pd.DataFrame()

            # Then populate the available tables
            for category, _ in new_category_metrics:
                if f"{category}_retrieval" in new_df.columns and f"{category}_generation" in new_df.columns:
                    category_df = create_category_df(new_df, category, f"{category}_retrieval", f"{category}_generation")
                    if category in category_tables:
                        category_tables_output[category] = category_df if not category_df.empty else pd.DataFrame()

            # Prepare all outputs
            outputs = [new_summary_df, gen_chart, ret_chart, filtered_selected_models]

            # Add the category tables to the outputs in the same order as category_tables
            for category in category_tables.keys():
                outputs.append(category_tables_output.get(category, pd.DataFrame()))

            # Update the module-level data for later use
            df = new_df
            retrieval_metrics = new_ret_metrics
            generation_metrics = new_gen_metrics
            summary_df = new_summary_df

            return outputs
        except Exception as e:
            print(f"Error in update_data: {e}")
            print(traceback.format_exc())
            # Return the original values in case of error
            empty_tables = [pd.DataFrame() for _ in category_tables]
            return summary_df, generation_chart.value, retrieval_chart.value, current_selected_models, *empty_tables

    # Define the filter button outputs
    filter_outputs = [metrics_table, generation_chart, retrieval_chart, selected_models]
    # Add the category tables to the outputs
    for category_table in category_tables.values():
        filter_outputs.append(category_table)

    filter_btn.click(
        fn=update_data,
        inputs=[n_versions_slider, only_actual_versions, selected_models],
        outputs=filter_outputs
    )

    # Clear the chart selection
    def clear_charts():
        empty_models = []
        # Create empty charts
        empty_gen_chart = create_radar_chart(df, empty_models, generation_metrics, "Performance on Generation Tasks")
        empty_ret_chart = create_radar_chart(df, empty_models, retrieval_metrics, "Performance on Retrieval Tasks")
        return empty_models, empty_gen_chart, empty_ret_chart

    # Connect the clear-charts button
    clear_charts_btn.click(
        fn=clear_charts,
        inputs=[],
        outputs=[selected_models, generation_chart, retrieval_chart]
    )

if __name__ == "__main__":
    demo.launch()
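
# To run the app locally (assumed setup, inferred from the imports above;
# exact package versions are not pinned in this file):
#   pip install gradio pandas numpy plotly packaging
#   python app.py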