# -*- coding: utf-8 -*-
# @Date    : 2025/2/5 16:26
# @Author  : q275343119
# @File    : data_page.py
# @Description:
"""Render an RTEB leaderboard page.

The page shows an AgGrid table of model results, a CSV download button and a
"share" dialog whose URL carries the current grid state.
"""
import io

from st_aggrid import AgGrid, JsCode, ColumnsAutoSizeMode
import streamlit as st
import streamlit.components.v1 as components
from utils.st_copy_to_clipboard import st_copy_to_clipboard
from streamlit_theme import st_theme

from app.backend.app_init_func import LEADERBOARD_MAP
from app.backend.constant import LEADERBOARD_ICON_MAP, BASE_URL
from app.backend.json_util import compress_msgpack, decompress_msgpack

COLUMNS = ['model_name', 'vendor', 'embd_dtype', 'embd_dim', 'num_params', 'max_tokens', 'similarity',
           'query_instruct', 'corpus_instruct', 'reference'
           ]
LARGER_HEADER_STYLE = {'fontSize': '18px'}
HEADER_STYLE = {'fontSize': '14px'}
CELL_STYLE = {'fontSize': '14px'}

# Names of the per-page average columns that render_page adds to the table.
_CLOSED_AVG_COLUMN = "Closed average"
_OPEN_AVG_COLUMN = "Open average"

# AgGrid header component used for dataset columns: the header text becomes a
# link to https://huggingface.co/datasets/embedding-benchmark/<column id>.
# The click listener stops the event from bubbling up to the grid header.
_DATASET_LINK_HEADER_JS = """
class DatasetHeaderRenderer {
    init(params) {
        this.eGui = document.createElement('div');
        const columnName = params.displayName;
        const fieldName = params.column.colId;
        const link = document.createElement('a');
        link.href = 'https://huggingface.co/datasets/embedding-benchmark/' + fieldName;
        link.target = '_blank';
        link.style.color = 'white';
        link.style.textDecoration = 'underline';
        link.style.cursor = 'pointer';
        link.textContent = columnName;
        link.addEventListener('click', function(e) { e.stopPropagation(); });
        this.eGui.appendChild(link);
    }

    getGui() {
        return this.eGui;
    }
}
"""

# Same as _DATASET_LINK_HEADER_JS, except that columns whose id contains
# "Average" get plain-text headers (they are not datasets).
_SECTION_HEADER_JS = """
class DatasetHeaderRenderer {
    init(params) {
        this.eGui = document.createElement('div');
        const columnName = params.displayName;
        const fieldName = params.column.colId;
        if (fieldName.includes('Average')) {
            this.eGui.textContent = columnName;
        } else {
            const link = document.createElement('a');
            link.href = 'https://huggingface.co/datasets/embedding-benchmark/' + fieldName;
            link.target = '_blank';
            link.style.color = 'white';
            link.style.textDecoration = 'underline';
            link.style.cursor = 'pointer';
            link.textContent = columnName;
            link.addEventListener('click', function(e) { e.stopPropagation(); });
            this.eGui.appendChild(link);
        }
    }

    getGui() {
        return this.eGui;
    }
}
"""

# Cell renderer for the model-name column; it reads the row's `reference` field.
_MODEL_NAME_RENDERER_JS = """class CustomHTML {
    init(params) {
        const link = params.data.reference;
        this.eGui = document.createElement('div');
        this.eGui.innerHTML = link ? `${params.value} ` : params.value;
    }

    getGui() {
        return this.eGui;
    }
}"""

# Formats parameter counts with B / M / K suffixes (two decimals).
_NUM_PARAMS_FORMATTER_JS = """function(params) {
    const num = params.value;
    if (num >= 1e9) return (num / 1e9).toFixed(2) + "B";
    if (num >= 1e6) return (num / 1e6).toFixed(2) + "M";
    if (num >= 1e3) return (num / 1e3).toFixed(2) + "K";
    return num;
}"""

