jasonshaoshun committed on
Commit
3cfa82d
·
1 Parent(s): 1980738
Files changed (2) hide show
  1. app.py +6 -1
  2. custom-select-columns.py +184 -0
app.py CHANGED
@@ -410,8 +410,10 @@ def init_leaderboard_mib_subgraph(dataframe, track):
410
  initial_selected=["Method", "Average"]
411
  )
412
 
 
 
413
  # Create Leaderboard
414
- return Leaderboard(
415
  value=renamed_df,
416
  datatype=[c.type for c in fields(AutoEvalColumn_mib_subgraph)],
417
  select_columns=smart_columns,
@@ -419,6 +421,9 @@ def init_leaderboard_mib_subgraph(dataframe, track):
419
  hide_columns=[],
420
  interactive=False
421
  )
 
 
 
422
 
423
 
424
 
 
410
  initial_selected=["Method", "Average"]
411
  )
412
 
413
+ print("\nDebugging DataFrame columns:", renamed_df.columns.tolist())
414
+
415
  # Create Leaderboard
416
+ leaderboard = Leaderboard(
417
  value=renamed_df,
418
  datatype=[c.type for c in fields(AutoEvalColumn_mib_subgraph)],
419
  select_columns=smart_columns,
 
421
  hide_columns=[],
422
  interactive=False
423
  )
424
+ print(f"Successfully created leaderboard.")
425
+ return leaderboard
426
+
427
 
428
 
429
 
custom-select-columns.py CHANGED
@@ -579,6 +579,190 @@ if __name__ == "__main__":
579
 
580
 
581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
 
584
 
 
579
 
580
 
581
 
582
+
583
+
584
+
585
+
586
+
587
+ from gradio_leaderboard import SelectColumns, Leaderboard
588
+ import pandas as pd
589
+ from typing import List, Dict, Union, Optional, Any
590
+ from dataclasses import fields
591
+
592
class SmartSelectColumns(SelectColumns):
    """
    Enhanced SelectColumns component for gradio_leaderboard with dynamic column filtering.

    Columns are grouped by substring match against user-supplied benchmark and
    model keywords, and raw column names can be shown under friendlier display
    names via ``column_mapping``.
    """

    def __init__(
        self,
        benchmark_keywords: Optional[List[str]] = None,
        model_keywords: Optional[List[str]] = None,
        column_mapping: Optional[Dict[str, str]] = None,
        initial_selected: Optional[List[str]] = None,
        **kwargs
    ):
        """
        Initialize SmartSelectColumns with dynamic filtering.

        Args:
            benchmark_keywords: List of benchmark names to filter by (e.g., ["ioi", "mcqa"])
            model_keywords: List of model names to filter by (e.g., ["llama3", "qwen2_5"])
            column_mapping: Dict mapping actual column names to display names
            initial_selected: List of columns to show initially (actual or display names)
            **kwargs: Forwarded unchanged to the parent ``SelectColumns`` initializer
        """
        super().__init__(**kwargs)
        self.benchmark_keywords = benchmark_keywords or []
        self.model_keywords = model_keywords or []
        self.column_mapping = column_mapping or {}
        # Inverse lookup (display name -> actual name); empty when no mapping is given.
        self.reverse_mapping = {
            display: actual for actual, display in self.column_mapping.items()
        }
        self.initial_selected = initial_selected or []

    def preprocess_value(self, x: List[str]) -> List[str]:
        """Transform selected display names back to actual column names."""
        return [self.reverse_mapping.get(col, col) for col in x]

    def postprocess_value(self, y: List[str]) -> List[str]:
        """Transform actual column names to display names."""
        return [self.column_mapping.get(col, col) for col in y]

    def _collect_groups(
        self,
        df: pd.DataFrame,
        keywords: List[str],
        label: str
    ) -> Dict[str, List[str]]:
        """Build ``"<label> group for <keyword>"`` entries for keywords matching a column."""
        groups: Dict[str, List[str]] = {}
        for keyword in keywords:
            # Lowercase BOTH sides so the match is genuinely case-insensitive, and
            # stringify the column label so non-string labels do not raise.
            needle = str(keyword).lower()
            matching_cols = [
                self.column_mapping.get(col, col)
                for col in df.columns
                if needle in str(col).lower()
            ]
            # Keywords that match nothing produce no group at all.
            if matching_cols:
                groups[f"{label} group for {keyword}"] = matching_cols
        return groups

    def get_filtered_groups(self, df: pd.DataFrame) -> Dict[str, List[str]]:
        """
        Dynamically create column groups based on keywords.

        Args:
            df: DataFrame whose column labels are matched against the keywords.

        Returns:
            Dict mapping a group title to the display names of the matching
            columns. Benchmark groups come first, followed by model groups.
        """
        filtered_groups = self._collect_groups(df, self.benchmark_keywords, "Benchmark")
        filtered_groups.update(self._collect_groups(df, self.model_keywords, "Model"))
        return filtered_groups

    def update(
        self,
        value: Union[pd.DataFrame, Dict[str, List[str]], Any]
    ) -> Dict:
        """
        Update component with new values.

        Args:
            value: A DataFrame, a dataclass (class or instance), or any other
                value, which is delegated to the parent implementation.

        Returns:
            Dict with ``choices`` and ``value`` (both as display names) and, for
            DataFrame input, ``filtered_cols`` holding the keyword groups.
        """
        # The configured initial selection may be given as actual column names;
        # translate it so it lines up with ``choices``, which are display names.
        initial_display = self.postprocess_value(self.initial_selected)

        if isinstance(value, pd.DataFrame):
            # Get all column names and convert to display names
            choices = [self.column_mapping.get(col, col) for col in value.columns]
            return {
                "choices": choices,
                # Fall back to selecting every column when nothing was configured.
                "value": initial_display if initial_display else choices,
                "filtered_cols": self.get_filtered_groups(value)
            }

        # Handle fields object
        if hasattr(value, '__dataclass_fields__'):
            choices = [
                self.column_mapping.get(field.name, field.name)
                for field in fields(value)
            ]
            return {
                "choices": choices,
                "value": initial_display if initial_display else choices
            }

        # NOTE(review): relies on the parent class exposing update(); confirm
        # against the installed gradio_leaderboard version.
        return super().update(value)
