Spaces:
Running
Running
# -*- coding: utf-8 -*-
# @Date : 2025/2/5 16:26
# @Author : q275343119
# @File : data_page.py
# @Description: Streamlit page that renders RTEB leaderboard tables with AgGrid, plus CSV download and shareable links.
import io | |
from st_aggrid import AgGrid, JsCode, ColumnsAutoSizeMode | |
import streamlit as st | |
from app.backend.data_engine import DataEngine | |
from app.backend.multi_header_util import get_header_options | |
from utils.st_copy_to_clipboard import st_copy_to_clipboard | |
from streamlit_theme import st_theme | |
from app.backend.app_init_func import LEADERBOARD_MAP | |
from app.backend.constant import LEADERBOARD_ICON_MAP, BASE_URL | |
from app.backend.json_util import compress_msgpack, decompress_msgpack | |
# Metadata columns shown (in this order) before the score columns of every table.
COLUMNS = [
    'model_name',
    'vendor',
    'embd_dtype',
    'embd_dim',
    'num_params',
    'max_tokens',
    'similarity',
    'query_instruct',
    'corpus_instruct',
    'reference',
]
# Font-size styles for header and body cells.
# NOTE(review): not referenced in the visible code (table_area styles cells via
# custom_css instead); may be imported elsewhere — verify before removing.
HEADER_STYLE = dict(fontSize='18px')
CELL_STYLE = dict(fontSize='18px')
def is_section(group_name):
    """
    Tell whether ``group_name`` names a leaderboard section.

    A name is a section when it equals ``entry[0][0]`` for some value ``entry``
    of LEADERBOARD_MAP (presumably the section's display name — confirm against
    app_init_func).

    :param group_name: group or section name to test
    :return: True on the first matching entry, otherwise False
    """
    return any(group_name == entry[0][0] for entry in LEADERBOARD_MAP.values())
def get_closed_dataset():
    """
    Collect the names of datasets whose result record is flagged ``is_closed``.

    Uses the DataEngine cached in ``st.session_state["data_engine"]``; a new
    DataEngine is constructed only when none is cached.

    :return: list of ``dataset_name`` values in ``data_engine.results`` order
             (duplicates are kept)
    """
    # dict.get(key, DataEngine()) would build a DataEngine on every call, even
    # when one is already cached, because the default is evaluated eagerly.
    data_engine = st.session_state.get("data_engine")
    if data_engine is None:
        data_engine = DataEngine()
    return [
        result.get("dataset_name")
        for result in data_engine.results
        if result.get("is_closed")
    ]
def convert_df_to_csv(df):
    """
    Serialize a DataFrame to CSV text, omitting the index column.

    :param df: pandas DataFrame to export
    :return: the CSV content as a ``str``
    """
    # With no target path, DataFrame.to_csv returns the CSV text itself.
    return df.to_csv(index=False)
def get_column_state():
    """
    Copy display state from the page URL's query parameters into session state.

    - ``grid_state``: when present and non-empty, it is decoded with
      decompress_msgpack and stored as ``st.session_state.grid_state``.
    - ``sider_bar_hidden`` / ``table_only``: when the value equals "false"
      (case-insensitive), or the parameter is absent, the session-state flag
      of the same name is set to False. Any other value leaves the flag as is.

    :return: None
    """
    encoded_state = st.query_params.get("grid_state", None)
    if encoded_state:
        st.session_state.grid_state = decompress_msgpack(encoded_state)

    # A missing parameter defaults to "False", which switches the flag off.
    for flag_name in ("sider_bar_hidden", "table_only"):
        if st.query_params.get(flag_name, "False").upper() == 'FALSE':
            st.session_state[flag_name] = False
    return None