# CSS overrides passed to AgGrid(custom_css=...).
_GRID_CUSTOM_CSS = {
    # Model Name Cell
    ".a-cell": {
        "display": "inline-block",
        "white-space": "nowrap",
        "overflow": "hidden",
        "text-overflow": "ellipsis",
        "width": "100%",
        "min-width": "0"
    },
    # Header
    ".multi-line-header": {
        "text-overflow": "clip",
        "overflow": "visible",
        "white-space": "normal",
        "height": "auto",
        "font-family": 'Arial',
        "font-size": "14px",
        "font-weight": "bold",
        "padding": "10px",
        "text-align": "left",
    },
    # Filter Options and Input
    ".ag-theme-streamlit .ag-popup": {
        "font-family": 'Arial',
        "font-size": "14px",
    },
    ".ag-picker-field-display": {
        "font-family": 'Arial',
        "font-size": "14px",
    },
    ".ag-input-field-input .ag-text-field-input": {
        "font-family": 'Arial',
        "font-size": "14px",
    }
}


def is_section(group_name):
    """Return True if ``group_name`` is a leaderboard name from LEADERBOARD_MAP.

    The leaderboard name of each LEADERBOARD_MAP value is read as ``value[0][0]``.
    """
    return any(value[0][0] == group_name for value in LEADERBOARD_MAP.values())


def get_closed_dataset():
    """Return the ``dataset_name`` of every result flagged ``is_closed``.

    Results are read from ``st.session_state["data_engine"].results``.
    """
    data_engine = st.session_state["data_engine"]
    return [result.get("dataset_name") for result in data_engine.results if result.get("is_closed")]


def convert_df_to_csv(df):
    """Serialize ``df`` to a CSV string without the index column."""
    with io.StringIO() as buffer:
        df.to_csv(buffer, index=False)
        return buffer.getvalue()


def get_column_state():
    """Copy grid / sidebar state from the page URL into ``st.session_state``.

    - ``grid_state`` query param: if present and non-empty, it is decoded with
      ``decompress_msgpack`` and stored as ``st.session_state.grid_state``.
    - ``sider_bar_hidden`` query param: when it equals "false" (any case), or is
      absent (the default is "False"), ``st.session_state.sider_bar_hidden`` is
      set to False. Otherwise the session value is left untouched.
    """
    encoded_grid_state = st.query_params.get("grid_state", None)
    sider_bar_hidden = st.query_params.get("sider_bar_hidden", "False")
    if encoded_grid_state:
        st.session_state.grid_state = decompress_msgpack(encoded_grid_state)
    if sider_bar_hidden.upper() == 'FALSE':
        st.session_state.sider_bar_hidden = False
    return None


def _dataset_only(column_list, avg_column):
    """Return ``column_list`` without the overall and open/closed average columns."""
    excluded = (avg_column, _CLOSED_AVG_COLUMN, _OPEN_AVG_COLUMN)
    return [column for column in column_list if column not in excluded]


def _linked_dataset_column_def(column):
    """Column definition for one dataset column on a non-section page."""
    return {
        'headerName': column,
        'field': column,
        'headerStyle': HEADER_STYLE,
        'cellStyle': CELL_STYLE,
        "headerTooltip": column,
        'headerComponent': JsCode(_DATASET_LINK_HEADER_JS),
    }


def _dataset_group_def(header_name, columns):
    """Column-group definition wrapping ``columns`` under ``header_name``."""
    return {
        'headerName': header_name,
        'headerStyle': LARGER_HEADER_STYLE,
        'headerClass': 'group-header',
        'marryChildren': True,
        'openByDefault': True,
        'children': [_linked_dataset_column_def(column) for column in columns],
    }


def _section_column_def(column):
    """Flat column definition used on section pages.

    "Average" is stripped from the header label and the tooltip, and the result is capitalized.
    """
    label = column if "Average" not in column else column.replace("Average", "").strip().capitalize()
    return {
        'headerName': label,
        'field': column,
        'headerStyle': HEADER_STYLE,
        'cellStyle': CELL_STYLE,
        "headerTooltip": label,
        'headerComponent': JsCode(_SECTION_HEADER_JS),
    }


def _get_dataset_columns(group_name, column_list, avg_column):
    """Generate dataset columns with proper grouping for individual dataset pages.

    Non-section pages with at least one dataset column get up to two groups:
    "Open Datasets" (names not starting with '_') and "Closed Datasets" (names
    starting with '_'). Section pages, and pages with no dataset columns, get a
    flat list built by ``_section_column_def``.
    """
    dataset_columns = _dataset_only(column_list, avg_column)

    if is_section(group_name) or not dataset_columns:
        return [_section_column_def(column) for column in dataset_columns]

    open_datasets = [d for d in dataset_columns if not d.startswith('_')]
    closed_datasets = [d for d in dataset_columns if d.startswith('_')]

    grouped_columns = []
    if open_datasets:
        grouped_columns.append(_dataset_group_def('Open Datasets', open_datasets))
    if closed_datasets:
        grouped_columns.append(_dataset_group_def('Closed Datasets', closed_datasets))
    return grouped_columns


