acmc committed on
Commit
cdde792
·
verified ·
1 Parent(s): 057eb4b

Create streamlit_app.py

Files changed (1)
  1. streamlit_app.py +521 -0
streamlit_app.py ADDED
@@ -0,0 +1,521 @@
#!/usr/bin/env python3
"""
Streamlit app for interactive complexity metrics visualization.
"""

import warnings

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st

warnings.filterwarnings('ignore')

# Import visualization utilities (only the helpers this app actually uses)
from visualization.utils import (
    load_and_prepare_dataset,
    get_available_turn_metrics,
    get_human_friendly_metric_name,
    PLOT_PALETTE
)

# Setup page config
st.set_page_config(
    page_title="Complexity Metrics Explorer",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Cache data loading
@st.cache_data
def load_data(dataset_name):
    """Load and cache the dataset."""
    df, df_exploded = load_and_prepare_dataset({
        'dataset_name': dataset_name
    })
    return df, df_exploded

@st.cache_data
def get_metrics(df_exploded):
    """Get available metrics from the dataset."""
    return get_available_turn_metrics(df_exploded)

def main():
    st.title("🔍 Complexity Metrics Explorer")
    st.markdown("Interactive visualization of conversation complexity metrics across different dataset types.")

    # Dataset selection
    st.sidebar.header("🗂️ Dataset Selection")

    # Available datasets
    available_datasets = [
        "jailbreaks_dataset_with_results_reduced",
        "jailbreaks_dataset_with_results",
        "jailbreaks_dataset_with_results_filtered_successful_jailbreak",
        "Custom..."
    ]

    selected_option = st.sidebar.selectbox(
        "Select Dataset",
        options=available_datasets,
        index=0,  # Default to the reduced dataset
        help="Choose which dataset to analyze"
    )

    # Handle custom dataset input
    if selected_option == "Custom...":
        selected_dataset = st.sidebar.text_input(
            "Custom Dataset Name",
            value="jailbreaks_dataset_with_results_reduced",
            help="Enter the full dataset name (e.g., 'jailbreaks_dataset_with_results_reduced')"
        )
        if not selected_dataset.strip():
            st.sidebar.warning("Please enter a dataset name")
            st.stop()
    else:
        selected_dataset = selected_option

    # Add refresh button
    if st.sidebar.button("🔄 Refresh Data", help="Clear cache and reload dataset"):
        st.cache_data.clear()
        st.rerun()

    # Load data
    with st.spinner(f"Loading dataset: {selected_dataset}..."):
        try:
            df, df_exploded = load_data(selected_dataset)
            available_metrics = get_metrics(df_exploded)

            # Display dataset info
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Dataset", selected_dataset.split('_')[-1].title())
            with col2:
                st.metric("Conversations", f"{len(df):,}")
            with col3:
                st.metric("Turns", f"{len(df_exploded):,}")
            with col4:
                st.metric("Metrics", len(available_metrics))

            data_loaded = True
        except Exception as e:
            st.error(f"Error loading dataset: {e}")
            st.info("Please check if the dataset exists and is accessible.")
            st.info("💡 Try using one of the predefined dataset options instead of custom input.")
            data_loaded = False

    if not data_loaded:
        st.stop()

    # Sidebar controls
    st.sidebar.header("🎛️ Controls")

    # Dataset type filter
    dataset_types = df['type'].unique()
    selected_types = st.sidebar.multiselect(
        "Select Dataset Types",
        options=dataset_types,
        default=dataset_types,
        help="Filter by conversation type"
    )

    # Role filter
    if 'turn.role' in df_exploded.columns:
        roles = df_exploded['turn.role'].unique()
        selected_roles = st.sidebar.multiselect(
            "Select Roles",
            options=roles,
            default=roles,
            help="Filter by turn role"
        )
    else:
        selected_roles = None

    # Metric selection
    st.sidebar.header("📊 Metrics")

    # Dynamic metric categorization based on common patterns
    def categorize_metrics(metrics):
        """Dynamically categorize metrics based on naming patterns."""
        categories = {"All": metrics}  # Always include all metrics

        # Common patterns to look for
        patterns = {
            "Length": ['length', 'byte', 'word', 'token', 'char'],
            "Readability": ['readability', 'flesch', 'standard'],
            "Compression": ['lzw', 'compression'],
            "Language Model": ['ll_', 'rll_', 'logprob'],
            "Working Memory": ['wm_'],
            "Discourse": ['discourse'],
            "Evaluation": ['rubric', 'evaluation', 'stealth'],
            "Distribution": ['zipf', 'type_token'],
            "Coherence": ['coherence'],
            "Entity": ['entity', 'entities'],
            "Cognitive": ['cognitive', 'load'],
        }

        # Categorize metrics
        for category, keywords in patterns.items():
            matching_metrics = [m for m in metrics if any(keyword in m.lower() for keyword in keywords)]
            if matching_metrics:
                categories[category] = matching_metrics

        # Find uncategorized metrics
        categorized = set()
        for name, cat_metrics in categories.items():
            if name != "All":  # Skip the catch-all category
                categorized.update(cat_metrics)

        uncategorized = [m for m in metrics if m not in categorized]
        if uncategorized:
            categories["Other"] = uncategorized

        return categories

    metric_categories = categorize_metrics(available_metrics)

    # Metric selection interface
    selection_mode = st.sidebar.radio(
        "Selection Mode",
        ["By Category", "Search/Filter", "Select All"],
        help="Choose how to select metrics"
    )

    if selection_mode == "By Category":
        selected_category = st.sidebar.selectbox(
            "Metric Category",
            options=list(metric_categories.keys()),
            help=f"Found {len(metric_categories)} categories"
        )

        available_in_category = metric_categories[selected_category]

        # Seed the widget state once, then drop selections that are no longer
        # valid options (e.g. after the category changes).
        if "metrics_multiselect_category" not in st.session_state:
            st.session_state.metrics_multiselect_category = available_in_category[:5]
        else:
            st.session_state.metrics_multiselect_category = [
                m for m in st.session_state.metrics_multiselect_category
                if m in available_in_category
            ]

        # The buttons write directly to the widget's key so the multiselect
        # below picks up the change on this rerun.
        col1, col2 = st.sidebar.columns(2)
        with col1:
            if st.button("Select All", key="select_all_category"):
                st.session_state.metrics_multiselect_category = list(available_in_category)
        with col2:
            if st.button("Clear All", key="clear_all_category"):
                st.session_state.metrics_multiselect_category = []

        selected_metrics = st.sidebar.multiselect(
            f"Select Metrics ({len(available_in_category)} available)",
            options=available_in_category,
            key="metrics_multiselect_category",
            help="Choose metrics to visualize"
        )

    elif selection_mode == "Search/Filter":
        search_term = st.sidebar.text_input(
            "Search Metrics",
            placeholder="Enter keywords to filter metrics...",
            help="Search for metrics containing specific terms"
        )

        if search_term:
            filtered_metrics = [m for m in available_metrics if search_term.lower() in m.lower()]
        else:
            filtered_metrics = available_metrics

        st.sidebar.write(f"Found {len(filtered_metrics)} metrics")

        # Seed once, then drop selections that no longer match the filter.
        if "metrics_multiselect_search" not in st.session_state:
            st.session_state.metrics_multiselect_search = filtered_metrics[:5]
        else:
            st.session_state.metrics_multiselect_search = [
                m for m in st.session_state.metrics_multiselect_search
                if m in filtered_metrics
            ]

        col1, col2 = st.sidebar.columns(2)
        with col1:
            if st.button("Select All", key="select_all_search"):
                st.session_state.metrics_multiselect_search = list(filtered_metrics)
        with col2:
            if st.button("Clear All", key="clear_all_search"):
                st.session_state.metrics_multiselect_search = []

        selected_metrics = st.sidebar.multiselect(
            "Select Metrics",
            options=filtered_metrics,
            key="metrics_multiselect_search",
            help="Choose metrics to visualize"
        )

    else:  # Select All
        # Seed once; limit the default to the first 10 metrics for performance.
        if "metrics_multiselect_all" not in st.session_state:
            st.session_state.metrics_multiselect_all = available_metrics[:10]

        col1, col2 = st.sidebar.columns(2)
        with col1:
            if st.button("Select All", key="select_all_all"):
                st.session_state.metrics_multiselect_all = list(available_metrics)
        with col2:
            if st.button("Clear All", key="clear_all_all"):
                st.session_state.metrics_multiselect_all = []

        selected_metrics = st.sidebar.multiselect(
            f"All Metrics ({len(available_metrics)} total)",
            options=available_metrics,
            key="metrics_multiselect_all",
            help="All available metrics - be careful with performance for large selections"
        )

    # Show selection summary
    if selected_metrics:
        st.sidebar.success(f"Selected {len(selected_metrics)} metrics")

        # Check the larger threshold first; otherwise the stronger warning
        # would be unreachable.
        if len(selected_metrics) > 50:
            st.sidebar.error(f"🚨 Very large selection ({len(selected_metrics)} metrics) - consider reducing for better performance")
        elif len(selected_metrics) > 20:
            st.sidebar.warning(f"⚠️ Large selection ({len(selected_metrics)} metrics) may impact performance")
    else:
        st.sidebar.warning("No metrics selected")

    # Metric info expander
    with st.sidebar.expander("ℹ️ Metric Information", expanded=False):
        st.write(f"**Total Available Metrics:** {len(available_metrics)}")
        st.write(f"**Categories Found:** {len(metric_categories)}")

        if st.checkbox("Show all metric names", key="show_all_metrics"):
            st.write("**All Available Metrics:**")
            for i, metric in enumerate(available_metrics, 1):
                st.write(f"{i}. `{metric}`")

    # Filter data
    filtered_df = df[df['type'].isin(selected_types)] if selected_types else df
    filtered_df_exploded = df_exploded[df_exploded['type'].isin(selected_types)] if selected_types else df_exploded

    if selected_roles and 'turn.role' in filtered_df_exploded.columns:
        filtered_df_exploded = filtered_df_exploded[filtered_df_exploded['turn.role'].isin(selected_roles)]

    # Main content tabs
    tab1, tab2, tab3, tab4 = st.tabs(["📊 Distributions", "🔗 Correlations", "📈 Comparisons", "🎯 Details"])

    with tab1:
        st.header("Distribution Analysis")

        if not selected_metrics:
            # Warn without returning, so the other tabs still render.
            st.warning("Please select at least one metric to visualize.")

        # Create distribution plots (the loop is skipped when nothing is selected)
        for metric in selected_metrics:
            full_metric_name = f"turn.turn_metrics.{metric}"

            if full_metric_name not in filtered_df_exploded.columns:
                st.warning(f"Metric {metric} not found in dataset")
                continue

            st.subheader(f"📊 {get_human_friendly_metric_name(metric)}")

            # Clean the data
            metric_data = filtered_df_exploded[['type', full_metric_name]].copy()
            metric_data = metric_data.dropna()

            if len(metric_data) == 0:
                st.warning(f"No data available for {metric}")
                continue

            # Create plotly histogram
            fig = px.histogram(
                metric_data,
                x=full_metric_name,
                color='type',
                marginal='box',
                title=f"Distribution of {get_human_friendly_metric_name(metric)}",
                color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None,
                opacity=0.7,
                nbins=50
            )

            fig.update_layout(
                xaxis_title=get_human_friendly_metric_name(metric),
                yaxis_title="Count",
                height=400
            )

            st.plotly_chart(fig, use_container_width=True)

            # Summary statistics
            col1, col2 = st.columns(2)

            with col1:
                st.write("**Summary Statistics**")
                summary_stats = metric_data.groupby('type')[full_metric_name].agg(['count', 'mean', 'std', 'min', 'max']).round(3)
                st.dataframe(summary_stats)

            with col2:
                st.write("**Percentiles**")
                percentiles = metric_data.groupby('type')[full_metric_name].quantile([0.25, 0.5, 0.75]).unstack().round(3)
                percentiles.columns = ['25%', '50%', '75%']
                st.dataframe(percentiles)

    with tab2:
        st.header("Correlation Analysis")

        if len(selected_metrics) < 2:
            st.warning("Please select at least 2 metrics for correlation analysis.")
        else:
            # Prepare correlation data
            corr_columns = [f"turn.turn_metrics.{m}" for m in selected_metrics]
            corr_data = filtered_df_exploded[corr_columns + ['type']].copy()

            # Clean column names for display
            corr_data.columns = [
                get_human_friendly_metric_name(col.replace('turn.turn_metrics.', ''))
                if col.startswith('turn.turn_metrics.') else col
                for col in corr_data.columns
            ]

            # Calculate correlation matrix
            corr_matrix = corr_data.select_dtypes(include=[np.number]).corr()

            # Create correlation heatmap
            fig = px.imshow(
                corr_matrix,
                text_auto=True,
                aspect="auto",
                title="Correlation Matrix",
                color_continuous_scale='RdBu_r',
                zmin=-1, zmax=1
            )

            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

            # Scatter plots for strong correlations
            st.subheader("Strong Correlations")

            # Find strong correlations (|r| > 0.7)
            strong_corrs = []
            for i in range(len(corr_matrix.columns)):
                for j in range(i + 1, len(corr_matrix.columns)):
                    corr_val = corr_matrix.iloc[i, j]
                    if abs(corr_val) > 0.7:
                        strong_corrs.append((corr_matrix.columns[i], corr_matrix.columns[j], corr_val))

            if strong_corrs:
                # Sort by strength so the top 3 really are the strongest pairs
                strong_corrs.sort(key=lambda pair: abs(pair[2]), reverse=True)
                for metric1, metric2, corr_val in strong_corrs[:3]:
                    fig = px.scatter(
                        corr_data,
                        x=metric1,
                        y=metric2,
                        color='type',
                        title=f"{metric1} vs {metric2} (r={corr_val:.3f})",
                        color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None,
                        opacity=0.6
                    )
                    st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("No strong correlations (|r| > 0.7) found between selected metrics.")

    with tab3:
        st.header("Type Comparisons")

        if not selected_metrics:
            st.warning("Please select at least one metric to compare.")
        else:
            # Box plots for each metric
            for metric in selected_metrics:
                full_metric_name = f"turn.turn_metrics.{metric}"

                if full_metric_name not in filtered_df_exploded.columns:
                    continue

                st.subheader(f"📦 {get_human_friendly_metric_name(metric)} by Type")

                # Create box plot
                fig = px.box(
                    filtered_df_exploded.dropna(subset=[full_metric_name]),
                    x='type',
                    y=full_metric_name,
                    title=f"Distribution of {get_human_friendly_metric_name(metric)} by Type",
                    color='type',
                    color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None
                )

                fig.update_layout(
                    xaxis_title="Dataset Type",
                    yaxis_title=get_human_friendly_metric_name(metric),
                    height=400
                )

                st.plotly_chart(fig, use_container_width=True)

    with tab4:
        st.header("Detailed View")

        # Data overview
        st.subheader("📋 Dataset Overview")

        st.info(f"**Current Dataset:** `{selected_dataset}`")

        col1, col2, col3 = st.columns(3)

        with col1:
            st.metric("Total Conversations", len(filtered_df))

        with col2:
            st.metric("Total Turns", len(filtered_df_exploded))

        with col3:
            st.metric("Available Metrics", len(available_metrics))

        # Type distribution
        st.subheader("📊 Type Distribution")
        type_counts = filtered_df['type'].value_counts()

        fig = px.pie(
            values=type_counts.values,
            names=type_counts.index,
            color=type_counts.index,  # required for color_discrete_map to take effect
            title="Distribution of Conversation Types",
            color_discrete_map=PLOT_PALETTE if len(type_counts) <= 3 else None
        )

        st.plotly_chart(fig, use_container_width=True)

        # Sample data
        st.subheader("📄 Sample Data")

        if st.checkbox("Show raw data sample"):
            sample_cols = ['type'] + [f"turn.turn_metrics.{m}" for m in selected_metrics if f"turn.turn_metrics.{m}" in filtered_df_exploded.columns]
            sample_data = filtered_df_exploded[sample_cols].head(100)
            st.dataframe(sample_data)

        # Metric availability
        st.subheader("📊 Metric Availability")

        metric_completeness = {}
        for metric in selected_metrics:
            full_metric_name = f"turn.turn_metrics.{metric}"
            if full_metric_name in filtered_df_exploded.columns:
                completeness = (1 - filtered_df_exploded[full_metric_name].isna().sum() / len(filtered_df_exploded)) * 100
                metric_completeness[get_human_friendly_metric_name(metric)] = completeness

        if metric_completeness:
            completeness_df = pd.DataFrame(list(metric_completeness.items()), columns=['Metric', 'Completeness (%)'])
            fig = px.bar(
                completeness_df,
                x='Metric',
                y='Completeness (%)',
                title="Data Completeness by Metric",
                color='Completeness (%)',
                color_continuous_scale='Viridis'
            )
            fig.update_layout(xaxis_tickangle=-45, height=400)
            st.plotly_chart(fig, use_container_width=True)

if __name__ == "__main__":
    main()
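
Note: the app imports load_and_prepare_dataset, get_available_turn_metrics, get_human_friendly_metric_name, and PLOT_PALETTE from a visualization.utils module that is not part of this commit. Below is a minimal sketch of the interface the app appears to expect, inferred only from how these helpers are called; the palette keys, the tiny hard-coded frames, and the name-prettifying rule are assumptions, not the real module. With the real module on the import path, the app starts with `streamlit run streamlit_app.py`.

# visualization/utils.py - hypothetical stub of the interface streamlit_app.py
# imports; the real module ships separately and does the actual dataset loading.
import pandas as pd

# Assumed palette: maps values of the 'type' column to plot colors
# (the app passes this dict as plotly's color_discrete_map).
PLOT_PALETTE = {
    "benign": "#4C72B0",
    "jailbreak": "#C44E52",
    "borderline": "#55A868",
}

def load_and_prepare_dataset(config: dict) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return (df, df_exploded): one row per conversation / one row per turn.

    Stub: the real loader would fetch config['dataset_name']; here we return
    a tiny hard-coded frame with the column layout the app expects.
    """
    df_exploded = pd.DataFrame({
        "type": ["benign", "jailbreak"],
        "turn.role": ["user", "user"],
        "turn.turn_metrics.word_count": [12, 87],
        "turn.turn_metrics.flesch_reading_ease": [64.2, 41.5],
    })
    df = df_exploded[["type"]].drop_duplicates().reset_index(drop=True)
    return df, df_exploded

def get_available_turn_metrics(df_exploded: pd.DataFrame) -> list[str]:
    """List metric names found under the 'turn.turn_metrics.' column prefix."""
    prefix = "turn.turn_metrics."
    return [c[len(prefix):] for c in df_exploded.columns if c.startswith(prefix)]

def get_human_friendly_metric_name(metric: str) -> str:
    """Turn e.g. 'flesch_reading_ease' into 'Flesch Reading Ease'."""
    return metric.replace("_", " ").title()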