# Hugging Face Spaces page header captured with the file ("Spaces: Running"); not part of the program.
import gradio as gr | |
import json | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import os | |
import traceback | |
from datetime import datetime | |
from packaging import version | |
# General-purpose chart palette: Plotly Express' qualitative "Plotly" colour sequence.
COLORS = px.colors.qualitative.Plotly
# Stroke colours for radar-chart traces (one per plotted model, cycled by index).
line_colors = ["#EE4266", "#00a6ed", "#ECA72C", "#B42318", "#3CBBB1"]


def _hex_to_rgba(hex_color, alpha):
    """Convert a '#RRGGBB' string into an 'rgba(r,g,b,alpha)' CSS colour."""
    red, green, blue = (int(hex_color[pos:pos + 2], 16) for pos in (1, 3, 5))
    return f"rgba({red},{green},{blue},{alpha})"


# Fill colours for radar-chart traces: the stroke colours at 5% opacity.
fill_colors = [_hex_to_rgba(stroke, 0.05) for stroke in line_colors]
# Language definitions

# Markdown (fenced BibTeX) shown in the citation panel. Shared by every
# language so the translations cannot drift apart: the Russian entry used
# to carry a placeholder citation instead of the real paper.
_CITATION_MD = """
```
@misc{chernogorskii2025dragondynamicragbenchmark,
title={DRAGON: Dynamic RAG Benchmark On News},
author={Fedor Chernogorskii and Sergei Averkiev and Liliya Kudraleeva and Zaven Martirosian and Maria Tikhonova and Valentin Malykh and Alena Fenogenova},
year={2025},
eprint={2507.05713},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2507.05713},
}
```
"""

# UI strings keyed by the language name shown in the language selector.
# Keys ending in "_template" are str.format templates filled in by the UI code.
LANGUAGES = {
    "English": {
        "clear_charts": "Clear Charts",
        "lang_selector_label": "Language / Язык",
        "description": (
            "This leaderboard allows comparing RAG systems based on generative and retrieval metrics across different question types (simple, comparison, multi-hop, conditional, etc.). "
            "<li>Questions are automatically generated from news sources.</li>"
            "<li>The question dataset is updated regularly, and metrics for open models are recalculated.</li>"
            "<li>User submissions use the latest calculated metrics for them.</li>"
            "<li>To recalculate a previously submitted configuration with the latest data version, use the submit_id received during the initial submission via the client (see instructions below).</li>"
        ),
        "version_info_template": "## Version {} → {} questions, generated from news sources → {}",
        "gen_metrics_title": "### Generation Metrics",
        "ret_metrics_title": "### Retrieval Metrics",
        "overall_tab_title": "Overall Table",
        "no_data_message": "No data available. Please submit some results.",
        "by_type_tab_title": "By Question Type",
        "category_display_names": {
            "simple": "Simple Questions",
            "set": "Set-based",
            "mh": "Multi-hop",
            "cond": "Conditional",
            "comp": "Comparison",
        },
        "no_data_category_template": "No data available for {} category.",
        "category_performance_template": "#### Performance on {}",
        "citation_title": "### Citation",
        "citation_description": _CITATION_MD,
        "version_selector_title": "### Version Selection",
        "only_actual_label": "Only actual versions",
        "only_actual_info": "Start counting from the current dataset version",
        "n_versions_label": "Take n last versions",
        "n_versions_info": "Number of versions to calculate metrics for",
        "filter_button": "Apply Filter",
        "info_text": "Click on models in the table to add them to the charts",
        "footer_text": "<footer>DRAGON. Dynamic RAG Benchmark Leaderboard</footer>",
        "radar_gen_title": "Performance on Generation Tasks",
        "radar_ret_title": "Performance on Retrieval Tasks",
    },
    "Русский": {
        "clear_charts": "Очистить графики",
        # "lang_selector_label" is intentionally defined for English only;
        # the English value is already bilingual.
        # "lang_selector_label": "Language",
        "description": (
            "На этом лидерборде можно сравнить RAG системы в разрезе генеративных и поисковых метрик моделей по вопросам разного типа (простые вопросы, сравнения, multi-hop, условные и др.). "
            "<li>Вопросы автоматически генерируются на основе новостных источников.</li>"
            "<li>Обновление датасета с вопросами происходит регулярно, при этом пересчитываются все метрики для открытых моделей.</li>"
            "<li>Для пользовательских сабмитов учитываются последние посчитанные для них метрики.</li>"
            "<li>Чтобы посчитать ранее отправленную конфигурацию на последней версии данных, используйте submit_id, полученный при первой отправке через клиент (см. инструкцию ниже).</li>"
        ),
        "version_info_template": "## Версия {} → {} вопросов, сгенерированных по новостным источникам → {}",
        "gen_metrics_title": "### Генеративные метрики",
        "ret_metrics_title": "### Метрики поиска",
        "overall_tab_title": "Общая таблица",
        "no_data_message": "Нет данных. Пожалуйста, отправьте результаты.",
        "by_type_tab_title": "По типам вопросов",
        "category_display_names": {
            "simple": "Simple",
            "set": "Set",
            "mh": "Multi-hop",
            "cond": "Conditional",
            "comp": "Comparison",
        },
        "no_data_category_template": "Нет данных для категории {}.",
        "category_performance_template": "#### Производительность на {}",
        "citation_title": "### Цитирование",
        "citation_description": _CITATION_MD,
        "version_selector_title": "### Выбор версий",
        "only_actual_label": "Только актуальные версии",
        "only_actual_info": "Считать, начиная с актуальной версии датасета",
        "n_versions_label": "Взять n последних версий",
        "n_versions_info": "Количество версий для подсчета метрик",
        "filter_button": "Применить фильтр",
        "info_text": "Кликайте на модели в таблице, чтобы добавить их в графики",
        "footer_text": "<footer>DRAGON. Dynamic RAG Benchmark Leaderboard</footer>",
        "radar_gen_title": "Производительность на Генеративных Заданиях",
        "radar_ret_title": "Производительность на Поисковых Заданиях",
    },
}
# Language shown on first load; used as a key into LANGUAGES.
DEFAULT_LANG = "English"

