Anas Awadalla committed on
Commit
628f62f
·
1 Parent(s): c148460

fix caching of elements

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +5 -284
src/streamlit_app.py CHANGED
@@ -411,19 +411,12 @@ def main():
411
 
412
  if selected_model != 'All':
413
  filtered_df = filtered_df[filtered_df['model'] == selected_model]
414
-
415
- # Create placeholders for components that update when dataset or metric changes
416
- header_placeholder = st.empty()
417
- metrics_placeholder = st.empty()
418
- chart_placeholder = st.empty()
419
- view_metrics_expander_placeholder = st.empty()
420
- progression_expander_placeholder = st.empty()
421
-
422
  # Main content
423
- header_placeholder.header(f"Results for {selected_dataset}")
424
 
425
  # Overall metrics
426
- col1, col2, col3 = metrics_placeholder.columns(3)
427
  with col1:
428
  st.metric("Models Evaluated", len(filtered_df))
429
  with col2:
@@ -482,281 +475,9 @@ def main():
482
  # Create single chart for selected metric
483
  chart = create_bar_chart(ui_metrics_df, selected_metric, metric_options[selected_metric])
484
  if chart:
485
- chart_placeholder.altair_chart(chart, use_container_width=True)
486
  else:
487
  st.warning(f"No data available for {metric_options[selected_metric]}")