def _page_label(group_name):
    """Display label for the page; section pages get a " Leaderboard" suffix."""
    if is_section(group_name):
        return group_name.capitalize() + " Leaderboard"
    return group_name.capitalize()


def _prepare_table(df, group_name, data_engine):
    """Select, augment, order and sort the leaderboard table for ``group_name``.

    Adds "Closed average" / "Open average" columns to ``df``. These hold the row
    mean over the page's datasets that are, or are not, listed by
    ``get_closed_dataset()``, rounded to 2 decimals.

    Returns ``(table, column_list, avg_column)``:
    - ``table`` is ``df`` restricted to COLUMNS + column_list and sorted by
      ``avg_column`` in descending order.
    - On non-section pages the average column is renamed by removing the
      capitalized group name, and it is removed from ``column_list``.

    Raises IndexError on a section page whose ``df`` has no "Average" column.
    """
    section = is_section(group_name)
    column_list = []
    avg_column = None

    if section:
        # Columns starting with "Average" go first (the last one seen ends up in
        # front); other columns that contain "Average" keep their relative order.
        avg_columns = []
        for column in df.columns:
            if column.startswith("Average"):
                avg_columns.insert(0, column)
            elif "Average" in column:
                avg_columns.append(column)
        avg_column = avg_columns[0]
        column_list.extend(avg_columns)
    else:
        prefix = group_name.capitalize() + " "
        for column in df.columns:
            if column.startswith(prefix):
                avg_column = column  # the last matching column wins
        column_list.append(avg_column)

    dataset_list = []
    for dataset_dict in data_engine.datasets:
        if dataset_dict["name"] == group_name:
            dataset_list = dataset_dict["datasets"]

    if not section:
        # Sort datasets: open first, then closed
        open_datasets = [d for d in dataset_list if not d.startswith('_')]
        closed_datasets = [d for d in dataset_list if d.startswith('_')]
        dataset_list = open_datasets + closed_datasets
    column_list.extend(dataset_list)

    # Same membership as the set intersection/difference, but with a
    # deterministic column order so the averages are reproducible across runs.
    closed_set = set(get_closed_dataset())
    unique_datasets = list(dict.fromkeys(dataset_list))
    close_avg_list = [d for d in unique_datasets if d in closed_set]
    open_avg_list = [d for d in unique_datasets if d not in closed_set]

    df[_CLOSED_AVG_COLUMN] = df[close_avg_list].mean(axis=1).round(2)
    column_list.append(_CLOSED_AVG_COLUMN)
    df[_OPEN_AVG_COLUMN] = df[open_avg_list].mean(axis=1).round(2)
    column_list.append(_OPEN_AVG_COLUMN)

    table = df[COLUMNS + column_list].sort_values(by=avg_column, ascending=False)

    # rename avg column name
    if not section:
        new_column = avg_column.replace(group_name.capitalize(), "").strip()
        table.rename(columns={avg_column: new_column}, inplace=True)
        column_list.remove(avg_column)
        avg_column = new_column

    return table, column_list, avg_column


