# -*- coding: utf-8 -*-
# @Date : 2025/2/5 16:26
# @Author : q275343119
# @File : data_page.py
# @Description:
import io
from st_aggrid import AgGrid, JsCode, ColumnsAutoSizeMode
import streamlit as st
from app.backend.data_engine import DataEngine
from app.backend.multi_header_util import get_header_options
from utils.st_copy_to_clipboard import st_copy_to_clipboard
from streamlit_theme import st_theme
from app.backend.app_init_func import LEADERBOARD_MAP
from app.backend.constant import LEADERBOARD_ICON_MAP, BASE_URL
from app.backend.json_util import compress_msgpack, decompress_msgpack
# Model metadata columns; table_area() selects these ahead of the score
# columns when it slices the frame (``df[COLUMNS + column_list]``).
COLUMNS = [
    'model_name',
    'vendor',
    'embd_dtype',
    'embd_dim',
    'num_params',
    'max_tokens',
    'similarity',
    'query_instruct',
    'corpus_instruct',
    'reference',
]

# Font-size presets for grid headers and cells. Neither name is referenced in
# the visible part of this module.
HEADER_STYLE = {'fontSize': '18px'}
CELL_STYLE = {'fontSize': '18px'}
def is_section(group_name):
    """
    Tell whether ``group_name`` names a leaderboard section.

    A name counts as a section when it equals ``value[0][0]`` for at least
    one value stored in ``LEADERBOARD_MAP``.

    :param group_name: name to look up.
    :return: True on the first match, False when no entry matches.
    """
    return any(
        group_name == entry[0][0]
        for entry in LEADERBOARD_MAP.values()
    )
def get_closed_dataset():
    """
    Collect the names of all datasets flagged as closed.

    The data engine is taken from ``st.session_state["data_engine"]``; the
    fallback ``DataEngine()`` is built on every call because it is passed as
    the ``get`` default.

    :return: list with the ``dataset_name`` of every entry in
        ``data_engine.results`` whose ``is_closed`` value is truthy.
    """
    engine = st.session_state.get("data_engine", DataEngine())
    return [
        entry.get("dataset_name")
        for entry in engine.results
        if entry.get("is_closed")
    ]
def convert_df_to_csv(df) -> str:
    """
    Serialise a DataFrame to CSV text without the index column.

    ``DataFrame.to_csv`` returns the CSV as a string when it is given no
    target, so no intermediate ``io.StringIO`` buffer is required.

    :param df: the pandas DataFrame to serialise.
    :return: the CSV content (header row plus one line per row).
    """
    return df.to_csv(index=False)
def get_column_state():
    """
    Copy view settings from the URL query string into the session state.

    * ``grid_state``: when present and non-empty, it is decoded with
      ``decompress_msgpack`` and stored as ``st.session_state.grid_state``.
    * ``sider_bar_hidden`` / ``table_only``: both default to ``"False"``;
      when the value equals "false" (case-insensitive) the matching session
      flag is set to ``False``. Any other value leaves the flag untouched.

    :return: always None.
    """
    params = st.query_params
    encoded_grid_state = params.get("grid_state", None)
    sidebar_flag = params.get("sider_bar_hidden", "False")
    table_only_flag = params.get("table_only", "False")

    if encoded_grid_state:
        st.session_state.grid_state = decompress_msgpack(encoded_grid_state)

    if sidebar_flag.upper() == 'FALSE':
        st.session_state.sider_bar_hidden = False

    if table_only_flag.upper() == 'FALSE':
        st.session_state.table_only = False

    return None
