Anas Awadalla committed on
Commit
4db9f63
·
1 Parent(s): 79cb6e1

some fixes

Browse files
Files changed (2) hide show
  1. README.md +0 -2
  2. src/streamlit_app.py +325 -165
README.md CHANGED
@@ -26,7 +26,6 @@ A Streamlit application for visualizing model performance on grounding benchmark
26
  - For other datasets: Desktop vs Web and Text vs Icon performance
27
  - **Checkpoint Progression Analysis**: Visualize how metrics evolve during training
28
  - **Model Details**: View training loss, checkpoint steps, and evaluation timestamps
29
- - **Sample Results**: Inspect the first 5 evaluation samples for each model
30
 
31
  ## Installation
32
 
@@ -62,7 +61,6 @@ The app will open in your browser at `http://localhost:8501`
62
  4. **Explore Details**:
63
  - Expand "Model Details" to see training metadata
64
  - Expand "Detailed UI Type Breakdown" for a comprehensive table
65
- - Expand "Sample Results" to see the first 5 evaluation samples
66
  - Expand "Checkpoint Progression Analysis" to:
67
  - View accuracy progression over training steps
68
  - See the relationship between training loss and accuracy
 
26
  - For other datasets: Desktop vs Web and Text vs Icon performance
27
  - **Checkpoint Progression Analysis**: Visualize how metrics evolve during training
28
  - **Model Details**: View training loss, checkpoint steps, and evaluation timestamps
 
29
 
30
  ## Installation
31
 
 
61
  4. **Explore Details**:
62
  - Expand "Model Details" to see training metadata
63
  - Expand "Detailed UI Type Breakdown" for a comprehensive table
 
64
  - Expand "Checkpoint Progression Analysis" to:
65
  - View accuracy progression over training steps
66
  - See the relationship between training loss and accuracy
src/streamlit_app.py CHANGED
@@ -167,12 +167,7 @@ def fetch_leaderboard_data():
167
  "checkpoint_steps": metadata.get("checkpoint_steps"),
168
  "training_loss": metadata.get("training_loss"),
169
  "ui_type_results": ui_type_results,
170
- "dataset_type_results": dataset_type_results,
171
- # Store minimal sample results for inspection
172
- "sample_results_summary": {
173
- "total_samples": len(data.get("sample_results", [])),
174
- "first_5_samples": data.get("sample_results", [])[:5]
175
- }
176
  }
177
 
178
  results.append(result_entry)
@@ -242,16 +237,31 @@ def parse_ui_type_metrics(df: pd.DataFrame, dataset_filter: str) -> pd.DataFrame
242
  continue
243
 
244
  model = row['model']
245
- ui_results = row['ui_type_results']
 
246
 
247
  # For ScreenSpot datasets, we have desktop/web and text/icon
248
  if 'screenspot' in dataset_filter.lower():
249
- # Calculate individual metrics
250
  desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
251
  desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
252
  web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
253
  web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  # Calculate averages
256
  desktop_avg = (desktop_text + desktop_icon) / 2 if (desktop_text > 0 or desktop_icon > 0) else 0
257
  web_avg = (web_text + web_icon) / 2 if (web_text > 0 or web_icon > 0) else 0
@@ -260,7 +270,7 @@ def parse_ui_type_metrics(df: pd.DataFrame, dataset_filter: str) -> pd.DataFrame
260
 
261
  # For screenspot-v2, calculate the overall as average of desktop and web
262
  if dataset_filter == 'screenspot-v2':
263
- overall = (desktop_avg + web_avg) / 2 if (desktop_avg > 0 or web_avg > 0) else 0
264
  else:
265
  overall = row['overall_accuracy']
266
 
@@ -278,6 +288,14 @@ def parse_ui_type_metrics(df: pd.DataFrame, dataset_filter: str) -> pd.DataFrame
278
  'is_best_not_last': row.get('is_best_not_last', False),
279
  'all_checkpoints': row.get('all_checkpoints', [])
280
  })
 
 
 
 
 
 
 
 
281
 
282
  return pd.DataFrame(metrics_list)
283
 
@@ -326,8 +344,8 @@ def create_bar_chart(data: pd.DataFrame, metric: str, title: str):
326
  tooltip=['Model', 'Score', 'Type']
327
  ).properties(
328
  title=title,
329
- width=400,
330
- height=300
331
  )
332
 
333
  # Add value labels
@@ -374,6 +392,38 @@ def main():
374
  # Main content