def _build_grid_options(group_name, column_list, avg_column, grid_state):
    """AgGrid ``gridOptions``; ``grid_state`` is passed through as ``initialState``."""
    return {
        'columnDefs': [
            {
                'headerName': 'Model Name',
                'field': 'model_name',
                'pinned': 'left',
                'sortable': False,
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
                "tooltipValueGetter": JsCode("""function(p) {return p.value}"""),
                "width": 250,
                'cellRenderer': JsCode(_MODEL_NAME_RENDERER_JS),
                'suppressSizeToFit': True
            },
            {
                'headerName': "Vendor",
                'field': 'vendor',
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
            },
            {
                'headerName': "Overall Score",
                'field': avg_column,
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
            },
            {
                'headerName': 'Open Average',
                'field': _OPEN_AVG_COLUMN,
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
            },
            {
                'headerName': 'Closed Average',
                'field': _CLOSED_AVG_COLUMN,
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
            },
            {
                'headerName': 'Embd Dtype',
                'field': 'embd_dtype',
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
            },
            {
                'headerName': 'Embd Dim',
                'field': 'embd_dim',
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
            },
            {
                'headerName': 'Number of Parameters',
                'field': 'num_params',
                'cellDataType': 'number',
                "colId": "num_params",
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
                'valueFormatter': JsCode(_NUM_PARAMS_FORMATTER_JS),
                "width": 120,
            },
            {
                'headerName': 'Context Length',
                'field': 'max_tokens',
                'headerStyle': HEADER_STYLE,
                'cellStyle': CELL_STYLE,
            },
            *_get_dataset_columns(group_name, column_list, avg_column),
        ],
        'defaultColDef': {
            'filter': True,
            'sortable': True,
            'resizable': True,
            'headerClass': "multi-line-header",
            'autoHeaderHeight': True,
            'width': 105
        },
        "autoSizeStrategy": {
            "type": 'fitCellContents',
            "colIds": _dataset_only(column_list, avg_column),
        },
        "tooltipShowDelay": 500,
        "initialState": grid_state,
    }


def _build_share_link(group_name, state):
    """Build the share URL for this page, embedding ``state`` when it is truthy.

    Section pages are addressed at the base URL and other pages at
    ``<base><group_name>``. This holds both with and without grid state.
    ``BASE_URL`` underscores are replaced with hyphens.
    """
    base_url = BASE_URL.replace("_", "-")
    section = is_section(group_name)
    if not state:
        return base_url if section else f"{base_url}{group_name}"
    query = f"?grid_state={compress_msgpack(state)}"
    return f"{base_url}{query}" if section else f"{base_url}{group_name}/{query}"


def render_page(group_name):
    """Render the leaderboard page for ``group_name``.

    Flow:
    1. Apply sidebar and grid state from the URL (``get_column_state``).
    2. Render the title.
    3. Build the table (``_prepare_table``).
    4. Show it in AgGrid, restoring any saved grid state as ``initialState``.
    5. Offer a CSV download of the un-augmented data and a "Share this page" dialog.
    """
    st.session_state.sider_bar_hidden = True
    get_column_state()
    # Read only after get_column_state(): a grid_state passed in via a shared
    # URL must be in session state before it is used as the grid's initialState.
    grid_state = st.session_state.get("grid_state", {})

    if st.session_state.sider_bar_hidden:
        # NOTE(review): this markdown payload is blank in this revision, so the
        # call renders nothing visible — confirm whether styling belongs here.
        st.markdown(""" """, unsafe_allow_html=True)

    # Add theme color and grid styles
    st.title("Retrieval Embedding Benchmark (RTEB)")
    st.markdown(""" """, unsafe_allow_html=True)

    label = _page_label(group_name)
    # NOTE(review): the title markup around icon/label consists only of blank
    # lines in this revision — confirm the intended HTML wrapper.
    title = f"\n\n{LEADERBOARD_ICON_MAP.get(label, '')} {label}\n\n"
    st.markdown(title, unsafe_allow_html=True)

    data_engine = st.session_state["data_engine"]
    df = data_engine.jsons_to_df().copy()
    # The CSV is taken before columns are filtered or averages are added.
    csv = convert_df_to_csv(df)
    file_name = label

    df, column_list, avg_column = _prepare_table(df, group_name, data_engine)

    grid = AgGrid(
        df,
        enable_enterprise_modules=False,
        gridOptions=_build_grid_options(group_name, column_list, avg_column, grid_state),
        allow_unsafe_jscode=True,
        columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW,
        theme="streamlit",
        custom_css=_GRID_CUSTOM_CSS,
        update_on=["stateUpdated"],
    )

    @st.dialog("URL")
    def share_url():
        """Dialog showing the share link plus a copy-to-clipboard button."""
        share_link = _build_share_link(group_name, grid.grid_state)
        st.write(share_link)
        theme = st_theme()
        theme = theme.get("base") if theme else "light"
        st_copy_to_clipboard(share_link, before_copy_label='📋Push to copy', after_copy_label='✅Text copied!',
                             theme=theme)

    col1, col2 = st.columns([1, 1])
    with col1:
        st.download_button(
            label="Download CSV",
            data=csv,
            file_name=f"{file_name}.csv",
            mime="text/csv",
            icon=":material/download:",
        )
    with col2:
        share_btn = st.button("Share this page", icon=":material/share:")
        if share_btn:
            share_url()