693
+
694
+
695
+ # Example usage
696
if __name__ == "__main__":
    # Small demo: build a toy results table and wire it into a Leaderboard
    # through SmartSelectColumns.
    sample_scores = {
        "eval_name": ["test1", "test2", "test3"],
        "Method": ["method1", "method2", "method3"],
        "ioi_llama3": [0.1, 0.2, 0.3],
        "ioi_qwen2_5": [0.4, 0.5, 0.6],
        "ioi_gpt2": [0.7, 0.8, 0.9],
        "mcqa_llama3": [0.2, 0.3, 0.4],
        "Average": [0.35, 0.45, 0.55]
    }
    df = pd.DataFrame(sample_scores)

    # Substrings used to bucket columns into benchmark / model groups.
    benchmark_keywords = [
        "ioi",
        "mcqa",
        "arithmetic_addition",
        "arithmetic_subtraction",
        "arc_easy",
        "arc_challenge",
    ]
    model_keywords = ["qwen2_5", "gpt2", "gemma2", "llama3"]

    # Actual column name -> display name. Entries for columns that are absent
    # from the sample table are simply never looked up.
    mappings = {
        "ioi_llama3": "IOI (LLaMA-3)",
        "ioi_qwen2_5": "IOI (Qwen-2.5)",
        "ioi_gpt2": "IOI (GPT-2)",
        "ioi_gemma2": "IOI (Gemma-2)",
        "mcqa_llama3": "MCQA (LLaMA-3)",
        "mcqa_qwen2_5": "MCQA (Qwen-2.5)",
        "mcqa_gemma2": "MCQA (Gemma-2)",
        "arithmetic_addition_llama3": "Arithmetic Addition (LLaMA-3)",
        "arithmetic_subtraction_llama3": "Arithmetic Subtraction (LLaMA-3)",
        "arc_easy_llama3": "ARC Easy (LLaMA-3)",
        "arc_easy_gemma2": "ARC Easy (Gemma-2)",
        "arc_challenge_llama3": "ARC Challenge (LLaMA-3)",
        "eval_name": "Evaluation Name",
        "Method": "Method",
        "Average": "Average Score"
    }

    # Column selector configured with the keyword groups and display names.
    selector_options = {
        "benchmark_keywords": benchmark_keywords,
        "model_keywords": model_keywords,
        "column_mapping": mappings,
        "initial_selected": ["Method", "Average"],
    }
    smart_columns = SmartSelectColumns(**selector_options)

    # NOTE(review): AutoEvalColumn_mib_subgraph is not imported in the visible
    # part of this file -- confirm it is in scope before running this demo.
    column_types = [column.type for column in fields(AutoEvalColumn_mib_subgraph)]

    # Assemble the read-only leaderboard, searchable by "Method".
    leaderboard = Leaderboard(
        value=df,
        datatype=column_types,
        select_columns=smart_columns,
        search_columns=["Method"],
        hide_columns=[],
        interactive=False
    )
748
+
749
+
750
+
751
+
752
+
753
+
754
+
755
+
756
+
757
+
758
+
759
+
760
+
761
+
762
+
763
+
764
+
765
+
766
 
767
 
768