def sidebar_css():
    """
    Emit the sidebar markup when the ``sider_bar_hidden`` session flag is set.

    Nothing is rendered when the flag is missing or falsy.

    :return: None
    """
    if not st.session_state.get("sider_bar_hidden"):
        return
    # NOTE(review): the markdown payload visible here is only a line break;
    # confirm the intended markup against the original file.
    st.markdown("""
""", unsafe_allow_html=True)
def table_only_css():
    """
    Emit the table-only markup when the ``table_only`` session flag is set.

    Nothing is rendered when the flag is missing or falsy.

    :return: None
    """
    if not st.session_state.get("table_only"):
        return
    # NOTE(review): the markdown payload visible here is only a line break;
    # confirm the intended markup against the original file.
    st.markdown("""
""", unsafe_allow_html=True)
def table_area(group_name, grid_state, data_engine=None, df=None):
    """
    Build the leaderboard table for ``group_name`` and render it with AgGrid.

    Steps, in order:

    1. Emit the table-only markup via ``table_only_css()``.
    2. Pick the score columns. For a section (``is_section`` is True) these
       are all columns whose name contains "Average", with names that start
       with "Average" moved to the front. Otherwise they are the column whose
       name starts with the capitalised group name followed by a space, the
       group's dataset columns, and two computed columns, "Closed average"
       and "Open average".
    3. Keep ``COLUMNS`` plus the score columns and sort descending by the
       average column.
    4. For non-sections, rename the average column by removing the
       capitalised group name from it.
    5. Fetch (or build and cache) the grid options for this group and render.

    NOTE(review): the block nesting below was reconstructed from a source
    whose indentation was lost - confirm it against the original file.

    :param group_name: name of the section or group to display.
    :param grid_state: value stored under ``"initialState"`` in the grid
        options on every call.
    :param data_engine: engine providing ``jsons_to_df()`` and ``datasets``;
        defaults to ``st.session_state["data_engine"]`` or a new
        ``DataEngine()``.
    :param df: frame to display; defaults to a copy of
        ``data_engine.jsons_to_df()``. A frame passed in by the caller is not
        copied, so the "Closed average" / "Open average" columns are added to
        it in place.
    :return: the object returned by ``AgGrid(...)``.
    """
    table_only_css()
    if data_engine is None:
        # The default argument is evaluated eagerly: DataEngine() is
        # constructed even when the session already holds an engine.
        data_engine = st.session_state.get("data_engine", DataEngine())
    if df is None:
        df = data_engine.jsons_to_df().copy()
    # Work out which score columns to show and which one drives the sort.
    column_list = []
    avg_column = None
    if is_section(group_name):
        avg_columns = []
        for column in df.columns:
            # Names that start with "Average" go to the front of the list ...
            if column.startswith("Average"):
                avg_columns.insert(0, column)
                continue
            # ... every other name containing "Average" is appended.
            if "Average" in column:
                avg_columns.append(column)
                continue
        # Raises IndexError when no column name contains "Average".
        avg_column = avg_columns[0]
        column_list.extend(avg_columns)
    else:
        for column in df.columns:
            if column.startswith(group_name.capitalize() + " "):
                avg_column = column
                # NOTE(review): assumed to sit inside the `if`, so every
                # matching column is listed - confirm.
                column_list.append(avg_column)
    # Datasets that belong to this group (the last matching entry wins).
    dataset_list = []
    for dataset_dict in data_engine.datasets:
        if dataset_dict["name"] == group_name:
            dataset_list = dataset_dict["datasets"]
    if not is_section(group_name):
        column_list.extend(dataset_list)
        closed_list = get_closed_dataset()
        # Row-wise mean over the group's closed datasets, rounded to 2 places.
        close_avg_list = list(set(dataset_list) & set(closed_list))
        df["Closed average"] = df[close_avg_list].mean(axis=1).round(2)
        column_list.append("Closed average")
        # Row-wise mean over the remaining datasets, rounded to 2 places.
        open_avg_list = list(set(dataset_list) - set(closed_list))
        df["Open average"] = df[open_avg_list].mean(axis=1).round(2)
        column_list.append("Open average")
    df = df[COLUMNS + column_list].sort_values(by=avg_column, ascending=False)
    # Rename the average column: remove the capitalised group name from it.
    if not is_section(group_name):
        new_column = avg_column.replace(group_name.capitalize(), "").strip()
        df.rename(columns={avg_column: new_column}, inplace=True)
        # The average column is handed to get_header_options separately.
        column_list.remove(avg_column)
        avg_column = new_column
    # Column configuration - optimised caching to reduce unnecessary
    # session_state updates: the options are built once per group and reused.
    grid_options = st.session_state.get(f"{group_name}_grid_options")
    if grid_options is None:
        grid_options = get_header_options(column_list, avg_column, is_section(group_name))
        st.session_state[f"{group_name}_grid_options"] = grid_options
    # Written on every call, so the cached options dict is updated in place.
    grid_options["initialState"] = grid_state
    custom_css = {
        # Model Name Cell
        ".a-cell": {
            "display": "inline-block",
            "white-space": "nowrap",
            "overflow": "hidden",
            "text-overflow": "ellipsis",
            "width": "100%",
            "min-width": "0"
        },
        # Header
        ".multi-line-header": {
            "text-overflow": "clip",
            "overflow": "visible",
            "white-space": "normal",
            "height": "auto",
            "font-family": 'Arial',
            "font-size": "14px",
            "font-weight": "bold",
            "padding": "10px",
            "text-align": "left",
        },
        # Custom header and cell styles to replace headerStyle and cellStyle
        ".custom-header-style": {
            "text-overflow": "clip",
            "overflow": "visible",
            "white-space": "normal",
            "height": "auto",
            "font-family": 'Arial',
            "font-size": "14px",
            "font-weight": "bold",
            "padding": "10px",
            "text-align": "left",
            "width":"150px"
        },
        ".custom-cell-style": {
            "font-size": "14px",
            "color": "inherit",
        },
        # Filter Options and Input
        ".ag-theme-streamlit .ag-popup": {
            "font-family": 'Arial',
            "font-size": "14px",
        }
        , ".ag-picker-field-display": {
            "font-family": 'Arial',
            "font-size": "14px",
        },
        ".ag-input-field-input .ag-text-field-input": {
            "font-family": 'Arial',
            "font-size": "14px",
        }
    }
    grid = AgGrid(
        df,
        enable_enterprise_modules=False,
        gridOptions=grid_options,
        allow_unsafe_jscode=True,
        columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW,
        theme="streamlit",
        custom_css=custom_css,
        update_on=["selectionChanged", "filterChanged"],  # Reduce how often WebSocket updates fire; trigger only on important changes
    )
    return grid
def main_page(group_name, grid_state):
"""
main_page
:param group_name:
:param grid_state:
:return:
"""
# Add theme color and grid styles
st.title("Retrieval Embedding Benchmark (RTEB)")
st.markdown("""
""", unsafe_allow_html=True)
# logo
# st.markdown('', unsafe_allow_html=True)
title = f'