375
  st.header(f"Results for {selected_dataset}")
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  # Overall metrics
378
  col1, col2, col3 = st.columns(3)
379
  with col1:
@@ -390,98 +440,137 @@ def main():
390
  # Parse UI type metrics
391
  ui_metrics_df = parse_ui_type_metrics(filtered_df, selected_dataset)
392
 
 
 
393
  if not ui_metrics_df.empty and 'screenspot' in selected_dataset.lower():
394
  st.subheader("Performance by UI Type")
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  # Add note about asterisks
397
  if any(ui_metrics_df['is_best_not_last']):
398
  st.info("* indicates the best checkpoint is not the last checkpoint")
399
 
400
- # Create charts in a grid
401
- if selected_dataset == 'screenspot-v2':
402
- # First row: Overall, Desktop, Web averages
403
- col1, col2, col3 = st.columns(3)
404
-
405
- with col1:
406
- chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average (Desktop + Web) / 2')
407
- if chart:
408
- st.altair_chart(chart, use_container_width=True)
409
-
410
- with col2:
411
- chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
412
- if chart:
413
- st.altair_chart(chart, use_container_width=True)
414
-
415
- with col3:
416
- chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
417
- if chart:
418
- st.altair_chart(chart, use_container_width=True)
419
-
420
- # Second row: Individual UI type metrics
421
- col1, col2, col3, col4 = st.columns(4)
422
-
423
- with col1:
424
- chart = create_bar_chart(ui_metrics_df, 'desktop_text', 'Desktop (Text)')
425
- if chart:
426
- st.altair_chart(chart, use_container_width=True)
427
-
428
- with col2:
429
- chart = create_bar_chart(ui_metrics_df, 'desktop_icon', 'Desktop (Icon)')
430
- if chart:
431
- st.altair_chart(chart, use_container_width=True)
432
-
433
- with col3:
434
- chart = create_bar_chart(ui_metrics_df, 'web_text', 'Web (Text)')
435
- if chart:
436
- st.altair_chart(chart, use_container_width=True)
437
-
438
- with col4:
439
- chart = create_bar_chart(ui_metrics_df, 'web_icon', 'Web (Icon)')
440
- if chart:
441
- st.altair_chart(chart, use_container_width=True)
442
-
443
- # Third row: Text vs Icon averages
444
- col1, col2 = st.columns(2)
445
-
446
- with col1:
447
- chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (Desktop + Web)')
448
- if chart:
449
- st.altair_chart(chart, use_container_width=True)
450
-
451
- with col2:
452
- chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (Desktop + Web)')
453
- if chart:
454
- st.altair_chart(chart, use_container_width=True)
455
  else:
