Lisa Dunlap
restart
4862c84
"""
Plotting functionality for functional metrics.
This module provides comprehensive visualization of metrics from functional_metrics.py.
"""
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, Any, List, Optional
import warnings
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
# Make "plotly_white" the default template for figures created after this module is imported.
pio.templates.default = "plotly_white"
# NOTE(review): this installs an 'ignore' filter for every warning category, process-wide,
# as an import side effect - not only for warnings raised by this module. Confirm this is intended.
warnings.filterwarnings('ignore')
# Cluster label that the model-cluster and cluster dataframe builders leave out.
_NO_PROPERTIES_CLUSTER = "No properties"


def _ci_columns(prefix: str, ci: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten one confidence-interval mapping into ``<prefix>_ci_*`` columns.

    Bounds missing from ``ci`` default to 0.
    """
    return {
        f'{prefix}_ci_lower': ci.get('lower', 0),
        f'{prefix}_ci_upper': ci.get('upper', 0),
        f'{prefix}_ci_mean': ci.get('mean', 0),
    }


def _quality_columns(metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the quality-related entries of one metrics record into columns.

    For every metric name found under ``metrics['quality']`` this emits
    ``quality_<name>``, ``quality_delta_<name>`` (default 0),
    ``quality_delta_<name>_significant`` (default False) and, when the
    corresponding interval is present, the ``*_ci_lower/upper/mean`` columns.
    """
    # ``or {}`` also covers entries that are present but null (e.g. JSON null),
    # which would otherwise fail on ``.get`` / ``in``.
    quality = metrics.get('quality') or {}
    quality_delta = metrics.get('quality_delta') or {}
    quality_ci = metrics.get('quality_ci') or {}
    quality_delta_ci = metrics.get('quality_delta_ci') or {}
    quality_delta_significant = metrics.get('quality_delta_significant') or {}

    columns: Dict[str, Any] = {}
    for metric_name, value in quality.items():
        columns[f'quality_{metric_name}'] = value
        columns[f'quality_delta_{metric_name}'] = quality_delta.get(metric_name, 0)
        columns[f'quality_delta_{metric_name}_significant'] = quality_delta_significant.get(metric_name, False)

        ci = quality_ci.get(metric_name)
        if ci is not None:
            columns.update(_ci_columns(f'quality_{metric_name}', ci))

        ci = quality_delta_ci.get(metric_name)
        if ci is not None:
            columns.update(_ci_columns(f'quality_delta_{metric_name}', ci))
    return columns


def create_model_cluster_dataframe(model_cluster_scores: Dict[str, Any]) -> pd.DataFrame:
    """Convert model-cluster scores to a tidy dataframe.

    Args:
        model_cluster_scores: Nested mapping ``{model: {cluster: metrics}}`` where
            ``metrics`` is a mapping that may contain ``size``, ``proportion``,
            ``proportion_delta``, ``proportion_ci``, ``proportion_delta_ci``,
            ``proportion_delta_significant`` and the ``quality*`` entries.

    Returns:
        One row per (model, cluster) pair, excluding clusters named
        "No properties". Missing scalar entries default to 0 (False for the
        significance flag). Interval columns are only added for intervals that
        are present and not None. An empty input yields an empty dataframe.
    """
    rows = []
    for model, clusters in model_cluster_scores.items():
        for cluster, metrics in clusters.items():
            if cluster == _NO_PROPERTIES_CLUSTER:
                continue

            row = {
                'model': model,
                'cluster': cluster,
                'size': metrics.get('size', 0),
                'proportion': metrics.get('proportion', 0),
                'proportion_delta': metrics.get('proportion_delta', 0),
            }

            # Proportion intervals first, then the delta intervals, to keep the
            # column order stable.
            for key in ('proportion', 'proportion_delta'):
                ci = metrics.get(f'{key}_ci')
                if ci is not None:
                    row.update(_ci_columns(key, ci))

            row['proportion_delta_significant'] = metrics.get('proportion_delta_significant', False)
            row.update(_quality_columns(metrics))
            rows.append(row)

    return pd.DataFrame(rows)


def create_cluster_dataframe(cluster_scores: Dict[str, Any]) -> pd.DataFrame:
    """Convert cluster scores to a tidy dataframe.

    Args:
        cluster_scores: Mapping ``{cluster: metrics}`` where ``metrics`` may
            contain ``size``, ``proportion``, ``proportion_ci`` and the
            ``quality*`` entries.

    Returns:
        One row per cluster, excluding the cluster named "No properties".
        Missing ``size`` / ``proportion`` default to 0. An empty input yields
        an empty dataframe.
    """
    rows = []
    for cluster, metrics in cluster_scores.items():
        if cluster == _NO_PROPERTIES_CLUSTER:
            continue

        row = {
            'cluster': cluster,
            'size': metrics.get('size', 0),
            'proportion': metrics.get('proportion', 0),
        }

        ci = metrics.get('proportion_ci')
        if ci is not None:
            row.update(_ci_columns('proportion', ci))

        row.update(_quality_columns(metrics))
        rows.append(row)

    return pd.DataFrame(rows)


def create_model_dataframe(model_scores: Dict[str, Any]) -> pd.DataFrame:
    """Convert model scores to a tidy dataframe.

    Args:
        model_scores: Mapping ``{model: metrics}`` where ``metrics`` may contain
            ``size``, ``proportion``, ``proportion_ci`` and the ``quality*``
            entries.

    Returns:
        One row per model. Missing ``size`` / ``proportion`` default to 0.
        An empty input yields an empty dataframe.
    """
    rows = []
    for model, metrics in model_scores.items():
        row = {
            'model': model,
            'size': metrics.get('size', 0),
            'proportion': metrics.get('proportion', 0),
        }

        ci = metrics.get('proportion_ci')
        if ci is not None:
            row.update(_ci_columns('proportion', ci))

        row.update(_quality_columns(metrics))
        rows.append(row)

    return pd.DataFrame(rows)
def get_quality_metrics(df: pd.DataFrame) -> List[str]:
    """Extract quality metric names from dataframe columns.

    A column counts as a quality column when its label is a string starting
    with ``quality_`` and not ending in ``_ci_lower``, ``_ci_upper``,
    ``_ci_mean`` or ``_significant``. Only the leading ``quality_`` prefix is
    stripped, so a metric whose own name contains ``quality_`` is returned
    intact.

    Note that ``quality_delta_<name>`` columns pass this filter too and are
    returned as ``delta_<name>``.

    Args:
        df: Dataframe whose column labels are inspected.

    Returns:
        Metric names in column order.
    """
    prefix = 'quality_'
    excluded_suffixes = ('_ci_lower', '_ci_upper', '_ci_mean', '_significant')
    return [
        col[len(prefix):]
        for col in df.columns
        if isinstance(col, str)
        and col.startswith(prefix)
        and not col.endswith(excluded_suffixes)
    ]
def create_interactive_cluster_plot(cluster_df: pd.DataFrame, model_cluster_df: pd.DataFrame,
                                    metric_col: str, title: str,
                                    ci_lower_col: Optional[str] = None, ci_upper_col: Optional[str] = None,
                                    significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive cluster plot with dropdown for view mode.

    The figure holds one bar trace built from ``cluster_df`` (visible by
    default) followed by one hidden bar trace per distinct model in
    ``model_cluster_df``. A dropdown switches between the two sets.

    Args:
        cluster_df: Per-cluster values with a ``cluster`` column (or an index
            named ``cluster``) and ``metric_col``.
        model_cluster_df: Per-(model, cluster) values with ``model``,
            ``cluster`` and ``metric_col`` columns.
        metric_col: Column plotted on the y axis; clusters are ordered by it,
            descending.
        title: Figure title.
        ci_lower_col: Column holding the lower CI bound, if any.
        ci_upper_col: Column holding the upper CI bound, if any.
        significant_col: Column of flags in ``cluster_df``; truthy, non-missing
            entries get a red ``*`` above the aggregated bar.

    Returns:
        The assembled plotly figure.
    """
    fig = make_subplots(
        rows=1, cols=1,
        specs=[[{"secondary_y": False}]],
        subplot_titles=[title]
    )

    # Accept a frame indexed by cluster as well as one carrying a 'cluster' column.
    if 'cluster' not in cluster_df.columns and cluster_df.index.name == 'cluster':
        cluster_df = cluster_df.reset_index()

    # Descending metric order defines the x-axis order for every trace.
    cluster_df = cluster_df.sort_values(metric_col, ascending=False)

    def _has_ci(frame: pd.DataFrame) -> bool:
        return bool(
            ci_lower_col and ci_upper_col
            and ci_lower_col in frame.columns and ci_upper_col in frame.columns
        )

    def _bar(frame: pd.DataFrame, name: str, visible: bool, with_ci: bool) -> go.Bar:
        kwargs = dict(x=frame['cluster'], y=frame[metric_col], name=name, visible=visible)
        if with_ci:
            # Error bars stay enabled on every trace. The trace's own `visible`
            # flag hides them together with the bars, so they show up as soon
            # as the dropdown reveals the trace.
            kwargs['error_y'] = dict(
                type='data',
                array=frame[ci_upper_col] - frame[metric_col],
                arrayminus=frame[metric_col] - frame[ci_lower_col],
                visible=True,
            )
        return go.Bar(**kwargs)

    # Trace 0: aggregated view (default).
    fig.add_trace(_bar(cluster_df, 'Aggregated (All Models)', True, _has_ci(cluster_df)))

    # Traces 1..n: one per model, aligned to the aggregated cluster order.
    models = model_cluster_df['model'].unique()
    per_model_ci = _has_ci(model_cluster_df)
    for model in models:
        model_df = model_cluster_df[model_cluster_df['model'] == model]
        model_df = model_df.set_index('cluster').reindex(cluster_df['cluster']).reset_index()
        fig.add_trace(_bar(model_df, f'Model: {model}', False, per_model_ci))

    # Significance markers for the aggregated bars. These are layout
    # annotations, so the dropdown (which only toggles traces) does not hide them.
    show_significance = bool(significant_col) and significant_col in cluster_df.columns
    if show_significance:
        marker_rows = zip(cluster_df['cluster'], cluster_df[metric_col], cluster_df[significant_col])
        for cluster, value, flag in marker_rows:
            # A missing flag (NaN) is truthy in Python; treat it as "not significant".
            if pd.notna(flag) and bool(flag):
                fig.add_annotation(
                    x=cluster,
                    y=value,
                    text="*",
                    showarrow=False,
                    font=dict(size=16, color="red"),
                    yshift=10
                )

    fig.update_layout(
        title=title,
        xaxis_title="Cluster",
        yaxis_title=metric_col.replace('_', ' ').title(),
        barmode='group',
        height=500,
        showlegend=True
    )

    if show_significance:
        # Added with add_annotation rather than update_layout(annotations=[...]),
        # which would modify the annotations already on the figure (the
        # markers above and the subplot title).
        fig.add_annotation(
            text="* = Statistically significant (CI does not contain 0)",
            showarrow=False,
            xref="paper", yref="paper",
            x=0.01, y=0.01,
            xanchor="left", yanchor="bottom",
            font=dict(size=10, color="red")
        )

    # Dropdown with two views. For method="update" the first args entry
    # updates traces and the second updates the layout, so barmode goes second.
    n_models = len(models)
    buttons = [
        dict(
            label="Aggregated (All Models)",
            method="update",
            args=[{"visible": [True] + [False] * n_models}, {"barmode": "group"}]
        ),
        dict(
            label="Grouped by Model",
            method="update",
            args=[{"visible": [False] + [True] * n_models}, {"barmode": "group"}]
        ),
    ]
    fig.update_layout(
        updatemenus=[
            dict(
                buttons=buttons,
                direction="down",
                showactive=True,
                x=0.95,
                xanchor="right",
                y=1.25,
                yanchor="top"
            )
        ]
    )
    return fig
def create_interactive_heatmap(df: pd.DataFrame, value_col: str, title: str,
                               pivot_index: str = 'model', pivot_columns: str = 'cluster',
                               significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive heatmap with hover information.

    ``df`` is pivoted on ``pivot_index`` x ``pivot_columns``, sorted by mean
    value (rows when ``pivot_index == 'model'``, columns otherwise), and then
    transposed, so ``pivot_index`` labels end up on the x axis and
    ``pivot_columns`` labels on the y axis.

    Args:
        df: Long-format data containing ``pivot_index``, ``pivot_columns`` and
            ``value_col`` columns.
        value_col: Column supplying the cell values. When its name contains
            ``'delta'`` a diverging colour scale centred on 0 is used.
        title: Figure title.
        pivot_index: Column used as the pivot index.
        pivot_columns: Column used as the pivot columns.
        significant_col: Column of flags; truthy, non-missing cells get a red
            ``*``.

    Returns:
        The assembled plotly figure.
    """
    pivot_df = df.pivot(index=pivot_index, columns=pivot_columns, values=value_col)

    # Order by mean value, descending, along the axis selected by pivot_index.
    if pivot_index == 'model':
        order = pivot_df.mean(axis=1).sort_values(ascending=False).index
        pivot_df = pivot_df.reindex(order)
    else:
        order = pivot_df.mean(axis=0).sort_values(ascending=False).index
        pivot_df = pivot_df.reindex(columns=order)

    # Transpose so that pivot_index labels run along x and pivot_columns along y.
    pivot_df = pivot_df.T

    is_delta = 'delta' in value_col
    fig = go.Figure(data=go.Heatmap(
        z=pivot_df.values,
        x=pivot_df.columns,
        y=pivot_df.index,
        colorscale='RdBu_r' if is_delta else 'Viridis',
        zmid=0 if is_delta else None,
        text=pivot_df.values.round(3),
        texttemplate="%{text}",
        textfont={"size": 10},
        hoverongaps=False
    ))

    show_significance = bool(significant_col) and significant_col in df.columns
    if show_significance:
        sig_pivot = df.pivot(index=pivot_index, columns=pivot_columns, values=significant_col)
        # Align the flags with the already sorted and transposed value grid.
        sig_pivot = sig_pivot.T.reindex(index=pivot_df.index, columns=pivot_df.columns)
        for y_label in pivot_df.index:
            for x_label in pivot_df.columns:
                flag = sig_pivot.at[y_label, x_label]
                # Combinations absent from df pivot to NaN, which is truthy;
                # they must not be marked as significant.
                if pd.notna(flag) and bool(flag):
                    fig.add_annotation(
                        x=x_label,
                        y=y_label,
                        text="*",
                        showarrow=False,
                        font=dict(size=16, color="red"),
                        xshift=10,
                        yshift=10
                    )

    fig.update_layout(
        title=title,
        xaxis_title="Model",
        yaxis_title="Cluster",
        height=500
    )

    if show_significance:
        # Added with add_annotation rather than update_layout(annotations=[...]),
        # which would modify the cell markers added above.
        fig.add_annotation(
            text="* = Statistically significant (CI does not contain 0)",
            showarrow=False,
            xref="paper", yref="paper",
            x=0.01, y=0.01,
            xanchor="left", yanchor="bottom",
            font=dict(size=10, color="red")
        )
    return fig
def create_interactive_model_plot(model_df: pd.DataFrame, model_cluster_df: pd.DataFrame,
                                  metric_col: str, title: str,
                                  ci_lower_col: Optional[str] = None, ci_upper_col: Optional[str] = None,
                                  significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive model plot with dropdown for view mode.

    The figure holds one bar trace built from ``model_df`` (visible by
    default) followed by one hidden bar trace per distinct cluster in
    ``model_cluster_df``. A dropdown switches between the two sets.

    Args:
        model_df: Per-model values with a ``model`` column (or an index named
            ``model``) and ``metric_col``.
        model_cluster_df: Per-(model, cluster) values with ``model``,
            ``cluster`` and ``metric_col`` columns.
        metric_col: Column plotted on the y axis.
        title: Figure title.
        ci_lower_col: Column holding the lower CI bound, if any.
        ci_upper_col: Column holding the upper CI bound, if any.
        significant_col: Column of flags in ``model_df``; truthy, non-missing
            entries get a red ``*`` above the aggregated bar.

    Returns:
        The assembled plotly figure.
    """
    fig = make_subplots(
        rows=1, cols=1,
        specs=[[{"secondary_y": False}]],
        subplot_titles=[title]
    )

    # Accept a frame indexed by model as well as one carrying a 'model' column.
    if 'model' not in model_df.columns and model_df.index.name == 'model':
        model_df = model_df.reset_index()

    def _has_ci(frame: pd.DataFrame) -> bool:
        return bool(
            ci_lower_col and ci_upper_col
            and ci_lower_col in frame.columns and ci_upper_col in frame.columns
        )

    def _bar(frame: pd.DataFrame, name: str, visible: bool, with_ci: bool) -> go.Bar:
        kwargs = dict(x=frame['model'], y=frame[metric_col], name=name, visible=visible)
        if with_ci:
            # Error bars stay enabled on every trace. The trace's own `visible`
            # flag hides them together with the bars, so they show up as soon
            # as the dropdown reveals the trace.
            kwargs['error_y'] = dict(
                type='data',
                array=frame[ci_upper_col] - frame[metric_col],
                arrayminus=frame[metric_col] - frame[ci_lower_col],
                visible=True,
            )
        return go.Bar(**kwargs)

    # Trace 0: aggregated view (default).
    fig.add_trace(_bar(model_df, 'Aggregated (All Clusters)', True, _has_ci(model_df)))

    # Traces 1..n: one per cluster.
    clusters = model_cluster_df['cluster'].unique()
    per_cluster_ci = _has_ci(model_cluster_df)
    for cluster in clusters:
        cluster_df = model_cluster_df[model_cluster_df['cluster'] == cluster]
        fig.add_trace(_bar(cluster_df, f'Cluster: {cluster}', False, per_cluster_ci))

    # Significance markers for the aggregated bars. These are layout
    # annotations, so the dropdown (which only toggles traces) does not hide them.
    if significant_col and significant_col in model_df.columns:
        marker_rows = zip(model_df['model'], model_df[metric_col], model_df[significant_col])
        for model, value, flag in marker_rows:
            # A missing flag (NaN) is truthy in Python; treat it as "not significant".
            if pd.notna(flag) and bool(flag):
                fig.add_annotation(
                    x=model,
                    y=value,
                    text="*",
                    showarrow=False,
                    font=dict(size=16, color="red"),
                    yshift=10
                )

    fig.update_layout(
        title=title,
        xaxis_title="Model",
        yaxis_title=metric_col.replace('_', ' ').title(),
        barmode='group',
        height=500,
        showlegend=True
    )

    # Dropdown with two views. For method="update" the first args entry
    # updates traces and the second updates the layout, so barmode goes second.
    n_clusters = len(clusters)
    buttons = [
        dict(
            label="Aggregated (All Clusters)",
            method="update",
            args=[{"visible": [True] + [False] * n_clusters}, {"barmode": "group"}]
        ),
        dict(
            label="Grouped by Cluster",
            method="update",
            args=[{"visible": [False] + [True] * n_clusters}, {"barmode": "group"}]
        ),
    ]
    fig.update_layout(
        updatemenus=[
            dict(
                buttons=buttons,
                direction="down",
                showactive=True,
                x=0.95,
                xanchor="right",
                y=1.25,
                yanchor="top"
            )
        ]
    )
    return fig
def create_interactive_model_cluster_plot(df: pd.DataFrame, metric_col: str, title: str,
                                          ci_lower_col: Optional[str] = None, ci_upper_col: Optional[str] = None,
                                          significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive model-cluster plot with grouped bars.

    Args:
        df: Per-(model, cluster) values with ``model``, ``cluster`` and
            ``metric_col`` columns.
        metric_col: Column plotted on the y axis.
        title: Figure title.
        ci_lower_col: Column holding the lower CI bound, if any.
        ci_upper_col: Column holding the upper CI bound, if any.
        significant_col: Column of flags; truthy, non-missing rows get a red
            ``*`` annotation.

    Returns:
        The assembled plotly figure: clusters on the x axis, one bar colour
        per model.
    """
    bar_kwargs = dict(
        x='cluster',
        y=metric_col,
        color='model',
        title=title,
        barmode='group',
    )
    # Asymmetric error bars are only drawn when both bound columns exist.
    if ci_lower_col and ci_upper_col and ci_lower_col in df.columns and ci_upper_col in df.columns:
        bar_kwargs['error_y'] = df[ci_upper_col] - df[metric_col]
        bar_kwargs['error_y_minus'] = df[metric_col] - df[ci_lower_col]
    fig = px.bar(df, **bar_kwargs)

    # NOTE: markers are anchored at the cluster category (x=cluster), not at the
    # individual model's bar within the group.
    if significant_col and significant_col in df.columns:
        for cluster, value, flag in zip(df['cluster'], df[metric_col], df[significant_col]):
            # A missing flag (NaN) is truthy in Python; treat it as "not significant".
            if pd.notna(flag) and bool(flag):
                fig.add_annotation(
                    x=cluster,
                    y=value,
                    text="*",
                    showarrow=False,
                    font=dict(size=16, color="red"),
                    yshift=10
                )

    fig.update_layout(
        height=500,
        xaxis_title="Cluster",
        yaxis_title=metric_col.replace('_', ' ').title()
    )
    return fig