# Question-type keys looked up in each submission's "metrics" mapping.
QUESTION_CATEGORIES = [
    "simple",
    "set",
    "mh",
    "cond",
    "comp",
]

# Metric groups looked up under every question category.
METRIC_TYPES = [
    "retrieval",
    "generation",
]
def load_results():
    """Read ``results.json`` located next to this script.

    Returns:
        The parsed JSON content. If the file is missing, or anything else
        goes wrong while reading/parsing it, a fresh empty structure
        ``{"items": {}, "last_version": "1.0", "n_questions": "0"}`` is
        returned instead and the problem is printed to stdout.
    """
    def _empty_results():
        # A new dict per call so callers never share a fallback object.
        return {"items": {}, "last_version": "1.0", "n_questions": "0"}

    try:
        here = os.path.dirname(os.path.abspath(__file__))
        json_path = os.path.join(here, 'results.json')
        print(f"Loading results from: {json_path}")
        with open(json_path, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        # Kept inside the try block: a non-mapping JSON document fails here
        # and is routed to the generic handler below.
        version_count = len(loaded.get('items', {}))
        print(f"Successfully loaded results with {version_count} version(s)")
        return loaded
    except FileNotFoundError:
        print("Results file not found, creating empty structure")
        return _empty_results()
    except Exception as exc:
        print(f"Error loading results: {exc}")
        print(traceback.format_exc())
        return _empty_results()
def filter_and_process_results(results, n_versions, only_actual_versions):
    """Flatten the newest dataset versions of ``results`` into a table.

    Args:
        results: Mapping as returned by ``load_results``. Its ``"items"``
            entry maps a version string to ``{guid: submission}``; each
            submission is read for ``"model_name"``, ``"config"``,
            ``"metrics"`` and ``"timestamp"``.
        n_versions: Number of newest versions (ordered with
            ``packaging.version``) whose submissions are included.
        only_actual_versions: Accepted for interface compatibility; the
            filtering below does not currently use it.

    Returns:
        ``(df, retrieval_metrics, generation_metrics, category_metrics)``:
        ``df`` has one row per submission; ``retrieval_metrics`` and
        ``generation_metrics`` list the ``"<category>_retrieval"`` /
        ``"<category>_generation"`` columns present in ``df``; and
        ``category_metrics`` is a list of ``(category, [columns])`` pairs.
        Missing or empty ``results`` yield an empty frame and three empty
        lists.
    """
    if not results or "items" not in results:
        return pd.DataFrame(), [], [], []

    all_items = results["items"]

    # Rank the ORIGINAL key strings by their parsed version. Comparing
    # str(parsed_version) with the keys would silently drop every key that
    # packaging normalises (e.g. "v1.0" and "1.00" both render as "1.0").
    newest_first = sorted(all_items, key=version.parse, reverse=True)
    # Slice bounds must be integers; coerce in case a float-valued reading
    # (e.g. 2.0) is handed over by the caller.
    versions_to_consider = set(newest_first[:int(n_versions)])

    rows = []
    for version_str, version_items in all_items.items():
        if version_str not in versions_to_consider:
            continue
        for guid, item in version_items.items():
            config = item.get("config", {})
            metrics = item.get("metrics", {})

            # Shorten ISO timestamps to a date; anything unparsable is
            # shown as-is rather than aborting the whole table.
            last_updated = item.get("timestamp", "")
            if last_updated:
                try:
                    parsed = datetime.fromisoformat(last_updated.replace('Z', '+00:00'))
                    last_updated = parsed.strftime("%Y-%m-%d")
                except (ValueError, TypeError, AttributeError):
                    pass

            row = {
                'Model': f"{item.get('model_name', 'N/A')} ({guid[:6]})",
                'Embeddings': config.get('embedding_model', 'N/A'),
                'Top k': config.get('retrieval_config', {}).get('top_k', 'N/A'),
                'Version': version_str,
                'Last Updated': last_updated,
                'guid': guid,
            }

            # Per category and metric type: mean over the individual metric
            # values, plus running totals for the per-type overall average.
            type_totals = dict.fromkeys(METRIC_TYPES, 0.0)
            type_counts = dict.fromkeys(METRIC_TYPES, 0)
            for category in QUESTION_CATEGORIES:
                per_category = metrics.get(category) or {}
                for metric_type in METRIC_TYPES:
                    metric_values = per_category.get(metric_type)
                    if not metric_values:
                        continue
                    avg_value = sum(metric_values.values()) / len(metric_values)
                    row[f"{category}_{metric_type}"] = round(avg_value, 4)
                    type_totals[metric_type] += avg_value
                    type_counts[metric_type] += 1

            for metric_type in METRIC_TYPES:
                if type_counts[metric_type] > 0:
                    overall = type_totals[metric_type] / type_counts[metric_type]
                    row[f"{metric_type}_avg"] = round(overall, 4)

            rows.append(row)

    df = pd.DataFrame(rows)

    # Categories that produced at least one metric column, in the order of
    # QUESTION_CATEGORIES (an empty frame has no columns, so nothing matches).
    category_metrics = []
    for category in QUESTION_CATEGORIES:
        present = [
            f"{category}_{metric_type}"
            for metric_type in METRIC_TYPES
            if f"{category}_{metric_type}" in df.columns
        ]
        if present:
            category_metrics.append((category, present))

    # Column lists that feed the two radar charts.
    retrieval_metrics = [
        f"{category}_retrieval"
        for category, _ in category_metrics
        if f"{category}_retrieval" in df.columns
    ]
    generation_metrics = [
        f"{category}_generation"
        for category, _ in category_metrics
        if f"{category}_generation" in df.columns
    ]

    return df, retrieval_metrics, generation_metrics, category_metrics
def create_radar_chart(df, selected_models, metrics, title, name_col="Model"):
    """Build a radar (polar) chart comparing the selected models.

    Args:
        df: Results table; rows are matched on its 'Model' column.
        selected_models: 'Model' labels to plot (at most five are drawn).
        metrics: Column names to plot; the text before the first '_' of
            each name becomes the axis label.
        title: Title placed on the placeholder figure returned when there
            is nothing to plot. The populated figure's layout does not set
            a title.
        name_col: Column whose value labels each trace in the legend.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    def _titled_placeholder():
        # Trace-less figure returned whenever there is nothing to draw.
        placeholder = go.Figure()
        placeholder.update_layout(
            title=title,
            title_font_size=16,
            height=400,
            width=500,
            margin=dict(l=30, r=30, t=50, b=30),
        )
        return placeholder

    # Checked before touching df: an empty frame has no 'Model' column.
    if not metrics or len(selected_models) == 0:
        return _titled_placeholder()

    chosen = df[df['Model'].isin(selected_models)]
    if chosen.empty:
        return _titled_placeholder()

    # Keep at most five traces so the chart stays readable.
    chosen = chosen.head(5)

    # Axis labels are the category part of each column name ("simple", "set", ...);
    # the first label is repeated so each polygon closes.
    axis_labels = [metric.split('_', 1)[0] for metric in metrics]
    closed_labels = axis_labels + [axis_labels[0]]

    fig = go.Figure()
    for trace_no, (_, record) in enumerate(chosen.iterrows()):
        radii = [record[metric] for metric in metrics]
        radii.append(radii[0])
        trace = go.Scatterpolar(
            name=record[name_col],
            r=radii,
            theta=closed_labels,
            showlegend=True,
            mode="lines",
            line=dict(width=2, color=line_colors[trace_no % len(line_colors)]),
            fill="toself",
            fillcolor=fill_colors[trace_no % len(fill_colors)],
        )
        fig.add_trace(trace)

    radial_axis = dict(
        visible=True,
        gridcolor="black",
        linecolor="rgba(0,0,0,0)",
        gridwidth=1,
        showticklabels=False,
        ticks="",
        range=[0, 1],  # fixed scale so charts are comparable
    )
    angular_axis = dict(
        gridcolor="black",
        gridwidth=1.5,
        linecolor="rgba(0,0,0,0)",
    )
    # Horizontal legend placed underneath the plot area.
    legend_below = dict(
        orientation="h",
        yanchor="bottom",
        y=-0.35,
        xanchor="center",
        x=0.4,
        itemwidth=30,
        font=dict(size=13),
        entrywidth=0.6,
        entrywidthmode="fraction",
    )
    fig.update_layout(
        font=dict(size=13, color="black"),
        template="plotly_white",
        polar=dict(radialaxis=radial_axis, angularaxis=angular_axis),
        legend=legend_below,
        margin=dict(l=0, r=16, t=30, b=30),
        autosize=True,
    )
    return fig
def create_summary_df(df, retrieval_metrics, generation_metrics):
    """Build the overall leaderboard table.

    Args:
        df: Results table with at least 'Model', 'Embeddings' and 'Top k'.
        retrieval_metrics: Columns averaged row-wise into 'Retrieval (avg)'.
        generation_metrics: Columns averaged row-wise into 'Generation (avg)'.

    Returns:
        A new frame with the display columns. When both averages exist, a
        'Total Score' (their mean) is added and rows are sorted by it in
        descending order. An empty ``df`` yields an empty frame.
    """
    if df.empty:
        return pd.DataFrame()

    table = df.copy()

    # Row-wise means over whichever metric groups were supplied.
    if retrieval_metrics:
        table['Retrieval (avg)'] = table[retrieval_metrics].mean(axis=1).round(4)
    if generation_metrics:
        table['Generation (avg)'] = table[generation_metrics].mean(axis=1).round(4)

    # The total needs both averages; without it the input order is kept.
    average_cols = ['Retrieval (avg)', 'Generation (avg)']
    if all(col in table.columns for col in average_cols):
        table['Total Score'] = table[average_cols].mean(axis=1).round(4)
        table = table.sort_values('Total Score', ascending=False)

    # Fixed leading columns, then every optional column that is present.
    optional_cols = ('Retrieval (avg)', 'Generation (avg)', 'Total Score', 'Version', 'Last Updated')
    shown_cols = ['Model', 'Embeddings', 'Top k']
    shown_cols.extend(col for col in optional_cols if col in table.columns)
    return table[shown_cols]
def create_category_df(df, category, retrieval_col, generation_col):
    """Build the per-category table shown in a question-type tab.

    Args:
        df: Results table containing 'Model' and 'Embeddings'.
        category: Category key; not used by the computation itself.
        retrieval_col: Column holding the category's retrieval score.
        generation_col: Column holding the category's generation score.

    Returns:
        A frame with columns 'Model', 'Embeddings', 'Retrieval',
        'Generation' and 'Score' (the rounded sum of the two scores),
        sorted by 'Score' descending. An empty frame is returned when
        ``df`` is empty or either score column is missing.
    """
    has_both = retrieval_col in df.columns and generation_col in df.columns
    if df.empty or not has_both:
        return pd.DataFrame()

    # assign() works on a copy, so the caller's frame is left untouched.
    combined = (df[retrieval_col] + df[generation_col]).round(4)
    ranked = df.assign(Score=combined).sort_values('Score', ascending=False)

    visible = ranked[['Model', 'Embeddings', retrieval_col, generation_col, 'Score']]
    return visible.rename(columns={retrieval_col: 'Retrieval', generation_col: 'Generation'})
# Read the stored results once at import time. The three header fields fall
# back to placeholder values when they are absent from the file.
results = load_results()
last_version, n_questions, date_title = (
    results.get(field, placeholder)
    for field, placeholder in (
        ("last_version", "1.0"),
        ("n_questions", "100"),
        ("date_title", "---"),
    )
)

# Initial table: only the newest dataset version.
df, retrieval_metrics, generation_metrics, category_metrics = filter_and_process_results(
    results,
    n_versions=1,
    only_actual_versions=True,
)

# Summary table shown on the overall tab.
summary_df = create_summary_df(df, retrieval_metrics, generation_metrics)

# Charts shown on first render: the first five models of the table (none if it is empty).
default_models = [] if df.empty else df['Model'].head(5).tolist()
initial_gen_chart_title = LANGUAGES[DEFAULT_LANG]["radar_gen_title"]
initial_ret_chart_title = LANGUAGES[DEFAULT_LANG]["radar_ret_title"]
initial_gen_chart = create_radar_chart(
    df, default_models, generation_metrics, initial_gen_chart_title
)
# The retrieval chart labels its traces by embedding model instead of model name.
initial_ret_chart = create_radar_chart(
    df, default_models, retrieval_metrics, initial_ret_chart_title, name_col='Embeddings'
)
with gr.Blocks(css=""" | |
.title-container { | |
text-align: center; | |
margin-bottom: 10px; | |
} | |
.description-text { | |
text-align: left; | |
padding: 10px; | |
margin-bottom: 0px; | |
} | |
.version-info { | |
text-align: center; | |
padding: 10px; | |
background-color: #f0f0f0; | |
border-radius: 8px; | |
margin-bottom: 15px; | |
} | |
.version-selector { | |
padding: 15px; | |
border: 1px solid #ddd; | |
border-radius: 8px; | |
margin-bottom: 20px; | |
background-color: #f9f9f9; | |
height: 100%; | |
} | |
.citation-block { | |
padding: 15px; | |
border: 1px solid #ddd; | |
border-radius: 8px; | |
margin-bottom: 20px; | |
background-color: #f9f9f9; | |
font-family: monospace; | |
font-size: 14px; | |
overflow-x: auto; | |
height: 100%; | |
} | |
.flex-row-container { | |
display: flex; | |
justify-content: space-between; | |
gap: 20px; | |
width: 100%; | |
} | |
.charts-container { | |
display: flex; | |
gap: 20px; | |
margin-bottom: 20px; | |
} | |
.chart-box { | |
flex: 1; | |
border: 1px solid #eee; | |
border-radius: 8px; | |
padding: 10px; | |
background-color: white; | |
min-height: 550px; /* Increased height to accommodate legend at bottom */ | |
} | |
.metrics-table { | |
border: 1px solid #eee; | |
border-radius: 8px; | |
padding: 15px; | |
background-color: white; | |
} | |
.info-text { | |
font-size: 0.9em; | |
font-style: italic; | |
color: #666; | |
margin-top: 5px; | |
} | |
footer { | |
text-align: center; | |
margin-top: 30px; | |
font-size: 0.9em; | |
color: #666; | |
} | |
/* Style for selected rows */ | |
table tbody tr.selected { | |
background-color: rgba(25, 118, 210, 0.1) !important; | |
border-left: 3px solid #1976d2; | |
} | |
/* Add this class via JavaScript */ | |
.gr-table tbody tr.selected td:first-child { | |
font-weight: bold; | |
color: #1976d2; | |
} | |
.category-tab { | |
padding: 10px; | |
} | |
.chart-title { | |
font-size: 1.2em; | |
font-weight: bold; | |
margin-bottom: 10px; | |
text-align: center; | |
} | |
.clear-charts-button { | |
display: flex; | |
justify-content: center; | |
margin-top: 10px; | |
margin-bottom: 20px; | |
} | |
.lang-selector { | |
width: fit-content; /* Adjust width to content */ | |
margin-left: auto; /* Push to the right */ | |
margin-right: 0; /* Keep it flush right */ | |
margin-bottom: 15px; /* Keep bottom margin */ | |
padding: 10px; | |
background-color: #f9f9f9; | |
border-radius: 8px; | |
border: none; | |
padding: 0 !important; | |
} | |
.lang-selector .form { | |
border: none !important; | |
} | |
""") as demo: | |
current_lang_dict = gr.State(LANGUAGES[DEFAULT_LANG]) | |
current_language = gr.State(DEFAULT_LANG) | |
with gr.Row(elem_classes=["title-container"]): | |
#title with emoji connected with dragon | |
main_title_md = gr.Markdown("# 🐉 DRAGON. Dynamic RAG Benchmark On News") | |
# Language Selector | |
with gr.Row(elem_classes=["lang-selector"]): | |
lang_selector = gr.Radio( | |
list(LANGUAGES.keys()), | |
label="", | |
value=DEFAULT_LANG, | |
interactive=True | |
) | |
# Description | |
with gr.Row(elem_classes=["description-text"]): | |
description_md = gr.Markdown(value=LANGUAGES[DEFAULT_LANG]["description"]) | |
# Version info | |
with gr.Row(elem_classes=["version-info"]): | |
version_info_md = gr.Markdown( | |
value=LANGUAGES[DEFAULT_LANG]["version_info_template"].format(last_version, n_questions, date_title) | |
) | |
# Radar Charts | |
with gr.Row(elem_classes=["charts-container"]): | |
with gr.Column(elem_classes=["chart-box"]): | |
gen_chart_title_md = gr.Markdown( | |
value=LANGUAGES[DEFAULT_LANG]["gen_metrics_title"], elem_classes=["chart-title"] | |
) | |
generation_chart = gr.Plot(value=initial_gen_chart) | |
with gr.Column(elem_classes=["chart-box"]): | |
ret_chart_title_md = gr.Markdown( | |
value=LANGUAGES[DEFAULT_LANG]["ret_metrics_title"], elem_classes=["chart-title"] | |
) | |
retrieval_chart = gr.Plot(value=initial_ret_chart) | |
# Clear Charts Button | |
with gr.Row(elem_classes=["clear-charts-button"]): | |
clear_charts_btn = gr.Button( | |
value=LANGUAGES[DEFAULT_LANG]["clear_charts"], | |
variant="secondary" | |
) | |
# Metrics table with tabs | |
with gr.Tabs(elem_classes=["metrics-table"]) as metrics_tabs: | |
with gr.TabItem(label=LANGUAGES[DEFAULT_LANG]["overall_tab_title"]) as summary_tab: | |
selected_models = gr.State(default_models) | |
empty_data_md = gr.Markdown( | |
value=LANGUAGES[DEFAULT_LANG]["no_data_message"], | |
visible=df.empty # Initially visible only if df is empty | |
) | |
# Initialize metrics_table even if empty, but maybe hide it | |
metrics_table = gr.DataFrame( | |
value=summary_df if not df.empty else pd.DataFrame(), | |
headers=summary_df.columns.tolist() if not df.empty else [], | |
datatype=["str"] * (len(summary_df.columns) if not df.empty else 0), | |
row_count=(min(10, len(summary_df)) if not summary_df.empty else 0), | |
col_count=(len(summary_df.columns) if not summary_df.empty else 0), | |
interactive=False, | |
wrap=True, | |
visible=not df.empty # Initially visible only if df is not empty | |
) | |
with gr.TabItem(label=LANGUAGES[DEFAULT_LANG]["by_type_tab_title"]) as category_main_tab: | |
category_tabs = gr.Tabs() | |
category_tables = {} | |
category_tab_items = {} # Store TabItem components | |
category_no_data_mds = {} # Store "no data" Markdowns | |
category_title_mds = {} # Store category title Markdowns | |
# Get initial display names | |
initial_category_display_names = LANGUAGES[DEFAULT_LANG]["category_display_names"] | |
with category_tabs: | |
for category, _ in category_metrics: | |
display_name = initial_category_display_names.get(category, category.capitalize()) | |
if f"{category}_retrieval" in df.columns and f"{category}_generation" in df.columns: | |
with gr.TabItem(label=display_name, elem_classes=["category-tab"]) as tab_item: | |
category_tab_items[category] = tab_item # Store the TabItem | |
# Create dataframe for this category | |
category_df = create_category_df(df, category, f"{category}_retrieval", f"{category}_generation") | |
category_no_data_mds[category] = gr.Markdown( | |
value=LANGUAGES[DEFAULT_LANG]["no_data_category_template"].format(display_name), | |
visible=category_df.empty | |
) | |
category_title_mds[category] = gr.Markdown( | |
value=LANGUAGES[DEFAULT_LANG]["category_performance_template"].format(display_name), | |
visible=not category_df.empty | |
) | |
category_tables[category] = gr.DataFrame( | |
value=category_df if not category_df.empty else pd.DataFrame(), | |
headers=category_df.columns.tolist() if not category_df.empty else [], | |
datatype=["str"] * (len(category_df.columns) if not category_df.empty else 0), | |
row_count=(min(10, len(category_df)) if not category_df.empty else 0), | |
col_count=(len(category_df.columns) if not category_df.empty else 0), | |
interactive=False, | |
wrap=True, | |
visible=not category_df.empty | |
) | |
# Version selector and Citation block in a flex container | |
with gr.Row(): | |
# Citation block (left side) | |
with gr.Column(scale=1, elem_classes=["citation-block"]): | |
citation_title_md = gr.Markdown(value=LANGUAGES[DEFAULT_LANG]["citation_title"]) | |
citation_desc_md = gr.Markdown(value=LANGUAGES[DEFAULT_LANG]["citation_description"]) | |
# Version selector (right side) | |
with gr.Column(scale=1, elem_classes=["version-selector"]): | |
version_selector_title_md = gr.Markdown(value=LANGUAGES[DEFAULT_LANG]["version_selector_title"]) | |
with gr.Column(): | |
with gr.Row(): | |
with gr.Column(scale=3): | |
only_actual_versions = gr.Checkbox( | |
label=LANGUAGES[DEFAULT_LANG]["only_actual_label"], | |
value=True, | |
info=LANGUAGES[DEFAULT_LANG]["only_actual_info"] | |
) | |
with gr.Column(scale=5): | |
n_versions_slider = gr.Slider( | |
minimum=1, | |
maximum=5, | |
value=1, | |
step=1, | |
label=LANGUAGES[DEFAULT_LANG]["n_versions_label"], | |
info=LANGUAGES[DEFAULT_LANG]["n_versions_info"] | |
) | |
with gr.Row(): | |
filter_btn = gr.Button(value=LANGUAGES[DEFAULT_LANG]["filter_button"], variant="primary") | |
info_text_md = gr.Markdown( | |
value=LANGUAGES[DEFAULT_LANG]["info_text"], | |
elem_classes=["info-text"] | |
) | |
# Footer | |
with gr.Row(): | |
footer_md = gr.Markdown(value=LANGUAGES[DEFAULT_LANG]["footer_text"]) | |
# Handle row selection for radar charts | |
def update_charts(evt: gr.SelectData, selected_models, current_lang):
    """Toggle the clicked row's model in the chart selection and redraw.

    Args:
        evt: Gradio selection event. ``evt.target`` is compared with the
            table components to find out which table was clicked and
            ``evt.index`` supplies the clicked row.
        selected_models: Current list of selected 'Model' labels.
        current_lang: Key into LANGUAGES used for the chart titles.

    Returns:
        ``(selected_models, generation_chart, retrieval_chart)``. On any
        error the selection is returned unchanged together with charts
        rebuilt from it.
    """
    def _charts_for(frame, models, gen_cols, ret_cols):
        # Generation/retrieval radar pair with titles in the active language.
        texts = LANGUAGES[current_lang]
        gen_fig = create_radar_chart(frame, models, gen_cols, texts["radar_gen_title"])
        ret_fig = create_radar_chart(frame, models, ret_cols, texts["radar_ret_title"], name_col='Embeddings')
        return gen_fig, ret_fig

    try:
        # Module-level data; update_data rebinds these when the filter changes.
        current_df = df
        current_ret_metrics = retrieval_metrics
        current_gen_metrics = generation_metrics

        # Debug info
        print(f"Selection event: {evt}, type: {type(evt)}")

        selected_model = None
        row_idx = None  # bound up front so the IndexError handler can always report it
        try:
            component = evt.target
            raw_index = evt.index
            # The row comes first when the index is a sequence; accept a
            # tuple as well as a list so the bounds check below always
            # compares integers.
            row_idx = raw_index[0] if isinstance(raw_index, (list, tuple)) else raw_index
            print(f"Row index: {row_idx}, Component: {component}")

            # Rebuild the table that was clicked so the row index can be
            # mapped back to a model label.
            clicked_df = None
            source_label = None
            if component is metrics_table:
                clicked_df = create_summary_df(current_df, current_ret_metrics, current_gen_metrics)
                source_label = "summary"
            else:
                for category, table in category_tables.items():
                    if component is table:
                        clicked_df = create_category_df(
                            current_df,
                            category,
                            f"{category}_retrieval",
                            f"{category}_generation",
                        )
                        source_label = category
                        break

            if isinstance(clicked_df, pd.DataFrame) and not clicked_df.empty and 0 <= row_idx < len(clicked_df):
                selected_model = clicked_df.iloc[row_idx]['Model']
                print(f"Selected from {source_label} table: {selected_model}")

            # Fallback: use the event's own value. A plain string is taken
            # whole (indexing it would yield only its first character); a
            # sequence contributes its first element.
            if selected_model is None and getattr(evt, 'value', None):
                cell_value = evt.value
                if isinstance(cell_value, str):
                    selected_model = cell_value
                elif isinstance(cell_value, (list, tuple)):
                    selected_model = cell_value[0]
                if selected_model is not None:
                    print(f"Selected model using fallback evt.value: {selected_model}")
        except IndexError:
            print(f"IndexError: row_idx {row_idx} out of bounds for the component's data.")
            # Keep the current selection and just redraw.
            gen_chart, ret_chart = _charts_for(current_df, selected_models, current_gen_metrics, current_ret_metrics)
            return selected_models, gen_chart, ret_chart
        except Exception as e:
            print(f"Error extracting model name: {e}")
            traceback.print_exc()

        # If we found a model name, toggle its selection
        if selected_model:
            print(f"Selected model: {selected_model}")
            available_models = current_df['Model'].tolist() if not current_df.empty else []
            if selected_model in available_models:
                new_selected_models = list(selected_models)  # never mutate the state object in place
                if selected_model in new_selected_models:
                    new_selected_models.remove(selected_model)
                else:
                    new_selected_models.append(selected_model)
                # Drop labels that are no longer in the current table.
                new_selected_models = [model for model in new_selected_models if model in available_models]
                # Never leave the charts empty after a click: fall back to the first model.
                if not new_selected_models and available_models:
                    new_selected_models = [available_models[0]]
                selected_models = new_selected_models
            else:
                print(f"Model {selected_model} not found in current dataframe")

        gen_chart, ret_chart = _charts_for(current_df, selected_models, current_gen_metrics, current_ret_metrics)
        return selected_models, gen_chart, ret_chart
    except Exception as e:
        print(f"Error in update_charts: {e}")
        print(traceback.format_exc())
        # Redraw from the unchanged selection so the outputs stay valid.
        current_gen_chart, current_ret_chart = _charts_for(df, selected_models, generation_metrics, retrieval_metrics)
        return selected_models, current_gen_chart, current_ret_chart
# Use custom event handler for row selection | |
# Make sure to pass current_language state | |
metrics_table.select( | |
fn=update_charts, | |
inputs=[selected_models, current_language], | |
outputs=[selected_models, generation_chart, retrieval_chart] | |
) | |
# Add selection handlers for category tables too | |
for category_table in category_tables.values(): | |
category_table.select( | |
fn=update_charts, | |
inputs=[selected_models, current_language], | |
outputs=[selected_models, generation_chart, retrieval_chart] | |
) | |
# Handle version filter changes | |
def update_data(n_versions, only_actual, current_selected_models, current_lang):
    """Re-filter the results and refresh every table and chart.

    Rebinds the module-level ``df``, ``retrieval_metrics`` and
    ``generation_metrics`` to the newly filtered data.

    Args:
        n_versions: Value of the "n last versions" slider.
        only_actual: Value of the "only actual versions" checkbox.
        current_selected_models: Models currently selected for the charts.
        current_lang: Key into LANGUAGES used for the chart titles.

    Returns:
        A list of updates in this order: summary table, "no data" note,
        generation chart, retrieval chart, selected models, then for each
        category table its table, "no data" note and title. On error, a
        list of no-op updates of the same length.
    """
    global df, retrieval_metrics, generation_metrics
    try:
        df, retrieval_metrics, generation_metrics, _ = filter_and_process_results(
            results, n_versions=n_versions, only_actual_versions=only_actual
        )

        # Keep the previous selection where possible; otherwise fall back
        # to the first (up to five) models of the new table.
        available_models = df['Model'].tolist() if not df.empty else []
        kept_models = [name for name in current_selected_models if name in available_models]
        if not kept_models and available_models:
            kept_models = available_models[:5]

        texts = LANGUAGES[current_lang]
        gen_fig = create_radar_chart(df, kept_models, generation_metrics, texts["radar_gen_title"])
        ret_fig = create_radar_chart(df, kept_models, retrieval_metrics, texts["radar_ret_title"], name_col='Embeddings')

        summary = create_summary_df(df, retrieval_metrics, generation_metrics)
        summary_has_rows = not summary.empty
        updates = [
            gr.update(value=summary if summary_has_rows else pd.DataFrame(), visible=summary_has_rows),
            gr.update(visible=not summary_has_rows),
            gen_fig,
            ret_fig,
            kept_models,
        ]

        # One (table, "no data" note, title) triple per category tab that
        # was created at start-up.
        for category in category_tables:
            ret_col = f"{category}_retrieval"
            gen_col = f"{category}_generation"
            if ret_col in df.columns and gen_col in df.columns:
                category_frame = create_category_df(df, category, ret_col, gen_col)
            else:
                # Category absent under the current filter: behave like an empty table.
                category_frame = pd.DataFrame()
            populated = not category_frame.empty
            updates.extend([
                gr.update(value=category_frame if populated else pd.DataFrame(), visible=populated),
                gr.update(visible=not populated),
                gr.update(visible=populated),
            ])
        return updates
    except Exception as exc:
        print(f"Error in update_data: {exc}")
        print(traceback.format_exc())
        # Leave every output untouched: 5 fixed outputs + 3 per category table.
        return [gr.update() for _ in range(5 + 3 * len(category_tables))]
# Define filter button outputs | |
filter_outputs = [metrics_table, empty_data_md, generation_chart, retrieval_chart, selected_models] | |
for category in category_tables.keys(): | |
filter_outputs.extend([category_tables[category], category_no_data_mds[category], category_title_mds[category]]) | |
filter_btn.click( | |
fn=update_data, | |
inputs=[n_versions_slider, only_actual_versions, selected_models, current_language], # Pass language | |
outputs=filter_outputs | |
) | |
# Function to clear charts | |
def clear_charts_localized(current_lang):
    """Return an empty model selection and two radar charts drawn for it.

    Args:
        current_lang: Key into ``LANGUAGES`` used to look up the chart titles.

    Returns:
        A tuple ``(models, generation_chart, retrieval_chart)`` where
        ``models`` is an empty list and both charts were built by
        ``create_radar_chart`` with that empty list and a localized title.
    """
    no_models = []
    translations = LANGUAGES[current_lang]
    blank_generation = create_radar_chart(
        df,
        no_models,
        generation_metrics,
        translations["radar_gen_title"],
    )
    blank_retrieval = create_radar_chart(
        df,
        no_models,
        retrieval_metrics,
        translations["radar_ret_title"],
        name_col='Embeddings',
    )
    return no_models, blank_generation, blank_retrieval
# Wire the "clear charts" button: the handler receives the current language
# and its three return values go to the model selector and both radar charts.
clear_charts_btn.click(
    fn=clear_charts_localized,
    inputs=[current_language],
    outputs=[
        selected_models,
        generation_chart,
        retrieval_chart,
    ],
)
# Function to update language-specific elements | |
def update_language(selected_lang, current_models=None):
    """Build the values that switch every localized component to a language.

    Args:
        selected_lang: Key into ``LANGUAGES`` chosen in the language selector.
        current_models: Live value of the ``selected_models`` component, as
            supplied by the event wiring below. When it is ``None`` the
            component's construction-time value is used instead, which is
            what this function did before the parameter existed.

    Returns:
        A list with one entry per component in ``lang_outputs``, in the same
        order: the two state values first, then ``gr.update`` objects and the
        two re-plotted radar charts.

    Raises:
        KeyError: If ``selected_lang`` or a required translation key is
            missing from ``LANGUAGES``.
    """
    lang_dict = LANGUAGES[selected_lang]
    category_display_names = lang_dict.get("category_display_names", {})

    # A component's `.value` attribute only holds the value it was built
    # with, so the user's live selection must arrive as an event input.
    # Reading `selected_models.value` here would redraw the charts for the
    # initial models while the selector still shows the user's choice.
    if current_models is None:
        current_models = selected_models.value

    def display_name(category):
        # Fall back to the capitalized key when no translation is provided.
        return category_display_names.get(category, category.capitalize())

    # Keyed by component so the return order can be taken from lang_outputs.
    # The language selector's own label is not updated here.
    updates = {
        current_language: selected_lang,  # state holding the language key
        current_lang_dict: lang_dict,  # state holding the translations
        description_md: gr.update(value=lang_dict["description"]),
        version_info_md: gr.update(
            value=lang_dict["version_info_template"].format(last_version, n_questions, date_title)
        ),
        gen_chart_title_md: gr.update(value=lang_dict["gen_metrics_title"]),
        ret_chart_title_md: gr.update(value=lang_dict["ret_metrics_title"]),
        clear_charts_btn: gr.update(value=lang_dict["clear_charts"]),
        summary_tab: gr.update(label=lang_dict["overall_tab_title"]),
        empty_data_md: gr.update(value=lang_dict["no_data_message"]),
        category_main_tab: gr.update(label=lang_dict["by_type_tab_title"]),
        citation_title_md: gr.update(value=lang_dict["citation_title"]),
        citation_desc_md: gr.update(value=lang_dict["citation_description"]),
        version_selector_title_md: gr.update(value=lang_dict["version_selector_title"]),
        only_actual_versions: gr.update(
            label=lang_dict["only_actual_label"], info=lang_dict["only_actual_info"]
        ),
        n_versions_slider: gr.update(
            label=lang_dict["n_versions_label"], info=lang_dict["n_versions_info"]
        ),
        filter_btn: gr.update(value=lang_dict["filter_button"]),
        info_text_md: gr.update(value=lang_dict["info_text"]),
        footer_md: gr.update(value=lang_dict["footer_text"]),
    }

    # Per-category tab labels and the two templated markdown blocks.
    for category, tab_item in category_tab_items.items():
        updates[tab_item] = gr.update(label=display_name(category))
    for category, no_data_md in category_no_data_mds.items():
        updates[no_data_md] = gr.update(
            value=lang_dict["no_data_category_template"].format(display_name(category))
        )
    for category, title_md in category_title_mds.items():
        updates[title_md] = gr.update(
            value=lang_dict["category_performance_template"].format(display_name(category))
        )

    # Chart titles are baked into the figures, so both charts are re-plotted.
    updates[generation_chart] = create_radar_chart(
        df, current_models, generation_metrics, lang_dict["radar_gen_title"]
    )
    updates[retrieval_chart] = create_radar_chart(
        df, current_models, retrieval_metrics, lang_dict["radar_ret_title"], name_col='Embeddings'
    )

    # lang_outputs (defined right below) is the single source of truth for
    # the output order, so the handler and its wiring cannot drift apart.
    return [updates[component] for component in lang_outputs]


# Components refreshed when the language changes, in the order in which
# update_language returns their values.
lang_outputs = [
    current_language,
    current_lang_dict,
    description_md,
    version_info_md,
    gen_chart_title_md,
    ret_chart_title_md,
    clear_charts_btn,
    summary_tab,
    empty_data_md,
    category_main_tab,
    citation_title_md,
    citation_desc_md,
    version_selector_title_md,
    only_actual_versions,
    n_versions_slider,
    filter_btn,
    info_text_md,
    footer_md,
    generation_chart,  # charts are re-plotted because their titles change
    retrieval_chart,
]
# category_tables decides which categories exist; for each one, add the tab
# item, the "no data" markdown and the title markdown when they are present.
for category in category_tables:
    if category in category_tab_items:
        lang_outputs.append(category_tab_items[category])
    if category in category_no_data_mds:
        lang_outputs.append(category_no_data_mds[category])
    if category in category_title_mds:
        lang_outputs.append(category_title_mds[category])

# Connect the language selector; selected_models is passed as an input so the
# charts are redrawn for the models the user currently has selected.
lang_selector.change(
    fn=update_language,
    inputs=[lang_selector, selected_models],
    outputs=lang_outputs,
)
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()