leaderboard / custom-select-columns.py
jasonshaoshun
debug
ef71549
raw
history blame
10.8 kB
import gradio as gr
import pandas as pd
from typing import List, Dict, Union, Optional
class SmartSelectColumns(gr.SelectColumns):
    """
    Enhanced SelectColumns component that supports substring matching and column mapping.

    Inherits from gr.SelectColumns but adds additional filtering capabilities:
    columns can be grouped by case-insensitive substring filters, and actual
    column names can be swapped for display names (and back).

    NOTE(review): SelectColumns is usually shipped by the ``gradio_leaderboard``
    package; confirm that ``gr.SelectColumns`` resolves in the installed gradio.
    """

    def __init__(
        self,
        *args,
        column_filters: Optional[Dict[str, List[str]]] = None,
        column_mapping: Optional[Dict[str, str]] = None,
        **kwargs
    ):
        """
        Initialize the SmartSelectColumns component.

        Args:
            column_filters: Dict mapping filter names to lists of substrings to match
            column_mapping: Dict mapping actual column names to display names
                (``update`` turns column names into display names with it and
                ``preprocess`` inverts it)
            *args, **kwargs: Arguments passed to parent SelectColumns
        """
        super().__init__(*args, **kwargs)
        # Normalise missing/None configuration to empty dicts so the other
        # methods never need to None-check these attributes.
        self.column_filters = column_filters or {}
        self.column_mapping = column_mapping or {}

    def preprocess(self, x: List[str]) -> List[str]:
        """
        Transform selected display names back to actual column names.

        Args:
            x: Selected display names; may be None when nothing is selected

        Returns:
            The actual column names. ``x`` is returned unchanged when no
            mapping is configured or when ``x`` is None.
        """
        # Guard against a None payload: iterating it would raise TypeError.
        if not self.column_mapping or x is None:
            return x
        # Invert actual -> display. If two columns share a display name the
        # later mapping entry wins.
        reverse_mapping = {display: actual for actual, display in self.column_mapping.items()}
        return [reverse_mapping.get(col, col) for col in x]

    def get_filtered_columns(self, df: pd.DataFrame) -> Dict[str, List[str]]:
        """
        Get columns filtered by substring matches.

        Matching is case-insensitive. Column labels are converted with ``str``
        before matching so non-string labels (e.g. integers) do not raise.

        Args:
            df: Input DataFrame

        Returns:
            Dict mapping filter names to lists of matching columns
        """
        # Lower-case every column label once instead of once per filter.
        lowered_columns = [(col, str(col).lower()) for col in df.columns]
        filtered_cols = {}
        for filter_name, substrings in self.column_filters.items():
            needles = [substr.lower() for substr in substrings]
            filtered_cols[filter_name] = [
                col
                for col, lowered in lowered_columns
                if any(needle in lowered for needle in needles)
            ]
        return filtered_cols

    def update(
        self,
        value: Union[pd.DataFrame, Dict[str, List[str]]],
        interactive: Optional[bool] = None
    ) -> Dict:
        """
        Update the component with new values.

        Args:
            value: Either a DataFrame or dict of predefined column groups
            interactive: Whether the component should be interactive; when
                None the component's current ``interactive`` attribute is used

        Returns:
            Dict containing the update configuration. For a DataFrame it has
            the keys "choices" (display names), "filtered_cols" (actual column
            names grouped by filter) and "interactive"; any other value is
            delegated to the parent ``update``.
        """
        if not isinstance(value, pd.DataFrame):
            return super().update(value, interactive)

        # Get filtered column groups
        filtered_cols = self.get_filtered_columns(value)
        # Create display names for columns; unmapped columns keep their name.
        choices = [self.column_mapping.get(col, col) for col in value.columns]
        return {
            "choices": choices,
            "filtered_cols": filtered_cols,
            "interactive": interactive if interactive is not None else self.interactive
        }