def sidebar_css():
    """
    Inject CSS that hides the sidebar chrome when ``sider_bar_hidden`` is set.

    When ``st.session_state.sider_bar_hidden`` is truthy, this hides the sidebar,
    the sidebar navigation and the header collapse button, and centers the page
    title. Otherwise it does nothing.

    :return: None
    """
    if not st.session_state.get("sider_bar_hidden"):
        return
    st.markdown("""
<style>
[data-testid="stSidebar"] {
display: none !important;
}
[data-testid="stSidebarNav"] {
display: none !important;
}
[data-testid="stBaseButton-headerNoPadding"] {
display: none !important;
}
h1#retrieval-embedding-benchmark-rteb {
text-align: center;
}
</style>
""", unsafe_allow_html=True)
def table_only_css():
    """
    Inject CSS that strips page padding and header height in table-only mode.

    Active only when ``st.session_state.table_only`` is truthy. It also fixes
    the app height at 456px.

    :return: None
    """
    if not st.session_state.get("table_only"):
        return
    # NOTE(review): ".st-emotion-cache-1dp5vir" looks like an auto-generated
    # class name and may change between Streamlit versions — verify after upgrades.
    st.markdown("""
<style>
[data-testid="stMainBlockContainer"] {
padding-top: 0px;
padding-left: 0px;
padding-bottom: 0px;
padding-right: 0px;
}
[data-testid="stHeader"] {
height: 0px;
}
[data-testid="stApp"] {
height: 456px;
}
.st-emotion-cache-1dp5vir {
height: 0px;
}
</style>
""", unsafe_allow_html=True)
def table_area(group_name, grid_state, data_engine=None, df=None):
    """
    Render the leaderboard table for one group or section with AgGrid.

    The table columns are COLUMNS, then the average column(s), then (for
    non-section groups) the group's datasets, then "Closed average" and
    "Open average". Rows are sorted by the main average column, descending.

    :param group_name: section name (see is_section) or individual group name
    :param grid_state: AgGrid state applied as the grid's ``initialState``
    :param data_engine: DataEngine to read from; defaults to the one cached in
                        ``st.session_state["data_engine"]`` (built if missing)
    :param df: source DataFrame; defaults to a copy of
               ``data_engine.jsons_to_df()``. NOTE: the "Closed average" and
               "Open average" columns are written onto this frame.
    :return: the value returned by ``AgGrid(...)`` (callers read its
             ``grid_state`` attribute)
    """
    table_only_css()
    if data_engine is None:
        # Build a DataEngine only when none is cached. A dict.get default
        # would be constructed eagerly on every call.
        data_engine = st.session_state.get("data_engine")
        if data_engine is None:
            data_engine = DataEngine()
    if df is None:
        df = data_engine.jsons_to_df().copy()

    # is_section scans LEADERBOARD_MAP, so evaluate it once.
    section = is_section(group_name)
    # A group's own columns are prefixed with its capitalized name, e.g. "Text ".
    group_prefix = group_name.capitalize() + " "

    # --- choose the average column(s) ---
    column_list = []
    avg_column = None
    if section:
        avg_columns = []
        for column in df.columns:
            if column.startswith("Average"):
                # The overall "Average..." column goes first.
                avg_columns.insert(0, column)
            elif "Average" in column:
                avg_columns.append(column)
        avg_column = avg_columns[0] if avg_columns else None
        column_list.extend(avg_columns)
    else:
        for column in df.columns:
            if column.startswith(group_prefix):
                avg_column = column  # last matching column wins
        if avg_column is not None:
            column_list.append(avg_column)

    # --- dataset columns and closed/open averages ---
    dataset_list = []
    for dataset_dict in data_engine.datasets:
        if dataset_dict["name"] == group_name:
            dataset_list = dataset_dict["datasets"]
    if not section:
        column_list.extend(dataset_list)

    dataset_set = set(dataset_list)
    closed_set = set(get_closed_dataset())
    close_avg_list = list(dataset_set & closed_set)
    df["Closed average"] = df[close_avg_list].mean(axis=1).round(2)
    column_list.append("Closed average")
    open_avg_list = list(dataset_set - closed_set)
    df["Open average"] = df[open_avg_list].mean(axis=1).round(2)
    column_list.append("Open average")

    df = df[COLUMNS + column_list]
    # Without an average column, show the table unsorted instead of crashing.
    if avg_column is not None:
        df = df.sort_values(by=avg_column, ascending=False)

    # Strip the group prefix from the average column header, e.g. "Text Average" -> "Average".
    if not section and avg_column is not None:
        new_column = avg_column[len(group_prefix):].strip()
        df = df.rename(columns={avg_column: new_column})
        column_list.remove(avg_column)
        avg_column = new_column

    # Column config is cached per group in session state to avoid rebuilding it
    # (and rewriting session_state) on every rerun.
    cache_key = f"{group_name}_grid_options"
    grid_options = st.session_state.get(cache_key)
    if grid_options is None:
        grid_options = get_header_options(column_list, avg_column, section)
        st.session_state[cache_key] = grid_options
    grid_options["initialState"] = grid_state

    custom_css = {
        # Model Name Cell
        ".a-cell": {
            "display": "inline-block",
            "white-space": "nowrap",
            "overflow": "hidden",
            "text-overflow": "ellipsis",
            "width": "100%",
            "min-width": "0"
        },
        # Header
        ".multi-line-header": {
            "text-overflow": "clip",
            "overflow": "visible",
            "white-space": "normal",
            "height": "auto",
            "font-family": 'Arial',
            "font-size": "14px",
            "font-weight": "bold",
            "padding": "10px",
            "text-align": "left",
        },
        # Custom header and cell styles to replace headerStyle and cellStyle
        ".custom-header-style": {
            "text-overflow": "clip",
            "overflow": "visible",
            "white-space": "normal",
            "height": "auto",
            "font-family": 'Arial',
            "font-size": "14px",
            "font-weight": "bold",
            "padding": "10px",
            "text-align": "left",
            "width": "150px"
        },
        ".custom-cell-style": {
            "font-size": "14px",
            "color": "inherit",
        },
        # Filter Options and Input
        ".ag-theme-streamlit .ag-popup": {
            "font-family": 'Arial',
            "font-size": "14px",
        },
        ".ag-picker-field-display": {
            "font-family": 'Arial',
            "font-size": "14px",
        },
        ".ag-input-field-input .ag-text-field-input": {
            "font-family": 'Arial',
            "font-size": "14px",
        }
    }

    grid = AgGrid(
        df,
        enable_enterprise_modules=False,
        gridOptions=grid_options,
        allow_unsafe_jscode=True,
        columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW,
        theme="streamlit",
        custom_css=custom_css,
        # Only report back on significant changes, to reduce WebSocket traffic.
        update_on=["selectionChanged", "filterChanged"],
    )
    return grid
