import gradio as gr import json import pandas as pd import numpy as np import plotly.express as px import plotly.graph_objects as go from plotly.subplots import make_subplots import os import traceback from datetime import datetime from packaging import version # Color scheme for charts COLORS = px.colors.qualitative.Plotly # Line colors for radar charts line_colors = [ "#EE4266", "#00a6ed", "#ECA72C", "#B42318", "#3CBBB1", ] # Fill colors for radar charts fill_colors = [ "rgba(238,66,102,0.05)", "rgba(0,166,237,0.05)", "rgba(236,167,44,0.05)", "rgba(180,35,24,0.05)", "rgba(60,187,177,0.05)", ] # Define the question categories QUESTION_CATEGORIES = ["simple", "set", "mh", "cond", "comp"] METRIC_TYPES = ["retrieval", "generation"] def load_results(): """Load results from the results.json file.""" try: # Get the directory of the current script script_dir = os.path.dirname(os.path.abspath(__file__)) # Build the path to results.json results_path = os.path.join(script_dir, 'results.json') print(f"Loading results from: {results_path}") with open(results_path, 'r', encoding='utf-8') as f: results = json.load(f) print(f"Successfully loaded results with {len(results.get('items', {}))} version(s)") return results except FileNotFoundError: # Return empty structure if file doesn't exist print(f"Results file not found, creating empty structure") return {"items": {}, "last_version": "1.0", "n_questions": "0"} except Exception as e: print(f"Error loading results: {e}") print(traceback.format_exc()) return {"items": {}, "last_version": "1.0", "n_questions": "0"} def filter_and_process_results(results, n_versions, only_actual_versions): """Filter results by version and process them for display.""" if not results or "items" not in results: return pd.DataFrame(), [], [], [] all_items = results["items"] last_version_str = results.get("last_version", "1.0") last_version = version.parse(last_version_str) print(f"Last version: {last_version_str}") # Group items by model_name model_groups = {} for version_str, version_items in all_items.items(): version_obj = version.parse(version_str) for item_id, item in version_items.items(): model_name = item.get("model_name", "Unknown") if model_name not in model_groups: model_groups[model_name] = [] # Add version info to the item (both as string and as parsed version object for comparison) item["version_str"] = version_str item["version_obj"] = version_obj model_groups[model_name].append(item) rows = [] for model_name, items in model_groups.items(): # Sort items by version (newest first) items.sort(key=lambda x: x["version_obj"], reverse=True) # Filter versions based on selection filtered_items = [] if only_actual_versions: # Get the n most recent actual dataset versions all_versions = sorted([version.parse(v_str) for v_str in all_items.keys()], reverse=True) # Take at most n_versions versions_to_consider = all_versions[:n_versions] if all_versions else [] # Filter items that match those versions filtered_items = [item for item in items if any(item["version_obj"] == v for v in versions_to_consider)] else: # Consider n_versions most recent items for this model filtered_items = items[:n_versions] if not filtered_items: continue config = filtered_items[0]["config"] # Use config from most recent version # Create row with basic info row = { 'Model': model_name, 'Embeddings': config.get('embedding_model', 'N/A'), 'Retriever': config.get('retriever_type', 'N/A'), 'Top-K': config.get('retrieval_config', {}).get('top_k', 'N/A'), 'Versions': ", ".join([item["version_str"] for item in filtered_items]), 'Last Updated': filtered_items[0].get("timestamp", "") } # Format timestamp if available if row['Last Updated']: try: dt = datetime.fromisoformat(row['Last Updated'].replace('Z', '+00:00')) row['Last Updated'] = dt.strftime("%Y-%m-%d") except: pass # Process metrics based on categories category_metrics = { category: { metric_type: { "avg": 0.0, "count": 0 } for metric_type in METRIC_TYPES } for category in QUESTION_CATEGORIES } # Collect metrics by category for item in filtered_items: metrics = item.get("metrics", {}) for category in QUESTION_CATEGORIES: if category in metrics: for metric_type in METRIC_TYPES: if metric_type in metrics[category]: metric_values = metrics[category][metric_type] avg_value = sum(metric_values.values()) / len(metric_values) # Add to the running sum for this category and metric type category_metrics[category][metric_type]["avg"] += avg_value category_metrics[category][metric_type]["count"] += 1 # Calculate averages and add to row for category in QUESTION_CATEGORIES: for metric_type in METRIC_TYPES: metric_data = category_metrics[category][metric_type] if metric_data["count"] > 0: avg_value = metric_data["avg"] / metric_data["count"] # Add to row with appropriate column name col_name = f"{category}_{metric_type}" row[col_name] = round(avg_value, 4) # Calculate overall averages for each metric type for metric_type in METRIC_TYPES: total_sum = 0 total_count = 0 for category in QUESTION_CATEGORIES: metric_data = category_metrics[category][metric_type] if metric_data["count"] > 0: total_sum += metric_data["avg"] total_count += metric_data["count"] if total_count > 0: row[f"{metric_type}_avg"] = round(total_sum / total_count, 4) rows.append(row) # Create DataFrame df = pd.DataFrame(rows) # Get lists of metrics for each category category_metrics = [] for category in QUESTION_CATEGORIES: metrics = [] for metric_type in METRIC_TYPES: col_name = f"{category}_{metric_type}" if col_name in df.columns: metrics.append(col_name) if metrics: category_metrics.append((category, metrics)) # Define retrieval and generation columns for radar charts retrieval_metrics = [f"{category}_retrieval" for category in QUESTION_CATEGORIES if f"{category}_retrieval" in df.columns] generation_metrics = [f"{category}_generation" for category in QUESTION_CATEGORIES if f"{category}_generation" in df.columns] return df, retrieval_metrics, generation_metrics, category_metrics def create_radar_chart(df, selected_models, metrics, title): """Create a radar chart for the selected models and metrics.""" if not metrics or len(selected_models) == 0: # Return empty figure if no metrics or models selected fig = go.Figure() fig.update_layout( title=title, title_font_size=16, height=400, width=500, margin=dict(l=30, r=30, t=50, b=30) ) return fig # Filter dataframe for selected models filtered_df = df[df['Model'].isin(selected_models)] if filtered_df.empty: # Return empty figure if no data fig = go.Figure() fig.update_layout( title=title, title_font_size=16, height=400, width=500, margin=dict(l=30, r=30, t=50, b=30) ) return fig # Limit to top 5 models for better visualization (similar to inspiration file) if len(filtered_df) > 5: filtered_df = filtered_df.head(5) # Prepare data for radar chart categories = [m.split('_', 1)[0] for m in metrics] # Get category name (simple, set, etc.) fig = go.Figure() # Process in reverse order to match inspiration file for i, (_, row) in enumerate(filtered_df.iterrows()): values = [row[m] for m in metrics] # Close the loop for radar chart values.append(values[0]) categories_loop = categories + [categories[0]] fig.add_trace(go.Scatterpolar( name=row['Model'], r=values, theta=categories_loop, showlegend=True, mode="lines", line=dict(width=2, color=line_colors[i % len(line_colors)]), fill="toself", fillcolor=fill_colors[i % len(fill_colors)] )) fig.update_layout( font=dict(size=13, color="black"), template="plotly_white", polar=dict( radialaxis=dict( visible=True, gridcolor="black", linecolor="rgba(0,0,0,0)", gridwidth=1, showticklabels=False, ticks="", range=[0, 1] # Ensure consistent range for scores ), angularaxis=dict( gridcolor="black", gridwidth=1.5, linecolor="rgba(0,0,0,0)" ), ), legend=dict( orientation="h", yanchor="bottom", y=-0.35, xanchor="center", x=0.4, itemwidth=30, font=dict(size=13), entrywidth=0.6, entrywidthmode="fraction", ), margin=dict(l=0, r=16, t=30, b=30), autosize=True, ) return fig def create_summary_df(df, retrieval_metrics, generation_metrics): """Create a summary dataframe with averaged metrics for display.""" if df.empty: return pd.DataFrame() summary_df = df.copy() # Add retrieval average if retrieval_metrics: retrieval_avg = summary_df[retrieval_metrics].mean(axis=1).round(4) summary_df['Retrieval (avg)'] = retrieval_avg # Add generation average if generation_metrics: generation_avg = summary_df[generation_metrics].mean(axis=1).round(4) summary_df['Generation (avg)'] = generation_avg # Add total score if both averages exist if 'Retrieval (avg)' in summary_df.columns and 'Generation (avg)' in summary_df.columns: summary_df['Total Score'] = summary_df['Retrieval (avg)'] + summary_df['Generation (avg)'] summary_df = summary_df.sort_values('Total Score', ascending=False) # Select columns for display summary_cols = ['Model', 'Embeddings', 'Retriever', 'Top-K'] if 'Retrieval (avg)' in summary_df.columns: summary_cols.append('Retrieval (avg)') if 'Generation (avg)' in summary_df.columns: summary_cols.append('Generation (avg)') if 'Total Score' in summary_df.columns: summary_cols.append('Total Score') if 'Versions' in summary_df.columns: summary_cols.append('Versions') if 'Last Updated' in summary_df.columns: summary_cols.append('Last Updated') return summary_df[summary_cols] def create_category_df(df, category, retrieval_col, generation_col): """Create a dataframe for a specific category with detailed metrics.""" if df.empty or retrieval_col not in df.columns or generation_col not in df.columns: return pd.DataFrame() category_df = df.copy() # Calculate total score for this category category_df[f'{category} Score'] = category_df[retrieval_col] + category_df[generation_col] # Sort by total score category_df = category_df.sort_values(f'{category} Score', ascending=False) # Select columns for display category_cols = ['Model', 'Embeddings', 'Retriever', retrieval_col, generation_col, f'{category} Score'] # Rename columns for display category_df = category_df[category_cols].rename(columns={ retrieval_col: 'Retrieval', generation_col: 'Generation' }) return category_df # Load initial data results = load_results() last_version = results.get("last_version", "1.0") n_questions = results.get("n_questions", "100") date_title = results.get("date_title", "---") # Initial data processing df, retrieval_metrics, generation_metrics, category_metrics = filter_and_process_results( results, n_versions=1, only_actual_versions=True ) # Pre-generate charts for initial display default_models = df['Model'].head(5).tolist() if not df.empty else [] initial_gen_chart = create_radar_chart(df, default_models, generation_metrics, "Performance on Generation Tasks") initial_ret_chart = create_radar_chart(df, default_models, retrieval_metrics, "Performance on Retrieval Tasks") # Create summary dataframe summary_df = create_summary_df(df, retrieval_metrics, generation_metrics) with gr.Blocks(css=""" .title-container { text-align: center; margin-bottom: 10px; } .description-text { text-align: left; padding: 10px; margin-bottom: 0px; } .version-info { text-align: center; padding: 10px; background-color: #f0f0f0; border-radius: 8px; margin-bottom: 15px; } .version-selector { padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin-bottom: 20px; background-color: #f9f9f9; height: 100%; } .citation-block { padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin-bottom: 20px; background-color: #f9f9f9; font-family: monospace; font-size: 14px; overflow-x: auto; height: 100%; } .flex-row-container { display: flex; justify-content: space-between; gap: 20px; width: 100%; } .charts-container { display: flex; gap: 20px; margin-bottom: 20px; } .chart-box { flex: 1; border: 1px solid #eee; border-radius: 8px; padding: 10px; background-color: white; min-height: 550px; /* Increased height to accommodate legend at bottom */ } .metrics-table { border: 1px solid #eee; border-radius: 8px; padding: 15px; background-color: white; } .info-text { font-size: 0.9em; font-style: italic; color: #666; margin-top: 5px; } footer { text-align: center; margin-top: 30px; font-size: 0.9em; color: #666; } /* Style for selected rows */ table tbody tr.selected { background-color: rgba(25, 118, 210, 0.1) !important; border-left: 3px solid #1976d2; } /* Add this class via JavaScript */ .gr-table tbody tr.selected td:first-child { font-weight: bold; color: #1976d2; } .category-tab { padding: 10px; } .chart-title { font-size: 1.2em; font-weight: bold; margin-bottom: 10px; text-align: center; } .clear-charts-button { display: flex; justify-content: center; margin-top: 10px; margin-bottom: 20px; } """) as demo: # Title with gr.Row(elem_classes=["title-container"]): gr.Markdown("# 🐙 Dynamic RAG Benchmark") # Version info with gr.Row(elem_classes=["description-text"]): gr.Markdown(f"На этом лидерборде можно сравнить RAG системы в разрезе генеративных и поисковых метрик моделей по вопросам разного типа (простые вопросы, сравнения, multi-hop, условные и др.).