# Example usage
if __name__ == "__main__":
    # Display names shown to the user, keyed by the real column name.
    demo_mapping = {
        "ioi_score_1": "IOI Score (Type 1)",
        "ioi_score_2": "IOI Score (Type 2)",
        "other_metric": "Other Metric",
        "performance_1": "Performance Metric 1",
    }
    # Named groups; a column joins a group when it contains any listed substring.
    demo_filters = {
        "IOI Metrics": ["ioi"],
        "Performance Metrics": ["performance"],
    }
    # Tiny frame whose columns exercise both groups plus one unmatched column.
    sample_frame = pd.DataFrame(
        {
            "ioi_score_1": [1, 2, 3],
            "ioi_score_2": [4, 5, 6],
            "other_metric": [7, 8, 9],
            "performance_1": [10, 11, 12],
        }
    )

    # Build the interface.
    with gr.Blocks() as demo:
        column_picker = SmartSelectColumns(
            multiselect=True,
            column_mapping=demo_mapping,
            column_filters=demo_filters,
        )
        # Feed the DataFrame to the component (the returned dict is not used).
        column_picker.update(sample_frame)

    demo.launch()
import gradio as gr
import pandas as pd
from typing import List, Dict, Union, Optional, Any
from dataclasses import fields
class SmartSelectColumns(gr.SelectColumns):
    """
    Enhanced SelectColumns component for Gradio Leaderboard with smart filtering and mapping capabilities.

    Adds case-insensitive substring filters that group columns, a two-way
    mapping between actual column names and display names, and an optional
    initial selection.

    NOTE(review): SelectColumns is usually shipped by the ``gradio_leaderboard``
    package; confirm that ``gr.SelectColumns`` resolves in the installed gradio.
    """

    def __init__(
        self,
        column_filters: Optional[Dict[str, List[str]]] = None,
        column_mapping: Optional[Dict[str, str]] = None,
        initial_selected: Optional[List[str]] = None,
        *args,
        **kwargs
    ):
        """
        Initialize SmartSelectColumns with enhanced functionality.

        Args:
            column_filters: Dict mapping filter names to lists of substrings to match
            column_mapping: Dict mapping actual column names to display names
            initial_selected: List of column names to be initially selected.
                ``update`` returns these as the selection next to display-name
                choices, so they are presumably expected to be display names
                (verify against the caller).
            *args, **kwargs: Additional arguments passed to parent SelectColumns
        """
        super().__init__(*args, **kwargs)
        # Normalise missing/None configuration so later code needs no None checks.
        self.column_filters = column_filters or {}
        self.column_mapping = column_mapping or {}
        # Inverse lookup (display -> actual), built once. If two columns share
        # a display name the later mapping entry wins.
        self.reverse_mapping = {
            display: actual for actual, display in self.column_mapping.items()
        }
        self.initial_selected = initial_selected or []

    def preprocess(self, x: List[str]) -> List[str]:
        """
        Transform selected display names back to actual column names.

        Args:
            x: List of selected display names; None is treated as no selection

        Returns:
            List of actual column names (names without a mapping pass through)
        """
        if x is None:
            # Iterating None would raise TypeError; treat it as an empty selection.
            return []
        return [self.reverse_mapping.get(col, col) for col in x]

    def postprocess(self, y: List[str]) -> List[str]:
        """
        Transform actual column names to display names.

        Args:
            y: List of actual column names; None is treated as no selection

        Returns:
            List of display names (names without a mapping pass through)
        """
        if y is None:
            return []
        return [self.column_mapping.get(col, col) for col in y]

    def get_filtered_columns(self, df: pd.DataFrame) -> Dict[str, List[str]]:
        """
        Get columns filtered by substring matches.

        Matching is case-insensitive and runs against the actual column
        labels; labels are converted with ``str`` first so non-string labels
        do not raise.

        Args:
            df: Input DataFrame

        Returns:
            Dict mapping filter names to lists of matching display names
        """
        # Pair each column's lower-cased label with its display name once,
        # rather than recomputing both for every filter.
        candidates = [
            (str(col).lower(), self.column_mapping.get(col, col))
            for col in df.columns
        ]
        filtered_cols = {}
        for filter_name, substrings in self.column_filters.items():
            needles = [substr.lower() for substr in substrings]
            filtered_cols[filter_name] = [
                display_name
                for lowered, display_name in candidates
                if any(needle in lowered for needle in needles)
            ]
        return filtered_cols

    def update(
        self,
        value: Union[pd.DataFrame, Dict[str, List[str]], Any],
        interactive: Optional[bool] = None
    ) -> Dict:
        """
        Update component with new values, supporting DataFrame fields.

        Args:
            value: DataFrame, dict of columns, or fields object
            interactive: Whether component should be interactive; when None
                the component's current ``interactive`` attribute is used

        Returns:
            Dict containing update configuration. DataFrame input yields the
            keys "choices", "value", "filtered_cols" and "interactive";
            dataclass input yields "choices", "value" and "interactive"; any
            other input is delegated to the parent ``update``.
        """
        resolved_interactive = interactive if interactive is not None else self.interactive

        if isinstance(value, pd.DataFrame):
            filtered_cols = self.get_filtered_columns(value)
            choices = [self.column_mapping.get(col, col) for col in value.columns]
            return {
                "choices": choices,
                # Use the initial selection if provided, otherwise select
                # everything. Copied so callers cannot mutate internal state
                # or alias the "choices" list.
                "value": list(self.initial_selected or choices),
                "filtered_cols": filtered_cols,
                "interactive": resolved_interactive
            }

        # Handle fields object (e.g., from dataclass)
        if hasattr(value, '__dataclass_fields__'):
            choices = [
                self.column_mapping.get(field.name, field.name)
                for field in fields(value)
            ]
            return {
                "choices": choices,
                "value": list(self.initial_selected or choices),
                "interactive": resolved_interactive
            }

        return super().update(value, interactive)