456
- # For other screenspot datasets, show the standard layout
457
- col1, col2 = st.columns(2)
458
-
459
- with col1:
460
- # Overall Average
461
- chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average')
462
- if chart:
463
- st.altair_chart(chart, use_container_width=True)
464
-
465
- # Desktop Average
466
- chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
467
- if chart:
468
- st.altair_chart(chart, use_container_width=True)
469
-
470
- # Text Average
471
- chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (UI-Type)')
472
- if chart:
473
- st.altair_chart(chart, use_container_width=True)
474
-
475
- with col2:
476
- # Web Average
477
- chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
478
- if chart:
479
- st.altair_chart(chart, use_container_width=True)
480
-
481
- # Icon Average
482
- chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (UI-Type)')
483
- if chart:
484
- st.altair_chart(chart, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  # Checkpoint progression visualization
487
  with st.expander("Checkpoint Progression Analysis"):
@@ -504,23 +593,46 @@ def main():
504
  # Prepare data for visualization
505
  checkpoint_metrics = []
506
  for _, cp in checkpoint_df.iterrows():
507
- ui_results = cp['ui_type_results']
 
508
 
509
- # Calculate metrics
510
  desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
511
  desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
512
  web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
513
  web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  desktop_avg = (desktop_text + desktop_icon) / 2
516
  web_avg = (web_text + web_icon) / 2
 
 
517
  overall = (desktop_avg + web_avg) / 2 if selected_dataset == 'screenspot-v2' else cp['overall_accuracy']
518
 
519
  checkpoint_metrics.append({
520
  'steps': cp['checkpoint_steps'] or 0,
521
  'overall': overall,
522
- 'desktop': desktop_avg,
523
- 'web': web_avg,
 
 
 
 
 
 
524
  'loss': cp['training_loss'],
525
  'neg_log_loss': -np.log(cp['training_loss']) if cp['training_loss'] and cp['training_loss'] > 0 else None
526
  })
@@ -533,74 +645,143 @@ def main():
533
  with col1:
534
  st.write("**Accuracy over Training Steps**")
535
 
536
- # Melt data for multi-line chart
537
- melted = metrics_df[['steps', 'overall', 'desktop', 'web']].melt(
538
- id_vars=['steps'],
539
- var_name='Metric',
540
- value_name='Accuracy'
541
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
- chart = alt.Chart(melted).mark_line(point=True).encode(
544
- x=alt.X('steps:Q', title='Training Steps'),
545
- y=alt.Y('Accuracy:Q', scale=alt.Scale(domain=[0, 100]), title='Accuracy (%)'),
546
- color=alt.Color('Metric:N', scale=alt.Scale(
547
- domain=['overall', 'desktop', 'web'],
548
- range=['#4ECDC4', '#45B7D1', '#96CEB4']
549
- )),
550
- tooltip=['steps', 'Metric', 'Accuracy']
551
- ).properties(
552
- width=400,
553
- height=300,
554
- title='Accuracy Progression During Training'
555
- )
556
- st.altair_chart(chart, use_container_width=True)
 
 
 
 
 
557
 
558
  with col2:
559
- st.write("**Accuracy vs. Training Loss**")
560
 
561
  if metrics_df['neg_log_loss'].notna().any():
562
  scatter_data = metrics_df[metrics_df['neg_log_loss'].notna()]
563
 
564
  chart = alt.Chart(scatter_data).mark_circle(size=100).encode(
565
  x=alt.X('neg_log_loss:Q', title='-log(Training Loss)'),
566
- y=alt.Y('overall:Q', scale=alt.Scale(domain=[0, 100]), title='Overall Accuracy (%)'),
567
  color=alt.Color('steps:Q', scale=alt.Scale(scheme='viridis'), title='Training Steps'),
568
- tooltip=['steps', 'loss', 'overall']
569
  ).properties(
570
- width=400,
571
- height=300,
572
- title='Accuracy vs. -log(Training Loss)'
573
  )
574
  st.altair_chart(chart, use_container_width=True)
575
  else:
576
  st.info("No training loss data available for this model")
577
 
578
- # Show checkpoint details table
579
  st.write("**Checkpoint Details**")
580
- display_metrics = metrics_df[['steps', 'overall', 'desktop', 'web', 'loss']].copy()
581
- display_metrics.columns = ['Steps', 'Overall %', 'Desktop %', 'Web %', 'Training Loss']
582
- display_metrics[['Overall %', 'Desktop %', 'Web %']] = display_metrics[['Overall %', 'Desktop %', 'Web %']].round(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  display_metrics['Training Loss'] = display_metrics['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
584
  st.dataframe(display_metrics, use_container_width=True)
585
  else:
586
  st.info("No models with multiple checkpoints available for progression analysis")
587
 
588
  # Detailed breakdown
589
- with st.expander("Detailed UI Type Breakdown"):
590
- # Create a heatmap-style table
591
- detailed_metrics = []
592
- for _, row in ui_metrics_df.iterrows():
593
- detailed_metrics.append({
594
- 'Model': row['model'],
595
- 'Desktop Text': f"{row['desktop_text']:.1f}%",
596
- 'Desktop Icon': f"{row['desktop_icon']:.1f}%",
597
- 'Web Text': f"{row['web_text']:.1f}%",
598
- 'Web Icon': f"{row['web_icon']:.1f}%",
599
- 'Overall': f"{row['overall']:.1f}%"
600
- })
601
-
602
- if detailed_metrics:
603
- st.dataframe(pd.DataFrame(detailed_metrics), use_container_width=True)
 
604
 
605
  else:
606
  # For non-ScreenSpot datasets, show a simple bar chart
@@ -627,27 +808,6 @@ def main():
627
  display_df['Accuracy (%)'] = display_df['Accuracy (%)'].apply(lambda x: f"{x:.2f}")
628
  display_df['Training Loss'] = display_df['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
629
  st.dataframe(display_df, use_container_width=True)
630
-
631
- # Raw data viewer
632
- with st.expander("Sample Results"):
633
- if selected_model != 'All' and len(filtered_df) == 1:
634
- summary = filtered_df.iloc[0]['sample_results_summary']
635
- st.write(f"**Total evaluation samples:** {summary['total_samples']}")
636
- st.write("**First 5 sample results:**")
637
- for i, sample in enumerate(summary['first_5_samples'], 1):
638
- st.write(f"\n**Sample {i}:**")
639
- col1, col2 = st.columns([1, 2])
640
- with col1:
641
- st.write(f"- **Correct:** {'✅' if sample.get('is_correct') else '❌'}")
642
- st.write(f"- **Image:** {sample.get('img_filename', 'N/A')}")
643
- with col2:
644
- st.write(f"- **Instruction:** {sample.get('instruction', 'N/A')}")
645
- if sample.get('predicted_click'):
646
- st.write(f"- **Predicted Click:** {sample['predicted_click']}")
647
- if sample.get('error_msg'):
648
- st.write(f"- **Error:** {sample['error_msg']}")
649
- else:
650
- st.info("Select a specific model to view sample results")
651
 
652
  if __name__ == "__main__":
653
  main()
 
167
  "checkpoint_steps": metadata.get("checkpoint_steps"),
168
  "training_loss": metadata.get("training_loss"),
169
  "ui_type_results": ui_type_results,
170
+ "dataset_type_results": dataset_type_results
 
 
 
 
 
171
  }
172
 
173
  results.append(result_entry)
 
237
  continue
238
 
239
  model = row['model']
240
+ ui_results = row.get('ui_type_results', {})
241
+ dataset_type_results = row.get('dataset_type_results', {})
242
 
243
  # For ScreenSpot datasets, we have desktop/web and text/icon
244
  if 'screenspot' in dataset_filter.lower():
245
+ # First try to get from ui_type_results
246
  desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
247
  desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
248
  web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
249
  web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
250
 
251
+ # If all zeros, try to get from dataset_type_results
252
+ if desktop_text == 0 and desktop_icon == 0 and web_text == 0 and web_icon == 0:
253
+ # Check if data is nested under dataset types (e.g., 'screenspot-v2')
254
+ for dataset_key in dataset_type_results:
255
+ if 'screenspot' in dataset_key.lower():
256
+ dataset_data = dataset_type_results[dataset_key]
257
+ if 'by_ui_type' in dataset_data:
258
+ ui_data = dataset_data['by_ui_type']
259
+ desktop_text = ui_data.get('desktop_text', {}).get('correct', 0) / max(ui_data.get('desktop_text', {}).get('total', 1), 1) * 100
260
+ desktop_icon = ui_data.get('desktop_icon', {}).get('correct', 0) / max(ui_data.get('desktop_icon', {}).get('total', 1), 1) * 100
261
+ web_text = ui_data.get('web_text', {}).get('correct', 0) / max(ui_data.get('web_text', {}).get('total', 1), 1) * 100
262
+ web_icon = ui_data.get('web_icon', {}).get('correct', 0) / max(ui_data.get('web_icon', {}).get('total', 1), 1) * 100
263
+ break
264
+
265
  # Calculate averages
266
  desktop_avg = (desktop_text + desktop_icon) / 2 if (desktop_text > 0 or desktop_icon > 0) else 0
267
  web_avg = (web_text + web_icon) / 2 if (web_text > 0 or web_icon > 0) else 0
 
270
 
271
  # For screenspot-v2, calculate the overall as average of desktop and web
272
  if dataset_filter == 'screenspot-v2':
273
+ overall = (desktop_avg + web_avg) / 2 if (desktop_avg > 0 or web_avg > 0) else row['overall_accuracy']
274
  else:
275
  overall = row['overall_accuracy']
276
 
 
288
  'is_best_not_last': row.get('is_best_not_last', False),
289
  'all_checkpoints': row.get('all_checkpoints', [])
290
  })
291
+ else:
292
+ # For non-screenspot datasets, just pass through overall accuracy
293
+ metrics_list.append({
294
+ 'model': model,
295
+ 'overall': row['overall_accuracy'],
296
+ 'is_best_not_last': row.get('is_best_not_last', False),
297
+ 'all_checkpoints': row.get('all_checkpoints', [])
298
+ })
299
 
300
  return pd.DataFrame(metrics_list)
301
 
 
344
  tooltip=['Model', 'Score', 'Type']
345
  ).properties(
346
  title=title,
347
+ width=500, # Increased from 400
348
+ height=400 # Increased from 300
349
  )
350
 
351
  # Add value labels
 
392
  # Main content
393
  st.header(f"Results for {selected_dataset}")
394
 
395
+ # Debug information (can be removed later)
396
+ with st.expander("Debug Information"):
397
+ st.write(f"Total rows in filtered_df: {len(filtered_df)}")
398
+ st.write(f"Total rows in ui_metrics_df: {len(ui_metrics_df)}")
399
+ if not filtered_df.empty:
400
+ st.write("Sample data from filtered_df:")
401
+ st.write(filtered_df[['model', 'base_model', 'is_checkpoint', 'overall_accuracy']].head())
402
+
403
+ # Show UI type results structure
404
+ st.write("\nUI Type Results Structure:")
405
+ for idx, row in filtered_df.head(2).iterrows():
406
+ st.write(f"\nModel: {row['model']}")
407
+ ui_results = row.get('ui_type_results', {})
408
+ if ui_results:
409
+ st.write("UI Type Keys:", list(ui_results.keys()))
410
+ # Show a sample of the structure
411
+ for key in list(ui_results.keys())[:2]:
412
+ st.write(f" {key}: {ui_results[key]}")
413
+ else:
414
+ st.write(" No UI type results found")
415
+
416
+ # Also check dataset_type_results
417
+ dataset_type_results = row.get('dataset_type_results', {})
418
+ if dataset_type_results:
419
+ st.write("Dataset Type Results Keys:", list(dataset_type_results.keys()))
420
+ for key in list(dataset_type_results.keys())[:2]:
421
+ st.write(f" {key}: {dataset_type_results[key]}")
422
+
423
+ if not ui_metrics_df.empty:
424
+ st.write("\nSample data from ui_metrics_df:")
425
+ st.write(ui_metrics_df[['model', 'overall', 'desktop_avg', 'web_avg']].head())
426
+
427
  # Overall metrics
428
  col1, col2, col3 = st.columns(3)
429
  with col1:
 
440
  # Parse UI type metrics
441
  ui_metrics_df = parse_ui_type_metrics(filtered_df, selected_dataset)
442
 
443
+ # Add metric selector for screenspot datasets
444
+ selected_metric = 'overall' # Default metric
445
  if not ui_metrics_df.empty and 'screenspot' in selected_dataset.lower():
446
  st.subheader("Performance by UI Type")
447
 
448
+ # Metric selector dropdown
449
+ if selected_dataset == 'screenspot-v2':
450
+ metric_options = {
451
+ 'overall': 'Overall Average (Desktop + Web) / 2',
452
+ 'desktop_avg': 'Desktop Average',
453
+ 'web_avg': 'Web Average',
454
+ 'desktop_text': 'Desktop (Text)',
455
+ 'desktop_icon': 'Desktop (Icon)',
456
+ 'web_text': 'Web (Text)',
457
+ 'web_icon': 'Web (Icon)',
458
+ 'text_avg': 'Text Average',
459
+ 'icon_avg': 'Icon Average'
460
+ }
461
+ else:
462
+ metric_options = {
463
+ 'overall': 'Overall Average',
464
+ 'desktop_avg': 'Desktop Average',
465
+ 'web_avg': 'Web Average',
466
+ 'text_avg': 'Text Average',
467
+ 'icon_avg': 'Icon Average'
468
+ }
469
+
470
+ selected_metric = st.selectbox(
471
+ "Select metric to visualize:",
472
+ options=list(metric_options.keys()),
473
+ format_func=lambda x: metric_options[x],
474
+ key="metric_selector"
475
+ )
476
+
477
  # Add note about asterisks
478
  if any(ui_metrics_df['is_best_not_last']):
479
  st.info("* indicates the best checkpoint is not the last checkpoint")
480
 
481
+ # Create single chart for selected metric
482
+ chart = create_bar_chart(ui_metrics_df, selected_metric, metric_options[selected_metric])
483
+ if chart:
484
+ st.altair_chart(chart, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  else:
486
+ st.warning(f"No data available for {metric_options[selected_metric]}")
487
+
488
+ # Show all metrics in an expandable section
489
+ with st.expander("View All Metrics"):
490
+ if selected_dataset == 'screenspot-v2':
491
+ # First row: Overall, Desktop, Web averages
492
+ col1, col2, col3 = st.columns(3)
493
+
494
+ with col1:
495
+ chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average (Desktop + Web) / 2')
496
+ if chart:
497
+ st.altair_chart(chart, use_container_width=True)
498
+
499
+ with col2:
500
+ chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
501
+ if chart:
502
+ st.altair_chart(chart, use_container_width=True)
503
+
504
+ with col3:
505
+ chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
506
+ if chart:
507
+ st.altair_chart(chart, use_container_width=True)
508
+
509
+ # Second row: Individual UI type metrics
510
+ col1, col2, col3, col4 = st.columns(4)
511
+
512
+ with col1:
513
+ chart = create_bar_chart(ui_metrics_df, 'desktop_text', 'Desktop (Text)')
514
+ if chart:
515
+ st.altair_chart(chart, use_container_width=True)
516
+
517
+ with col2:
518
+ chart = create_bar_chart(ui_metrics_df, 'desktop_icon', 'Desktop (Icon)')
519
+ if chart:
520
+ st.altair_chart(chart, use_container_width=True)
521
+
522
+ with col3:
523
+ chart = create_bar_chart(ui_metrics_df, 'web_text', 'Web (Text)')
524
+ if chart:
525
+ st.altair_chart(chart, use_container_width=True)
526
+
527
+ with col4:
528
+ chart = create_bar_chart(ui_metrics_df, 'web_icon', 'Web (Icon)')
529
+ if chart:
530
+ st.altair_chart(chart, use_container_width=True)
531
+
532
+ # Third row: Text vs Icon averages
533
+ col1, col2 = st.columns(2)
534
+
535
+ with col1:
536
+ chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (Desktop + Web)')
537
+ if chart:
538
+ st.altair_chart(chart, use_container_width=True)
539
+
540
+ with col2:
541
+ chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (Desktop + Web)')
542
+ if chart:
543
+ st.altair_chart(chart, use_container_width=True)
544
+ else:
545
+ # For other screenspot datasets, show the standard layout
546
+ col1, col2 = st.columns(2)
547
+
548
+ with col1:
549
+ # Overall Average
550
+ chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average')
551
+ if chart:
552
+ st.altair_chart(chart, use_container_width=True)
553
+
554
+ # Desktop Average
555
+ chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
556
+ if chart:
557
+ st.altair_chart(chart, use_container_width=True)
558
+
559
+ # Text Average
560
+ chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (UI-Type)')
561
+ if chart:
562
+ st.altair_chart(chart, use_container_width=True)
563
+
564
+ with col2:
565
+ # Web Average
566
+ chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
567
+ if chart:
568
+ st.altair_chart(chart, use_container_width=True)
569
+
570
+ # Icon Average
571
+ chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (UI-Type)')
572
+ if chart:
573
+ st.altair_chart(chart, use_container_width=True)
574
 
575
  # Checkpoint progression visualization
576
  with st.expander("Checkpoint Progression Analysis"):
 
593
  # Prepare data for visualization
594
  checkpoint_metrics = []
595
  for _, cp in checkpoint_df.iterrows():
596
+ ui_results = cp.get('ui_type_results', {})
597
+ dataset_type_results = cp.get('dataset_type_results', {})
598
 
599
+ # First try to get from ui_type_results
600
  desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
601
  desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
602
  web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
603
  web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
604
 
605
+ # If all zeros, try to get from dataset_type_results
606
+ if desktop_text == 0 and desktop_icon == 0 and web_text == 0 and web_icon == 0:
607
+ # Check if data is nested under dataset types
608
+ for dataset_key in dataset_type_results:
609
+ if 'screenspot' in dataset_key.lower():
610
+ dataset_data = dataset_type_results[dataset_key]
611
+ if 'by_ui_type' in dataset_data:
612
+ ui_data = dataset_data['by_ui_type']
613
+ desktop_text = ui_data.get('desktop_text', {}).get('correct', 0) / max(ui_data.get('desktop_text', {}).get('total', 1), 1) * 100
614
+ desktop_icon = ui_data.get('desktop_icon', {}).get('correct', 0) / max(ui_data.get('desktop_icon', {}).get('total', 1), 1) * 100
615
+ web_text = ui_data.get('web_text', {}).get('correct', 0) / max(ui_data.get('web_text', {}).get('total', 1), 1) * 100
616
+ web_icon = ui_data.get('web_icon', {}).get('correct', 0) / max(ui_data.get('web_icon', {}).get('total', 1), 1) * 100
617
+ break
618
+
619
  desktop_avg = (desktop_text + desktop_icon) / 2
620
  web_avg = (web_text + web_icon) / 2
621
+ text_avg = (desktop_text + web_text) / 2
622
+ icon_avg = (desktop_icon + web_icon) / 2
623
  overall = (desktop_avg + web_avg) / 2 if selected_dataset == 'screenspot-v2' else cp['overall_accuracy']
624
 
625
  checkpoint_metrics.append({
626
  'steps': cp['checkpoint_steps'] or 0,
627
  'overall': overall,
628
+ 'desktop_avg': desktop_avg,
629
+ 'web_avg': web_avg,
630
+ 'desktop_text': desktop_text,
631
+ 'desktop_icon': desktop_icon,
632
+ 'web_text': web_text,
633
+ 'web_icon': web_icon,
634
+ 'text_avg': text_avg,
635
+ 'icon_avg': icon_avg,
636
  'loss': cp['training_loss'],
637
  'neg_log_loss': -np.log(cp['training_loss']) if cp['training_loss'] and cp['training_loss'] > 0 else None
638
  })
 
645
  with col1:
646
  st.write("**Accuracy over Training Steps**")
647
 
648
+ # Determine which metrics to show based on selected metric
649
+ if selected_metric == 'overall':
650
+ # Show overall, desktop, and web averages
651
+ metrics_to_show = ['overall', 'desktop_avg', 'web_avg']
652
+ metric_labels = ['Overall', 'Desktop Avg', 'Web Avg']
653
+ colors = ['#4ECDC4', '#45B7D1', '#96CEB4']
654
+ elif 'desktop' in selected_metric:
655
+ # Show all desktop metrics
656
+ metrics_to_show = ['desktop_avg', 'desktop_text', 'desktop_icon']
657
+ metric_labels = ['Desktop Avg', 'Desktop Text', 'Desktop Icon']
658
+ colors = ['#45B7D1', '#FFA726', '#FF6B6B']
659
+ elif 'web' in selected_metric:
660
+ # Show all web metrics
661
+ metrics_to_show = ['web_avg', 'web_text', 'web_icon']
662
+ metric_labels = ['Web Avg', 'Web Text', 'Web Icon']
663
+ colors = ['#96CEB4', '#9C27B0', '#E91E63']
664
+ elif 'text' in selected_metric:
665
+ # Show text metrics across environments
666
+ metrics_to_show = ['text_avg', 'desktop_text', 'web_text']
667
+ metric_labels = ['Text Avg', 'Desktop Text', 'Web Text']
668
+ colors = ['#FF9800', '#FFA726', '#FFB74D']
669
+ elif 'icon' in selected_metric:
670
+ # Show icon metrics across environments
671
+ metrics_to_show = ['icon_avg', 'desktop_icon', 'web_icon']
672
+ metric_labels = ['Icon Avg', 'Desktop Icon', 'Web Icon']
673
+ colors = ['#3F51B5', '#5C6BC0', '#7986CB']
674
+ else:
675
+ # Default: just show the selected metric
676
+ metrics_to_show = [selected_metric]
677
+ metric_labels = [metric_options.get(selected_metric, selected_metric)]
678
+ colors = ['#4ECDC4']
679
+
680
+ # Create multi-line chart data
681
+ chart_data = []
682
+ for i, (metric, label) in enumerate(zip(metrics_to_show, metric_labels)):
683
+ for _, row in metrics_df.iterrows():
684
+ if metric in row:
685
+ chart_data.append({
686
+ 'steps': row['steps'],
687
+ 'value': row[metric],
688
+ 'metric': label,
689
+ 'color_idx': i
690
+ })
691
 
692
+ if chart_data:
693
+ chart_df = pd.DataFrame(chart_data)
694
+
695
+ # Create multi-line chart with distinct colors
696
+ chart = alt.Chart(chart_df).mark_line(point=True, strokeWidth=2).encode(
697
+ x=alt.X('steps:Q', title='Training Steps'),
698
+ y=alt.Y('value:Q', scale=alt.Scale(domain=[0, 100]), title='Accuracy (%)'),
699
+ color=alt.Color('metric:N',
700
+ scale=alt.Scale(domain=metric_labels, range=colors),
701
+ legend=alt.Legend(title="Metric")),
702
+ tooltip=['steps:Q', 'metric:N', alt.Tooltip('value:Q', format='.1f', title='Accuracy')]
703
+ ).properties(
704
+ width=500,
705
+ height=400,
706
+ title='Accuracy Progression During Training'
707
+ )
708
+ st.altair_chart(chart, use_container_width=True)
709
+ else:
710
+ st.warning("No data available for the selected metrics")
711
 
712
  with col2:
713
+ st.write(f"**{metric_options[selected_metric]} vs. Training Loss**")
714
 
715
  if metrics_df['neg_log_loss'].notna().any():
716
  scatter_data = metrics_df[metrics_df['neg_log_loss'].notna()]
717
 
718
  chart = alt.Chart(scatter_data).mark_circle(size=100).encode(
719
  x=alt.X('neg_log_loss:Q', title='-log(Training Loss)'),
720
+ y=alt.Y(f'{selected_metric}:Q', scale=alt.Scale(domain=[0, 100]), title=f'{metric_options[selected_metric]} (%)'),
721
  color=alt.Color('steps:Q', scale=alt.Scale(scheme='viridis'), title='Training Steps'),
722
+ tooltip=['steps', 'loss', selected_metric]
723
  ).properties(
724
+ width=500, # Increased from 400
725
+ height=400, # Increased from 300
726
+ title=f'{metric_options[selected_metric]} vs. -log(Training Loss)'
727
  )
728
  st.altair_chart(chart, use_container_width=True)
729
  else:
730
  st.info("No training loss data available for this model")
731
 
732
+ # Show checkpoint details table with selected metric
733
  st.write("**Checkpoint Details**")
734
+
735
+ # Determine columns to display based on selected metric category
736
+ if selected_metric == 'overall':
737
+ display_cols = ['steps', 'overall', 'desktop_avg', 'web_avg', 'loss']
738
+ col_labels = ['Steps', 'Overall %', 'Desktop Avg %', 'Web Avg %', 'Training Loss']
739
+ elif 'desktop' in selected_metric:
740
+ display_cols = ['steps', 'desktop_avg', 'desktop_text', 'desktop_icon', 'loss']
741
+ col_labels = ['Steps', 'Desktop Avg %', 'Desktop Text %', 'Desktop Icon %', 'Training Loss']
742
+ elif 'web' in selected_metric:
743
+ display_cols = ['steps', 'web_avg', 'web_text', 'web_icon', 'loss']
744
+ col_labels = ['Steps', 'Web Avg %', 'Web Text %', 'Web Icon %', 'Training Loss']
745
+ elif 'text' in selected_metric:
746
+ display_cols = ['steps', 'text_avg', 'desktop_text', 'web_text', 'loss']
747
+ col_labels = ['Steps', 'Text Avg %', 'Desktop Text %', 'Web Text %', 'Training Loss']
748
+ elif 'icon' in selected_metric:
749
+ display_cols = ['steps', 'icon_avg', 'desktop_icon', 'web_icon', 'loss']
750
+ col_labels = ['Steps', 'Icon Avg %', 'Desktop Icon %', 'Web Icon %', 'Training Loss']
751
+ else:
752
+ display_cols = ['steps', selected_metric, 'loss']
753
+ col_labels = ['Steps', f'{metric_options[selected_metric]} %', 'Training Loss']
754
+
755
+ display_metrics = metrics_df[display_cols].copy()
756
+ display_metrics.columns = col_labels
757
+
758
+ # Format percentage columns
759
+ for col in col_labels:
760
+ if '%' in col and col != 'Training Loss':
761
+ display_metrics[col] = display_metrics[col].round(2)
762
+
763
  display_metrics['Training Loss'] = display_metrics['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
764
  st.dataframe(display_metrics, use_container_width=True)
765
  else:
766
  st.info("No models with multiple checkpoints available for progression analysis")
767
 
768
  # Detailed breakdown
769
+ if selected_dataset == 'screenspot-v2':
770
+ with st.expander("Detailed UI Type Breakdown"):
771
+ # Create a heatmap-style table
772
+ detailed_metrics = []
773
+ for _, row in ui_metrics_df.iterrows():
774
+ detailed_metrics.append({
775
+ 'Model': row['model'],
776
+ 'Desktop Text': f"{row['desktop_text']:.1f}%",
777
+ 'Desktop Icon': f"{row['desktop_icon']:.1f}%",
778
+ 'Web Text': f"{row['web_text']:.1f}%",
779
+ 'Web Icon': f"{row['web_icon']:.1f}%",
780
+ 'Overall': f"{row['overall']:.1f}%"
781
+ })
782
+
783
+ if detailed_metrics:
784
+ st.dataframe(pd.DataFrame(detailed_metrics), use_container_width=True)
785
 
786
  else:
787
  # For non-ScreenSpot datasets, show a simple bar chart
 
808
  display_df['Accuracy (%)'] = display_df['Accuracy (%)'].apply(lambda x: f"{x:.2f}")
809
  display_df['Training Loss'] = display_df['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
810
  st.dataframe(display_df, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
 
812
  if __name__ == "__main__":
813
  main()