Anas Awadalla committed on
Commit
4f9fa17
·
1 Parent(s): fc25316

add subset avg for pro baselines

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +133 -8
src/streamlit_app.py CHANGED
@@ -366,11 +366,30 @@ def create_bar_chart(data: pd.DataFrame, metric: str, title: str):
366
  for baseline_name, baseline_metrics in BASELINES[dataset].items():
367
  metric_key = metric.replace('_avg', '').replace('avg', 'overall')
368
  if metric_key in baseline_metrics:
369
- chart_data.append({
370
- 'Model': baseline_name,
371
- 'Score': baseline_metrics[metric_key],
372
- 'Type': 'Baseline'
373
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
  if not chart_data:
376
  return None
@@ -565,6 +584,75 @@ def main():
565
  # If no models selected, show empty dataframe
566
  filtered_df = pd.DataFrame()
567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  # Main content
569
  st.header(f"Results for {selected_dataset}")
570
 
@@ -589,6 +677,30 @@ def main():
589
  # Parse UI type metrics
590
  ui_metrics_df = parse_ui_type_metrics(filtered_df, selected_dataset)
591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
  # Add metric selector for screenspot datasets
593
  selected_metric = 'overall' # Default metric
594
  if not ui_metrics_df.empty:
@@ -634,9 +746,9 @@ def main():
634
  # Display results table
635
  st.subheader("📊 Results Table")
636
 
637
- # Filter ui_metrics_df to only include selected models
638
  if not ui_metrics_df.empty:
639
- table_df = ui_metrics_df[ui_metrics_df['model'].isin(selected_models)].copy()
640
 
641
  # Add baselines to the table if available
642
  if selected_dataset in BASELINES:
@@ -672,7 +784,20 @@ def main():
672
  # For other datasets (showdown-clicks, etc.)
673
  baseline_row['overall'] = baseline_metrics.get('overall', 0)
674
 
675
- baseline_rows.append(baseline_row)
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
  # Append baselines to table
678
  if baseline_rows:
 
366
  for baseline_name, baseline_metrics in BASELINES[dataset].items():
367
  metric_key = metric.replace('_avg', '').replace('avg', 'overall')
368
  if metric_key in baseline_metrics:
369
+ baseline_value = baseline_metrics[metric_key]
370
+
371
+ # Check performance bounds if filter is enabled
372
+ should_include = True
373
+ if st.session_state.get('perf_filter_enabled', False):
374
+ filter_metric = st.session_state.get('perf_filter_metric', 'overall')
375
+ min_perf = st.session_state.get('perf_filter_min', 0.0)
376
+ max_perf = st.session_state.get('perf_filter_max', 100.0)
377
+
378
+ # Only filter if we're filtering by the same metric being displayed
379
+ if filter_metric == metric and (baseline_value < min_perf or baseline_value > max_perf):
380
+ should_include = False
381
+ # Or if filtering by a different metric, check that metric's value
382
+ elif filter_metric != metric and filter_metric in baseline_metrics:
383
+ filter_value = baseline_metrics[filter_metric]
384
+ if filter_value < min_perf or filter_value > max_perf:
385
+ should_include = False
386
+
387
+ if should_include:
388
+ chart_data.append({
389
+ 'Model': baseline_name,
390
+ 'Score': baseline_value,
391
+ 'Type': 'Baseline'
392
+ })
393
 
394
  if not chart_data:
395
  return None
 
584
  # If no models selected, show empty dataframe
585
  filtered_df = pd.DataFrame()
586
 
587
+ # Performance bounds filter
588
+ st.sidebar.divider()
589
+ st.sidebar.subheader("Performance Filters")
590
+
591
+ # Enable/disable performance filtering
592
+ enable_perf_filter = st.sidebar.checkbox("Enable performance bounds", value=False)
593
+
594
+ if enable_perf_filter:
595
+ # Get the metric to filter on
596
+ filter_metric_help = "Filter models based on their performance in the selected metric"
597
+
598
+ # Determine available metrics for filtering
599
+ if selected_dataset == 'screenspot-v2':
600
+ filter_metrics = ['overall', 'desktop_text', 'desktop_icon', 'web_text', 'web_icon']
601
+ filter_metric_names = {
602
+ 'overall': 'Overall Average',
603
+ 'desktop_text': 'Desktop (Text)',
604
+ 'desktop_icon': 'Desktop (Icon)',
605
+ 'web_text': 'Web (Text)',
606
+ 'web_icon': 'Web (Icon)'
607
+ }
608
+ elif selected_dataset == 'screenspot-pro':
609
+ filter_metrics = ['overall', 'text', 'icon']
610
+ filter_metric_names = {
611
+ 'overall': 'Overall Average',
612
+ 'text': 'Text',
613
+ 'icon': 'Icon'
614
+ }
615
+ else:
616
+ filter_metrics = ['overall']
617
+ filter_metric_names = {'overall': 'Overall Average'}
618
+
619
+ # Metric selector for filtering
620
+ filter_metric = st.sidebar.selectbox(
621
+ "Filter by metric:",
622
+ options=filter_metrics,
623
+ format_func=lambda x: filter_metric_names[x],
624
+ help=filter_metric_help
625
+ )
626
+
627
+ # Performance bounds inputs
628
+ col1, col2 = st.sidebar.columns(2)
629
+ with col1:
630
+ min_perf = st.number_input(
631
+ "Min %",
632
+ min_value=0.0,
633
+ max_value=100.0,
634
+ value=0.0,
635
+ step=5.0,
636
+ help="Minimum performance threshold"
637
+ )
638
+ with col2:
639
+ max_perf = st.number_input(
640
+ "Max %",
641
+ min_value=0.0,
642
+ max_value=100.0,
643
+ value=100.0,
644
+ step=5.0,
645
+ help="Maximum performance threshold"
646
+ )
647
+
648
+ # Store filter settings in session state
649
+ st.session_state['perf_filter_enabled'] = True
650
+ st.session_state['perf_filter_metric'] = filter_metric
651
+ st.session_state['perf_filter_min'] = min_perf
652
+ st.session_state['perf_filter_max'] = max_perf
653
+ else:
654
+ st.session_state['perf_filter_enabled'] = False
655
+
656
  # Main content
657
  st.header(f"Results for {selected_dataset}")
658
 
 
677
  # Parse UI type metrics
678
  ui_metrics_df = parse_ui_type_metrics(filtered_df, selected_dataset)
679
 
680
+ # Apply performance bounds filter if enabled
681
+ if st.session_state.get('perf_filter_enabled', False) and not ui_metrics_df.empty:
682
+ filter_metric = st.session_state.get('perf_filter_metric', 'overall')
683
+ min_perf = st.session_state.get('perf_filter_min', 0.0)
684
+ max_perf = st.session_state.get('perf_filter_max', 100.0)
685
+
686
+ # Check if the filter metric exists in the dataframe
687
+ if filter_metric in ui_metrics_df.columns:
688
+ # Filter models based on performance bounds
689
+ ui_metrics_df = ui_metrics_df[
690
+ (ui_metrics_df[filter_metric] >= min_perf) &
691
+ (ui_metrics_df[filter_metric] <= max_perf)
692
+ ]
693
+
694
+ # Update selected models to only include those within bounds
695
+ models_in_bounds = ui_metrics_df['model'].tolist()
696
+ filtered_models = [m for m in selected_models if m in models_in_bounds]
697
+
698
+ # Show info about filtered models
699
+ total_models = len(selected_models)
700
+ shown_models = len(filtered_models)
701
+ if shown_models < total_models:
702
+ st.info(f"Showing {shown_models} of {total_models} selected models within performance bounds ({min_perf:.1f}% - {max_perf:.1f}% {filter_metric})")
703
+
704
  # Add metric selector for screenspot datasets
705
  selected_metric = 'overall' # Default metric
706
  if not ui_metrics_df.empty:
 
746
  # Display results table
747
  st.subheader("📊 Results Table")
748
 
749
+ # Use the already filtered ui_metrics_df which respects performance bounds
750
  if not ui_metrics_df.empty:
751
+ table_df = ui_metrics_df.copy()
752
 
753
  # Add baselines to the table if available
754
  if selected_dataset in BASELINES:
 
784
  # For other datasets (showdown-clicks, etc.)
785
  baseline_row['overall'] = baseline_metrics.get('overall', 0)
786
 
787
+ # Apply performance filter to baselines if enabled
788
+ should_include_baseline = True
789
+ if st.session_state.get('perf_filter_enabled', False):
790
+ filter_metric = st.session_state.get('perf_filter_metric', 'overall')
791
+ min_perf = st.session_state.get('perf_filter_min', 0.0)
792
+ max_perf = st.session_state.get('perf_filter_max', 100.0)
793
+
794
+ if filter_metric in baseline_row:
795
+ metric_value = baseline_row[filter_metric]
796
+ if metric_value < min_perf or metric_value > max_perf:
797
+ should_include_baseline = False
798
+
799
+ if should_include_baseline:
800
+ baseline_rows.append(baseline_row)
801
 
802
  # Append baselines to table
803
  if baseline_rows: