Anas Awadalla
committed on
Commit
·
628f62f
1
Parent(s):
c148460
fix caching of elements
Browse files- src/streamlit_app.py +5 -284
src/streamlit_app.py
CHANGED
@@ -411,19 +411,12 @@ def main():
|
|
411 |
|
412 |
if selected_model != 'All':
|
413 |
filtered_df = filtered_df[filtered_df['model'] == selected_model]
|
414 |
-
|
415 |
-
# Create placeholders for components that update when dataset or metric changes
|
416 |
-
header_placeholder = st.empty()
|
417 |
-
metrics_placeholder = st.empty()
|
418 |
-
chart_placeholder = st.empty()
|
419 |
-
view_metrics_expander_placeholder = st.empty()
|
420 |
-
progression_expander_placeholder = st.empty()
|
421 |
-
|
422 |
# Main content
|
423 |
-
|
424 |
|
425 |
# Overall metrics
|
426 |
-
col1, col2, col3 =
|
427 |
with col1:
|
428 |
st.metric("Models Evaluated", len(filtered_df))
|
429 |
with col2:
|
@@ -482,281 +475,9 @@ def main():
|
|
482 |
# Create single chart for selected metric
|
483 |
chart = create_bar_chart(ui_metrics_df, selected_metric, metric_options[selected_metric])
|
484 |
if chart:
|
485 |
-
|
486 |
else:
|
487 |
st.warning(f"No data available for {metric_options[selected_metric]}")
|
488 |
-
|
489 |
-
# Show all metrics in an expandable section - available for all datasets
|
490 |
-
with view_metrics_expander_placeholder.expander("View All Metrics"):
|
491 |
-
if selected_dataset == 'screenspot-v2':
|
492 |
-
# First row: Overall, Desktop, Web averages
|
493 |
-
col1, col2, col3 = st.columns(3)
|
494 |
-
|
495 |
-
with col1:
|
496 |
-
chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average (Desktop + Web) / 2')
|
497 |
-
if chart:
|
498 |
-
st.altair_chart(chart, use_container_width=True)
|
499 |
-
|
500 |
-
with col2:
|
501 |
-
chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
|
502 |
-
if chart:
|
503 |
-
st.altair_chart(chart, use_container_width=True)
|
504 |
-
|
505 |
-
with col3:
|
506 |
-
chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
|
507 |
-
if chart:
|
508 |
-
st.altair_chart(chart, use_container_width=True)
|
509 |
-
|
510 |
-
# Second row: Individual UI type metrics
|
511 |
-
col1, col2, col3, col4 = st.columns(4)
|
512 |
-
|
513 |
-
with col1:
|
514 |
-
chart = create_bar_chart(ui_metrics_df, 'desktop_text', 'Desktop (Text)')
|
515 |
-
if chart:
|
516 |
-
st.altair_chart(chart, use_container_width=True)
|
517 |
-
|
518 |
-
with col2:
|
519 |
-
chart = create_bar_chart(ui_metrics_df, 'desktop_icon', 'Desktop (Icon)')
|
520 |
-
if chart:
|
521 |
-
st.altair_chart(chart, use_container_width=True)
|
522 |
-
|
523 |
-
with col3:
|
524 |
-
chart = create_bar_chart(ui_metrics_df, 'web_text', 'Web (Text)')
|
525 |
-
if chart:
|
526 |
-
st.altair_chart(chart, use_container_width=True)
|
527 |
-
|
528 |
-
with col4:
|
529 |
-
chart = create_bar_chart(ui_metrics_df, 'web_icon', 'Web (Icon)')
|
530 |
-
if chart:
|
531 |
-
st.altair_chart(chart, use_container_width=True)
|
532 |
-
|
533 |
-
# Third row: Text vs Icon averages
|
534 |
-
col1, col2 = st.columns(2)
|
535 |
-
|
536 |
-
with col1:
|
537 |
-
chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (Desktop + Web)')
|
538 |
-
if chart:
|
539 |
-
st.altair_chart(chart, use_container_width=True)
|
540 |
-
|
541 |
-
with col2:
|
542 |
-
chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (Desktop + Web)')
|
543 |
-
if chart:
|
544 |
-
st.altair_chart(chart, use_container_width=True)
|
545 |
-
else:
|
546 |
-
# For screenspot-pro and showdown-clicks
|
547 |
-
st.info("No additional UI type metrics available for this dataset. Only overall accuracy is reported.")
|
548 |
-
|
549 |
-
# Checkpoint progression visualization
|
550 |
-
with progression_expander_placeholder.expander("Checkpoint Progression Analysis"):
|
551 |
-
# Select a model with checkpoints
|
552 |
-
models_with_checkpoints = ui_metrics_df[ui_metrics_df['all_checkpoints'].apply(lambda x: len(x) > 1)]
|
553 |
-
|
554 |
-
if not models_with_checkpoints.empty:
|
555 |
-
selected_checkpoint_model = st.selectbox(
|
556 |
-
"Select a model to view checkpoint progression:",
|
557 |
-
models_with_checkpoints['model'].str.replace('*', '').unique()
|
558 |
-
)
|
559 |
-
|
560 |
-
# Get checkpoint data for selected model
|
561 |
-
model_row = models_with_checkpoints[models_with_checkpoints['model'].str.replace('*', '') == selected_checkpoint_model].iloc[0]
|
562 |
-
checkpoint_data = model_row['all_checkpoints']
|
563 |
-
|
564 |
-
# Create DataFrame from checkpoint data
|
565 |
-
checkpoint_df = pd.DataFrame(checkpoint_data)
|
566 |
-
|
567 |
-
# Prepare data for visualization
|
568 |
-
checkpoint_metrics = []
|
569 |
-
for _, cp in checkpoint_df.iterrows():
|
570 |
-
ui_results = cp.get('ui_type_results', {})
|
571 |
-
dataset_type_results = cp.get('dataset_type_results', {})
|
572 |
-
results_by_file = cp.get('results_by_file', {})
|
573 |
-
|
574 |
-
# Check if we have desktop/web breakdown in results_by_file
|
575 |
-
desktop_file = None
|
576 |
-
web_file = None
|
577 |
-
|
578 |
-
for filename, file_results in results_by_file.items():
|
579 |
-
if 'desktop' in filename.lower():
|
580 |
-
desktop_file = file_results
|
581 |
-
elif 'web' in filename.lower():
|
582 |
-
web_file = file_results
|
583 |
-
|
584 |
-
if desktop_file and web_file:
|
585 |
-
# We have desktop/web breakdown
|
586 |
-
desktop_text = desktop_file.get('by_ui_type', {}).get('text', {}).get('correct', 0) / max(desktop_file.get('by_ui_type', {}).get('text', {}).get('total', 1), 1) * 100
|
587 |
-
desktop_icon = desktop_file.get('by_ui_type', {}).get('icon', {}).get('correct', 0) / max(desktop_file.get('by_ui_type', {}).get('icon', {}).get('total', 1), 1) * 100
|
588 |
-
web_text = web_file.get('by_ui_type', {}).get('text', {}).get('correct', 0) / max(web_file.get('by_ui_type', {}).get('text', {}).get('total', 1), 1) * 100
|
589 |
-
web_icon = web_file.get('by_ui_type', {}).get('icon', {}).get('correct', 0) / max(web_file.get('by_ui_type', {}).get('icon', {}).get('total', 1), 1) * 100
|
590 |
-
else:
|
591 |
-
# Fallback to simple UI type results
|
592 |
-
desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
|
593 |
-
desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
|
594 |
-
web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
|
595 |
-
web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
|
596 |
-
|
597 |
-
# If still all zeros, try dataset_type_results
|
598 |
-
if desktop_text == 0 and desktop_icon == 0 and web_text == 0 and web_icon == 0:
|
599 |
-
for dataset_key in dataset_type_results:
|
600 |
-
if 'screenspot' in dataset_key.lower():
|
601 |
-
dataset_data = dataset_type_results[dataset_key]
|
602 |
-
if 'by_ui_type' in dataset_data:
|
603 |
-
ui_data = dataset_data['by_ui_type']
|
604 |
-
# For simple text/icon without desktop/web
|
605 |
-
text_val = ui_data.get('text', {}).get('correct', 0) / max(ui_data.get('text', {}).get('total', 1), 1) * 100
|
606 |
-
icon_val = ui_data.get('icon', {}).get('correct', 0) / max(ui_data.get('icon', {}).get('total', 1), 1) * 100
|
607 |
-
# Assign same values to desktop and web as we don't have the breakdown
|
608 |
-
desktop_text = web_text = text_val
|
609 |
-
desktop_icon = web_icon = icon_val
|
610 |
-
break
|
611 |
-
|
612 |
-
desktop_avg = (desktop_text + desktop_icon) / 2
|
613 |
-
web_avg = (web_text + web_icon) / 2
|
614 |
-
text_avg = (desktop_text + web_text) / 2
|
615 |
-
icon_avg = (desktop_icon + web_icon) / 2
|
616 |
-
overall = (desktop_avg + web_avg) / 2 if selected_dataset == 'screenspot-v2' else cp['overall_accuracy']
|
617 |
-
|
618 |
-
checkpoint_metrics.append({
|
619 |
-
'steps': cp['checkpoint_steps'] or 0,
|
620 |
-
'overall': overall,
|
621 |
-
'desktop_avg': desktop_avg,
|
622 |
-
'web_avg': web_avg,
|
623 |
-
'desktop_text': desktop_text,
|
624 |
-
'desktop_icon': desktop_icon,
|
625 |
-
'web_text': web_text,
|
626 |
-
'web_icon': web_icon,
|
627 |
-
'text_avg': text_avg,
|
628 |
-
'icon_avg': icon_avg,
|
629 |
-
'loss': cp['training_loss'],
|
630 |
-
'neg_log_loss': -np.log(cp['training_loss']) if cp['training_loss'] and cp['training_loss'] > 0 else None
|
631 |
-
})
|
632 |
-
|
633 |
-
metrics_df = pd.DataFrame(checkpoint_metrics).sort_values('steps')
|
634 |
-
|
635 |
-
# Plot metrics over training steps
|
636 |
-
col1, col2 = st.columns(2)
|
637 |
-
|
638 |
-
with col1:
|
639 |
-
st.write("**Accuracy over Training Steps**")
|
640 |
-
|
641 |
-
# Determine which metrics to show based on selected metric
|
642 |
-
if selected_metric == 'overall':
|
643 |
-
# Show overall, desktop, and web averages
|
644 |
-
metrics_to_show = ['overall', 'desktop_avg', 'web_avg']
|
645 |
-
metric_labels = ['Overall', 'Desktop Avg', 'Web Avg']
|
646 |
-
colors = ['#4ECDC4', '#45B7D1', '#96CEB4']
|
647 |
-
elif 'desktop' in selected_metric:
|
648 |
-
# Show all desktop metrics
|
649 |
-
metrics_to_show = ['desktop_avg', 'desktop_text', 'desktop_icon']
|
650 |
-
metric_labels = ['Desktop Avg', 'Desktop Text', 'Desktop Icon']
|
651 |
-
colors = ['#45B7D1', '#FFA726', '#FF6B6B']
|
652 |
-
elif 'web' in selected_metric:
|
653 |
-
# Show all web metrics
|
654 |
-
metrics_to_show = ['web_avg', 'web_text', 'web_icon']
|
655 |
-
metric_labels = ['Web Avg', 'Web Text', 'Web Icon']
|
656 |
-
colors = ['#96CEB4', '#9C27B0', '#E91E63']
|
657 |
-
elif 'text' in selected_metric:
|
658 |
-
# Show text metrics across environments
|
659 |
-
metrics_to_show = ['text_avg', 'desktop_text', 'web_text']
|
660 |
-
metric_labels = ['Text Avg', 'Desktop Text', 'Web Text']
|
661 |
-
colors = ['#FF9800', '#FFA726', '#FFB74D']
|
662 |
-
elif 'icon' in selected_metric:
|
663 |
-
# Show icon metrics across environments
|
664 |
-
metrics_to_show = ['icon_avg', 'desktop_icon', 'web_icon']
|
665 |
-
metric_labels = ['Icon Avg', 'Desktop Icon', 'Web Icon']
|
666 |
-
colors = ['#3F51B5', '#5C6BC0', '#7986CB']
|
667 |
-
else:
|
668 |
-
# Default: just show the selected metric
|
669 |
-
metrics_to_show = [selected_metric]
|
670 |
-
metric_labels = [metric_options.get(selected_metric, selected_metric)]
|
671 |
-
colors = ['#4ECDC4']
|
672 |
-
|
673 |
-
# Create multi-line chart data
|
674 |
-
chart_data = []
|
675 |
-
for i, (metric, label) in enumerate(zip(metrics_to_show, metric_labels)):
|
676 |
-
for _, row in metrics_df.iterrows():
|
677 |
-
if metric in row:
|
678 |
-
chart_data.append({
|
679 |
-
'steps': row['steps'],
|
680 |
-
'value': row[metric],
|
681 |
-
'metric': label,
|
682 |
-
'color_idx': i
|
683 |
-
})
|
684 |
-
|
685 |
-
if chart_data:
|
686 |
-
chart_df = pd.DataFrame(chart_data)
|
687 |
-
|
688 |
-
# Create multi-line chart with distinct colors
|
689 |
-
chart = alt.Chart(chart_df).mark_line(point=True, strokeWidth=2).encode(
|
690 |
-
x=alt.X('steps:Q', title='Training Steps'),
|
691 |
-
y=alt.Y('value:Q', scale=alt.Scale(domain=[0, 100]), title='Accuracy (%)'),
|
692 |
-
color=alt.Color('metric:N',
|
693 |
-
scale=alt.Scale(domain=metric_labels, range=colors),
|
694 |
-
legend=alt.Legend(title="Metric")),
|
695 |
-
tooltip=['steps:Q', 'metric:N', alt.Tooltip('value:Q', format='.1f', title='Accuracy')]
|
696 |
-
).properties(
|
697 |
-
width=500,
|
698 |
-
height=400,
|
699 |
-
title='Accuracy Progression During Training'
|
700 |
-
)
|
701 |
-
st.altair_chart(chart, use_container_width=True)
|
702 |
-
else:
|
703 |
-
st.warning("No data available for the selected metrics")
|
704 |
-
|
705 |
-
with col2:
|
706 |
-
st.write(f"**{metric_options[selected_metric]} vs. Training Loss**")
|
707 |
-
|
708 |
-
if metrics_df['neg_log_loss'].notna().any():
|
709 |
-
scatter_data = metrics_df[metrics_df['neg_log_loss'].notna()]
|
710 |
-
|
711 |
-
chart = alt.Chart(scatter_data).mark_circle(size=100).encode(
|
712 |
-
x=alt.X('neg_log_loss:Q', title='-log(Training Loss)'),
|
713 |
-
y=alt.Y(f'{selected_metric}:Q', scale=alt.Scale(domain=[0, 100]), title=f'{metric_options[selected_metric]} (%)'),
|
714 |
-
color=alt.Color('steps:Q', scale=alt.Scale(scheme='viridis'), title='Training Steps'),
|
715 |
-
tooltip=['steps', 'loss', selected_metric]
|
716 |
-
).properties(
|
717 |
-
width=500, # Increased from 400
|
718 |
-
height=400, # Increased from 300
|
719 |
-
title=f'{metric_options[selected_metric]} vs. -log(Training Loss)'
|
720 |
-
)
|
721 |
-
st.altair_chart(chart, use_container_width=True)
|
722 |
-
else:
|
723 |
-
st.info("No training loss data available for this model")
|
724 |
-
|
725 |
-
# Show checkpoint details table with selected metric
|
726 |
-
st.write("**Checkpoint Details**")
|
727 |
-
|
728 |
-
# Determine columns to display based on selected metric category
|
729 |
-
if selected_metric == 'overall':
|
730 |
-
display_cols = ['steps', 'overall', 'desktop_avg', 'web_avg', 'loss']
|
731 |
-
col_labels = ['Steps', 'Overall %', 'Desktop Avg %', 'Web Avg %', 'Training Loss']
|
732 |
-
elif 'desktop' in selected_metric:
|
733 |
-
display_cols = ['steps', 'desktop_avg', 'desktop_text', 'desktop_icon', 'loss']
|
734 |
-
col_labels = ['Steps', 'Desktop Avg %', 'Desktop Text %', 'Desktop Icon %', 'Training Loss']
|
735 |
-
elif 'web' in selected_metric:
|
736 |
-
display_cols = ['steps', 'web_avg', 'web_text', 'web_icon', 'loss']
|
737 |
-
col_labels = ['Steps', 'Web Avg %', 'Web Text %', 'Web Icon %', 'Training Loss']
|
738 |
-
elif 'text' in selected_metric:
|
739 |
-
display_cols = ['steps', 'text_avg', 'desktop_text', 'web_text', 'loss']
|
740 |
-
col_labels = ['Steps', 'Text Avg %', 'Desktop Text %', 'Web Text %', 'Training Loss']
|
741 |
-
elif 'icon' in selected_metric:
|
742 |
-
display_cols = ['steps', 'icon_avg', 'desktop_icon', 'web_icon', 'loss']
|
743 |
-
col_labels = ['Steps', 'Icon Avg %', 'Desktop Icon %', 'Web Icon %', 'Training Loss']
|
744 |
-
else:
|
745 |
-
display_cols = ['steps', selected_metric, 'loss']
|
746 |
-
col_labels = ['Steps', f'{metric_options[selected_metric]} %', 'Training Loss']
|
747 |
-
|
748 |
-
display_metrics = metrics_df[display_cols].copy()
|
749 |
-
display_metrics.columns = col_labels
|
750 |
-
|
751 |
-
# Format percentage columns
|
752 |
-
for col in col_labels:
|
753 |
-
if '%' in col and col != 'Training Loss':
|
754 |
-
display_metrics[col] = display_metrics[col].round(2)
|
755 |
-
|
756 |
-
display_metrics['Training Loss'] = display_metrics['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
|
757 |
-
st.dataframe(display_metrics, use_container_width=True)
|
758 |
-
else:
|
759 |
-
st.info("No models with multiple checkpoints available for progression analysis")
|
760 |
|
761 |
else:
|
762 |
# For non-ScreenSpot datasets, show a simple bar chart
|
@@ -772,7 +493,7 @@ def main():
|
|
772 |
height=400
|
773 |
)
|
774 |
|
775 |
-
|
776 |
|
777 |
if __name__ == "__main__":
|
778 |
main()
|
|
|
411 |
|
412 |
if selected_model != 'All':
|
413 |
filtered_df = filtered_df[filtered_df['model'] == selected_model]
|
414 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
# Main content
|
416 |
+
st.header(f"Results for {selected_dataset}")
|
417 |
|
418 |
# Overall metrics
|
419 |
+
col1, col2, col3 = st.columns(3)
|
420 |
with col1:
|
421 |
st.metric("Models Evaluated", len(filtered_df))
|
422 |
with col2:
|
|
|
475 |
# Create single chart for selected metric
|
476 |
chart = create_bar_chart(ui_metrics_df, selected_metric, metric_options[selected_metric])
|
477 |
if chart:
|
478 |
+
st.altair_chart(chart, use_container_width=True)
|
479 |
else:
|
480 |
st.warning(f"No data available for {metric_options[selected_metric]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
|
482 |
else:
|
483 |
# For non-ScreenSpot datasets, show a simple bar chart
|
|
|
493 |
height=400
|
494 |
)
|
495 |
|
496 |
+
st.altair_chart(chart, use_container_width=True)
|
497 |
|
498 |
if __name__ == "__main__":
|
499 |
main()
|