#!/usr/bin/env python3
"""
Streamlit app for interactive complexity metrics visualization.
"""
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
import datasets
import logging

warnings.filterwarnings("ignore")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
PLOT_PALETTE = {
    "jailbreak": "#D000D8",  # Purple
    "benign": "#008393",  # Cyan
    "control": "#EF0000",  # Red
}
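
# To run the app locally (assuming this file is saved as app.py and the
# imports above are installed):
#
#     streamlit run app.py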

# Utility functions
def load_and_prepare_dataset(dataset_config):
    """Load the risky conversations dataset and prepare it for analysis."""
    logger.info("Loading dataset...")
    dataset_name = dataset_config["dataset_name"]
    logger.info(f"Loading dataset: {dataset_name}")

    # Load the dataset
    dataset = datasets.load_dataset(dataset_name, split="train")
    logger.info(f"Dataset loaded with {len(dataset)} conversations")

    # Convert to pandas
    pandas_dataset = dataset.to_pandas()

    # Explode the conversation column
    pandas_dataset_exploded = pandas_dataset.explode("conversation")
    pandas_dataset_exploded = pandas_dataset_exploded.reset_index(drop=True)

    # Normalize conversation data
    conversations_unfolded = pd.json_normalize(
        pandas_dataset_exploded["conversation"],
    )
    conversations_unfolded = conversations_unfolded.add_prefix("turn.")

    # Ensure there's a 'conversation_metrics' column, even if empty
    if "conversation_metrics" not in pandas_dataset_exploded.columns:
        pandas_dataset_exploded["conversation_metrics"] = [{}] * len(
            pandas_dataset_exploded
        )

    # Normalize conversation metrics
    conversations_metrics_unfolded = pd.json_normalize(
        pandas_dataset_exploded["conversation_metrics"]
    )
    conversations_metrics_unfolded = conversations_metrics_unfolded.add_prefix(
        "conversation_metrics."
    )

    # Concatenate all dataframes
    pandas_dataset_exploded = pd.concat(
        [
            pandas_dataset_exploded.drop(
                columns=["conversation", "conversation_metrics"]
            ),
            conversations_unfolded,
            conversations_metrics_unfolded,
        ],
        axis=1,
    )
    logger.info(f"Dataset prepared with {len(pandas_dataset_exploded)} turns")
    return pandas_dataset, pandas_dataset_exploded
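
# Expected record layout (inferred from the access patterns in this file) --
# each row of the raw dataset looks roughly like:
#
#     {
#         "type": "jailbreak",  # e.g. "jailbreak", "benign", or "control"
#         "conversation": [
#             {"role": "user", "content": "...", "turn_metrics": {...}, "toxicities": {...}},
#             {"role": "assistant", "content": "...", "turn_metrics": {...}, "toxicities": {...}},
#         ],
#         "conversation_metrics": {...},
#     }
#
# After exploding and json-normalizing, turn-level fields appear as "turn.*"
# columns (e.g. "turn.role", "turn.turn_metrics.<metric>") next to the original
# conversation-level columns.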

def get_available_turn_metrics(dataset_exploded):
    """Dynamically discover all available turn metrics from the dataset."""
    # Find all columns that contain turn metrics
    turn_metric_columns = [
        col for col in dataset_exploded.columns if col.startswith("turn.turn_metrics.")
    ]
    # Extract the metric names by removing the prefix
    metrics = [col.replace("turn.turn_metrics.", "") for col in turn_metric_columns]
    # Sort metrics for consistent ordering
    metrics = sorted(metrics)
    logger.info(f"Discovered {len(metrics)} available turn metrics")
    return metrics
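
# For example, a column named "turn.turn_metrics.readability_flesch_score" is
# reported here as the metric "readability_flesch_score".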

def get_human_friendly_metric_name(metric_name):
    """Convert a metric name to a human-friendly display name."""
    # Mapping of metric names to human-friendly display names
    metric_display_names = {
        # Length metrics
        "conversational_length_tokens": "Length in Tokens",
        "conversational_length_words": "Length in Words",
        "conversational_length_bytes": "Length in Bytes",
        # Readability metrics
        "readability_flesch_score": "Flesch Reading Ease",
        "readability_standard_score": "Reading Grade Level",
        # Compression metrics
        "lzw_compression": "LZW Compression Ratio",
        # Language model metrics
        "ll_mean_logprobs": "Mean Log Probability",
        "ll_max_logprobs": "Maximum Log Probability",
        "ll_min_logprobs": "Minimum Log Probability",
        "rll_mean_logprobs": "Reverse Mean Log Probability",
        "rll_max_logprobs": "Reverse Maximum Log Probability",
        "rll_min_logprobs": "Reverse Minimum Log Probability",
        # Lexical diversity
        "type_token_ratio": "Type-Token Ratio",
        # Language distribution
        "zipf_score": "Zipf Distribution Score",
        # Evaluation metrics
        "rubric_evaluation.Stealth": "Stealth Score",
        # Working memory metrics
        "wm_total_entity_references": "Total Entity References",
        "wm_unique_entities_count": "Unique Entities Count",
        "wm_named_entity_mentions": "Named Entity Mentions",
        "wm_concept_noun_mentions": "Concept Noun Mentions",
        "wm_pronoun_references": "Pronoun References",
        "wm_entity_density_per_word": "Entity Density per Word",
        "wm_entity_density_per_100_words": "Entity Density per 100 Words",
        "wm_entity_density_per_100_chars": "Entity Density per 100 Characters",
        "wm_entity_diversity_ratio": "Entity Diversity Ratio",
        "wm_entity_repetition_ratio": "Entity Repetition Ratio",
        "wm_cognitive_load_score": "Cognitive Load Score",
        "wm_high_cognitive_load": "High Cognitive Load",
        # Discourse coherence metrics
        "discourse_coherence_to_next_user": "Coherence to Next User Turn",
        "discourse_coherence_to_next_turn": "Coherence to Next Turn",
        "discourse_mean_user_coherence": "Mean User Coherence",
        "discourse_user_coherence_variance": "User Coherence Variance",
        "discourse_user_topic_drift": "User Topic Drift",
        "discourse_user_entity_continuity": "User Entity Continuity",
        "discourse_num_user_turns": "Number of User Turns",
        # Tokens per byte
        "tokens_per_byte": "Tokens per Byte",
    }

    # Check exact match first
    if metric_name in metric_display_names:
        return metric_display_names[metric_name]

    # Handle conversation-level aggregations
    for suffix in [
        "_conversation_mean",
        "_conversation_min",
        "_conversation_max",
        "_conversation_std",
        "_conversation_count",
    ]:
        if metric_name.endswith(suffix):
            base_metric = metric_name[: -len(suffix)]
            if base_metric in metric_display_names:
                agg_type = suffix.split("_")[-1].title()
                return f"{metric_display_names[base_metric]} ({agg_type})"

    # Handle turn-level metrics with "turn.turn_metrics." prefix
    if metric_name.startswith("turn.turn_metrics."):
        base_metric = metric_name[len("turn.turn_metrics.") :]
        if base_metric in metric_display_names:
            return metric_display_names[base_metric]

    # Fallback: convert underscores to spaces and title case
    clean_name = metric_name
    for prefix in ["turn.turn_metrics.", "conversation_metrics.", "turn_metrics."]:
        if clean_name.startswith(prefix):
            clean_name = clean_name[len(prefix) :]
            break

    # Convert to human-readable format
    clean_name = clean_name.replace("_", " ").title()
    return clean_name
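
# A few examples of the mapping above:
#   "lzw_compression"                    -> "LZW Compression Ratio"
#   "ll_mean_logprobs_conversation_mean" -> "Mean Log Probability (Mean)"
#   "turn.turn_metrics.type_token_ratio" -> "Type-Token Ratio"
#   "some_unknown_metric"                -> "Some Unknown Metric"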

def render_metric_distribution(metric, filtered_df_exploded, selected_types):
    """Render distribution plot for a single metric."""
    full_metric_name = f"turn.turn_metrics.{metric}"
    if full_metric_name not in filtered_df_exploded.columns:
        st.warning(f"Metric {metric} not found in dataset")
        return

    st.subheader(get_human_friendly_metric_name(metric))

    # Clean the data
    metric_data = filtered_df_exploded[["type", full_metric_name]].copy()
    metric_data = metric_data.dropna()
    if len(metric_data) == 0:
        st.warning(f"No data available for {metric}")
        return

    # Create plotly histogram
    fig = px.histogram(
        metric_data,
        x=full_metric_name,
        color="type",
        marginal="box",
        title=f"Distribution of {get_human_friendly_metric_name(metric)}",
        color_discrete_map=PLOT_PALETTE if len(selected_types) <= 3 else None,
        opacity=0.7,
        nbins=50,
    )
    fig.update_layout(
        xaxis_title=get_human_friendly_metric_name(metric),
        yaxis_title="Count",
        height=400,
    )
    st.plotly_chart(fig, use_container_width=True)

    # Summary statistics
    col1, col2 = st.columns(2)
    with col1:
        st.write("**Summary Statistics**")
        summary_stats = (
            metric_data.groupby("type")[full_metric_name]
            .agg(["count", "mean", "std", "min", "max"])
            .round(3)
        )
        st.dataframe(summary_stats)
    with col2:
        st.write("**Percentiles**")
        percentiles = (
            metric_data.groupby("type")[full_metric_name]
            .quantile([0.25, 0.5, 0.75])
            .unstack()
            .round(3)
        )
        percentiles.columns = ["25%", "50%", "75%"]
        st.dataframe(percentiles)
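
# Note: the plotting helper above operates on the exploded (one-row-per-turn)
# dataframe and looks up metric values under the "turn.turn_metrics." prefix.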

# Setup page config
st.set_page_config(
    page_title="Complexity Metrics Explorer",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Cache data loading; the decorators make the sidebar's "Refresh Data" button
# (which calls st.cache_data.clear()) actually invalidate a cache.
@st.cache_data
def load_data(dataset_name):
    """Load and cache the dataset."""
    df, df_exploded = load_and_prepare_dataset({"dataset_name": dataset_name})
    return df, df_exploded


@st.cache_data
def get_metrics(df_exploded):
    """Get available metrics from the dataset."""
    return get_available_turn_metrics(df_exploded)

def main():
    st.title("Complexity Metrics Explorer")
    st.markdown(
        "Interactive visualization of conversation complexity metrics across different dataset types."
    )

    # Dataset selection
    st.sidebar.header("Dataset Selection")

    # Available datasets
    available_datasets = [
        "risky-conversations/jailbreaks_dataset_with_results_reduced",
        "risky-conversations/jailbreaks_dataset_with_results",
        "risky-conversations/jailbreaks_dataset_with_results_filtered_successful_jailbreak",
        "Custom...",
    ]
    selected_option = st.sidebar.selectbox(
        "Select Dataset",
        options=available_datasets,
        index=0,  # Default to reduced dataset
        help="Choose which dataset to analyze",
    )

    # Handle custom dataset input
    if selected_option == "Custom...":
        selected_dataset = st.sidebar.text_input(
            "Custom Dataset Name",
            value="risky-conversations/jailbreaks_dataset_with_results_reduced",
            help="Enter the full dataset name (e.g., 'risky-conversations/jailbreaks_dataset_with_results_reduced')",
        )
        if not selected_dataset.strip():
            st.sidebar.warning("Please enter a dataset name")
            st.stop()
    else:
        selected_dataset = selected_option

    # Add refresh button
    if st.sidebar.button("Refresh Data", help="Clear cache and reload dataset"):
        st.cache_data.clear()
        st.rerun()

    # Load data
    with st.spinner(f"Loading dataset: {selected_dataset}..."):
        try:
            df, df_exploded = load_data(selected_dataset)
            available_metrics = get_metrics(df_exploded)

            # Display dataset info
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Dataset", selected_dataset.split("_")[-1].title())
            with col2:
                st.metric("Conversations", f"{len(df):,}")
            with col3:
                st.metric("Turns", f"{len(df_exploded):,}")
            with col4:
                st.metric("Metrics", len(available_metrics))
            data_loaded = True
        except Exception as e:
            st.error(f"Error loading dataset: {e}")
            st.info("Please check if the dataset exists and is accessible.")
            st.info(
                "Try using one of the predefined dataset options instead of custom input."
            )
            data_loaded = False

    if not data_loaded:
        st.stop()

    # Sidebar controls
    st.sidebar.header("Controls")

    # Dataset type filter
    dataset_types = df["type"].unique()
    selected_types = st.sidebar.multiselect(
        "Select Dataset Types",
        options=dataset_types,
        default=dataset_types,
        help="Filter by conversation type",
    )

    # Role filter
    if "turn.role" in df_exploded.columns:
        roles = df_exploded["turn.role"].dropna().unique()
        # Assert only user and assistant roles exist
        expected_roles = {"user", "assistant"}
        actual_roles = set(roles)
        assert actual_roles.issubset(
            expected_roles
        ), f"Unexpected roles found: {actual_roles - expected_roles}. Expected only 'user' and 'assistant'"

        st.sidebar.subheader("Role Filter")
        col1, col2 = st.sidebar.columns(2)
        with col1:
            include_user = st.checkbox("User", value=True, help="Include user turns")
        with col2:
            include_assistant = st.checkbox(
                "Assistant", value=True, help="Include assistant turns"
            )

        # Build selected roles list
        selected_roles = []
        if include_user and "user" in roles:
            selected_roles.append("user")
        if include_assistant and "assistant" in roles:
            selected_roles.append("assistant")

        # Show selection info
        if selected_roles:
            st.sidebar.success(f"Including: {', '.join(selected_roles)}")
        else:
            st.sidebar.warning("No roles selected")
    else:
        selected_roles = None

    # Filter data based on selections
    filtered_df = df[df["type"].isin(selected_types)] if selected_types else df
    filtered_df_exploded = (
        df_exploded[df_exploded["type"].isin(selected_types)]
        if selected_types
        else df_exploded
    )
    if selected_roles and "turn.role" in filtered_df_exploded.columns:
        filtered_df_exploded = filtered_df_exploded[
            filtered_df_exploded["turn.role"].isin(selected_roles)
        ]
    elif selected_roles is not None and len(selected_roles) == 0:
        # If roles exist but none are selected, show an empty dataframe with the same structure
        filtered_df_exploded = filtered_df_exploded.iloc[0:0]

    # Check if we have data after filtering
    if len(filtered_df_exploded) == 0:
        st.error(
            "No data available with current filters. Please adjust your selection."
        )
        st.stop()
    # Main content tabs
    tab1, tab2, tab3, tab4, tab5 = st.tabs(
        [
            "Distributions",
            "Correlations",
            "Comparisons",
            "Conversation",
            "Details",
        ]
    )

    # Make available metrics accessible to all tabs
    available_metrics_for_analysis = available_metrics

    with tab1:
        st.header("Distribution Analysis")

        # Simple metric selection - just show all metrics with checkboxes
        st.subheader("Select Metrics to Plot")
        st.info(
            f"**{len(available_metrics)} metrics available** - Check the boxes below to plot their distributions"
        )

        # Optional: Add search functionality to help users find metrics
        search_term = st.text_input(
            "Search metrics (optional)",
            placeholder="Enter keywords to filter metrics...",
            help="Search for metrics containing specific terms",
        )
        if search_term:
            filtered_metrics = [
                m for m in available_metrics if search_term.lower() in m.lower()
            ]
            st.write(f"**{len(filtered_metrics)} metrics** match your search")
        else:
            filtered_metrics = available_metrics

        # Create checkboxes for each metric to allow multiple selections
        if not filtered_metrics:
            st.warning("No metrics found. Try adjusting your search.")
        else:
            # Organize checkboxes in columns for better layout
            cols_per_row = 3
            selected_for_plotting = []
            for i in range(0, len(filtered_metrics), cols_per_row):
                cols = st.columns(cols_per_row)
                for j, metric in enumerate(filtered_metrics[i : i + cols_per_row]):
                    with cols[j]:
                        friendly_name = get_human_friendly_metric_name(metric)
                        # Truncate checkbox text if too long
                        checkbox_text = (
                            friendly_name[:25] + "..."
                            if len(friendly_name) > 25
                            else friendly_name
                        )
                        if st.checkbox(
                            checkbox_text,
                            key=f"plot_{metric}",
                            help=f"Plot distribution for {friendly_name}",
                        ):
                            selected_for_plotting.append(metric)

            # Render selected metrics
            if selected_for_plotting:
                st.success(f"Plotting {len(selected_for_plotting)} selected metrics...")
                for metric in selected_for_plotting:
                    render_metric_distribution(
                        metric, filtered_df_exploded, selected_types
                    )
            else:
                st.info("Check the boxes above to plot metric distributions")
    with tab2:
        st.header("Correlation Analysis")
        if len(available_metrics_for_analysis) < 2:
            st.warning("Please select at least 2 metrics for correlation analysis.")
        else:
            # Add button to trigger correlation analysis
            st.info(
                f"Ready to analyze correlations between {len(available_metrics_for_analysis)} metrics"
            )
            col1, col2 = st.columns([1, 3])
            with col1:
                run_correlation = st.button(
                    "Run Correlation Analysis",
                    help="Calculate and display correlation matrix and scatter plots",
                )
            with col2:
                if len(available_metrics_for_analysis) > 10:
                    st.warning(
                        f"Large analysis ({len(available_metrics_for_analysis)} metrics) - may take some time"
                    )

            if run_correlation:
                with st.spinner("Calculating correlations..."):
                    # Prepare correlation data
                    corr_columns = [
                        f"turn.turn_metrics.{m}"
                        for m in available_metrics_for_analysis
                    ]
                    corr_data = filtered_df_exploded[corr_columns + ["type"]].copy()

                    # Clean column names for display
                    corr_data.columns = [
                        (
                            get_human_friendly_metric_name(
                                col.replace("turn.turn_metrics.", "")
                            )
                            if col.startswith("turn.turn_metrics.")
                            else col
                        )
                        for col in corr_data.columns
                    ]

                    # Calculate correlation matrix
                    corr_matrix = corr_data.select_dtypes(include=[np.number]).corr()

                    # Create correlation heatmap
                    fig = px.imshow(
                        corr_matrix,
                        text_auto=True,
                        aspect="auto",
                        title="Correlation Matrix",
                        color_continuous_scale="RdBu_r",
                        zmin=-1,
                        zmax=1,
                    )
                    fig.update_layout(height=600)
                    st.plotly_chart(fig, use_container_width=True)

                    # Scatter plots for strong correlations
                    st.subheader("Strong Correlations")

                    # Find strong correlations (>0.7 or <-0.7)
                    strong_corrs = []
                    for i in range(len(corr_matrix.columns)):
                        for j in range(i + 1, len(corr_matrix.columns)):
                            corr_val = corr_matrix.iloc[i, j]
                            if abs(corr_val) > 0.7:
                                strong_corrs.append(
                                    (
                                        corr_matrix.columns[i],
                                        corr_matrix.columns[j],
                                        corr_val,
                                    )
                                )

                    if strong_corrs:
                        # Show top 3
                        for metric1, metric2, corr_val in strong_corrs[:3]:
                            fig = px.scatter(
                                corr_data,
                                x=metric1,
                                y=metric2,
                                color="type",
                                title=f"{metric1} vs {metric2} (r={corr_val:.3f})",
                                color_discrete_map=(
                                    PLOT_PALETTE if len(selected_types) <= 3 else None
                                ),
                                opacity=0.6,
                            )
                            st.plotly_chart(fig, use_container_width=True)
                    else:
                        st.info(
                            "No strong correlations (|r| > 0.7) found between selected metrics."
                        )
    with tab3:
        st.header("Type Comparisons")
        if not available_metrics_for_analysis:
            st.warning("Please select at least one metric to compare.")
        else:
            # Box plots for each metric
            for metric in available_metrics_for_analysis:
                full_metric_name = f"turn.turn_metrics.{metric}"
                if full_metric_name not in filtered_df_exploded.columns:
                    continue
                st.subheader(f"{get_human_friendly_metric_name(metric)} by Type")

                # Create box plot
                fig = px.box(
                    filtered_df_exploded.dropna(subset=[full_metric_name]),
                    x="type",
                    y=full_metric_name,
                    title=f"Distribution of {get_human_friendly_metric_name(metric)} by Type",
                    color="type",
                    color_discrete_map=(
                        PLOT_PALETTE if len(selected_types) <= 3 else None
                    ),
                )
                fig.update_layout(
                    xaxis_title="Dataset Type",
                    yaxis_title=get_human_friendly_metric_name(metric),
                    height=400,
                )
                st.plotly_chart(fig, use_container_width=True)
    with tab4:
        st.header("Individual Conversation Analysis")

        # Conversation selector
        st.subheader("Select Conversation")

        # Get total number of conversations and basic info
        total_conversations = len(filtered_df)
        available_indices = [int(i) for i in filtered_df.index]
        st.info(
            f"Dataset contains {total_conversations:,} conversations "
            f"(indices: {min(available_indices)} to {max(available_indices)})"
        )

        # Conversation selection with number input
        col1, col2, col3 = st.columns([2, 1, 1])
        with col1:
            selected_idx = st.number_input(
                "Conversation Index",
                min_value=min(available_indices),
                max_value=max(available_indices),
                value=available_indices[0],  # Default to first available
                step=1,
                help=f"Enter a conversation index between {min(available_indices)} and {max(available_indices)}",
            )
        with col2:
            if st.button("Random", help="Select a random conversation"):
                import random

                selected_idx = random.choice(available_indices)
                st.rerun()
        with col3:
            if st.button("Info", help="Show conversation preview"):
                if selected_idx in available_indices:
                    preview_row = filtered_df.loc[selected_idx]
                    st.info(
                        f"**Type:** {preview_row['type']} | **Turns:** {len(preview_row.get('conversation', []))}"
                    )
                else:
                    st.error("Invalid conversation index")

        # Validate and get the selected conversation data
        if selected_idx not in available_indices:
            st.error(
                f"Conversation index {selected_idx} not found in filtered dataset. "
                f"Available range: {min(available_indices)} to {max(available_indices)}"
            )
            st.stop()
        selected_conversation = filtered_df.loc[selected_idx]

        # Display conversation metadata
        st.subheader("Conversation Overview")

        # First row - basic info
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("Type", selected_conversation["type"])
        with col2:
            st.metric("Index", selected_idx)
        with col3:
            st.metric("Total Turns", len(selected_conversation.get("conversation", [])))
        with col4:
            # Count user vs assistant turns
            roles = [
                turn.get("role", "unknown")
                for turn in selected_conversation.get("conversation", [])
            ]
            user_turns = roles.count("user")
            assistant_turns = roles.count("assistant")
            st.metric("User/Assistant", f"{user_turns}/{assistant_turns}")

        # Second row - additional metadata
        col1, col2, col3 = st.columns(3)
        with col1:
            provenance = selected_conversation.get("provenance_dataset", "Unknown")
            st.metric("Dataset Source", provenance)
        with col2:
            language = selected_conversation.get("language", "Unknown")
            st.metric("Language", language.upper() if language else "Unknown")
        with col3:
            timestamp = selected_conversation.get("timestamp", None)
            if timestamp:
                # Handle different timestamp formats
                if isinstance(timestamp, str):
                    st.metric("Timestamp", timestamp)
                else:
                    st.metric("Timestamp", str(timestamp))
            else:
                st.metric("Timestamp", "Not Available")
        # Add toxicity summary
        conversation_turns_temp = selected_conversation.get("conversation", [])
        if hasattr(conversation_turns_temp, "tolist"):
            conversation_turns_temp = conversation_turns_temp.tolist()
        elif conversation_turns_temp is None:
            conversation_turns_temp = []

        if len(conversation_turns_temp) > 0:
            # Calculate overall toxicity statistics
            all_toxicities = []
            for turn in conversation_turns_temp:
                toxicities = turn.get("toxicities", {})
                if toxicities and "toxicity" in toxicities:
                    all_toxicities.append(toxicities["toxicity"])

            if all_toxicities:
                avg_toxicity = sum(all_toxicities) / len(all_toxicities)
                max_toxicity = max(all_toxicities)

                st.markdown("**Toxicity Summary:**")
                col1, col2, col3 = st.columns(3)
                with col1:
                    # Color code average toxicity
                    if avg_toxicity > 0.5:
                        st.metric(
                            "Average Toxicity",
                            f"{avg_toxicity:.4f}",
                            delta="HIGH",
                            delta_color="inverse",
                        )
                    elif avg_toxicity > 0.1:
                        st.metric(
                            "Average Toxicity",
                            f"{avg_toxicity:.4f}",
                            delta="MED",
                            delta_color="off",
                        )
                    else:
                        st.metric(
                            "Average Toxicity",
                            f"{avg_toxicity:.4f}",
                            delta="LOW",
                            delta_color="normal",
                        )
                with col2:
                    # Color code max toxicity
                    if max_toxicity > 0.5:
                        st.metric(
                            "Max Toxicity",
                            f"{max_toxicity:.4f}",
                            delta="HIGH",
                            delta_color="inverse",
                        )
                    elif max_toxicity > 0.1:
                        st.metric(
                            "Max Toxicity",
                            f"{max_toxicity:.4f}",
                            delta="MED",
                            delta_color="off",
                        )
                    else:
                        st.metric(
                            "Max Toxicity",
                            f"{max_toxicity:.4f}",
                            delta="LOW",
                            delta_color="normal",
                        )
                with col3:
                    high_tox_turns = sum(1 for t in all_toxicities if t > 0.5)
                    st.metric("High Toxicity Turns", high_tox_turns)
        # Get the turns of the selected conversation directly from its record
        conversation_turns = selected_conversation.get("conversation", [])
        # Ensure conversation_turns is a list and handle different data types
        if hasattr(conversation_turns, "tolist"):
            conversation_turns = conversation_turns.tolist()
        elif conversation_turns is None:
            conversation_turns = []
        if len(conversation_turns) > 0:
            # Display conversation content with metrics
            st.subheader("Conversation with Metrics")

            # Get actual turn-level data for this conversation
            turn_metric_columns = [
                f"turn.turn_metrics.{m}" for m in available_metrics_for_analysis
            ]
            available_columns = [
                col
                for col in turn_metric_columns
                if col in filtered_df_exploded.columns
            ]

            # Get sample metrics for this conversation type (since exact matching is complex)
            sample_metrics = None
            if available_columns:
                type_turns = filtered_df_exploded[
                    filtered_df_exploded["type"] == selected_conversation["type"]
                ]
                sample_size = min(len(conversation_turns), len(type_turns))
                if sample_size > 0:
                    sample_metrics = type_turns.head(sample_size)

            # Display each turn with its metrics
            for i, turn in enumerate(conversation_turns):
                role = turn.get("role", "unknown")
                content = turn.get("content", "No content")

                # Display turn content with role styling
                if role == "user":
                    st.markdown(f"**User (Turn {i+1}):**")
                    st.info(content)
                elif role == "assistant":
                    st.markdown(f"**Assistant (Turn {i+1}):**")
                    st.success(content)
                else:
                    st.markdown(f"**{role.title()} (Turn {i+1}):**")
                    st.warning(content)

                # Display metrics for this turn
                if sample_metrics is not None and i < len(sample_metrics):
                    turn_row = sample_metrics.iloc[i]

                    # Create metrics display
                    metrics_for_turn = {}
                    for col in available_columns:
                        metric_name = col.replace("turn.turn_metrics.", "")
                        friendly_name = get_human_friendly_metric_name(metric_name)
                        value = turn_row.get(col, "N/A")
                        if pd.notna(value) and isinstance(value, (int, float)):
                            metrics_for_turn[friendly_name] = round(value, 3)
                        else:
                            metrics_for_turn[friendly_name] = "N/A"

                    # Add toxicity metrics if available
                    toxicities = turn.get("toxicities", {})
                    if toxicities:
                        st.markdown("**Toxicity Scores:**")
                        tox_cols = st.columns(4)
                        tox_metrics = [
                            ("toxicity", "Overall Toxicity"),
                            ("severe_toxicity", "Severe Toxicity"),
                            ("identity_attack", "Identity Attack"),
                            ("insult", "Insult"),
                            ("obscene", "Obscene"),
                            ("sexual_explicit", "Sexual Explicit"),
                            ("threat", "Threat"),
                        ]
                        for idx, (tox_key, tox_name) in enumerate(tox_metrics):
                            if tox_key in toxicities:
                                col_idx = idx % 4
                                with tox_cols[col_idx]:
                                    tox_value = toxicities[tox_key]
                                    if isinstance(tox_value, (int, float)):
                                        # Color code based on toxicity level
                                        if tox_value > 0.5:
                                            st.metric(
                                                tox_name,
                                                f"{tox_value:.4f}",
                                                delta="HIGH",
                                                delta_color="inverse",
                                            )
                                        elif tox_value > 0.1:
                                            st.metric(
                                                tox_name,
                                                f"{tox_value:.4f}",
                                                delta="MED",
                                                delta_color="off",
                                            )
                                        else:
                                            st.metric(
                                                tox_name,
                                                f"{tox_value:.4f}",
                                                delta="LOW",
                                                delta_color="normal",
                                            )
                                    else:
                                        st.metric(tox_name, str(tox_value))

                    # Display complexity metrics
                    if metrics_for_turn:
                        st.markdown("**Complexity Metrics:**")
                        # Display metrics in columns
                        num_cols = min(4, len(metrics_for_turn))
                        if num_cols > 0:
                            cols = st.columns(num_cols)
                            for idx, (metric_name, value) in enumerate(
                                metrics_for_turn.items()
                            ):
                                col_idx = idx % num_cols
                                with cols[col_idx]:
                                    if (
                                        isinstance(value, (int, float))
                                        and value != "N/A"
                                    ):
                                        st.metric(metric_name, value)
                                    else:
                                        st.metric(metric_name, str(value))
                else:
                    # Show toxicity even when no complexity metrics available
                    toxicities = turn.get("toxicities", {})
                    if toxicities:
                        st.markdown("**Toxicity Scores:**")
                        tox_cols = st.columns(4)
                        tox_metrics = [
                            ("toxicity", "Overall Toxicity"),
                            ("severe_toxicity", "Severe Toxicity"),
                            ("identity_attack", "Identity Attack"),
                            ("insult", "Insult"),
                            ("obscene", "Obscene"),
                            ("sexual_explicit", "Sexual Explicit"),
                            ("threat", "Threat"),
                        ]
                        for idx, (tox_key, tox_name) in enumerate(tox_metrics):
                            if tox_key in toxicities:
                                col_idx = idx % 4
                                with tox_cols[col_idx]:
                                    tox_value = toxicities[tox_key]
                                    if isinstance(tox_value, (int, float)):
                                        # Color code based on toxicity level
                                        if tox_value > 0.5:
                                            st.metric(
                                                tox_name,
                                                f"{tox_value:.4f}",
                                                delta="HIGH",
                                                delta_color="inverse",
                                            )
                                        elif tox_value > 0.1:
                                            st.metric(
                                                tox_name,
                                                f"{tox_value:.4f}",
                                                delta="MED",
                                                delta_color="off",
                                            )
                                        else:
                                            st.metric(
                                                tox_name,
                                                f"{tox_value:.4f}",
                                                delta="LOW",
                                                delta_color="normal",
                                            )
                                    else:
                                        st.metric(tox_name, str(tox_value))

                    # Show basic turn statistics when no complexity metrics available
                    st.markdown("**Basic Statistics:**")
                    col1, col2, col3 = st.columns(3)
                    with col1:
                        st.metric("Characters", len(content))
                    with col2:
                        st.metric("Words", len(content.split()))
                    with col3:
                        st.metric("Role", role.title())

                # Add separator between turns
                st.divider()
            # Plot metrics over turns with real data if available
            if available_columns and sample_metrics is not None:
                st.subheader("Metrics Over Turns")
                fig = go.Figure()

                # Add traces for each selected metric (real data)
                for col in available_columns[:5]:  # Limit to first 5 for readability
                    metric_name = col.replace("turn.turn_metrics.", "")
                    friendly_name = get_human_friendly_metric_name(metric_name)

                    # Get values for this metric
                    y_values = []
                    for _, turn_row in sample_metrics.iterrows():
                        value = turn_row.get(col, None)
                        if pd.notna(value) and isinstance(value, (int, float)):
                            y_values.append(value)
                        else:
                            y_values.append(None)

                    if any(v is not None for v in y_values):
                        fig.add_trace(
                            go.Scatter(
                                x=list(range(1, len(y_values) + 1)),
                                y=y_values,
                                mode="lines+markers",
                                name=friendly_name,
                                line=dict(width=2),
                                marker=dict(size=8),
                                connectgaps=False,
                            )
                        )

                if fig.data:  # Only show if we have data
                    fig.update_layout(
                        title="Complexity Metrics Across Conversation Turns",
                        xaxis_title="Turn Number",
                        yaxis_title="Metric Value",
                        height=400,
                        hovermode="x unified",
                    )
                    st.plotly_chart(fig, use_container_width=True)
                else:
                    st.info(
                        "No numeric metric data available to plot for this conversation type."
                    )
            elif available_metrics_for_analysis:
                st.info(
                    "Select metrics that are available in the dataset to see turn-level analysis."
                )
            else:
                st.warning("Select some metrics to see detailed turn-level analysis.")
        else:
            st.warning("No conversation data available for the selected conversation.")
    with tab5:
        st.header("Detailed View")

        # Add button to trigger detailed analysis
        st.info("Generate detailed dataset analysis and visualizations")
        col1, col2 = st.columns([1, 3])
        with col1:
            show_details = st.button(
                "Show Detailed Analysis",
                help="Generate comprehensive dataset overview and metric analysis",
            )
        with col2:
            if len(available_metrics_for_analysis) > 20:
                st.warning("Large metric selection - analysis may take some time")

        if show_details:
            with st.spinner("Generating detailed analysis..."):
                # Data overview
                st.subheader("Dataset Overview")
                st.info(f"**Current Dataset:** `{selected_dataset}`")
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total Conversations", len(filtered_df))
                with col2:
                    st.metric("Total Turns", len(filtered_df_exploded))
                with col3:
                    st.metric("Available Metrics", len(available_metrics))

                # Type distribution
                st.subheader("Type Distribution")
                type_counts = filtered_df["type"].value_counts()
                fig = px.pie(
                    values=type_counts.values,
                    names=type_counts.index,
                    color=type_counts.index,
                    title="Distribution of Conversation Types",
                    color_discrete_map=PLOT_PALETTE if len(type_counts) <= 3 else None,
                )
                st.plotly_chart(fig, use_container_width=True)

                # Sample data
                st.subheader("Sample Data")
                if st.checkbox("Show raw data sample"):
                    sample_cols = ["type"] + [
                        f"turn.turn_metrics.{m}"
                        for m in available_metrics_for_analysis
                        if f"turn.turn_metrics.{m}" in filtered_df_exploded.columns
                    ]
                    sample_data = filtered_df_exploded[sample_cols].head(100)
                    st.dataframe(sample_data)

                # Metric availability
                st.subheader("Metric Availability")
                metric_completeness = {}
                for metric in available_metrics_for_analysis:
                    full_metric_name = f"turn.turn_metrics.{metric}"
                    if full_metric_name in filtered_df_exploded.columns:
                        completeness = (
                            1
                            - filtered_df_exploded[full_metric_name].isna().sum()
                            / len(filtered_df_exploded)
                        ) * 100
                        metric_completeness[get_human_friendly_metric_name(metric)] = (
                            completeness
                        )

                if metric_completeness:
                    completeness_df = pd.DataFrame(
                        list(metric_completeness.items()),
                        columns=["Metric", "Completeness (%)"],
                    )
                    fig = px.bar(
                        completeness_df,
                        x="Metric",
                        y="Completeness (%)",
                        title="Data Completeness by Metric",
                        color="Completeness (%)",
                        color_continuous_scale="Viridis",
                    )
                    fig.update_layout(xaxis_tickangle=-45, height=400)
                    st.plotly_chart(fig, use_container_width=True)


if __name__ == "__main__":
    main()