acmc committed
Commit d6b031d · verified
1 Parent(s): 6911975

Update streamlit_app.py

Files changed (1)
  1. streamlit_app.py +695 -438
streamlit_app.py CHANGED
@@ -14,7 +14,8 @@ from plotly.subplots import make_subplots
14
  import warnings
15
  import datasets
16
  import logging
17
- warnings.filterwarnings('ignore')
 
18
 
19
  # Configure logging
20
  logging.basicConfig(level=logging.INFO)
@@ -23,41 +24,42 @@ logger = logging.getLogger(__name__)
23
  # Constants
24
  PLOT_PALETTE = {
25
  "jailbreak": "#D000D8", # Purple
26
- "benign": "#008393", # Cyan
27
- "control": "#EF0000", # Red
28
  }
29
 
 
30
  # Utility functions
31
  def load_and_prepare_dataset(dataset_config):
32
  """Load the risky conversations dataset and prepare it for analysis."""
33
  logger.info("Loading dataset...")
34
-
35
  dataset_name = dataset_config["dataset_name"]
36
  logger.info(f"Loading dataset: {dataset_name}")
37
-
38
  # Load the dataset
39
  dataset = datasets.load_dataset(dataset_name, split="train")
40
  logger.info(f"Dataset loaded with {len(dataset)} conversations")
41
-
42
  # Convert to pandas
43
  pandas_dataset = dataset.to_pandas()
44
-
45
  # Explode the conversation column
46
  pandas_dataset_exploded = pandas_dataset.explode("conversation")
47
  pandas_dataset_exploded = pandas_dataset_exploded.reset_index(drop=True)
48
-
49
  # Normalize conversation data
50
  conversations_unfolded = pd.json_normalize(
51
  pandas_dataset_exploded["conversation"],
52
  )
53
  conversations_unfolded = conversations_unfolded.add_prefix("turn.")
54
-
55
  # Ensure there's a 'conversation_metrics' column, even if empty
56
  if "conversation_metrics" not in pandas_dataset_exploded.columns:
57
  pandas_dataset_exploded["conversation_metrics"] = [{}] * len(
58
  pandas_dataset_exploded
59
  )
60
-
61
  # Normalize conversation metrics
62
  conversations_metrics_unfolded = pd.json_normalize(
63
  pandas_dataset_exploded["conversation_metrics"]
@@ -65,7 +67,7 @@ def load_and_prepare_dataset(dataset_config):
65
  conversations_metrics_unfolded = conversations_metrics_unfolded.add_prefix(
66
  "conversation_metrics."
67
  )
68
-
69
  # Concatenate all dataframes
70
  pandas_dataset_exploded = pd.concat(
71
  [
@@ -77,42 +79,41 @@ def load_and_prepare_dataset(dataset_config):
77
  ],
78
  axis=1,
79
  )
80
-
81
  logger.info(f"Dataset prepared with {len(pandas_dataset_exploded)} turns")
82
  return pandas_dataset, pandas_dataset_exploded
83
 
 
84
  def get_available_turn_metrics(dataset_exploded):
85
  """Dynamically discover all available turn metrics from the dataset."""
86
  # Find all columns that contain turn metrics
87
  turn_metric_columns = [
88
  col for col in dataset_exploded.columns if col.startswith("turn.turn_metrics.")
89
  ]
90
-
91
  # Extract the metric names by removing the prefix
92
  metrics = [col.replace("turn.turn_metrics.", "") for col in turn_metric_columns]
93
-
94
  # Sort metrics for consistent ordering
95
  metrics = sorted(metrics)
96
-
97
  logger.info(f"Discovered {len(metrics)} available turn metrics")
98
  return metrics
99
 
 
100
  def get_human_friendly_metric_name(metric_name):
101
  """Convert a metric name to a human-friendly display name."""
102
  # Mapping of metric names to human-friendly display names
103
  metric_display_names = {
104
  # Length metrics
105
  "conversational_length_tokens": "Length in Tokens",
106
- "conversational_length_words": "Length in Words",
107
  "conversational_length_bytes": "Length in Bytes",
108
-
109
  # Readability metrics
110
  "readability_flesch_score": "Flesch Reading Ease",
111
  "readability_standard_score": "Reading Grade Level",
112
-
113
  # Compression metrics
114
  "lzw_compression": "LZW Compression Ratio",
115
-
116
  # Language model metrics
117
  "ll_mean_logprobs": "Mean Log Probability",
118
  "ll_max_logprobs": "Maximum Log Probability",
@@ -120,21 +121,17 @@ def get_human_friendly_metric_name(metric_name):
120
  "rll_mean_logprobs": "Reverse Mean Log Probability",
121
  "rll_max_logprobs": "Reverse Maximum Log Probability",
122
  "rll_min_logprobs": "Reverse Minimum Log Probability",
123
-
124
  # Lexical diversity
125
  "type_token_ratio": "Type-Token Ratio",
126
-
127
  # Language distribution
128
  "zipf_score": "Zipf Distribution Score",
129
-
130
  # Evaluation metrics
131
  "rubric_evaluation.Stealth": "Stealth Score",
132
-
133
  # Working memory metrics
134
  "wm_total_entity_references": "Total Entity References",
135
  "wm_unique_entities_count": "Unique Entities Count",
136
  "wm_named_entity_mentions": "Named Entity Mentions",
137
- "wm_concept_noun_mentions": "Concept Noun Mentions",
138
  "wm_pronoun_references": "Pronoun References",
139
  "wm_entity_density_per_word": "Entity Density per Word",
140
  "wm_entity_density_per_100_words": "Entity Density per 100 Words",
@@ -143,7 +140,6 @@ def get_human_friendly_metric_name(metric_name):
143
  "wm_entity_repetition_ratio": "Entity Repetition Ratio",
144
  "wm_cognitive_load_score": "Cognitive Load Score",
145
  "wm_high_cognitive_load": "High Cognitive Load",
146
-
147
  # Discourse coherence metrics
148
  "discourse_coherence_to_next_user": "Coherence to Next User Turn",
149
  "discourse_coherence_to_next_turn": "Coherence to Next Turn",
@@ -152,164 +148,241 @@ def get_human_friendly_metric_name(metric_name):
152
  "discourse_user_topic_drift": "User Topic Drift",
153
  "discourse_user_entity_continuity": "User Entity Continuity",
154
  "discourse_num_user_turns": "Number of User Turns",
155
-
156
  # Tokens per byte
157
  "tokens_per_byte": "Tokens per Byte",
158
  }
159
-
160
  # Check exact match first
161
  if metric_name in metric_display_names:
162
  return metric_display_names[metric_name]
163
-
164
  # Handle conversation-level aggregations
165
- for suffix in ["_conversation_mean", "_conversation_min", "_conversation_max", "_conversation_std", "_conversation_count"]:
166
  if metric_name.endswith(suffix):
167
- base_metric = metric_name[:-len(suffix)]
168
  if base_metric in metric_display_names:
169
  agg_type = suffix.split("_")[-1].title()
170
  return f"{metric_display_names[base_metric]} ({agg_type})"
171
-
172
  # Handle turn-level metrics with "turn.turn_metrics." prefix
173
  if metric_name.startswith("turn.turn_metrics."):
174
- base_metric = metric_name[len("turn.turn_metrics."):]
175
  if base_metric in metric_display_names:
176
  return metric_display_names[base_metric]
177
-
178
  # Fallback: convert underscores to spaces and title case
179
  clean_name = metric_name
180
  for prefix in ["turn.turn_metrics.", "conversation_metrics.", "turn_metrics."]:
181
  if clean_name.startswith(prefix):
182
- clean_name = clean_name[len(prefix):]
183
  break
184
-
185
  # Convert to human-readable format
186
  clean_name = clean_name.replace("_", " ").title()
187
  return clean_name
188
 
189
  # Setup page config
190
  st.set_page_config(
191
  page_title="Complexity Metrics Explorer",
192
  page_icon="📊",
193
  layout="wide",
194
- initial_sidebar_state="expanded"
195
  )
196
 
 
197
  # Cache data loading
198
  @st.cache_data
199
  def load_data(dataset_name):
200
  """Load and cache the dataset"""
201
- df, df_exploded = load_and_prepare_dataset({
202
- 'dataset_name': dataset_name
203
- })
204
  return df, df_exploded
205
 
 
206
  @st.cache_data
207
  def get_metrics(df_exploded):
208
  """Get available metrics from the dataset"""
209
  return get_available_turn_metrics(df_exploded)
210
 
 
211
  def main():
212
  st.title("🔍 Complexity Metrics Explorer")
213
- st.markdown("Interactive visualization of conversation complexity metrics across different dataset types.")
214
-
 
 
215
  # Dataset selection
216
  st.sidebar.header("🗂️ Dataset Selection")
217
-
218
  # Available datasets
219
  available_datasets = [
220
  "risky-conversations/jailbreaks_dataset_with_results_reduced",
221
  "risky-conversations/jailbreaks_dataset_with_results",
222
  "risky-conversations/jailbreaks_dataset_with_results_filtered_successful_jailbreak",
223
- "Custom..."
224
  ]
225
-
226
  selected_option = st.sidebar.selectbox(
227
  "Select Dataset",
228
  options=available_datasets,
229
  index=0, # Default to reduced dataset
230
- help="Choose which dataset to analyze"
231
  )
232
-
233
  # Handle custom dataset input
234
  if selected_option == "Custom...":
235
  selected_dataset = st.sidebar.text_input(
236
  "Custom Dataset Name",
237
  value="risky-conversations/jailbreaks_dataset_with_results_reduced",
238
- help="Enter the full dataset name (e.g., 'risky-conversations/jailbreaks_dataset_with_results_reduced')"
239
  )
240
  if not selected_dataset.strip():
241
  st.sidebar.warning("Please enter a dataset name")
242
  st.stop()
243
  else:
244
  selected_dataset = selected_option
245
-
246
  # Add refresh button
247
  if st.sidebar.button("🔄 Refresh Data", help="Clear cache and reload dataset"):
248
  st.cache_data.clear()
249
  st.rerun()
250
-
251
  # Load data
252
  with st.spinner(f"Loading dataset: {selected_dataset}..."):
253
  try:
254
  df, df_exploded = load_data(selected_dataset)
255
  available_metrics = get_metrics(df_exploded)
256
-
257
  # Display dataset info
258
  col1, col2, col3, col4 = st.columns(4)
259
  with col1:
260
- st.metric("Dataset", selected_dataset.split('_')[-1].title())
261
  with col2:
262
  st.metric("Conversations", f"{len(df):,}")
263
  with col3:
264
  st.metric("Turns", f"{len(df_exploded):,}")
265
  with col4:
266
  st.metric("Metrics", len(available_metrics))
267
-
268
  data_loaded = True
269
  except Exception as e:
270
  st.error(f"Error loading dataset: {e}")
271
  st.info("Please check if the dataset exists and is accessible.")
272
- st.info("💡 Try using one of the predefined dataset options instead of custom input.")
 
 
273
  data_loaded = False
274
-
275
  if not data_loaded:
276
  st.stop()
277
-
278
  # Sidebar controls
279
  st.sidebar.header("🎛️ Controls")
280
-
281
  # Dataset type filter
282
- dataset_types = df['type'].unique()
283
  selected_types = st.sidebar.multiselect(
284
  "Select Dataset Types",
285
  options=dataset_types,
286
  default=dataset_types,
287
- help="Filter by conversation type"
288
  )
289
-
290
  # Role filter
291
- if 'turn.role' in df_exploded.columns:
292
- roles = df_exploded['turn.role'].dropna().unique()
293
  # Assert only user and assistant roles exist
294
- expected_roles = {'user', 'assistant'}
295
  actual_roles = set(roles)
296
- assert actual_roles.issubset(expected_roles), f"Unexpected roles found: {actual_roles - expected_roles}. Expected only 'user' and 'assistant'"
297
-
 
 
298
  st.sidebar.subheader("👥 Role Filter")
299
  col1, col2 = st.sidebar.columns(2)
300
-
301
  with col1:
302
  include_user = st.checkbox("User", value=True, help="Include user turns")
303
  with col2:
304
- include_assistant = st.checkbox("Assistant", value=True, help="Include assistant turns")
305
-
 
 
306
  # Build selected roles list
307
  selected_roles = []
308
- if include_user and 'user' in roles:
309
- selected_roles.append('user')
310
- if include_assistant and 'assistant' in roles:
311
- selected_roles.append('assistant')
312
-
313
  # Show selection info
314
  if selected_roles:
315
  st.sidebar.success(f"Including: {', '.join(selected_roles)}")
@@ -317,82 +390,98 @@ def main():
317
  st.sidebar.warning("No roles selected")
318
  else:
319
  selected_roles = None
320
-
321
  # Filter data based on selections
322
- filtered_df = df[df['type'].isin(selected_types)] if selected_types else df
323
- filtered_df_exploded = df_exploded[df_exploded['type'].isin(selected_types)] if selected_types else df_exploded
324
-
325
- if selected_roles and 'turn.role' in filtered_df_exploded.columns:
326
- filtered_df_exploded = filtered_df_exploded[filtered_df_exploded['turn.role'].isin(selected_roles)]
327
  elif selected_roles is not None and len(selected_roles) == 0:
328
  # If roles exist but none are selected, show empty dataset
329
- filtered_df_exploded = filtered_df_exploded.iloc[0:0] # Empty dataframe with same structure
330
-
 
 
331
  # Check if we have data after filtering
332
  if len(filtered_df_exploded) == 0:
333
- st.error("No data available with current filters. Please adjust your selection.")
 
 
334
  st.stop()
335
-
336
  # Metric selection
337
  st.sidebar.header("📊 Metrics")
338
-
339
  # Dynamic metric categorization based on common patterns
340
  def categorize_metrics(metrics):
341
  """Dynamically categorize metrics based on naming patterns"""
342
  categories = {"All": metrics} # Always include all metrics
343
-
344
  # Common patterns to look for
345
  patterns = {
346
- "Length": ['length', 'byte', 'word', 'token', 'char'],
347
- "Readability": ['readability', 'flesch', 'standard'],
348
- "Compression": ['lzw', 'compression'],
349
- "Language Model": ['ll_', 'rll_', 'logprob'],
350
- "Working Memory": ['wm_'],
351
- "Discourse": ['discourse'],
352
- "Evaluation": ['rubric', 'evaluation', 'stealth'],
353
- "Distribution": ['zipf', 'type_token'],
354
- "Coherence": ['coherence'],
355
- "Entity": ['entity', 'entities'],
356
- "Cognitive": ['cognitive', 'load'],
357
  }
358
-
359
  # Categorize metrics
360
  for category, keywords in patterns.items():
361
- matching_metrics = [m for m in metrics if any(keyword in m.lower() for keyword in keywords)]
 
 
362
  if matching_metrics:
363
  categories[category] = matching_metrics
364
-
365
  # Find uncategorized metrics
366
  categorized = set()
367
  for cat_metrics in categories.values():
368
  if cat_metrics != metrics: # Skip "All" category
369
  categorized.update(cat_metrics)
370
-
371
  uncategorized = [m for m in metrics if m not in categorized]
372
  if uncategorized:
373
  categories["Other"] = uncategorized
374
-
375
  return categories
376
-
377
  metric_categories = categorize_metrics(available_metrics)
378
-
379
  # Metric selection interface
380
  selection_mode = st.sidebar.radio(
381
  "Selection Mode",
382
  ["By Category", "Search/Filter", "Select All"],
383
- help="Choose how to select metrics"
384
  )
385
-
386
  if selection_mode == "By Category":
387
  selected_category = st.sidebar.selectbox(
388
- "Metric Category",
389
  options=list(metric_categories.keys()),
390
- help=f"Found {len(metric_categories)} categories"
391
  )
392
-
393
  available_in_category = metric_categories[selected_category]
394
- default_selection = available_in_category[:5] if len(available_in_category) > 5 else available_in_category
395
-
396
  # Add select all button for category
397
  col1, col2 = st.sidebar.columns(2)
398
  with col1:
@@ -401,33 +490,35 @@ def main():
401
  with col2:
402
  if st.button("Clear All", key="clear_all_category"):
403
  st.session_state.selected_metrics_category = []
404
-
405
  # Use session state for persistence
406
  if "selected_metrics_category" not in st.session_state:
407
  st.session_state.selected_metrics_category = default_selection
408
-
409
  selected_metrics = st.sidebar.multiselect(
410
  f"Select Metrics ({len(available_in_category)} available)",
411
  options=available_in_category,
412
  default=st.session_state.selected_metrics_category,
413
  key="metrics_multiselect_category",
414
- help="Choose metrics to visualize"
415
  )
416
-
417
  elif selection_mode == "Search/Filter":
418
  search_term = st.sidebar.text_input(
419
  "Search Metrics",
420
  placeholder="Enter keywords to filter metrics...",
421
- help="Search for metrics containing specific terms"
422
  )
423
-
424
  if search_term:
425
- filtered_metrics = [m for m in available_metrics if search_term.lower() in m.lower()]
 
 
426
  else:
427
  filtered_metrics = available_metrics
428
-
429
  st.sidebar.write(f"Found {len(filtered_metrics)} metrics")
430
-
431
  # Add select all button for search results
432
  col1, col2 = st.sidebar.columns(2)
433
  with col1:
@@ -436,19 +527,23 @@ def main():
436
  with col2:
437
  if st.button("Clear All", key="clear_all_search"):
438
  st.session_state.selected_metrics_search = []
439
-
440
  # Use session state for persistence
441
  if "selected_metrics_search" not in st.session_state:
442
- st.session_state.selected_metrics_search = filtered_metrics[:5] if len(filtered_metrics) > 5 else filtered_metrics[:3]
443
-
444
  selected_metrics = st.sidebar.multiselect(
445
  "Select Metrics",
446
  options=filtered_metrics,
447
  default=st.session_state.selected_metrics_search,
448
  key="metrics_multiselect_search",
449
- help="Choose metrics to visualize"
450
  )
451
-
452
  else: # Select All
453
  # Add select all button for all metrics
454
  col1, col2 = st.sidebar.columns(2)
@@ -458,262 +553,309 @@ def main():
458
  with col2:
459
  if st.button("Clear All", key="clear_all_all"):
460
  st.session_state.selected_metrics_all = []
461
-
462
  # Use session state for persistence
463
  if "selected_metrics_all" not in st.session_state:
464
- st.session_state.selected_metrics_all = available_metrics[:10] # Limit default to first 10 for performance
465
-
 
 
466
  selected_metrics = st.sidebar.multiselect(
467
  f"All Metrics ({len(available_metrics)} total)",
468
  options=available_metrics,
469
  default=st.session_state.selected_metrics_all,
470
  key="metrics_multiselect_all",
471
- help="All available metrics - be careful with performance for large selections"
472
  )
473
-
474
  # Show selection summary
475
  if selected_metrics:
476
  st.sidebar.success(f"Selected {len(selected_metrics)} metrics")
477
-
478
  # Performance warning for large selections
479
  if len(selected_metrics) > 20:
480
- st.sidebar.warning(f"⚠️ Large selection ({len(selected_metrics)} metrics) may impact performance")
 
 
481
  elif len(selected_metrics) > 50:
482
- st.sidebar.error(f"🚨 Very large selection ({len(selected_metrics)} metrics) - consider reducing for better performance")
 
 
483
  else:
484
  st.sidebar.warning("No metrics selected")
485
-
486
  # Metric info expander
487
  with st.sidebar.expander("ℹ️ Metric Information", expanded=False):
488
  st.write(f"**Total Available Metrics:** {len(available_metrics)}")
489
  st.write(f"**Categories Found:** {len(metric_categories)}")
490
-
491
  if st.checkbox("Show all metric names", key="show_all_metrics"):
492
  st.write("**All Available Metrics:**")
493
  for i, metric in enumerate(available_metrics, 1):
494
  st.write(f"{i}. `{metric}`")
495
-
496
  # Main content tabs
497
- tab1, tab2, tab3, tab4, tab5 = st.tabs(["📊 Distributions", "🔗 Correlations", "📈 Comparisons", "🔍 Conversation", "🎯 Details"])
498
-
499
  with tab1:
500
  st.header("Distribution Analysis")
501
-
502
  if not selected_metrics:
503
  st.warning("Please select at least one metric to visualize.")
504
  return
505
-
506
- # Create distribution plots
507
- for metric in selected_metrics:
508
- full_metric_name = f"turn.turn_metrics.{metric}"
509
-
510
- if full_metric_name not in filtered_df_exploded.columns:
511
- st.warning(f"Metric {metric} not found in dataset")
512
- continue
513
-
514
- st.subheader(f"📊 {get_human_friendly_metric_name(metric)}")
515
-
516
- # Clean the data
517
- metric_data = filtered_df_exploded[['type', full_metric_name]].copy()
518
- metric_data = metric_data.dropna()
519
-
520
- if len(metric_data) == 0:
521
- st.warning(f"No data available for {metric}")
522
- continue
523
-
524
- # Create plotly histogram
525
- fig = px.histogram(
526
- metric_data,
527
- x=full_metric_name,
528
- color='type',
529
- marginal='box',
530
- title=f"Distribution of {get_human_friendly_metric_name(metric)}",
531
- color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None,
532
- opacity=0.7,
533
- nbins=50
534
- )
535
-
536
- fig.update_layout(
537
- xaxis_title=get_human_friendly_metric_name(metric),
538
- yaxis_title="Count",
539
- height=400
540
- )
541
-
542
- st.plotly_chart(fig, use_container_width=True)
543
-
544
- # Summary statistics
545
- col1, col2 = st.columns(2)
546
-
547
- with col1:
548
- st.write("**Summary Statistics**")
549
- summary_stats = metric_data.groupby('type')[full_metric_name].agg(['count', 'mean', 'std', 'min', 'max']).round(3)
550
- st.dataframe(summary_stats)
551
-
552
- with col2:
553
- st.write("**Percentiles**")
554
- percentiles = metric_data.groupby('type')[full_metric_name].quantile([0.25, 0.5, 0.75]).unstack().round(3)
555
- percentiles.columns = ['25%', '50%', '75%']
556
- st.dataframe(percentiles)
557
-
558
  with tab2:
559
  st.header("Correlation Analysis")
560
-
561
  if len(selected_metrics) < 2:
562
  st.warning("Please select at least 2 metrics for correlation analysis.")
563
  else:
564
- # Prepare correlation data
565
- corr_columns = [f"turn.turn_metrics.{m}" for m in selected_metrics]
566
- corr_data = filtered_df_exploded[corr_columns + ['type']].copy()
567
-
568
- # Clean column names for display
569
- corr_data.columns = [get_human_friendly_metric_name(col.replace('turn.turn_metrics.', '')) if col.startswith('turn.turn_metrics.') else col for col in corr_data.columns]
570
-
571
- # Calculate correlation matrix
572
- corr_matrix = corr_data.select_dtypes(include=[np.number]).corr()
573
-
574
- # Create correlation heatmap
575
- fig = px.imshow(
576
- corr_matrix,
577
- text_auto=True,
578
- aspect="auto",
579
- title="Correlation Matrix",
580
- color_continuous_scale='RdBu_r',
581
- zmin=-1, zmax=1
582
  )
583
-
584
- fig.update_layout(height=600)
585
- st.plotly_chart(fig, use_container_width=True)
586
-
587
- # Scatter plots for strong correlations
588
- st.subheader("Strong Correlations")
589
-
590
- # Find strong correlations (>0.7 or <-0.7)
591
- strong_corrs = []
592
- for i in range(len(corr_matrix.columns)):
593
- for j in range(i+1, len(corr_matrix.columns)):
594
- corr_val = corr_matrix.iloc[i, j]
595
- if abs(corr_val) > 0.7:
596
- strong_corrs.append((corr_matrix.columns[i], corr_matrix.columns[j], corr_val))
597
-
598
- if strong_corrs:
599
- for metric1, metric2, corr_val in strong_corrs[:3]: # Show top 3
600
- fig = px.scatter(
601
- corr_data,
602
- x=metric1,
603
- y=metric2,
604
- color='type',
605
- title=f"{metric1} vs {metric2} (r={corr_val:.3f})",
606
- color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None,
607
- opacity=0.6
608
  )
609
  st.plotly_chart(fig, use_container_width=True)
610
- else:
611
- st.info("No strong correlations (|r| > 0.7) found between selected metrics.")
612
-
613
  with tab3:
614
  st.header("Type Comparisons")
615
-
616
  if not selected_metrics:
617
  st.warning("Please select at least one metric to compare.")
618
  else:
619
  # Box plots for each metric
620
  for metric in selected_metrics:
621
  full_metric_name = f"turn.turn_metrics.{metric}"
622
-
623
  if full_metric_name not in filtered_df_exploded.columns:
624
  continue
625
-
626
  st.subheader(f"📦 {get_human_friendly_metric_name(metric)} by Type")
627
-
628
  # Create box plot
629
  fig = px.box(
630
  filtered_df_exploded.dropna(subset=[full_metric_name]),
631
- x='type',
632
  y=full_metric_name,
633
  title=f"Distribution of {get_human_friendly_metric_name(metric)} by Type",
634
- color='type',
635
- color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None
 
 
636
  )
637
-
638
  fig.update_layout(
639
  xaxis_title="Dataset Type",
640
  yaxis_title=get_human_friendly_metric_name(metric),
641
- height=400
642
  )
643
-
644
  st.plotly_chart(fig, use_container_width=True)
645
-
646
  with tab4:
647
  st.header("Individual Conversation Analysis")
648
-
649
  # Conversation selector
650
  st.subheader("🔍 Select Conversation")
651
-
652
  # Get unique conversations with some metadata
653
  conversation_info = []
654
  for idx, row in filtered_df.iterrows():
655
- conv_type = row['type']
656
  # Get basic info about the conversation
657
- conv_turns = len(row.get('conversation', []))
658
- conversation_info.append({
659
- 'index': idx,
660
- 'type': conv_type,
661
- 'turns': conv_turns,
662
- 'display': f"Conversation {idx} ({conv_type}) - {conv_turns} turns"
663
- })
664
-
 
 
665
  # Sort by type and number of turns for better organization
666
- conversation_info = sorted(conversation_info, key=lambda x: (x['type'], -x['turns']))
667
-
 
 
668
  # Conversation selection
669
  col1, col2 = st.columns([3, 1])
670
-
671
  with col1:
672
  selected_conv_display = st.selectbox(
673
  "Choose a conversation to analyze",
674
- options=[conv['display'] for conv in conversation_info],
675
- help="Select a conversation to view detailed metrics and content"
676
  )
677
-
678
  with col2:
679
  if st.button("🎲 Random", help="Select a random conversation"):
680
  import random
681
- selected_conv_display = random.choice([conv['display'] for conv in conversation_info])
682
  st.rerun()
683
-
684
  # Get the selected conversation data
685
- selected_conv_info = next(conv for conv in conversation_info if conv['display'] == selected_conv_display)
686
- selected_idx = selected_conv_info['index']
687
  selected_conversation = filtered_df.iloc[selected_idx]
688
-
689
  # Display conversation metadata
690
  st.subheader("📋 Conversation Overview")
691
-
692
  # First row - basic info
693
  col1, col2, col3, col4 = st.columns(4)
694
  with col1:
695
- st.metric("Type", selected_conversation['type'])
696
  with col2:
697
  st.metric("Index", selected_idx)
698
  with col3:
699
- st.metric("Total Turns", len(selected_conversation.get('conversation', [])))
700
  with col4:
701
  # Count user vs assistant turns
702
- roles = [turn.get('role', 'unknown') for turn in selected_conversation.get('conversation', [])]
703
- user_turns = roles.count('user')
704
- assistant_turns = roles.count('assistant')
705
  st.metric("User/Assistant", f"{user_turns}/{assistant_turns}")
706
-
707
  # Second row - additional metadata
708
  col1, col2, col3 = st.columns(3)
709
  with col1:
710
- provenance = selected_conversation.get('provenance_dataset', 'Unknown')
711
  st.metric("Dataset Source", provenance)
712
  with col2:
713
- language = selected_conversation.get('language', 'Unknown')
714
- st.metric("Language", language.upper() if language else 'Unknown')
715
  with col3:
716
- timestamp = selected_conversation.get('timestamp', None)
717
  if timestamp:
718
  # Handle different timestamp formats
719
  if isinstance(timestamp, str):
@@ -722,139 +864,184 @@ def main():
722
  st.metric("Timestamp", str(timestamp))
723
  else:
724
  st.metric("Timestamp", "Not Available")
725
-
726
  # Add toxicity summary
727
- conversation_turns_temp = selected_conversation.get('conversation', [])
728
- if hasattr(conversation_turns_temp, 'tolist'):
729
  conversation_turns_temp = conversation_turns_temp.tolist()
730
  elif conversation_turns_temp is None:
731
  conversation_turns_temp = []
732
-
733
  if len(conversation_turns_temp) > 0:
734
  # Calculate overall toxicity statistics
735
  all_toxicities = []
736
  for turn in conversation_turns_temp:
737
- toxicities = turn.get('toxicities', {})
738
- if toxicities and 'toxicity' in toxicities:
739
- all_toxicities.append(toxicities['toxicity'])
740
-
741
  if all_toxicities:
742
  avg_toxicity = sum(all_toxicities) / len(all_toxicities)
743
  max_toxicity = max(all_toxicities)
744
-
745
  st.markdown("**🔍 Toxicity Summary:**")
746
  col1, col2, col3 = st.columns(3)
747
  with col1:
748
  # Color code average toxicity
749
  if avg_toxicity > 0.5:
750
- st.metric("Average Toxicity", f"{avg_toxicity:.4f}", delta="HIGH", delta_color="inverse")
751
  elif avg_toxicity > 0.1:
752
- st.metric("Average Toxicity", f"{avg_toxicity:.4f}", delta="MED", delta_color="off")
753
  else:
754
- st.metric("Average Toxicity", f"{avg_toxicity:.4f}", delta="LOW", delta_color="normal")
755
-
756
  with col2:
757
  # Color code max toxicity
758
  if max_toxicity > 0.5:
759
- st.metric("Max Toxicity", f"{max_toxicity:.4f}", delta="HIGH", delta_color="inverse")
760
  elif max_toxicity > 0.1:
761
- st.metric("Max Toxicity", f"{max_toxicity:.4f}", delta="MED", delta_color="off")
762
  else:
763
- st.metric("Max Toxicity", f"{max_toxicity:.4f}", delta="LOW", delta_color="normal")
764
-
765
  with col3:
766
  high_tox_turns = sum(1 for t in all_toxicities if t > 0.5)
767
  st.metric("High Toxicity Turns", high_tox_turns)
768
-
769
  # Get conversation turns with metrics
770
- conv_turns_data = filtered_df_exploded[filtered_df_exploded.index.isin(
771
- filtered_df_exploded[filtered_df_exploded.index // len(filtered_df_exploded) * len(filtered_df) +
772
- filtered_df_exploded.index % len(filtered_df) == selected_idx].index
773
- )].copy()
774
-
775
  # Alternative approach: filter by matching all conversation data
776
  # This is more reliable but less efficient
777
  conv_turns_data = []
778
  start_idx = None
779
  for idx, row in filtered_df_exploded.iterrows():
780
  # Check if this row belongs to our selected conversation
781
- if (row['type'] == selected_conversation['type'] and
782
- hasattr(row, 'conversation') and
783
- row.get('conversation') is not None):
 
 
784
  # This is a simplified approach - in reality you'd need better conversation matching
785
  pass
786
-
787
  # Simpler approach: get all turns from the conversation directly
788
- conversation_turns = selected_conversation.get('conversation', [])
789
-
790
  # Ensure conversation_turns is a list and handle different data types
791
- if hasattr(conversation_turns, 'tolist'):
792
  conversation_turns = conversation_turns.tolist()
793
  elif conversation_turns is None:
794
  conversation_turns = []
795
-
796
  if len(conversation_turns) > 0:
797
  # Display conversation content with metrics
798
  st.subheader("💬 Conversation with Metrics")
799
-
800
  # Get actual turn-level data for this conversation
801
  turn_metric_columns = [f"turn.turn_metrics.{m}" for m in selected_metrics]
802
- available_columns = [col for col in turn_metric_columns if col in filtered_df_exploded.columns]
803
-
804
  # Get sample metrics for this conversation type (since exact matching is complex)
805
  sample_metrics = None
806
  if available_columns:
807
- type_turns = filtered_df_exploded[filtered_df_exploded['type'] == selected_conversation['type']]
 
 
808
  sample_size = min(len(conversation_turns), len(type_turns))
809
  if sample_size > 0:
810
  sample_metrics = type_turns.head(sample_size)
811
-
812
  # Display each turn with its metrics
813
  for i, turn in enumerate(conversation_turns):
814
- role = turn.get('role', 'unknown')
815
- content = turn.get('content', 'No content')
816
-
817
  # Display turn content with role styling
818
- if role == 'user':
819
  st.markdown(f"**👤 User (Turn {i+1}):**")
820
  st.info(content)
821
- elif role == 'assistant':
822
  st.markdown(f"**🤖 Assistant (Turn {i+1}):**")
823
  st.success(content)
824
  else:
825
  st.markdown(f"**❓ {role.title()} (Turn {i+1}):**")
826
  st.warning(content)
827
-
828
  # Display metrics for this turn
829
  if sample_metrics is not None and i < len(sample_metrics):
830
  turn_row = sample_metrics.iloc[i]
831
-
832
  # Create metrics display
833
  metrics_for_turn = {}
834
  for col in available_columns:
835
- metric_name = col.replace('turn.turn_metrics.', '')
836
  friendly_name = get_human_friendly_metric_name(metric_name)
837
- value = turn_row.get(col, 'N/A')
838
  if pd.notna(value) and isinstance(value, (int, float)):
839
  metrics_for_turn[friendly_name] = round(value, 3)
840
  else:
841
- metrics_for_turn[friendly_name] = 'N/A'
842
-
843
  # Add toxicity metrics if available
844
- toxicities = turn.get('toxicities', {})
845
  if toxicities:
846
  st.markdown("**🔍 Toxicity Scores:**")
847
  tox_cols = st.columns(4)
848
  tox_metrics = [
849
- ('toxicity', 'Overall Toxicity'),
850
- ('severe_toxicity', 'Severe Toxicity'),
851
- ('identity_attack', 'Identity Attack'),
852
- ('insult', 'Insult'),
853
- ('obscene', 'Obscene'),
854
- ('sexual_explicit', 'Sexual Explicit'),
855
- ('threat', 'Threat')
856
  ]
857
-
858
  for idx, (tox_key, tox_name) in enumerate(tox_metrics):
859
  if tox_key in toxicities:
860
  col_idx = idx % 4
@@ -863,14 +1050,29 @@ def main():
863
  if isinstance(tox_value, (int, float)):
864
  # Color code based on toxicity level
865
  if tox_value > 0.5:
866
- st.metric(tox_name, f"{tox_value:.4f}", delta="HIGH", delta_color="inverse")
867
  elif tox_value > 0.1:
868
- st.metric(tox_name, f"{tox_value:.4f}", delta="MED", delta_color="off")
869
  else:
870
- st.metric(tox_name, f"{tox_value:.4f}", delta="LOW", delta_color="normal")
871
  else:
872
  st.metric(tox_name, str(tox_value))
873
-
874
  # Display complexity metrics
875
  if metrics_for_turn:
876
  st.markdown("**πŸ“Š Complexity Metrics:**")
@@ -878,29 +1080,34 @@ def main():
878
  num_cols = min(4, len(metrics_for_turn))
879
  if num_cols > 0:
880
  cols = st.columns(num_cols)
881
- for idx, (metric_name, value) in enumerate(metrics_for_turn.items()):
 
 
882
  col_idx = idx % num_cols
883
  with cols[col_idx]:
884
- if isinstance(value, (int, float)) and value != 'N/A':
885
  st.metric(metric_name, value)
886
  else:
887
  st.metric(metric_name, str(value))
888
  else:
889
  # Show toxicity even when no complexity metrics available
890
- toxicities = turn.get('toxicities', {})
891
  if toxicities:
892
  st.markdown("**🔍 Toxicity Scores:**")
893
  tox_cols = st.columns(4)
894
  tox_metrics = [
895
- ('toxicity', 'Overall Toxicity'),
896
- ('severe_toxicity', 'Severe Toxicity'),
897
- ('identity_attack', 'Identity Attack'),
898
- ('insult', 'Insult'),
899
- ('obscene', 'Obscene'),
900
- ('sexual_explicit', 'Sexual Explicit'),
901
- ('threat', 'Threat')
902
  ]
903
-
904
  for idx, (tox_key, tox_name) in enumerate(tox_metrics):
905
  if tox_key in toxicities:
906
  col_idx = idx % 4
@@ -909,14 +1116,29 @@ def main():
909
  if isinstance(tox_value, (int, float)):
910
  # Color code based on toxicity level
911
  if tox_value > 0.5:
912
- st.metric(tox_name, f"{tox_value:.4f}", delta="HIGH", delta_color="inverse")
913
  elif tox_value > 0.1:
914
- st.metric(tox_name, f"{tox_value:.4f}", delta="MED", delta_color="off")
915
  else:
916
- st.metric(tox_name, f"{tox_value:.4f}", delta="LOW", delta_color="normal")
917
  else:
918
  st.metric(tox_name, str(tox_value))
919
-
920
  # Show basic turn statistics when no complexity metrics available
921
  st.markdown("**📈 Basic Statistics:**")
922
  col1, col2, col3 = st.columns(3)
@@ -926,21 +1148,21 @@ def main():
926
  st.metric("Words", len(content.split()))
927
  with col3:
928
  st.metric("Role", role.title())
929
-
930
  # Add separator between turns
931
  st.divider()
932
-
933
  # Plot metrics over turns with real data if available
934
  if available_columns and sample_metrics is not None:
935
  st.subheader("📈 Metrics Over Turns")
936
-
937
  fig = go.Figure()
938
-
939
  # Add traces for each selected metric (real data)
940
  for col in available_columns[:5]: # Limit to first 5 for readability
941
- metric_name = col.replace('turn.turn_metrics.', '')
942
  friendly_name = get_human_friendly_metric_name(metric_name)
943
-
944
  # Get values for this metric
945
  y_values = []
946
  for _, turn_row in sample_metrics.iterrows():
@@ -949,101 +1171,136 @@ def main():
949
  y_values.append(value)
950
  else:
951
  y_values.append(None)
952
-
953
  if any(v is not None for v in y_values):
954
- fig.add_trace(go.Scatter(
955
- x=list(range(1, len(y_values) + 1)),
956
- y=y_values,
957
- mode='lines+markers',
958
- name=friendly_name,
959
- line=dict(width=2),
960
- marker=dict(size=8),
961
- connectgaps=False
962
- ))
963
-
 
 
964
  if fig.data: # Only show if we have data
965
  fig.update_layout(
966
  title="Complexity Metrics Across Conversation Turns",
967
  xaxis_title="Turn Number",
968
  yaxis_title="Metric Value",
969
  height=400,
970
- hovermode='x unified'
971
  )
972
-
973
  st.plotly_chart(fig, use_container_width=True)
974
  else:
975
- st.info("No numeric metric data available to plot for this conversation type.")
976
-
 
 
977
  elif selected_metrics:
978
- st.info("Select metrics that are available in the dataset to see turn-level analysis.")
 
 
979
  else:
980
  st.warning("Select some metrics to see detailed turn-level analysis.")
981
-
982
  else:
983
  st.warning("No conversation data available for the selected conversation.")
984
-
985
  with tab5:
986
  st.header("Detailed View")
987
-
988
- # Data overview
989
- st.subheader("📋 Dataset Overview")
990
-
991
- st.info(f"**Current Dataset:** `{selected_dataset}`")
992
-
993
- col1, col2, col3 = st.columns(3)
994
-
995
  with col1:
996
- st.metric("Total Conversations", len(filtered_df))
997
-
998
- with col2:
999
- st.metric("Total Turns", len(filtered_df_exploded))
1000
-
1001
- with col3:
1002
- st.metric("Available Metrics", len(available_metrics))
1003
-
1004
- # Type distribution
1005
- st.subheader("📊 Type Distribution")
1006
- type_counts = filtered_df['type'].value_counts()
1007
-
1008
- fig = px.pie(
1009
- values=type_counts.values,
1010
- names=type_counts.index,
1011
- title="Distribution of Conversation Types",
1012
- color_discrete_map=PLOT_PALETTE if len(type_counts) <= 3 else None
1013
- )
1014
-
1015
- st.plotly_chart(fig, use_container_width=True)
1016
-
1017
- # Sample data
1018
- st.subheader("📄 Sample Data")
1019
-
1020
- if st.checkbox("Show raw data sample"):
1021
- sample_cols = ['type'] + [f"turn.turn_metrics.{m}" for m in selected_metrics if f"turn.turn_metrics.{m}" in filtered_df_exploded.columns]
1022
- sample_data = filtered_df_exploded[sample_cols].head(100)
1023
- st.dataframe(sample_data)
1024
-
1025
- # Metric availability
1026
- st.subheader("📊 Metric Availability")
1027
-
1028
- metric_completeness = {}
1029
- for metric in selected_metrics:
1030
- full_metric_name = f"turn.turn_metrics.{metric}"
1031
- if full_metric_name in filtered_df_exploded.columns:
1032
- completeness = (1 - filtered_df_exploded[full_metric_name].isna().sum() / len(filtered_df_exploded)) * 100
1033
- metric_completeness[get_human_friendly_metric_name(metric)] = completeness
1034
-
1035
- if metric_completeness:
1036
- completeness_df = pd.DataFrame(list(metric_completeness.items()), columns=['Metric', 'Completeness (%)'])
1037
- fig = px.bar(
1038
- completeness_df,
1039
- x='Metric',
1040
- y='Completeness (%)',
1041
- title="Data Completeness by Metric",
1042
- color='Completeness (%)',
1043
- color_continuous_scale='Viridis'
1044
  )
1045
- fig.update_layout(xaxis_tickangle=-45, height=400)
1046
- st.plotly_chart(fig, use_container_width=True)
1047
 
1048
  if __name__ == "__main__":
1049
  main()
 
14
  import warnings
15
  import datasets
16
  import logging
17
+
18
+ warnings.filterwarnings("ignore")
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO)
 
24
  # Constants
25
  PLOT_PALETTE = {
26
  "jailbreak": "#D000D8", # Purple
27
+ "benign": "#008393", # Cyan
28
+ "control": "#EF0000", # Red
29
  }
30
 
31
+
32
  # Utility functions
33
  def load_and_prepare_dataset(dataset_config):
34
  """Load the risky conversations dataset and prepare it for analysis."""
35
  logger.info("Loading dataset...")
36
+
37
  dataset_name = dataset_config["dataset_name"]
38
  logger.info(f"Loading dataset: {dataset_name}")
39
+
40
  # Load the dataset
41
  dataset = datasets.load_dataset(dataset_name, split="train")
42
  logger.info(f"Dataset loaded with {len(dataset)} conversations")
43
+
44
  # Convert to pandas
45
  pandas_dataset = dataset.to_pandas()
46
+
47
  # Explode the conversation column
48
  pandas_dataset_exploded = pandas_dataset.explode("conversation")
49
  pandas_dataset_exploded = pandas_dataset_exploded.reset_index(drop=True)
50
+
51
  # Normalize conversation data
52
  conversations_unfolded = pd.json_normalize(
53
  pandas_dataset_exploded["conversation"],
54
  )
55
  conversations_unfolded = conversations_unfolded.add_prefix("turn.")
56
+
57
  # Ensure there's a 'conversation_metrics' column, even if empty
58
  if "conversation_metrics" not in pandas_dataset_exploded.columns:
59
  pandas_dataset_exploded["conversation_metrics"] = [{}] * len(
60
  pandas_dataset_exploded
61
  )
62
+
63
  # Normalize conversation metrics
64
  conversations_metrics_unfolded = pd.json_normalize(
65
  pandas_dataset_exploded["conversation_metrics"]
 
67
  conversations_metrics_unfolded = conversations_metrics_unfolded.add_prefix(
68
  "conversation_metrics."
69
  )
70
+
71
  # Concatenate all dataframes
72
  pandas_dataset_exploded = pd.concat(
73
  [
 
79
  ],
80
  axis=1,
81
  )
82
+
83
  logger.info(f"Dataset prepared with {len(pandas_dataset_exploded)} turns")
84
  return pandas_dataset, pandas_dataset_exploded
85
 
86
+
87
  def get_available_turn_metrics(dataset_exploded):
88
  """Dynamically discover all available turn metrics from the dataset."""
89
  # Find all columns that contain turn metrics
90
  turn_metric_columns = [
91
  col for col in dataset_exploded.columns if col.startswith("turn.turn_metrics.")
92
  ]
93
+
94
  # Extract the metric names by removing the prefix
95
  metrics = [col.replace("turn.turn_metrics.", "") for col in turn_metric_columns]
96
+
97
  # Sort metrics for consistent ordering
98
  metrics = sorted(metrics)
99
+
100
  logger.info(f"Discovered {len(metrics)} available turn metrics")
101
  return metrics
102
 
103
+
104
  def get_human_friendly_metric_name(metric_name):
105
  """Convert a metric name to a human-friendly display name."""
106
  # Mapping of metric names to human-friendly display names
107
  metric_display_names = {
108
  # Length metrics
109
  "conversational_length_tokens": "Length in Tokens",
110
+ "conversational_length_words": "Length in Words",
111
  "conversational_length_bytes": "Length in Bytes",
 
112
  # Readability metrics
113
  "readability_flesch_score": "Flesch Reading Ease",
114
  "readability_standard_score": "Reading Grade Level",
 
115
  # Compression metrics
116
  "lzw_compression": "LZW Compression Ratio",
 
117
  # Language model metrics
118
  "ll_mean_logprobs": "Mean Log Probability",
119
  "ll_max_logprobs": "Maximum Log Probability",
 
121
  "rll_mean_logprobs": "Reverse Mean Log Probability",
122
  "rll_max_logprobs": "Reverse Maximum Log Probability",
123
  "rll_min_logprobs": "Reverse Minimum Log Probability",
 
124
  # Lexical diversity
125
  "type_token_ratio": "Type-Token Ratio",
 
126
  # Language distribution
127
  "zipf_score": "Zipf Distribution Score",
 
128
  # Evaluation metrics
129
  "rubric_evaluation.Stealth": "Stealth Score",
 
130
  # Working memory metrics
131
  "wm_total_entity_references": "Total Entity References",
132
  "wm_unique_entities_count": "Unique Entities Count",
133
  "wm_named_entity_mentions": "Named Entity Mentions",
134
+ "wm_concept_noun_mentions": "Concept Noun Mentions",
135
  "wm_pronoun_references": "Pronoun References",
136
  "wm_entity_density_per_word": "Entity Density per Word",
137
  "wm_entity_density_per_100_words": "Entity Density per 100 Words",
 
140
  "wm_entity_repetition_ratio": "Entity Repetition Ratio",
141
  "wm_cognitive_load_score": "Cognitive Load Score",
142
  "wm_high_cognitive_load": "High Cognitive Load",
 
143
  # Discourse coherence metrics
144
  "discourse_coherence_to_next_user": "Coherence to Next User Turn",
145
  "discourse_coherence_to_next_turn": "Coherence to Next Turn",
 
148
  "discourse_user_topic_drift": "User Topic Drift",
149
  "discourse_user_entity_continuity": "User Entity Continuity",
150
  "discourse_num_user_turns": "Number of User Turns",
 
151
  # Tokens per byte
152
  "tokens_per_byte": "Tokens per Byte",
153
  }
154
+
155
  # Check exact match first
156
  if metric_name in metric_display_names:
157
  return metric_display_names[metric_name]
158
+
159
  # Handle conversation-level aggregations
160
+ for suffix in [
161
+ "_conversation_mean",
162
+ "_conversation_min",
163
+ "_conversation_max",
164
+ "_conversation_std",
165
+ "_conversation_count",
166
+ ]:
167
  if metric_name.endswith(suffix):
168
+ base_metric = metric_name[: -len(suffix)]
169
  if base_metric in metric_display_names:
170
  agg_type = suffix.split("_")[-1].title()
171
  return f"{metric_display_names[base_metric]} ({agg_type})"
172
+
173
  # Handle turn-level metrics with "turn.turn_metrics." prefix
174
  if metric_name.startswith("turn.turn_metrics."):
175
+ base_metric = metric_name[len("turn.turn_metrics.") :]
176
  if base_metric in metric_display_names:
177
  return metric_display_names[base_metric]
178
+
179
  # Fallback: convert underscores to spaces and title case
180
  clean_name = metric_name
181
  for prefix in ["turn.turn_metrics.", "conversation_metrics.", "turn_metrics."]:
182
  if clean_name.startswith(prefix):
183
+ clean_name = clean_name[len(prefix) :]
184
  break
185
+
186
  # Convert to human-readable format
187
  clean_name = clean_name.replace("_", " ").title()
188
  return clean_name
189
 
190
+
191
+ def render_metric_distribution(metric, filtered_df_exploded, selected_types):
192
+ """Render distribution plot for a single metric."""
193
+ full_metric_name = f"turn.turn_metrics.{metric}"
194
+
195
+ if full_metric_name not in filtered_df_exploded.columns:
196
+ st.warning(f"Metric {metric} not found in dataset")
197
+ return
198
+
199
+ st.subheader(f"📊 {get_human_friendly_metric_name(metric)}")
200
+
201
+ # Clean the data
202
+ metric_data = filtered_df_exploded[["type", full_metric_name]].copy()
203
+ metric_data = metric_data.dropna()
204
+
205
+ if len(metric_data) == 0:
206
+ st.warning(f"No data available for {metric}")
207
+ return
208
+
209
+ # Create plotly histogram
210
+ fig = px.histogram(
211
+ metric_data,
212
+ x=full_metric_name,
213
+ color="type",
214
+ marginal="box",
215
+ title=f"Distribution of {get_human_friendly_metric_name(metric)}",
216
+ color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None,
217
+ opacity=0.7,
218
+ nbins=50,
219
+ )
220
+
221
+ fig.update_layout(
222
+ xaxis_title=get_human_friendly_metric_name(metric),
223
+ yaxis_title="Count",
224
+ height=400,
225
+ )
226
+
227
+ st.plotly_chart(fig, use_container_width=True)
228
+
229
+ # Summary statistics
230
+ col1, col2 = st.columns(2)
231
+
232
+ with col1:
233
+ st.write("**Summary Statistics**")
234
+ summary_stats = (
235
+ metric_data.groupby("type")[full_metric_name]
236
+ .agg(["count", "mean", "std", "min", "max"])
237
+ .round(3)
238
+ )
239
+ st.dataframe(summary_stats)
240
+
241
+ with col2:
242
+ st.write("**Percentiles**")
243
+ percentiles = (
244
+ metric_data.groupby("type")[full_metric_name]
245
+ .quantile([0.25, 0.5, 0.75])
246
+ .unstack()
247
+ .round(3)
248
+ )
249
+ percentiles.columns = ["25%", "50%", "75%"]
250
+ st.dataframe(percentiles)
251
+
252
+
253
  # Setup page config
254
  st.set_page_config(
255
  page_title="Complexity Metrics Explorer",
256
  page_icon="πŸ“Š",
257
  layout="wide",
258
+ initial_sidebar_state="expanded",
259
  )
260
 
261
+
262
  # Cache data loading
263
  @st.cache_data
264
  def load_data(dataset_name):
265
  """Load and cache the dataset"""
266
+ df, df_exploded = load_and_prepare_dataset({"dataset_name": dataset_name})
 
 
267
  return df, df_exploded
268
 
269
+
270
  @st.cache_data
271
  def get_metrics(df_exploded):
272
  """Get available metrics from the dataset"""
273
  return get_available_turn_metrics(df_exploded)
274
 
275
+
276
  def main():
277
  st.title("🔍 Complexity Metrics Explorer")
278
+ st.markdown(
279
+ "Interactive visualization of conversation complexity metrics across different dataset types."
280
+ )
281
+
282
  # Dataset selection
283
  st.sidebar.header("🗂️ Dataset Selection")
284
+
285
  # Available datasets
286
  available_datasets = [
287
  "risky-conversations/jailbreaks_dataset_with_results_reduced",
288
  "risky-conversations/jailbreaks_dataset_with_results",
289
  "risky-conversations/jailbreaks_dataset_with_results_filtered_successful_jailbreak",
290
+ "Custom...",
291
  ]
292
+
293
  selected_option = st.sidebar.selectbox(
294
  "Select Dataset",
295
  options=available_datasets,
296
  index=0, # Default to reduced dataset
297
+ help="Choose which dataset to analyze",
298
  )
299
+
300
  # Handle custom dataset input
301
  if selected_option == "Custom...":
302
  selected_dataset = st.sidebar.text_input(
303
  "Custom Dataset Name",
304
  value="risky-conversations/jailbreaks_dataset_with_results_reduced",
305
+ help="Enter the full dataset name (e.g., 'risky-conversations/jailbreaks_dataset_with_results_reduced')",
306
  )
307
  if not selected_dataset.strip():
308
  st.sidebar.warning("Please enter a dataset name")
309
  st.stop()
310
  else:
311
  selected_dataset = selected_option
312
+
313
  # Add refresh button
314
  if st.sidebar.button("🔄 Refresh Data", help="Clear cache and reload dataset"):
315
  st.cache_data.clear()
316
  st.rerun()
317
+
318
  # Load data
319
  with st.spinner(f"Loading dataset: {selected_dataset}..."):
320
  try:
321
  df, df_exploded = load_data(selected_dataset)
322
  available_metrics = get_metrics(df_exploded)
323
+
324
  # Display dataset info
325
  col1, col2, col3, col4 = st.columns(4)
326
  with col1:
327
+ st.metric("Dataset", selected_dataset.split("_")[-1].title())
328
  with col2:
329
  st.metric("Conversations", f"{len(df):,}")
330
  with col3:
331
  st.metric("Turns", f"{len(df_exploded):,}")
332
  with col4:
333
  st.metric("Metrics", len(available_metrics))
334
+
335
  data_loaded = True
336
  except Exception as e:
337
  st.error(f"Error loading dataset: {e}")
338
  st.info("Please check if the dataset exists and is accessible.")
339
+ st.info(
340
+ "💡 Try using one of the predefined dataset options instead of custom input."
341
+ )
342
  data_loaded = False
343
+
344
  if not data_loaded:
345
  st.stop()
346
+
347
  # Sidebar controls
348
  st.sidebar.header("🎛️ Controls")
349
+
350
  # Dataset type filter
351
+ dataset_types = df["type"].unique()
352
  selected_types = st.sidebar.multiselect(
353
  "Select Dataset Types",
354
  options=dataset_types,
355
  default=dataset_types,
356
+ help="Filter by conversation type",
357
  )
358
+
359
  # Role filter
360
+ if "turn.role" in df_exploded.columns:
361
+ roles = df_exploded["turn.role"].dropna().unique()
362
  # Assert only user and assistant roles exist
363
+ expected_roles = {"user", "assistant"}
364
  actual_roles = set(roles)
365
+ assert actual_roles.issubset(
366
+ expected_roles
367
+ ), f"Unexpected roles found: {actual_roles - expected_roles}. Expected only 'user' and 'assistant'"
368
+
369
  st.sidebar.subheader("👥 Role Filter")
370
  col1, col2 = st.sidebar.columns(2)
371
+
372
  with col1:
373
  include_user = st.checkbox("User", value=True, help="Include user turns")
374
  with col2:
375
+ include_assistant = st.checkbox(
376
+ "Assistant", value=True, help="Include assistant turns"
377
+ )
378
+
379
  # Build selected roles list
380
  selected_roles = []
381
+ if include_user and "user" in roles:
382
+ selected_roles.append("user")
383
+ if include_assistant and "assistant" in roles:
384
+ selected_roles.append("assistant")
385
+
386
  # Show selection info
387
  if selected_roles:
388
  st.sidebar.success(f"Including: {', '.join(selected_roles)}")
 
390
  st.sidebar.warning("No roles selected")
391
  else:
392
  selected_roles = None
393
+
394
  # Filter data based on selections
395
+ filtered_df = df[df["type"].isin(selected_types)] if selected_types else df
396
+ filtered_df_exploded = (
397
+ df_exploded[df_exploded["type"].isin(selected_types)]
398
+ if selected_types
399
+ else df_exploded
400
+ )
401
+
402
+ if selected_roles and "turn.role" in filtered_df_exploded.columns:
403
+ filtered_df_exploded = filtered_df_exploded[
404
+ filtered_df_exploded["turn.role"].isin(selected_roles)
405
+ ]
406
  elif selected_roles is not None and len(selected_roles) == 0:
407
  # If roles exist but none are selected, show empty dataset
408
+ filtered_df_exploded = filtered_df_exploded.iloc[
409
+ 0:0
410
+ ] # Empty dataframe with same structure
411
+
412
  # Check if we have data after filtering
413
  if len(filtered_df_exploded) == 0:
414
+ st.error(
415
+ "No data available with current filters. Please adjust your selection."
416
+ )
417
  st.stop()
418
+
419
  # Metric selection
420
  st.sidebar.header("📊 Metrics")
421
+
422
  # Dynamic metric categorization based on common patterns
423
  def categorize_metrics(metrics):
424
  """Dynamically categorize metrics based on naming patterns"""
425
  categories = {"All": metrics} # Always include all metrics
426
+
427
  # Common patterns to look for
428
  patterns = {
429
+ "Length": ["length", "byte", "word", "token", "char"],
430
+ "Readability": ["readability", "flesch", "standard"],
431
+ "Compression": ["lzw", "compression"],
432
+ "Language Model": ["ll_", "rll_", "logprob"],
433
+ "Working Memory": ["wm_"],
434
+ "Discourse": ["discourse"],
435
+ "Evaluation": ["rubric", "evaluation", "stealth"],
436
+ "Distribution": ["zipf", "type_token"],
437
+ "Coherence": ["coherence"],
438
+ "Entity": ["entity", "entities"],
439
+ "Cognitive": ["cognitive", "load"],
440
  }
441
+
442
  # Categorize metrics
443
  for category, keywords in patterns.items():
444
+ matching_metrics = [
445
+ m for m in metrics if any(keyword in m.lower() for keyword in keywords)
446
+ ]
447
  if matching_metrics:
448
  categories[category] = matching_metrics
449
+
450
  # Find uncategorized metrics
451
  categorized = set()
452
  for cat_metrics in categories.values():
453
  if cat_metrics != metrics: # Skip "All" category
454
  categorized.update(cat_metrics)
455
+
456
  uncategorized = [m for m in metrics if m not in categorized]
457
  if uncategorized:
458
  categories["Other"] = uncategorized
459
+
460
  return categories
461
+
462
  metric_categories = categorize_metrics(available_metrics)
463
+
464
  # Metric selection interface
465
  selection_mode = st.sidebar.radio(
466
  "Selection Mode",
467
  ["By Category", "Search/Filter", "Select All"],
468
+ help="Choose how to select metrics",
469
  )
470
+
471
  if selection_mode == "By Category":
472
  selected_category = st.sidebar.selectbox(
473
+ "Metric Category",
474
  options=list(metric_categories.keys()),
475
+ help=f"Found {len(metric_categories)} categories",
476
  )
477
+
478
  available_in_category = metric_categories[selected_category]
479
+ default_selection = (
480
+ available_in_category[:5]
481
+ if len(available_in_category) > 5
482
+ else available_in_category
483
+ )
484
+
485
  # Add select all button for category
486
  col1, col2 = st.sidebar.columns(2)
487
  with col1:
 
490
  with col2:
491
  if st.button("Clear All", key="clear_all_category"):
492
  st.session_state.selected_metrics_category = []
493
+
494
  # Use session state for persistence
495
  if "selected_metrics_category" not in st.session_state:
496
  st.session_state.selected_metrics_category = default_selection
497
+
498
  selected_metrics = st.sidebar.multiselect(
499
  f"Select Metrics ({len(available_in_category)} available)",
500
  options=available_in_category,
501
  default=st.session_state.selected_metrics_category,
502
  key="metrics_multiselect_category",
503
+ help="Choose metrics to visualize",
504
  )
505
+
506
  elif selection_mode == "Search/Filter":
507
  search_term = st.sidebar.text_input(
508
  "Search Metrics",
509
  placeholder="Enter keywords to filter metrics...",
510
+ help="Search for metrics containing specific terms",
511
  )
512
+
513
  if search_term:
514
+ filtered_metrics = [
515
+ m for m in available_metrics if search_term.lower() in m.lower()
516
+ ]
517
  else:
518
  filtered_metrics = available_metrics
519
+
520
  st.sidebar.write(f"Found {len(filtered_metrics)} metrics")
521
+
522
  # Add select all button for search results
523
  col1, col2 = st.sidebar.columns(2)
524
  with col1:
 
527
  with col2:
528
  if st.button("Clear All", key="clear_all_search"):
529
  st.session_state.selected_metrics_search = []
530
+
531
  # Use session state for persistence
532
  if "selected_metrics_search" not in st.session_state:
533
+ st.session_state.selected_metrics_search = (
534
+ filtered_metrics[:5]
535
+ if len(filtered_metrics) > 5
536
+ else filtered_metrics[:3]
537
+ )
538
+
539
  selected_metrics = st.sidebar.multiselect(
540
  "Select Metrics",
541
  options=filtered_metrics,
542
  default=st.session_state.selected_metrics_search,
543
  key="metrics_multiselect_search",
544
+ help="Choose metrics to visualize",
545
  )
546
+
547
  else: # Select All
548
  # Add select all button for all metrics
549
  col1, col2 = st.sidebar.columns(2)
 
553
  with col2:
554
  if st.button("Clear All", key="clear_all_all"):
555
  st.session_state.selected_metrics_all = []
556
+
557
  # Use session state for persistence
558
  if "selected_metrics_all" not in st.session_state:
559
+ st.session_state.selected_metrics_all = available_metrics[
560
+ :10
561
+ ] # Limit default to first 10 for performance
562
+
563
  selected_metrics = st.sidebar.multiselect(
564
  f"All Metrics ({len(available_metrics)} total)",
565
  options=available_metrics,
566
  default=st.session_state.selected_metrics_all,
567
  key="metrics_multiselect_all",
568
+ help="All available metrics - be careful with performance for large selections",
569
  )
570
+
571
  # Show selection summary
572
  if selected_metrics:
573
  st.sidebar.success(f"Selected {len(selected_metrics)} metrics")
574
+
575
  # Performance warning for large selections
576
  if len(selected_metrics) > 50:  # check the stricter threshold first so this branch is reachable
577
+ st.sidebar.error(
578
+ f"🚨 Very large selection ({len(selected_metrics)} metrics) - consider reducing for better performance"
579
+ )
580
  elif len(selected_metrics) > 20:
581
+ st.sidebar.warning(
582
+ f"⚠️ Large selection ({len(selected_metrics)} metrics) may impact performance"
583
+ )
584
  else:
585
  st.sidebar.warning("No metrics selected")
586
+
587
  # Metric info expander
588
  with st.sidebar.expander("ℹ️ Metric Information", expanded=False):
589
  st.write(f"**Total Available Metrics:** {len(available_metrics)}")
590
  st.write(f"**Categories Found:** {len(metric_categories)}")
591
+
592
  if st.checkbox("Show all metric names", key="show_all_metrics"):
593
  st.write("**All Available Metrics:**")
594
  for i, metric in enumerate(available_metrics, 1):
595
  st.write(f"{i}. `{metric}`")
596
+
597
  # Main content tabs
598
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(
599
+ [
600
+ "πŸ“Š Distributions",
601
+ "πŸ”— Correlations",
602
+ "πŸ“ˆ Comparisons",
603
+ "πŸ” Conversation",
604
+ "🎯 Details",
605
+ ]
606
+ )
607
+
608
  with tab1:
609
  st.header("Distribution Analysis")
610
+
611
  if not selected_metrics:
612
  st.warning("Please select at least one metric to visualize.")
613
  return
614
+
615
+ # Create buttons for each metric to prevent loading all at once
616
+ st.info(
617
+ f"πŸ“Š Select a metric to plot its distribution ({len(selected_metrics)} metrics available)"
618
+ )
619
+
620
+ # Organize buttons in columns for better layout
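+ # For example, 7 selected metrics with 3 columns per row are laid out as
+ # rows of 3, 3 and 1 buttons (i takes the values 0, 3 and 6 below).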
621
+ cols_per_row = 3
622
+ for i in range(0, len(selected_metrics), cols_per_row):
623
+ cols = st.columns(cols_per_row)
624
+ for j, metric in enumerate(selected_metrics[i : i + cols_per_row]):
625
+ with cols[j]:
626
+ friendly_name = get_human_friendly_metric_name(metric)
627
+ # Truncate button text if too long
628
+ button_text = (
629
+ friendly_name[:25] + "..."
630
+ if len(friendly_name) > 25
631
+ else friendly_name
632
+ )
633
+
634
+ if st.button(
635
+ f"πŸ“ˆ {button_text}",
636
+ key=f"plot_{metric}",
637
+ help=f"Plot distribution for {friendly_name}",
638
+ ):
639
+ render_metric_distribution(
640
+ metric, filtered_df_exploded, selected_types
641
+ )
642
+
643
  with tab2:
644
  st.header("Correlation Analysis")
645
+
646
  if len(selected_metrics) < 2:
647
  st.warning("Please select at least 2 metrics for correlation analysis.")
648
  else:
649
+ # Add button to trigger correlation analysis
650
+ st.info(
651
+ f"πŸ”— Ready to analyze correlations between {len(selected_metrics)} metrics"
652
  )
653
+
654
+ col1, col2 = st.columns([1, 3])
655
+ with col1:
656
+ run_correlation = st.button(
657
+ "πŸ” Run Correlation Analysis",
658
+ help="Calculate and display correlation matrix and scatter plots",
659
+ )
660
+ with col2:
661
+ if len(selected_metrics) > 10:
662
+ st.warning(
663
+ f"⚠️ Large analysis ({len(selected_metrics)} metrics) - may take some time"
664
  )
665
+
666
+ if run_correlation:
667
+ with st.spinner("Calculating correlations..."):
668
+ # Prepare correlation data
669
+ corr_columns = [f"turn.turn_metrics.{m}" for m in selected_metrics]
670
+ corr_data = filtered_df_exploded[corr_columns + ["type"]].copy()
671
+
672
+ # Clean column names for display
673
+ corr_data.columns = [
674
+ (
675
+ get_human_friendly_metric_name(
676
+ col.replace("turn.turn_metrics.", "")
677
+ )
678
+ if col.startswith("turn.turn_metrics.")
679
+ else col
680
+ )
681
+ for col in corr_data.columns
682
+ ]
683
+
684
+ # Calculate correlation matrix
685
+ corr_matrix = corr_data.select_dtypes(include=[np.number]).corr()
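+ # Note: DataFrame.corr() defaults to Pearson correlation, computed over all
+ # selected turns pooled across the dataset types.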
686
+
687
+ # Create correlation heatmap
688
+ fig = px.imshow(
689
+ corr_matrix,
690
+ text_auto=True,
691
+ aspect="auto",
692
+ title="Correlation Matrix",
693
+ color_continuous_scale="RdBu_r",
694
+ zmin=-1,
695
+ zmax=1,
696
+ )
697
+
698
+ fig.update_layout(height=600)
699
  st.plotly_chart(fig, use_container_width=True)
700
+
701
+ # Scatter plots for strong correlations
702
+ st.subheader("Strong Correlations")
703
+
704
+ # Find strong correlations (>0.7 or <-0.7)
705
+ strong_corrs = []
706
+ for i in range(len(corr_matrix.columns)):
707
+ for j in range(i + 1, len(corr_matrix.columns)):
708
+ corr_val = corr_matrix.iloc[i, j]
709
+ if abs(corr_val) > 0.7:
710
+ strong_corrs.append(
711
+ (
712
+ corr_matrix.columns[i],
713
+ corr_matrix.columns[j],
714
+ corr_val,
715
+ )
716
+ )
717
+
718
+ if strong_corrs:
719
+ for metric1, metric2, corr_val in strong_corrs[
720
+ :3
721
+ ]:  # Show the first 3 strong pairs found
722
+ fig = px.scatter(
723
+ corr_data,
724
+ x=metric1,
725
+ y=metric2,
726
+ color="type",
727
+ title=f"{metric1} vs {metric2} (r={corr_val:.3f})",
728
+ color_discrete_map=(
729
+ PLOT_PALETTE if len(selected_types) <= 3 else None
730
+ ),
731
+ opacity=0.6,
732
+ )
733
+ st.plotly_chart(fig, use_container_width=True)
734
+ else:
735
+ st.info(
736
+ "No strong correlations (|r| > 0.7) found between selected metrics."
737
+ )
738
+
739
  with tab3:
740
  st.header("Type Comparisons")
741
+
742
  if not selected_metrics:
743
  st.warning("Please select at least one metric to compare.")
744
  else:
745
  # Box plots for each metric
746
  for metric in selected_metrics:
747
  full_metric_name = f"turn.turn_metrics.{metric}"
748
+
749
  if full_metric_name not in filtered_df_exploded.columns:
750
  continue
751
+
752
  st.subheader(f"πŸ“¦ {get_human_friendly_metric_name(metric)} by Type")
753
+
754
  # Create box plot
755
  fig = px.box(
756
  filtered_df_exploded.dropna(subset=[full_metric_name]),
757
+ x="type",
758
  y=full_metric_name,
759
  title=f"Distribution of {get_human_friendly_metric_name(metric)} by Type",
760
+ color="type",
761
+ color_discrete_map=(
762
+ PLOT_PALETTE if len(selected_types) <= 3 else None
763
+ ),
764
  )
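+ # Turns lacking a value for this metric are dropped above, so each box only
+ # summarizes the turns where the metric was actually computed.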
765
+
766
  fig.update_layout(
767
  xaxis_title="Dataset Type",
768
  yaxis_title=get_human_friendly_metric_name(metric),
769
+ height=400,
770
  )
771
+
772
  st.plotly_chart(fig, use_container_width=True)
773
+
774
  with tab4:
775
  st.header("Individual Conversation Analysis")
776
+
777
  # Conversation selector
778
  st.subheader("πŸ” Select Conversation")
779
+
780
  # Get unique conversations with some metadata
781
  conversation_info = []
782
  for idx, row in filtered_df.iterrows():
783
+ conv_type = row["type"]
784
  # Get basic info about the conversation
785
+ conv_turns = len(row.get("conversation", []))
786
+ conversation_info.append(
787
+ {
788
+ "index": idx,
789
+ "type": conv_type,
790
+ "turns": conv_turns,
791
+ "display": f"Conversation {idx} ({conv_type}) - {conv_turns} turns",
792
+ }
793
+ )
794
+
795
  # Sort by type and number of turns for better organization
796
+ conversation_info = sorted(
797
+ conversation_info, key=lambda x: (x["type"], -x["turns"])
798
+ )
799
+
800
  # Conversation selection
801
  col1, col2 = st.columns([3, 1])
802
+
803
  with col1:
804
  selected_conv_display = st.selectbox(
805
  "Choose a conversation to analyze",
806
+ options=[conv["display"] for conv in conversation_info],
807
+ help="Select a conversation to view detailed metrics and content",
808
  )
809
+
810
  with col2:
811
  if st.button("🎲 Random", help="Select a random conversation"):
812
  import random
813
+
814
+ selected_conv_display = random.choice(
815
+ [conv["display"] for conv in conversation_info]
816
+ )
817
  st.rerun()
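+ # Note: the random pick above is kept only in a local variable and is not
+ # written to st.session_state, so after st.rerun() the selectbox will most
+ # likely revert to its previously selected conversation.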
818
+
819
  # Get the selected conversation data
820
+ selected_conv_info = next(
821
+ conv
822
+ for conv in conversation_info
823
+ if conv["display"] == selected_conv_display
824
+ )
825
+ selected_idx = selected_conv_info["index"]
826
  selected_conversation = filtered_df.loc[selected_idx]  # .loc matches the index labels from iterrows()
827
+
828
  # Display conversation metadata
829
  st.subheader("πŸ“‹ Conversation Overview")
830
+
831
  # First row - basic info
832
  col1, col2, col3, col4 = st.columns(4)
833
  with col1:
834
+ st.metric("Type", selected_conversation["type"])
835
  with col2:
836
  st.metric("Index", selected_idx)
837
  with col3:
838
+ st.metric("Total Turns", len(selected_conversation.get("conversation", [])))
839
  with col4:
840
  # Count user vs assistant turns
841
+ roles = [
842
+ turn.get("role", "unknown")
843
+ for turn in selected_conversation.get("conversation", [])
844
+ ]
845
+ user_turns = roles.count("user")
846
+ assistant_turns = roles.count("assistant")
847
  st.metric("User/Assistant", f"{user_turns}/{assistant_turns}")
848
+
849
  # Second row - additional metadata
850
  col1, col2, col3 = st.columns(3)
851
  with col1:
852
+ provenance = selected_conversation.get("provenance_dataset", "Unknown")
853
  st.metric("Dataset Source", provenance)
854
  with col2:
855
+ language = selected_conversation.get("language", "Unknown")
856
+ st.metric("Language", language.upper() if language else "Unknown")
857
  with col3:
858
+ timestamp = selected_conversation.get("timestamp", None)
859
  if timestamp:
860
  # Handle different timestamp formats
861
  if isinstance(timestamp, str):
 
864
  st.metric("Timestamp", str(timestamp))
865
  else:
866
  st.metric("Timestamp", "Not Available")
867
+
868
  # Add toxicity summary
869
+ conversation_turns_temp = selected_conversation.get("conversation", [])
870
+ if hasattr(conversation_turns_temp, "tolist"):
871
  conversation_turns_temp = conversation_turns_temp.tolist()
872
  elif conversation_turns_temp is None:
873
  conversation_turns_temp = []
874
+
875
  if len(conversation_turns_temp) > 0:
876
  # Calculate overall toxicity statistics
877
  all_toxicities = []
878
  for turn in conversation_turns_temp:
879
+ toxicities = turn.get("toxicities", {})
880
+ if toxicities and "toxicity" in toxicities:
881
+ all_toxicities.append(toxicities["toxicity"])
882
+
883
  if all_toxicities:
884
  avg_toxicity = sum(all_toxicities) / len(all_toxicities)
885
  max_toxicity = max(all_toxicities)
886
+
887
  st.markdown("**πŸ” Toxicity Summary:**")
888
  col1, col2, col3 = st.columns(3)
889
  with col1:
890
  # Color code average toxicity
891
  if avg_toxicity > 0.5:
892
+ st.metric(
893
+ "Average Toxicity",
894
+ f"{avg_toxicity:.4f}",
895
+ delta="HIGH",
896
+ delta_color="inverse",
897
+ )
898
  elif avg_toxicity > 0.1:
899
+ st.metric(
900
+ "Average Toxicity",
901
+ f"{avg_toxicity:.4f}",
902
+ delta="MED",
903
+ delta_color="off",
904
+ )
905
  else:
906
+ st.metric(
907
+ "Average Toxicity",
908
+ f"{avg_toxicity:.4f}",
909
+ delta="LOW",
910
+ delta_color="normal",
911
+ )
912
+
913
  with col2:
914
  # Color code max toxicity
915
  if max_toxicity > 0.5:
916
+ st.metric(
917
+ "Max Toxicity",
918
+ f"{max_toxicity:.4f}",
919
+ delta="HIGH",
920
+ delta_color="inverse",
921
+ )
922
  elif max_toxicity > 0.1:
923
+ st.metric(
924
+ "Max Toxicity",
925
+ f"{max_toxicity:.4f}",
926
+ delta="MED",
927
+ delta_color="off",
928
+ )
929
  else:
930
+ st.metric(
931
+ "Max Toxicity",
932
+ f"{max_toxicity:.4f}",
933
+ delta="LOW",
934
+ delta_color="normal",
935
+ )
936
+
937
  with col3:
938
  high_tox_turns = sum(1 for t in all_toxicities if t > 0.5)
939
  st.metric("High Toxicity Turns", high_tox_turns)
940
+
941
  # Get conversation turns with metrics
942
+ conv_turns_data = filtered_df_exploded[
943
+ filtered_df_exploded.index.isin(
944
+ filtered_df_exploded[
945
+ filtered_df_exploded.index
946
+ // len(filtered_df_exploded)
947
+ * len(filtered_df)
948
+ + filtered_df_exploded.index % len(filtered_df)
949
+ == selected_idx
950
+ ].index
951
+ )
952
+ ].copy()
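+ # Note: the index arithmetic above does not reliably map exploded turns back
+ # to the selected conversation; its result is discarded and replaced by the
+ # simpler sample-based approach below.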
953
+
954
  # Alternative approach: filter by matching all conversation data
955
  # This is more reliable but less efficient
956
  conv_turns_data = []
957
  start_idx = None
958
  for idx, row in filtered_df_exploded.iterrows():
959
  # Check if this row belongs to our selected conversation
960
+ if (
961
+ row["type"] == selected_conversation["type"]
962
+ and hasattr(row, "conversation")
963
+ and row.get("conversation") is not None
964
+ ):
965
  # This is a simplified approach - in reality you'd need better conversation matching
966
  pass
967
+
968
  # Simpler approach: get all turns from the conversation directly
969
+ conversation_turns = selected_conversation.get("conversation", [])
970
+
971
  # Ensure conversation_turns is a list and handle different data types
972
+ if hasattr(conversation_turns, "tolist"):
973
  conversation_turns = conversation_turns.tolist()
974
  elif conversation_turns is None:
975
  conversation_turns = []
976
+
977
  if len(conversation_turns) > 0:
978
  # Display conversation content with metrics
979
  st.subheader("πŸ’¬ Conversation with Metrics")
980
+
981
  # Get actual turn-level data for this conversation
982
  turn_metric_columns = [f"turn.turn_metrics.{m}" for m in selected_metrics]
983
+ available_columns = [
984
+ col
985
+ for col in turn_metric_columns
986
+ if col in filtered_df_exploded.columns
987
+ ]
988
+
989
  # Get sample metrics for this conversation type (since exact matching is complex)
990
  sample_metrics = None
991
  if available_columns:
992
+ type_turns = filtered_df_exploded[
993
+ filtered_df_exploded["type"] == selected_conversation["type"]
994
+ ]
995
  sample_size = min(len(conversation_turns), len(type_turns))
996
  if sample_size > 0:
997
  sample_metrics = type_turns.head(sample_size)
998
+
999
  # Display each turn with its metrics
1000
  for i, turn in enumerate(conversation_turns):
1001
+ role = turn.get("role", "unknown")
1002
+ content = turn.get("content", "No content")
1003
+
1004
  # Display turn content with role styling
1005
+ if role == "user":
1006
  st.markdown(f"**πŸ‘€ User (Turn {i+1}):**")
1007
  st.info(content)
1008
+ elif role == "assistant":
1009
  st.markdown(f"**πŸ€– Assistant (Turn {i+1}):**")
1010
  st.success(content)
1011
  else:
1012
  st.markdown(f"**❓ {role.title()} (Turn {i+1}):**")
1013
  st.warning(content)
1014
+
1015
  # Display metrics for this turn
1016
  if sample_metrics is not None and i < len(sample_metrics):
1017
  turn_row = sample_metrics.iloc[i]
1018
+
1019
  # Create metrics display
1020
  metrics_for_turn = {}
1021
  for col in available_columns:
1022
+ metric_name = col.replace("turn.turn_metrics.", "")
1023
  friendly_name = get_human_friendly_metric_name(metric_name)
1024
+ value = turn_row.get(col, "N/A")
1025
  if pd.notna(value) and isinstance(value, (int, float)):
1026
  metrics_for_turn[friendly_name] = round(value, 3)
1027
  else:
1028
+ metrics_for_turn[friendly_name] = "N/A"
1029
+
1030
  # Add toxicity metrics if available
1031
+ toxicities = turn.get("toxicities", {})
1032
  if toxicities:
1033
  st.markdown("**πŸ” Toxicity Scores:**")
1034
  tox_cols = st.columns(4)
1035
  tox_metrics = [
1036
+ ("toxicity", "Overall Toxicity"),
1037
+ ("severe_toxicity", "Severe Toxicity"),
1038
+ ("identity_attack", "Identity Attack"),
1039
+ ("insult", "Insult"),
1040
+ ("obscene", "Obscene"),
1041
+ ("sexual_explicit", "Sexual Explicit"),
1042
+ ("threat", "Threat"),
1043
  ]
1044
+
1045
  for idx, (tox_key, tox_name) in enumerate(tox_metrics):
1046
  if tox_key in toxicities:
1047
  col_idx = idx % 4
 
1050
  if isinstance(tox_value, (int, float)):
1051
  # Color code based on toxicity level
1052
  if tox_value > 0.5:
1053
+ st.metric(
1054
+ tox_name,
1055
+ f"{tox_value:.4f}",
1056
+ delta="HIGH",
1057
+ delta_color="inverse",
1058
+ )
1059
  elif tox_value > 0.1:
1060
+ st.metric(
1061
+ tox_name,
1062
+ f"{tox_value:.4f}",
1063
+ delta="MED",
1064
+ delta_color="off",
1065
+ )
1066
  else:
1067
+ st.metric(
1068
+ tox_name,
1069
+ f"{tox_value:.4f}",
1070
+ delta="LOW",
1071
+ delta_color="normal",
1072
+ )
1073
  else:
1074
  st.metric(tox_name, str(tox_value))
1075
+
1076
  # Display complexity metrics
1077
  if metrics_for_turn:
1078
  st.markdown("**πŸ“Š Complexity Metrics:**")
 
1080
  num_cols = min(4, len(metrics_for_turn))
1081
  if num_cols > 0:
1082
  cols = st.columns(num_cols)
1083
+ for idx, (metric_name, value) in enumerate(
1084
+ metrics_for_turn.items()
1085
+ ):
1086
  col_idx = idx % num_cols
1087
  with cols[col_idx]:
1088
+ if (
1089
+ isinstance(value, (int, float))
1090
+ and value != "N/A"
1091
+ ):
1092
  st.metric(metric_name, value)
1093
  else:
1094
  st.metric(metric_name, str(value))
1095
  else:
1096
  # Show toxicity even when no complexity metrics available
1097
+ toxicities = turn.get("toxicities", {})
1098
  if toxicities:
1099
  st.markdown("**πŸ” Toxicity Scores:**")
1100
  tox_cols = st.columns(4)
1101
  tox_metrics = [
1102
+ ("toxicity", "Overall Toxicity"),
1103
+ ("severe_toxicity", "Severe Toxicity"),
1104
+ ("identity_attack", "Identity Attack"),
1105
+ ("insult", "Insult"),
1106
+ ("obscene", "Obscene"),
1107
+ ("sexual_explicit", "Sexual Explicit"),
1108
+ ("threat", "Threat"),
1109
  ]
1110
+
1111
  for idx, (tox_key, tox_name) in enumerate(tox_metrics):
1112
  if tox_key in toxicities:
1113
  col_idx = idx % 4
 
1116
  if isinstance(tox_value, (int, float)):
1117
  # Color code based on toxicity level
1118
  if tox_value > 0.5:
1119
+ st.metric(
1120
+ tox_name,
1121
+ f"{tox_value:.4f}",
1122
+ delta="HIGH",
1123
+ delta_color="inverse",
1124
+ )
1125
  elif tox_value > 0.1:
1126
+ st.metric(
1127
+ tox_name,
1128
+ f"{tox_value:.4f}",
1129
+ delta="MED",
1130
+ delta_color="off",
1131
+ )
1132
  else:
1133
+ st.metric(
1134
+ tox_name,
1135
+ f"{tox_value:.4f}",
1136
+ delta="LOW",
1137
+ delta_color="normal",
1138
+ )
1139
  else:
1140
  st.metric(tox_name, str(tox_value))
1141
+
1142
  # Show basic turn statistics when no complexity metrics available
1143
  st.markdown("**πŸ“ˆ Basic Statistics:**")
1144
  col1, col2, col3 = st.columns(3)
 
1148
  st.metric("Words", len(content.split()))
1149
  with col3:
1150
  st.metric("Role", role.title())
1151
+
1152
  # Add separator between turns
1153
  st.divider()
1154
+
1155
  # Plot metrics over turns with real data if available
1156
  if available_columns and sample_metrics is not None:
1157
  st.subheader("πŸ“ˆ Metrics Over Turns")
1158
+
1159
  fig = go.Figure()
1160
+
1161
  # Add traces for each selected metric (real data)
1162
  for col in available_columns[:5]: # Limit to first 5 for readability
1163
+ metric_name = col.replace("turn.turn_metrics.", "")
1164
  friendly_name = get_human_friendly_metric_name(metric_name)
1165
+
1166
  # Get values for this metric
1167
  y_values = []
1168
  for _, turn_row in sample_metrics.iterrows():
 
1171
  y_values.append(value)
1172
  else:
1173
  y_values.append(None)
1174
+
1175
  if any(v is not None for v in y_values):
1176
+ fig.add_trace(
1177
+ go.Scatter(
1178
+ x=list(range(1, len(y_values) + 1)),
1179
+ y=y_values,
1180
+ mode="lines+markers",
1181
+ name=friendly_name,
1182
+ line=dict(width=2),
1183
+ marker=dict(size=8),
1184
+ connectgaps=False,
1185
+ )
1186
+ )
1187
+
1188
  if fig.data: # Only show if we have data
1189
  fig.update_layout(
1190
  title="Complexity Metrics Across Conversation Turns",
1191
  xaxis_title="Turn Number",
1192
  yaxis_title="Metric Value",
1193
  height=400,
1194
+ hovermode="x unified",
1195
  )
1196
+
1197
  st.plotly_chart(fig, use_container_width=True)
1198
  else:
1199
+ st.info(
1200
+ "No numeric metric data available to plot for this conversation type."
1201
+ )
1202
+
1203
  elif selected_metrics:
1204
+ st.info(
1205
+ "Select metrics that are available in the dataset to see turn-level analysis."
1206
+ )
1207
  else:
1208
  st.warning("Select some metrics to see detailed turn-level analysis.")
1209
+
1210
  else:
1211
  st.warning("No conversation data available for the selected conversation.")
1212
+
1213
  with tab5:
1214
  st.header("Detailed View")
1215
+
1216
+ # Add button to trigger detailed analysis
1217
+ st.info("🎯 Generate detailed dataset analysis and visualizations")
1218
+
1219
+ col1, col2 = st.columns([1, 3])
1220
  with col1:
1221
+ show_details = st.button(
1222
+ "πŸ“Š Show Detailed Analysis",
1223
+ help="Generate comprehensive dataset overview and metric analysis",
1224
  )
1225
+ with col2:
1226
+ if len(selected_metrics) > 20:
1227
+ st.warning("⚠️ Large metric selection - analysis may take some time")
1228
+
1229
+ if show_details:
1230
+ with st.spinner("Generating detailed analysis..."):
1231
+ # Data overview
1232
+ st.subheader("πŸ“‹ Dataset Overview")
1233
+
1234
+ st.info(f"**Current Dataset:** `{selected_dataset}`")
1235
+
1236
+ col1, col2, col3 = st.columns(3)
1237
+
1238
+ with col1:
1239
+ st.metric("Total Conversations", len(filtered_df))
1240
+
1241
+ with col2:
1242
+ st.metric("Total Turns", len(filtered_df_exploded))
1243
+
1244
+ with col3:
1245
+ st.metric("Available Metrics", len(available_metrics))
1246
+
1247
+ # Type distribution
1248
+ st.subheader("πŸ“Š Type Distribution")
1249
+ type_counts = filtered_df["type"].value_counts()
1250
+
1251
+ fig = px.pie(
1252
+ values=type_counts.values,
1253
+ names=type_counts.index,
1254
+ title="Distribution of Conversation Types",
1255
+ color_discrete_map=PLOT_PALETTE if len(type_counts) <= 3 else None,
1256
+ )
1257
+
1258
+ st.plotly_chart(fig, use_container_width=True)
1259
+
1260
+ # Sample data
1261
+ st.subheader("πŸ“„ Sample Data")
1262
+
1263
+ if st.checkbox("Show raw data sample"):
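+ # Caveat: this checkbox appears to live inside the `if show_details:` block,
+ # and st.button() only returns True on the rerun immediately after a click,
+ # so toggling the checkbox re-runs the script with show_details back to False
+ # and the sample table may not render.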
1264
+ sample_cols = ["type"] + [
1265
+ f"turn.turn_metrics.{m}"
1266
+ for m in selected_metrics
1267
+ if f"turn.turn_metrics.{m}" in filtered_df_exploded.columns
1268
+ ]
1269
+ sample_data = filtered_df_exploded[sample_cols].head(100)
1270
+ st.dataframe(sample_data)
1271
+
1272
+ # Metric availability
1273
+ st.subheader("πŸ“Š Metric Availability")
1274
+
1275
+ metric_completeness = {}
1276
+ for metric in selected_metrics:
1277
+ full_metric_name = f"turn.turn_metrics.{metric}"
1278
+ if full_metric_name in filtered_df_exploded.columns:
1279
+ completeness = (
1280
+ 1
1281
+ - filtered_df_exploded[full_metric_name].isna().sum()
1282
+ / len(filtered_df_exploded)
1283
+ ) * 100
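+ # Worked example: if 25 of 500 turns are missing this metric,
+ # completeness = (1 - 25 / 500) * 100 = 95.0 percent.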
1284
+ metric_completeness[get_human_friendly_metric_name(metric)] = (
1285
+ completeness
1286
+ )
1287
+
1288
+ if metric_completeness:
1289
+ completeness_df = pd.DataFrame(
1290
+ list(metric_completeness.items()),
1291
+ columns=["Metric", "Completeness (%)"],
1292
+ )
1293
+ fig = px.bar(
1294
+ completeness_df,
1295
+ x="Metric",
1296
+ y="Completeness (%)",
1297
+ title="Data Completeness by Metric",
1298
+ color="Completeness (%)",
1299
+ color_continuous_scale="Viridis",
1300
+ )
1301
+ fig.update_layout(xaxis_tickangle=-45, height=400)
1302
+ st.plotly_chart(fig, use_container_width=True)
1303
+
1304
 
1305
  if __name__ == "__main__":
1306
  main()