#!/usr/bin/env python3
"""
Streamlit app for interactive complexity metrics visualization.
"""
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings

# Suppress every Python warning for the rest of the process (including those
# raised while importing visualization.utils below). Note this also hides
# potentially useful pandas/plotly diagnostics.
warnings.filterwarnings('ignore')

# Import visualization utilities
# NOTE(review): plt, sns, go, make_subplots, clean_metric_values and
# setup_plot_style are not referenced elsewhere in this file; kept so
# import-time behavior is unchanged.
from visualization.utils import (
    load_and_prepare_dataset,
    get_available_turn_metrics,
    get_human_friendly_metric_name,
    clean_metric_values,
    PLOT_PALETTE,
    setup_plot_style
)

# Setup page config. Streamlit requires this to be the first Streamlit
# command the script executes, so it stays at module level after the imports.
st.set_page_config(
    page_title="Complexity Metrics Explorer",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Cache data loading
@st.cache_data(show_spinner=False)
def load_data(dataset_name):
    """Load a dataset by name, memoized across Streamlit reruns.

    Args:
        dataset_name: Passed to ``load_and_prepare_dataset`` as the
            ``'dataset_name'`` entry of its config dict.

    Returns:
        The ``(df, df_exploded)`` pair returned by
        ``load_and_prepare_dataset``. ``main`` counts rows of ``df`` as
        conversations and rows of ``df_exploded`` as turns.

    Without the ``st.cache_data`` decorator the dataset would be reloaded on
    every widget interaction. The cache is keyed on ``dataset_name``; the
    "Refresh Data" button in ``main`` calls ``st.cache_data.clear()`` to force
    a reload. ``show_spinner=False`` because ``main`` already wraps this call
    in its own ``st.spinner``.
    NOTE(review): ``st.cache_data`` pickles return values, so this assumes
    both returned objects are picklable (e.g. DataFrames) — confirm.
    """
    df, df_exploded = load_and_prepare_dataset({
        'dataset_name': dataset_name
    })
    return df, df_exploded
def get_metrics(df_exploded):
    """Return the turn-level metric names found in *df_exploded*.

    Delegates to ``get_available_turn_metrics``. ``main`` treats each returned
    name as the suffix of a ``turn.turn_metrics.<name>`` column.
    """
    metric_names = get_available_turn_metrics(df_exploded)
    return metric_names
# Keyword patterns used to group metric names into sidebar categories; a
# metric joins a category when any keyword is a substring of its lowercased
# name (so one metric may land in several categories).
_METRIC_CATEGORY_PATTERNS = {
    "Length": ['length', 'byte', 'word', 'token', 'char'],
    "Readability": ['readability', 'flesch', 'standard'],
    "Compression": ['lzw', 'compression'],
    "Language Model": ['ll_', 'rll_', 'logprob'],
    "Working Memory": ['wm_'],
    "Discourse": ['discourse'],
    "Evaluation": ['rubric', 'evaluation', 'stealth'],
    "Distribution": ['zipf', 'type_token'],
    "Coherence": ['coherence'],
    "Entity": ['entity', 'entities'],
    "Cognitive": ['cognitive', 'load'],
}


def _metric_column(metric):
    """Return the per-turn column name that holds *metric*'s values."""
    return f"turn.turn_metrics.{metric}"


def _categorize_metrics(metrics):
    """Group metric names into categories by keyword substring match.

    Returns a dict whose first entry is ``"All"`` (every metric), followed by
    each pattern category with at least one match, and finally ``"Other"``
    holding metrics no pattern matched (present only when non-empty).
    """
    categories = {"All": list(metrics)}
    categorized = set()
    for category, keywords in _METRIC_CATEGORY_PATTERNS.items():
        matching = [m for m in metrics if any(k in m.lower() for k in keywords)]
        if matching:
            categories[category] = matching
            # Track membership directly rather than re-deriving it later, so a
            # category that happens to match every metric is still counted.
            categorized.update(matching)
    uncategorized = [m for m in metrics if m not in categorized]
    if uncategorized:
        categories["Other"] = uncategorized
    return categories


def _metric_picker(label, options, initial, *, mode, help_text):
    """Render Select All / Clear All buttons plus a sidebar metrics multiselect.

    Args:
        label: Multiselect label.
        options: Metric names the user may choose from in this mode.
        initial: Selection used the first time the picker is shown.
        mode: Suffix for the widget keys (``category``/``search``/``all``) so
            each selection mode keeps its own state.
        help_text: Tooltip for the multiselect.

    Returns:
        The list of selected metric names.

    The selection is stored only under the multiselect's own key in
    ``st.session_state`` (no ``default=``), so the buttons can overwrite it
    before the widget is drawn. Streamlit requires every selected value to be
    one of ``options``; since the option list changes with the category or
    search term, stale selections are dropped before the widget is created.
    """
    widget_key = f"metrics_multiselect_{mode}"
    options = list(options)
    col1, col2 = st.sidebar.columns(2)
    with col1:
        if st.button("Select All", key=f"select_all_{mode}"):
            st.session_state[widget_key] = options
    with col2:
        if st.button("Clear All", key=f"clear_all_{mode}"):
            st.session_state[widget_key] = []
    allowed = set(options)
    previous = st.session_state.get(widget_key, initial)
    st.session_state[widget_key] = [m for m in previous if m in allowed]
    return st.sidebar.multiselect(label, options=options, key=widget_key, help=help_text)


def _render_distributions(data, selected_metrics, palette):
    """Tab 1: per-metric histogram (box marginal), summary stats and quartiles.

    Args:
        data: Filtered per-turn frame with a ``type`` column.
        selected_metrics: Metric names chosen in the sidebar.
        palette: plotly ``color_discrete_map`` or ``None`` for defaults.
    """
    st.header("Distribution Analysis")
    if not selected_metrics:
        st.warning("Please select at least one metric to visualize.")
        return
    for metric in selected_metrics:
        column = _metric_column(metric)
        if column not in data.columns:
            st.warning(f"Metric {metric} not found in dataset")
            continue
        friendly = get_human_friendly_metric_name(metric)
        st.subheader(f"📊 {friendly}")
        metric_data = data[['type', column]].dropna()
        if metric_data.empty:
            st.warning(f"No data available for {metric}")
            continue
        fig = px.histogram(
            metric_data,
            x=column,
            color='type',
            marginal='box',
            title=f"Distribution of {friendly}",
            color_discrete_map=palette,
            opacity=0.7,
            nbins=50,
        )
        fig.update_layout(xaxis_title=friendly, yaxis_title="Count", height=400)
        st.plotly_chart(fig, use_container_width=True)

        by_type = metric_data.groupby('type')[column]
        col1, col2 = st.columns(2)
        with col1:
            st.write("**Summary Statistics**")
            st.dataframe(by_type.agg(['count', 'mean', 'std', 'min', 'max']).round(3))
        with col2:
            st.write("**Percentiles**")
            percentiles = by_type.quantile([0.25, 0.5, 0.75]).unstack().round(3)
            percentiles.columns = ['25%', '50%', '75%']
            st.dataframe(percentiles)


def _render_correlations(data, selected_metrics, palette):
    """Tab 2: correlation heatmap plus scatter plots of the strongest pairs.

    Args:
        data: Filtered per-turn frame with a ``type`` column.
        selected_metrics: Metric names chosen in the sidebar.
        palette: plotly ``color_discrete_map`` or ``None`` for defaults.
    """
    st.header("Correlation Analysis")
    if len(selected_metrics) < 2:
        st.warning("Please select at least 2 metrics for correlation analysis.")
        return
    # Only metrics present as columns can be correlated; indexing a missing
    # column would raise KeyError.
    friendly_names = {
        _metric_column(m): get_human_friendly_metric_name(m)
        for m in selected_metrics
        if _metric_column(m) in data.columns
    }
    if len(friendly_names) < 2:
        st.warning("At least 2 of the selected metrics must be present in the dataset for correlation analysis.")
        return
    corr_data = data[list(friendly_names) + ['type']].rename(columns=friendly_names)
    # 'type' is non-numeric, so select_dtypes leaves only the metric columns.
    corr_matrix = corr_data.select_dtypes(include=[np.number]).corr()

    fig = px.imshow(
        corr_matrix,
        text_auto=True,
        aspect="auto",
        title="Correlation Matrix",
        color_continuous_scale='RdBu_r',
        zmin=-1, zmax=1,
    )
    fig.update_layout(height=600)
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("Strong Correlations")
    # Visit each unordered pair once (upper triangle) and keep |r| > 0.7.
    names = corr_matrix.columns
    strong_corrs = []
    for i, first in enumerate(names):
        for j in range(i + 1, len(names)):
            corr_val = corr_matrix.iloc[i, j]
            if abs(corr_val) > 0.7:
                strong_corrs.append((first, names[j], corr_val))
    if not strong_corrs:
        st.info("No strong correlations (|r| > 0.7) found between selected metrics.")
        return
    # Show the three strongest pairs by |r|.
    strong_corrs.sort(key=lambda pair: abs(pair[2]), reverse=True)
    for metric1, metric2, corr_val in strong_corrs[:3]:
        fig = px.scatter(
            corr_data,
            x=metric1,
            y=metric2,
            color='type',
            title=f"{metric1} vs {metric2} (r={corr_val:.3f})",
            color_discrete_map=palette,
            opacity=0.6,
        )
        st.plotly_chart(fig, use_container_width=True)


def _render_comparisons(data, selected_metrics, palette):
    """Tab 3: one box plot per metric, split by conversation type.

    Args:
        data: Filtered per-turn frame with a ``type`` column.
        selected_metrics: Metric names chosen in the sidebar.
        palette: plotly ``color_discrete_map`` or ``None`` for defaults.
    """
    st.header("Type Comparisons")
    if not selected_metrics:
        st.warning("Please select at least one metric to compare.")
        return
    for metric in selected_metrics:
        column = _metric_column(metric)
        if column not in data.columns:
            continue
        friendly = get_human_friendly_metric_name(metric)
        st.subheader(f"📦 {friendly} by Type")
        fig = px.box(
            data.dropna(subset=[column]),
            x='type',
            y=column,
            title=f"Distribution of {friendly} by Type",
            color='type',
            color_discrete_map=palette,
        )
        fig.update_layout(xaxis_title="Dataset Type", yaxis_title=friendly, height=400)
        st.plotly_chart(fig, use_container_width=True)


def _render_details(conversations, turns, dataset_name, selected_metrics, available_metrics):
    """Tab 4: dataset counts, type pie chart, raw sample and metric completeness.

    Args:
        conversations: Filtered per-conversation frame with a ``type`` column.
        turns: Filtered per-turn frame.
        dataset_name: Name shown in the overview banner.
        selected_metrics: Metric names chosen in the sidebar (may be empty).
        available_metrics: All metric names found in the dataset.
    """
    st.header("Detailed View")

    st.subheader("📋 Dataset Overview")
    st.info(f"**Current Dataset:** `{dataset_name}`")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Conversations", len(conversations))
    with col2:
        st.metric("Total Turns", len(turns))
    with col3:
        st.metric("Available Metrics", len(available_metrics))

    st.subheader("📈 Type Distribution")
    type_counts = conversations['type'].value_counts()
    pie_data = pd.DataFrame({'type': type_counts.index, 'count': type_counts.values})
    # color_discrete_map is only applied to the values of ``color``, so the
    # type column must be passed as ``color`` for the palette to take effect.
    fig = px.pie(
        pie_data,
        values='count',
        names='type',
        color='type',
        title="Distribution of Conversation Types",
        color_discrete_map=PLOT_PALETTE if len(type_counts) <= 3 else None,
    )
    st.plotly_chart(fig, use_container_width=True)

    present_metrics = [m for m in selected_metrics if _metric_column(m) in turns.columns]

    st.subheader("📄 Sample Data")
    if st.checkbox("Show raw data sample"):
        sample_cols = ['type'] + [_metric_column(m) for m in present_metrics]
        st.dataframe(turns[sample_cols].head(100))

    st.subheader("🔍 Metric Availability")
    # Share of non-null values per metric, as a percentage.
    metric_completeness = {
        get_human_friendly_metric_name(m): turns[_metric_column(m)].notna().mean() * 100
        for m in present_metrics
    }
    if metric_completeness:
        completeness_df = pd.DataFrame(
            list(metric_completeness.items()), columns=['Metric', 'Completeness (%)']
        )
        fig = px.bar(
            completeness_df,
            x='Metric',
            y='Completeness (%)',
            title="Data Completeness by Metric",
            color='Completeness (%)',
            color_continuous_scale='Viridis',
        )
        fig.update_layout(xaxis_tickangle=-45, height=400)
        st.plotly_chart(fig, use_container_width=True)


def main():
    """Render the Complexity Metrics Explorer page.

    Sidebar: dataset choice (with a free-text "Custom..." option and a cache
    refresh button), conversation-type and turn-role filters, and metric
    selection (by category, by search, or from the full list). Main area: a
    row of dataset counts followed by four tabs rendered over the filtered
    per-turn data. An empty type/role selection means "no filter".
    """
    st.title("📊 Complexity Metrics Explorer")
    st.markdown("Interactive visualization of conversation complexity metrics across different dataset types.")

    # ---- Dataset selection ------------------------------------------------
    st.sidebar.header("🗂️ Dataset Selection")
    available_datasets = [
        "jailbreaks_dataset_with_results_reduced",
        "jailbreaks_dataset_with_results",
        "jailbreaks_dataset_with_results_filtered_successful_jailbreak",
        "Custom...",
    ]
    selected_option = st.sidebar.selectbox(
        "Select Dataset",
        options=available_datasets,
        index=0,  # Default to reduced dataset
        help="Choose which dataset to analyze",
    )
    if selected_option == "Custom...":
        # Strip so stray whitespace neither reaches the loader nor creates a
        # separate cache entry for the same dataset.
        selected_dataset = st.sidebar.text_input(
            "Custom Dataset Name",
            value="jailbreaks_dataset_with_results_reduced",
            help="Enter the full dataset name (e.g., 'jailbreaks_dataset_with_results_reduced')",
        ).strip()
        if not selected_dataset:
            st.sidebar.warning("Please enter a dataset name")
            st.stop()
    else:
        selected_dataset = selected_option

    if st.sidebar.button("🔄 Refresh Data", help="Clear cache and reload dataset"):
        st.cache_data.clear()
        st.rerun()

    # ---- Load data ----------------------------------------------------------
    with st.spinner(f"Loading dataset: {selected_dataset}..."):
        try:
            df, df_exploded = load_data(selected_dataset)
            available_metrics = get_metrics(df_exploded)
        except Exception as e:  # UI boundary: show any loader failure to the user
            st.error(f"Error loading dataset: {e}")
            st.info("Please check if the dataset exists and is accessible.")
            st.info("💡 Try using one of the predefined dataset options instead of custom input.")
            st.stop()

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Dataset", selected_dataset.split('_')[-1].title())
    with col2:
        st.metric("Conversations", f"{len(df):,}")
    with col3:
        st.metric("Turns", f"{len(df_exploded):,}")
    with col4:
        st.metric("Metrics", len(available_metrics))

    # ---- Sidebar filters ----------------------------------------------------
    st.sidebar.header("🎛️ Controls")
    dataset_types = df['type'].unique()
    selected_types = st.sidebar.multiselect(
        "Select Dataset Types",
        options=dataset_types,
        default=dataset_types,
        help="Filter by conversation type",
    )
    if 'turn.role' in df_exploded.columns:
        roles = df_exploded['turn.role'].unique()
        selected_roles = st.sidebar.multiselect(
            "Select Roles",
            options=roles,
            default=roles,
            help="Filter by turn role",
        )
    else:
        selected_roles = None

    # ---- Metric selection ---------------------------------------------------
    st.sidebar.header("📈 Metrics")
    metric_categories = _categorize_metrics(available_metrics)
    selection_mode = st.sidebar.radio(
        "Selection Mode",
        ["By Category", "Search/Filter", "Select All"],
        help="Choose how to select metrics",
    )
    if selection_mode == "By Category":
        selected_category = st.sidebar.selectbox(
            "Metric Category",
            options=list(metric_categories.keys()),
            help=f"Found {len(metric_categories)} categories",
        )
        category_metrics = metric_categories[selected_category]
        selected_metrics = _metric_picker(
            f"Select Metrics ({len(category_metrics)} available)",
            category_metrics,
            category_metrics[:5],
            mode="category",
            help_text="Choose metrics to visualize",
        )
    elif selection_mode == "Search/Filter":
        search_term = st.sidebar.text_input(
            "Search Metrics",
            placeholder="Enter keywords to filter metrics...",
            help="Search for metrics containing specific terms",
        )
        if search_term:
            needle = search_term.lower()
            filtered_metrics = [m for m in available_metrics if needle in m.lower()]
        else:
            filtered_metrics = available_metrics
        st.sidebar.write(f"Found {len(filtered_metrics)} metrics")
        selected_metrics = _metric_picker(
            "Select Metrics",
            filtered_metrics,
            filtered_metrics[:5],
            mode="search",
            help_text="Choose metrics to visualize",
        )
    else:  # Select All
        selected_metrics = _metric_picker(
            f"All Metrics ({len(available_metrics)} total)",
            available_metrics,
            available_metrics[:10],  # Limit default to first 10 for performance
            mode="all",
            help_text="All available metrics - be careful with performance for large selections",
        )

    if selected_metrics:
        st.sidebar.success(f"Selected {len(selected_metrics)} metrics")
        # Check the stricter threshold first; otherwise it can never fire.
        if len(selected_metrics) > 50:
            st.sidebar.error(f"🚨 Very large selection ({len(selected_metrics)} metrics) - consider reducing for better performance")
        elif len(selected_metrics) > 20:
            st.sidebar.warning(f"⚠️ Large selection ({len(selected_metrics)} metrics) may impact performance")
    else:
        st.sidebar.warning("No metrics selected")

    with st.sidebar.expander("ℹ️ Metric Information", expanded=False):
        st.write(f"**Total Available Metrics:** {len(available_metrics)}")
        st.write(f"**Categories Found:** {len(metric_categories)}")
        if st.checkbox("Show all metric names", key="show_all_metrics"):
            st.write("**All Available Metrics:**")
            for i, metric in enumerate(available_metrics, 1):
                st.write(f"{i}. `{metric}`")

    # ---- Apply filters (empty selection = no filter) ------------------------
    if selected_types:
        filtered_df = df[df['type'].isin(selected_types)]
        filtered_df_exploded = df_exploded[df_exploded['type'].isin(selected_types)]
    else:
        filtered_df = df
        filtered_df_exploded = df_exploded
    if selected_roles and 'turn.role' in filtered_df_exploded.columns:
        filtered_df_exploded = filtered_df_exploded[filtered_df_exploded['turn.role'].isin(selected_roles)]

    # PLOT_PALETTE is only used when at most three types are selected.
    palette = PLOT_PALETTE if len(selected_types) <= 3 else None

    # Each tab renders independently, so an empty metric selection no longer
    # prevents the later tabs (notably Details) from being drawn.
    tab1, tab2, tab3, tab4 = st.tabs(["📊 Distributions", "🔗 Correlations", "📈 Comparisons", "🎯 Details"])
    with tab1:
        _render_distributions(filtered_df_exploded, selected_metrics, palette)
    with tab2:
        _render_correlations(filtered_df_exploded, selected_metrics, palette)
    with tab3:
        _render_comparisons(filtered_df_exploded, selected_metrics, palette)
    with tab4:
        _render_details(filtered_df, filtered_df_exploded, selected_dataset, selected_metrics, available_metrics)


if __name__ == "__main__":
    main()