488
-
489
- # Show all metrics in an expandable section - available for all datasets
490
- with view_metrics_expander_placeholder.expander("View All Metrics"):
491
- if selected_dataset == 'screenspot-v2':
492
- # First row: Overall, Desktop, Web averages
493
- col1, col2, col3 = st.columns(3)
494
-
495
- with col1:
496
- chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average (Desktop + Web) / 2')
497
- if chart:
498
- st.altair_chart(chart, use_container_width=True)
499
-
500
- with col2:
501
- chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
502
- if chart:
503
- st.altair_chart(chart, use_container_width=True)
504
-
505
- with col3:
506
- chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
507
- if chart:
508
- st.altair_chart(chart, use_container_width=True)
509
-
510
- # Second row: Individual UI type metrics
511
- col1, col2, col3, col4 = st.columns(4)
512
-
513
- with col1:
514
- chart = create_bar_chart(ui_metrics_df, 'desktop_text', 'Desktop (Text)')
515
- if chart:
516
- st.altair_chart(chart, use_container_width=True)
517
-
518
- with col2:
519
- chart = create_bar_chart(ui_metrics_df, 'desktop_icon', 'Desktop (Icon)')
520
- if chart:
521
- st.altair_chart(chart, use_container_width=True)
522
-
523
- with col3:
524
- chart = create_bar_chart(ui_metrics_df, 'web_text', 'Web (Text)')
525
- if chart:
526
- st.altair_chart(chart, use_container_width=True)
527
-
528
- with col4:
529
- chart = create_bar_chart(ui_metrics_df, 'web_icon', 'Web (Icon)')
530
- if chart:
531
- st.altair_chart(chart, use_container_width=True)
532
-
533
- # Third row: Text vs Icon averages
534
- col1, col2 = st.columns(2)
535
-
536
- with col1:
537
- chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (Desktop + Web)')
538
- if chart:
539
- st.altair_chart(chart, use_container_width=True)
540
-
541
- with col2:
542
- chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (Desktop + Web)')
543
- if chart:
544
- st.altair_chart(chart, use_container_width=True)
545
- else:
546
- # For screenspot-pro and showdown-clicks
547
- st.info("No additional UI type metrics available for this dataset. Only overall accuracy is reported.")
548
-
549
- # Checkpoint progression visualization
550
- with progression_expander_placeholder.expander("Checkpoint Progression Analysis"):
551
- # Select a model with checkpoints
552
- models_with_checkpoints = ui_metrics_df[ui_metrics_df['all_checkpoints'].apply(lambda x: len(x) > 1)]
553
-
554
- if not models_with_checkpoints.empty:
555
- selected_checkpoint_model = st.selectbox(
556
- "Select a model to view checkpoint progression:",
557
- models_with_checkpoints['model'].str.replace('*', '').unique()
558
- )
559
-
560
- # Get checkpoint data for selected model
561
- model_row = models_with_checkpoints[models_with_checkpoints['model'].str.replace('*', '') == selected_checkpoint_model].iloc[0]
562
- checkpoint_data = model_row['all_checkpoints']
563
-
564
- # Create DataFrame from checkpoint data
565
- checkpoint_df = pd.DataFrame(checkpoint_data)
566
-
567
- # Prepare data for visualization
568
- checkpoint_metrics = []
569
- for _, cp in checkpoint_df.iterrows():
570
- ui_results = cp.get('ui_type_results', {})
571
- dataset_type_results = cp.get('dataset_type_results', {})
572
- results_by_file = cp.get('results_by_file', {})
573
-
574
- # Check if we have desktop/web breakdown in results_by_file
575
- desktop_file = None
576
- web_file = None
577
-
578
- for filename, file_results in results_by_file.items():
579
- if 'desktop' in filename.lower():
580
- desktop_file = file_results
581
- elif 'web' in filename.lower():
582
- web_file = file_results
583
-
584
- if desktop_file and web_file:
585
- # We have desktop/web breakdown
586
- desktop_text = desktop_file.get('by_ui_type', {}).get('text', {}).get('correct', 0) / max(desktop_file.get('by_ui_type', {}).get('text', {}).get('total', 1), 1) * 100
587
- desktop_icon = desktop_file.get('by_ui_type', {}).get('icon', {}).get('correct', 0) / max(desktop_file.get('by_ui_type', {}).get('icon', {}).get('total', 1), 1) * 100
588
- web_text = web_file.get('by_ui_type', {}).get('text', {}).get('correct', 0) / max(web_file.get('by_ui_type', {}).get('text', {}).get('total', 1), 1) * 100
589
- web_icon = web_file.get('by_ui_type', {}).get('icon', {}).get('correct', 0) / max(web_file.get('by_ui_type', {}).get('icon', {}).get('total', 1), 1) * 100
590
- else:
591
- # Fallback to simple UI type results
592
- desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
593
- desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
594
- web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
595
- web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
596
-
597
- # If still all zeros, try dataset_type_results
598
- if desktop_text == 0 and desktop_icon == 0 and web_text == 0 and web_icon == 0:
599
- for dataset_key in dataset_type_results:
600
- if 'screenspot' in dataset_key.lower():
601
- dataset_data = dataset_type_results[dataset_key]
602
- if 'by_ui_type' in dataset_data:
603
- ui_data = dataset_data['by_ui_type']
604
- # For simple text/icon without desktop/web
605
- text_val = ui_data.get('text', {}).get('correct', 0) / max(ui_data.get('text', {}).get('total', 1), 1) * 100
606
- icon_val = ui_data.get('icon', {}).get('correct', 0) / max(ui_data.get('icon', {}).get('total', 1), 1) * 100
607
- # Assign same values to desktop and web as we don't have the breakdown
608
- desktop_text = web_text = text_val
609
- desktop_icon = web_icon = icon_val
610
- break
611
-
612
- desktop_avg = (desktop_text + desktop_icon) / 2
613
- web_avg = (web_text + web_icon) / 2
614
- text_avg = (desktop_text + web_text) / 2
615
- icon_avg = (desktop_icon + web_icon) / 2
616
- overall = (desktop_avg + web_avg) / 2 if selected_dataset == 'screenspot-v2' else cp['overall_accuracy']
617
-
618
- checkpoint_metrics.append({
619
- 'steps': cp['checkpoint_steps'] or 0,
620
- 'overall': overall,
621
- 'desktop_avg': desktop_avg,
622
- 'web_avg': web_avg,
623
- 'desktop_text': desktop_text,
624
- 'desktop_icon': desktop_icon,
625
- 'web_text': web_text,
626
- 'web_icon': web_icon,
627
- 'text_avg': text_avg,
628
- 'icon_avg': icon_avg,
629
- 'loss': cp['training_loss'],
630
- 'neg_log_loss': -np.log(cp['training_loss']) if cp['training_loss'] and cp['training_loss'] > 0 else None
631
- })
632
-
633
- metrics_df = pd.DataFrame(checkpoint_metrics).sort_values('steps')
634
-
635
- # Plot metrics over training steps
636
- col1, col2 = st.columns(2)
637
-
638
- with col1:
639
- st.write("**Accuracy over Training Steps**")
640
-
641
- # Determine which metrics to show based on selected metric
642
- if selected_metric == 'overall':
643
- # Show overall, desktop, and web averages
644
- metrics_to_show = ['overall', 'desktop_avg', 'web_avg']
645
- metric_labels = ['Overall', 'Desktop Avg', 'Web Avg']
646
- colors = ['#4ECDC4', '#45B7D1', '#96CEB4']
647
- elif 'desktop' in selected_metric:
648
- # Show all desktop metrics
649
- metrics_to_show = ['desktop_avg', 'desktop_text', 'desktop_icon']
650
- metric_labels = ['Desktop Avg', 'Desktop Text', 'Desktop Icon']
651
- colors = ['#45B7D1', '#FFA726', '#FF6B6B']
652
- elif 'web' in selected_metric:
653
- # Show all web metrics
654
- metrics_to_show = ['web_avg', 'web_text', 'web_icon']
655
- metric_labels = ['Web Avg', 'Web Text', 'Web Icon']
656
- colors = ['#96CEB4', '#9C27B0', '#E91E63']
657
- elif 'text' in selected_metric:
658
- # Show text metrics across environments
659
- metrics_to_show = ['text_avg', 'desktop_text', 'web_text']
660
- metric_labels = ['Text Avg', 'Desktop Text', 'Web Text']
661
- colors = ['#FF9800', '#FFA726', '#FFB74D']
662
- elif 'icon' in selected_metric:
663
- # Show icon metrics across environments
664
- metrics_to_show = ['icon_avg', 'desktop_icon', 'web_icon']
665
- metric_labels = ['Icon Avg', 'Desktop Icon', 'Web Icon']
666
- colors = ['#3F51B5', '#5C6BC0', '#7986CB']
667
- else:
668
- # Default: just show the selected metric
669
- metrics_to_show = [selected_metric]
670
- metric_labels = [metric_options.get(selected_metric, selected_metric)]
671
- colors = ['#4ECDC4']
672
-
673
- # Create multi-line chart data
674
- chart_data = []
675
- for i, (metric, label) in enumerate(zip(metrics_to_show, metric_labels)):
676
- for _, row in metrics_df.iterrows():
677
- if metric in row:
678
- chart_data.append({
679
- 'steps': row['steps'],
680
- 'value': row[metric],
681
- 'metric': label,
682
- 'color_idx': i
683
- })
684
-
685
- if chart_data:
686
- chart_df = pd.DataFrame(chart_data)
687
-
688
- # Create multi-line chart with distinct colors
689
- chart = alt.Chart(chart_df).mark_line(point=True, strokeWidth=2).encode(
690
- x=alt.X('steps:Q', title='Training Steps'),
691
- y=alt.Y('value:Q', scale=alt.Scale(domain=[0, 100]), title='Accuracy (%)'),
692
- color=alt.Color('metric:N',
693
- scale=alt.Scale(domain=metric_labels, range=colors),
694
- legend=alt.Legend(title="Metric")),
695
- tooltip=['steps:Q', 'metric:N', alt.Tooltip('value:Q', format='.1f', title='Accuracy')]
696
- ).properties(
697
- width=500,
698
- height=400,
699
- title='Accuracy Progression During Training'
700
- )
701
- st.altair_chart(chart, use_container_width=True)
702
- else:
703
- st.warning("No data available for the selected metrics")
704
-
705
- with col2:
706
- st.write(f"**{metric_options[selected_metric]} vs. Training Loss**")
707
-
708
- if metrics_df['neg_log_loss'].notna().any():
709
- scatter_data = metrics_df[metrics_df['neg_log_loss'].notna()]
710
-
711
- chart = alt.Chart(scatter_data).mark_circle(size=100).encode(
712
- x=alt.X('neg_log_loss:Q', title='-log(Training Loss)'),
713
- y=alt.Y(f'{selected_metric}:Q', scale=alt.Scale(domain=[0, 100]), title=f'{metric_options[selected_metric]} (%)'),
714
- color=alt.Color('steps:Q', scale=alt.Scale(scheme='viridis'), title='Training Steps'),
715
- tooltip=['steps', 'loss', selected_metric]
716
- ).properties(
717
- width=500, # Increased from 400
718
- height=400, # Increased from 300
719
- title=f'{metric_options[selected_metric]} vs. -log(Training Loss)'
720
- )
721
- st.altair_chart(chart, use_container_width=True)
722
- else:
723
- st.info("No training loss data available for this model")
724
-
725
- # Show checkpoint details table with selected metric
726
- st.write("**Checkpoint Details**")
727
-
728
- # Determine columns to display based on selected metric category
729
- if selected_metric == 'overall':
730
- display_cols = ['steps', 'overall', 'desktop_avg', 'web_avg', 'loss']
731
- col_labels = ['Steps', 'Overall %', 'Desktop Avg %', 'Web Avg %', 'Training Loss']
732
- elif 'desktop' in selected_metric:
733
- display_cols = ['steps', 'desktop_avg', 'desktop_text', 'desktop_icon', 'loss']
734
- col_labels = ['Steps', 'Desktop Avg %', 'Desktop Text %', 'Desktop Icon %', 'Training Loss']
735
- elif 'web' in selected_metric:
736
- display_cols = ['steps', 'web_avg', 'web_text', 'web_icon', 'loss']
737
- col_labels = ['Steps', 'Web Avg %', 'Web Text %', 'Web Icon %', 'Training Loss']
738
- elif 'text' in selected_metric:
739
- display_cols = ['steps', 'text_avg', 'desktop_text', 'web_text', 'loss']
740
- col_labels = ['Steps', 'Text Avg %', 'Desktop Text %', 'Web Text %', 'Training Loss']
741
- elif 'icon' in selected_metric:
742
- display_cols = ['steps', 'icon_avg', 'desktop_icon', 'web_icon', 'loss']
743
- col_labels = ['Steps', 'Icon Avg %', 'Desktop Icon %', 'Web Icon %', 'Training Loss']
744
- else:
745
- display_cols = ['steps', selected_metric, 'loss']
746
- col_labels = ['Steps', f'{metric_options[selected_metric]} %', 'Training Loss']
747
-
748
- display_metrics = metrics_df[display_cols].copy()
749
- display_metrics.columns = col_labels
750
-
751
- # Format percentage columns
752
- for col in col_labels:
753
- if '%' in col and col != 'Training Loss':
754
- display_metrics[col] = display_metrics[col].round(2)
755
-
756
- display_metrics['Training Loss'] = display_metrics['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
757
- st.dataframe(display_metrics, use_container_width=True)
758
- else:
759
- st.info("No models with multiple checkpoints available for progression analysis")
760
 
761
  else:
762
  # For non-ScreenSpot datasets, show a simple bar chart
@@ -772,7 +493,7 @@ def main():
772
  height=400
773
  )
774
 
775
- chart_placeholder.altair_chart(chart, use_container_width=True)
776
 
777
  if __name__ == "__main__":
778
  main()
 
411
 
412
  if selected_model != 'All':
413
  filtered_df = filtered_df[filtered_df['model'] == selected_model]
414
+
 
 
 
 
 
 
 
415
  # Main content
416
+ st.header(f"Results for {selected_dataset}")
417
 
418
  # Overall metrics
419
+ col1, col2, col3 = st.columns(3)
420
  with col1:
421
  st.metric("Models Evaluated", len(filtered_df))
422
  with col2:
 
475
  # Create single chart for selected metric
476
  chart = create_bar_chart(ui_metrics_df, selected_metric, metric_options[selected_metric])
477
  if chart:
478
+ st.altair_chart(chart, use_container_width=True)
479
  else:
480
  st.warning(f"No data available for {metric_options[selected_metric]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
  else:
483
  # For non-ScreenSpot datasets, show a simple bar chart
 
493
  height=400
494
  )
495
 
496
+ st.altair_chart(chart, use_container_width=True)
497
 
498
  if __name__ == "__main__":
499
  main()