Commit 79cb6e1 by Anas Awadalla (parent: 2dbb46e)

Commit message: more analysis + baselines

Files changed (2):
  1. README.md (+37 -13)
  2. src/streamlit_app.py (+285 -33)
README.md CHANGED
@@ -20,10 +20,11 @@ A Streamlit application for visualizing model performance on grounding benchmark
 - **Real-time Data**: Streams results directly from the HuggingFace leaderboard repository without local storage
 - **Interactive Visualizations**: Bar charts comparing model performance across different metrics
 - **Baseline Comparisons**: Shows baseline models (Qwen2-VL, UI-TARS) alongside evaluated models
-- **UI Type Breakdown**: For ScreenSpot datasets, shows performance by:
-  - Desktop vs Web
-  - Text vs Icon elements
-  - Overall averages
+- **Best Checkpoint Selection**: Automatically shows the best-performing checkpoint for each model (marked with * if it is not the last checkpoint)
+- **UI Type Breakdown**:
+  - For ScreenSpot-v2: comprehensive charts showing Overall, Desktop, Web, and individual UI type performance
+  - For other datasets: Desktop vs Web and Text vs Icon performance
+- **Checkpoint Progression Analysis**: Visualize how metrics evolve during training
 - **Model Details**: View training loss, checkpoint steps, and evaluation timestamps
 - **Sample Results**: Inspect the first 5 evaluation samples for each model
 
@@ -49,15 +50,23 @@ The app will open in your browser at `http://localhost:8501`
 
 2. **Filter Models**: Optionally filter to view a specific model or all models
 
-3. **View Charts**: The main page displays:
-   - Overall metrics (number of models, best accuracy, total samples)
-   - Bar charts comparing performance across different UI types
-   - Baseline model comparisons (shown in orange)
+3. **View Charts**:
+   - For ScreenSpot-v2:
+     - Overall performance (average of desktop and web)
+     - Desktop and Web averages
+     - Individual UI type metrics: Desktop (Text), Desktop (Icon), Web (Text), Web (Icon)
+     - Text and Icon averages across environments
+   - Baseline model comparisons shown in orange
+   - Models marked with * indicate the best checkpoint is not the final one
 
 4. **Explore Details**:
    - Expand "Model Details" to see training metadata
    - Expand "Detailed UI Type Breakdown" for a comprehensive table
    - Expand "Sample Results" to see the first 5 evaluation samples
+   - Expand "Checkpoint Progression Analysis" to:
+     - View accuracy progression over training steps
+     - See the relationship between training loss and accuracy
+     - Compare performance across checkpoints
 
 ## Data Source
 
@@ -75,7 +84,7 @@ To minimize local storage requirements, the app:
 
 ## Supported Datasets
 
-- **ScreenSpot-v2**: Web and desktop UI element grounding
+- **ScreenSpot-v2**: Web and desktop UI element grounding (with special handling for desktop/web averaging)
 - **ScreenSpot-Pro**: Professional UI grounding benchmark
 - **ShowdownClicks**: Click prediction benchmark
 - And more as they are added to the leaderboard
@@ -83,10 +92,25 @@ To minimize local storage requirements, the app:
 ## Baseline Models
 
 For ScreenSpot-v2, the following baselines are included:
-- Qwen2-VL-7B
-- UI-TARS-2B
-- UI-TARS-7B
-- UI-TARS-72B
+- Qwen2-VL-7B: 37.96%
+- UI-TARS-2B: 82.8%
+- UI-TARS-7B: 92.2%
+- UI-TARS-72B: 88.3%
+
+For ScreenSpot-Pro, the following baselines are included:
+- Qwen2.5-VL-3B-Instruct: 16.1%
+- Qwen2.5-VL-7B-Instruct: 26.8%
+- Qwen2.5-VL-72B-Instruct: 53.3%
+- UI-TARS-2B: 27.7%
+- UI-TARS-7B: 35.7%
+- UI-TARS-72B: 38.1%
+
+## Checkpoint Handling
+
+- The app automatically identifies the best-performing checkpoint for each model
+- If multiple checkpoints exist, only the best one is shown in the main charts
+- An asterisk (*) indicates when the best checkpoint is not the last one
+- Use the "Checkpoint Progression Analysis" to explore all checkpoints
 
 ## Caching
 
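The best-checkpoint selection described above (and implemented in the app's `fetch_leaderboard_data`) boils down to a group-by-and-argmax over (dataset, base model) pairs. A minimal self-contained sketch of that logic; `select_best_checkpoints` and the toy three-checkpoint frame are illustrative, not code from the app:

```python
import pandas as pd

def select_best_checkpoints(df: pd.DataFrame) -> pd.DataFrame:
    """Keep the highest-accuracy checkpoint per (dataset, base_model) group,
    appending '*' to the model name when that checkpoint is not the last one."""
    best_rows = []
    for _, group in df.groupby(["dataset", "base_model"]):
        # The row with the maximum overall accuracy wins
        best = group.loc[group["overall_accuracy"].idxmax()].copy()
        # Flag it when a later checkpoint exists but scored worse
        if len(group) > 1 and best["checkpoint_steps"] != group["checkpoint_steps"].max():
            best["model"] += "*"
        best_rows.append(best)
    return pd.DataFrame(best_rows).reset_index(drop=True)

runs = pd.DataFrame({
    "dataset": ["screenspot-v2"] * 3,
    "base_model": ["model-a"] * 3,
    "model": ["model-a/checkpoint-100", "model-a/checkpoint-200", "model-a/checkpoint-300"],
    "checkpoint_steps": [100, 200, 300],
    "overall_accuracy": [71.2, 84.5, 80.1],
})
best = select_best_checkpoints(runs)
print(best.loc[0, "model"])  # model-a/checkpoint-200*
```

The app's version additionally carries the full per-group checkpoint list along in an `all_checkpoints` column so the progression expander can re-plot every run; the sketch omits that bookkeeping.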
src/streamlit_app.py CHANGED
@@ -53,6 +53,26 @@ BASELINES = {
             "web_icon": 86.3,
             "overall": 88.3
         }
+    },
+    "screenspot-pro": {
+        "Qwen2.5-VL-3B-Instruct": {
+            "overall": 16.1
+        },
+        "Qwen2.5-VL-7B-Instruct": {
+            "overall": 26.8
+        },
+        "Qwen2.5-VL-72B-Instruct": {
+            "overall": 53.3
+        },
+        "UI-TARS-2B": {
+            "overall": 27.7
+        },
+        "UI-TARS-7B": {
+            "overall": 35.7
+        },
+        "UI-TARS-72B": {
+            "overall": 38.1
+        }
     }
 }
 
@@ -99,18 +119,36 @@ def fetch_leaderboard_data():
             # Get model name from metadata or path
             model_checkpoint = metadata.get("model_checkpoint", "")
             model_name = model_checkpoint.split('/')[-1]
+            base_model_name = None
+            is_checkpoint = False
 
             # Handle checkpoint names
             if not model_name and len(path_parts) > 2:
                 # Check if it's a checkpoint subdirectory structure
                 if len(path_parts) > 3 and path_parts[2] != path_parts[3]:
                     # Format: grounding/dataset/base_model/checkpoint.json
-                    base_model = path_parts[2]
+                    base_model_name = path_parts[2]
                     checkpoint_file = path_parts[3].replace(".json", "")
-                    model_name = f"{base_model}/{checkpoint_file}"
+                    model_name = f"{base_model_name}/{checkpoint_file}"
+                    is_checkpoint = True
                 else:
                     # Regular format: grounding/dataset/results_modelname.json
                     model_name = path_parts[2].replace("results_", "").replace(".json", "")
+                    base_model_name = model_name
+
+            # Check if model name indicates a checkpoint
+            if 'checkpoint-' in model_name:
+                is_checkpoint = True
+                if not base_model_name:
+                    # Extract base model name from full path
+                    if '/' in model_name:
+                        parts = model_name.split('/')
+                        base_model_name = parts[0]
+                    else:
+                        # Try to get from model_checkpoint path
+                        checkpoint_parts = model_checkpoint.split('/')
+                        if len(checkpoint_parts) > 1:
+                            base_model_name = checkpoint_parts[-2]
 
             # Extract UI type results if available
             ui_type_results = detailed_results.get("by_ui_type", {})
@@ -120,6 +158,8 @@ def fetch_leaderboard_data():
             result_entry = {
                 "dataset": dataset_name,
                 "model": model_name,
+                "base_model": base_model_name or model_name,
+                "is_checkpoint": is_checkpoint,
                 "model_path": model_checkpoint,
                 "overall_accuracy": metrics.get("accuracy", 0) * 100,  # Convert to percentage
                 "total_samples": metrics.get("total", 0),
@@ -145,7 +185,49 @@ def fetch_leaderboard_data():
         progress_bar.empty()
         status_text.empty()
 
-        return pd.DataFrame(results)
+        # Create DataFrame
+        df = pd.DataFrame(results)
+
+        # Process checkpoints: for each base model, find the best checkpoint
+        if not df.empty:
+            # Group by dataset and base_model
+            grouped = df.groupby(['dataset', 'base_model'])
+
+            # For each group, find the best checkpoint
+            best_models = []
+            for (dataset, base_model), group in grouped:
+                if len(group) > 1:
+                    # Multiple entries for this model (likely checkpoints)
+                    best_idx = group['overall_accuracy'].idxmax()
+                    best_row = group.loc[best_idx].copy()
+
+                    # Check if the best is the last checkpoint
+                    checkpoint_steps = group[group['checkpoint_steps'].notna()]['checkpoint_steps'].sort_values()
+                    if len(checkpoint_steps) > 0:
+                        last_checkpoint_steps = checkpoint_steps.iloc[-1]
+                        best_checkpoint_steps = best_row['checkpoint_steps']
+                        if pd.notna(best_checkpoint_steps) and best_checkpoint_steps != last_checkpoint_steps:
+                            # Best checkpoint is not the last one, add asterisk
+                            best_row['model'] = best_row['model'] + '*'
+                            best_row['is_best_not_last'] = True
+                        else:
+                            best_row['is_best_not_last'] = False
+
+                    # Store all checkpoints for this model
+                    best_row['all_checkpoints'] = group.to_dict('records')
+                    best_models.append(best_row)
+                else:
+                    # Single entry for this model
+                    row = group.iloc[0].copy()
+                    row['is_best_not_last'] = False
+                    row['all_checkpoints'] = [row.to_dict()]
+                    best_models.append(row)
+
+            # Create new dataframe with best models
+            df_best = pd.DataFrame(best_models)
+            return df_best
+
+        return df
 
     except Exception as e:
         st.error(f"Error fetching leaderboard data: {str(e)}")
@@ -164,17 +246,23 @@ def parse_ui_type_metrics(df: pd.DataFrame, dataset_filter: str) -> pd.DataFrame
 
         # For ScreenSpot datasets, we have desktop/web and text/icon
         if 'screenspot' in dataset_filter.lower():
-            # Calculate aggregated metrics
+            # Calculate individual metrics
            desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
            desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
            web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
            web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
 
            # Calculate averages
-           desktop_avg = (desktop_text + desktop_icon) / 2 if desktop_text or desktop_icon else 0
-           web_avg = (web_text + web_icon) / 2 if web_text or web_icon else 0
-           text_avg = (desktop_text + web_text) / 2 if desktop_text or web_text else 0
-           icon_avg = (desktop_icon + web_icon) / 2 if desktop_icon or web_icon else 0
+           desktop_avg = (desktop_text + desktop_icon) / 2 if (desktop_text > 0 or desktop_icon > 0) else 0
+           web_avg = (web_text + web_icon) / 2 if (web_text > 0 or web_icon > 0) else 0
+           text_avg = (desktop_text + web_text) / 2 if (desktop_text > 0 or web_text > 0) else 0
+           icon_avg = (desktop_icon + web_icon) / 2 if (desktop_icon > 0 or web_icon > 0) else 0
+
+           # For screenspot-v2, calculate the overall as average of desktop and web
+           if dataset_filter == 'screenspot-v2':
+               overall = (desktop_avg + web_avg) / 2 if (desktop_avg > 0 or web_avg > 0) else 0
+           else:
+               overall = row['overall_accuracy']
 
            metrics_list.append({
                'model': model,
@@ -186,7 +274,9 @@ def main():
                'web_avg': web_avg,
                'text_avg': text_avg,
                'icon_avg': icon_avg,
-               'overall': row['overall_accuracy']
+               'overall': overall,
+               'is_best_not_last': row.get('is_best_not_last', False),
+               'all_checkpoints': row.get('all_checkpoints', [])
            })
 
    return pd.DataFrame(metrics_list)
@@ -303,35 +393,197 @@ def main():
     if not ui_metrics_df.empty and 'screenspot' in selected_dataset.lower():
         st.subheader("Performance by UI Type")
 
-        # Create charts in a grid
-        col1, col2 = st.columns(2)
-
-        with col1:
-            # Overall Average
-            chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average')
-            if chart:
-                st.altair_chart(chart, use_container_width=True)
-
-            # Desktop Average
-            chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
-            if chart:
-                st.altair_chart(chart, use_container_width=True)
-
-            # Text Average
-            chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (UI-Type)')
-            if chart:
-                st.altair_chart(chart, use_container_width=True)
-
-        with col2:
-            # Web Average
-            chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
-            if chart:
-                st.altair_chart(chart, use_container_width=True)
-
-            # Icon Average
-            chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (UI-Type)')
-            if chart:
-                st.altair_chart(chart, use_container_width=True)
+        # Add note about asterisks
+        if any(ui_metrics_df['is_best_not_last']):
+            st.info("* indicates the best checkpoint is not the last checkpoint")
+
+        # Create charts in a grid
+        if selected_dataset == 'screenspot-v2':
+            # First row: Overall, Desktop, Web averages
+            col1, col2, col3 = st.columns(3)
+
+            with col1:
+                chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average (Desktop + Web) / 2')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            with col2:
+                chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            with col3:
+                chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            # Second row: Individual UI type metrics
+            col1, col2, col3, col4 = st.columns(4)
+
+            with col1:
+                chart = create_bar_chart(ui_metrics_df, 'desktop_text', 'Desktop (Text)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            with col2:
+                chart = create_bar_chart(ui_metrics_df, 'desktop_icon', 'Desktop (Icon)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            with col3:
+                chart = create_bar_chart(ui_metrics_df, 'web_text', 'Web (Text)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            with col4:
+                chart = create_bar_chart(ui_metrics_df, 'web_icon', 'Web (Icon)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            # Third row: Text vs Icon averages
+            col1, col2 = st.columns(2)
+
+            with col1:
+                chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (Desktop + Web)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            with col2:
+                chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (Desktop + Web)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+        else:
+            # For other screenspot datasets, show the standard layout
+            col1, col2 = st.columns(2)
+
+            with col1:
+                # Overall Average
+                chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+                # Desktop Average
+                chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+                # Text Average
+                chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (UI-Type)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+            with col2:
+                # Web Average
+                chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+                # Icon Average
+                chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (UI-Type)')
+                if chart:
+                    st.altair_chart(chart, use_container_width=True)
+
+        # Checkpoint progression visualization
+        with st.expander("Checkpoint Progression Analysis"):
+            # Select a model with checkpoints
+            models_with_checkpoints = ui_metrics_df[ui_metrics_df['all_checkpoints'].apply(lambda x: len(x) > 1)]
+
+            if not models_with_checkpoints.empty:
+                selected_checkpoint_model = st.selectbox(
+                    "Select a model to view checkpoint progression:",
+                    models_with_checkpoints['model'].str.replace('*', '').unique()
+                )
+
+                # Get checkpoint data for selected model
+                model_row = models_with_checkpoints[models_with_checkpoints['model'].str.replace('*', '') == selected_checkpoint_model].iloc[0]
+                checkpoint_data = model_row['all_checkpoints']
+
+                # Create DataFrame from checkpoint data
+                checkpoint_df = pd.DataFrame(checkpoint_data)
+
+                # Prepare data for visualization
+                checkpoint_metrics = []
+                for _, cp in checkpoint_df.iterrows():
+                    ui_results = cp['ui_type_results']
+
+                    # Calculate metrics
+                    desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
+                    desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
+                    web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
+                    web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
+
+                    desktop_avg = (desktop_text + desktop_icon) / 2
+                    web_avg = (web_text + web_icon) / 2
+                    overall = (desktop_avg + web_avg) / 2 if selected_dataset == 'screenspot-v2' else cp['overall_accuracy']
+
+                    checkpoint_metrics.append({
+                        'steps': cp['checkpoint_steps'] or 0,
+                        'overall': overall,
+                        'desktop': desktop_avg,
+                        'web': web_avg,
+                        'loss': cp['training_loss'],
+                        'neg_log_loss': -np.log(cp['training_loss']) if cp['training_loss'] and cp['training_loss'] > 0 else None
+                    })
+
+                metrics_df = pd.DataFrame(checkpoint_metrics).sort_values('steps')
+
+                # Plot metrics over training steps
+                col1, col2 = st.columns(2)
+
+                with col1:
+                    st.write("**Accuracy over Training Steps**")
+
+                    # Melt data for multi-line chart
+                    melted = metrics_df[['steps', 'overall', 'desktop', 'web']].melt(
+                        id_vars=['steps'],
+                        var_name='Metric',
+                        value_name='Accuracy'
+                    )
+
+                    chart = alt.Chart(melted).mark_line(point=True).encode(
+                        x=alt.X('steps:Q', title='Training Steps'),
+                        y=alt.Y('Accuracy:Q', scale=alt.Scale(domain=[0, 100]), title='Accuracy (%)'),
+                        color=alt.Color('Metric:N', scale=alt.Scale(
+                            domain=['overall', 'desktop', 'web'],
+                            range=['#4ECDC4', '#45B7D1', '#96CEB4']
+                        )),
+                        tooltip=['steps', 'Metric', 'Accuracy']
+                    ).properties(
+                        width=400,
+                        height=300,
+                        title='Accuracy Progression During Training'
+                    )
+                    st.altair_chart(chart, use_container_width=True)
+
+                with col2:
+                    st.write("**Accuracy vs. Training Loss**")
+
+                    if metrics_df['neg_log_loss'].notna().any():
+                        scatter_data = metrics_df[metrics_df['neg_log_loss'].notna()]
+
+                        chart = alt.Chart(scatter_data).mark_circle(size=100).encode(
+                            x=alt.X('neg_log_loss:Q', title='-log(Training Loss)'),
+                            y=alt.Y('overall:Q', scale=alt.Scale(domain=[0, 100]), title='Overall Accuracy (%)'),
+                            color=alt.Color('steps:Q', scale=alt.Scale(scheme='viridis'), title='Training Steps'),
+                            tooltip=['steps', 'loss', 'overall']
+                        ).properties(
+                            width=400,
+                            height=300,
+                            title='Accuracy vs. -log(Training Loss)'
+                        )
+                        st.altair_chart(chart, use_container_width=True)
+                    else:
+                        st.info("No training loss data available for this model")
+
+                # Show checkpoint details table
+                st.write("**Checkpoint Details**")
+                display_metrics = metrics_df[['steps', 'overall', 'desktop', 'web', 'loss']].copy()
+                display_metrics.columns = ['Steps', 'Overall %', 'Desktop %', 'Web %', 'Training Loss']
+                display_metrics[['Overall %', 'Desktop %', 'Web %']] = display_metrics[['Overall %', 'Desktop %', 'Web %']].round(2)
+                display_metrics['Training Loss'] = display_metrics['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
+                st.dataframe(display_metrics, use_container_width=True)
+            else:
+                st.info("No models with multiple checkpoints available for progression analysis")
 
         # Detailed breakdown
         with st.expander("Detailed UI Type Breakdown"):
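One detail worth calling out in `parse_ui_type_metrics`: for screenspot-v2 the "overall" number is no longer the raw sample accuracy but the mean of the desktop and web averages, each of which is itself the mean of its text and icon bucket accuracies. A minimal sketch of that averaging scheme, using the same `{"correct", "total"}` bucket shape the app reads; `accuracy` and `screenspot_v2_overall` are hypothetical helper names, not functions from the app:

```python
def accuracy(bucket: dict) -> float:
    """Bucket accuracy in percent; missing or empty buckets count as 0."""
    return bucket.get("correct", 0) / max(bucket.get("total", 1), 1) * 100

def screenspot_v2_overall(ui_results: dict) -> float:
    """Overall = mean(desktop average, web average), where each environment
    average is the mean of its text and icon bucket accuracies."""
    desktop_avg = (accuracy(ui_results.get("desktop_text", {})) +
                   accuracy(ui_results.get("desktop_icon", {}))) / 2
    web_avg = (accuracy(ui_results.get("web_text", {})) +
               accuracy(ui_results.get("web_icon", {}))) / 2
    return (desktop_avg + web_avg) / 2

sample = {
    "desktop_text": {"correct": 90, "total": 100},
    "desktop_icon": {"correct": 70, "total": 100},
    "web_text": {"correct": 80, "total": 100},
    "web_icon": {"correct": 60, "total": 100},
}
print(screenspot_v2_overall(sample))  # 75.0 (desktop 80.0, web 70.0)
```

Because every bucket carries equal weight, this macro average can differ from the sample-weighted accuracy when bucket sizes are unequal, which is why the commit special-cases screenspot-v2 instead of reusing `overall_accuracy`.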