def initialize_leaderboard(df: pd.DataFrame, column_class: Any,
                           filters: Dict[str, List[str]],
                           mappings: Dict[str, str],
                           initial_columns: Optional[List[str]] = None) -> gr.Leaderboard:
    """
    Initialize a Gradio Leaderboard with SmartSelectColumns.

    Args:
        df: Input DataFrame
        column_class: Dataclass containing column definitions (e.g.,
            AutoEvalColumn_mib_subgraph). Each field's ``type`` is forwarded
            as the column datatype. May be None, in which case no ``datatype``
            is passed and the Leaderboard falls back to its own default.
        filters: Column filters for substring matching
        mappings: Column name mappings (actual -> display)
        initial_columns: List of columns to show initially

    Returns:
        Configured Leaderboard instance
    """
    # Create renamed DataFrame with display names
    renamed_df = df.rename(columns=mappings)

    # Initialize SmartSelectColumns
    smart_columns = SmartSelectColumns(
        column_filters=filters,
        column_mapping=mappings,
        initial_selected=initial_columns,
        multiselect=True
    )

    leaderboard_kwargs = {
        "value": renamed_df,
        "select_columns": smart_columns,
        "search_columns": ["Method"],
        "hide_columns": [],
        "interactive": False,
    }
    # dataclasses.fields(None) raises TypeError, so derive datatypes only when
    # a column class was actually supplied.
    # NOTE(review): ``c.type`` is the field's annotation; confirm the column
    # class annotates fields with the datatype values the Leaderboard expects.
    if column_class is not None:
        leaderboard_kwargs["datatype"] = [c.type for c in fields(column_class)]

    return gr.Leaderboard(**leaderboard_kwargs)
# Example usage
if __name__ == "__main__":
    # Substring groups and display names used by the demo.
    filters = {
        "IOI Metrics": ["ioi"],
        "Performance Metrics": ["performance"],
    }
    mappings = {
        "ioi_score_1": "IOI Score (Type 1)",
        "ioi_score_2": "IOI Score (Type 2)",
        "other_metric": "Other Metric",
        "performance_1": "Performance Metric 1",
    }
    # Columns shown when the leaderboard first renders (display names).
    shown_at_start = ["Method", "IOI Score (Type 1)"]

    # Sample data; "Method" has no mapping entry and keeps its name.
    sample_scores = pd.DataFrame(
        {
            "ioi_score_1": [1, 2, 3],
            "ioi_score_2": [4, 5, 6],
            "other_metric": [7, 8, 9],
            "performance_1": [10, 11, 12],
            "Method": ["A", "B", "C"],
        }
    )

    # Create demo interface
    with gr.Blocks() as demo:
        # Placeholder: pass your real column dataclass as ``column_class`` so
        # the column datatypes can be derived from it.
        leaderboard = initialize_leaderboard(
            sample_scores,
            None,
            filters,
            mappings,
            initial_columns=shown_at_start,
        )

    demo.launch()