def main_page(group_name, grid_state):
    """
    Render the full leaderboard page for one group or section.

    The page contains the title, themed CSS, a section heading, the AgGrid
    table (via table_area), a "Download CSV" button and a "Share this page"
    button that prints a shareable link and a copy-to-clipboard widget.

    :param group_name: section name (see is_section) or individual group name
    :param grid_state: AgGrid state passed through to table_area
    :return: None
    """
    # Add theme color and grid styles
    st.title("Retrieval Embedding Benchmark (RTEB)")
    st.markdown("""
<style>
:root {
--theme-color: rgb(129, 150, 64);
--theme-color-light: rgba(129, 150, 64, 0.2);
}
/* AG Grid specific overrides */
.ag-theme-alpine {
--ag-selected-row-background-color: var(--theme-color-light) !important;
--ag-row-hover-color: var(--theme-color-light) !important;
--ag-selected-tab-color: var(--theme-color) !important;
--ag-range-selection-border-color: var(--theme-color) !important;
--ag-range-selection-background-color: var(--theme-color-light) !important;
}
.ag-row-hover {
background-color: var(--theme-color-light) !important;
}
.ag-row-selected {
background-color: var(--theme-color-light) !important;
}
.ag-row-focus {
background-color: var(--theme-color-light) !important;
}
.ag-cell-focus {
border-color: var(--theme-color) !important;
}
/* Keep existing styles */
.center-text {
text-align: center;
color: var(--theme-color);
}
.center-image {
display: block;
margin-left: auto;
margin-right: auto;
}
h2 {
color: var(--theme-color) !important;
}
.ag-header-cell {
background-color: var(--theme-color) !important;
color: white !important;
}
a {
color: var(--theme-color) !important;
}
a:hover {
color: rgba(129, 150, 64, 0.8) !important;
}
/* Download Button */
button[data-testid="stBaseButton-secondary"] {
float: right;
}
/* Toast On The Top*/
div[data-testid="stToastContainer"] {
position: fixed !important;
z-index: 2147483647 !important;
}
</style>
""", unsafe_allow_html=True)
    # logo
    # st.markdown('<img src="https://www.voyageai.com/logo.svg" class="center-image" width="200">', unsafe_allow_html=True)

    # is_section scans LEADERBOARD_MAP, so evaluate it once.
    section = is_section(group_name)
    # Sections are titled "<Name> Leaderboard". The same string is used as the
    # icon-map key and as the CSV file name.
    display_name = group_name.capitalize()
    if section:
        display_name += " Leaderboard"
    icon = LEADERBOARD_ICON_MAP.get(display_name, "")
    # title
    st.markdown(f'<h2 class="center-text">{icon} {display_name}</h2>', unsafe_allow_html=True)

    # Build a DataEngine only when none is cached. A dict.get default would be
    # constructed eagerly on every rerun.
    data_engine = st.session_state.get("data_engine")
    if data_engine is None:
        data_engine = DataEngine()
    df = data_engine.jsons_to_df().copy()
    # Serialize before table_area adds its derived average columns to ``df``.
    csv = convert_df_to_csv(df)
    grid = table_area(group_name, grid_state, data_engine, df)

    base_url = BASE_URL.replace("_", "-")

    def share_url():
        # Link to this page, embedding the grid's current state when there is one.
        state = grid.grid_state
        if not state:
            share_link = f'{base_url}{group_name}'
        elif section:
            share_link = f'{base_url}?grid_state={compress_msgpack(state)}'
        else:
            share_link = f'{base_url}{group_name}/?grid_state={compress_msgpack(state)}'
        st.write(share_link)
        theme = st_theme()
        theme = theme.get("base") if theme else "light"
        st_copy_to_clipboard(share_link, before_copy_label='📋Push to copy', after_copy_label='✅Text copied!',
                             theme=theme)

    col1, col2 = st.columns([1, 1])
    with col1:
        st.download_button(
            label="Download CSV",
            data=csv,
            file_name=f"{display_name}.csv",
            mime="text/csv",
            icon=":material/download:",
        )
    with col2:
        share_btn = st.button("Share this page", icon=":material/share:")
    if share_btn:
        share_url()
def render_page(group_name):
    """
    Page entry point: apply URL-driven display state, then render the table.

    Both display flags default to True (hidden sidebar, table-only view).
    get_column_state switches them off according to the query parameters and
    loads any shared ``grid_state`` from the URL.

    :param group_name: section name (see is_section) or individual group name
    :return: None
    """
    st.session_state.sider_bar_hidden = True
    st.session_state.table_only = True
    get_column_state()
    # Read grid_state only AFTER get_column_state has copied the URL's shared
    # state into session state. Reading it earlier ignored a shared link's
    # state on the first render.
    grid_state = st.session_state.get("grid_state", {})
    sidebar_css()
    if st.session_state.get("table_only"):
        table_area(group_name, grid_state)
    else:
        main_page(group_name, grid_state)