"""
Plotting functionality for functional metrics.

This module provides interactive Plotly visualizations of the metrics produced
by functional_metrics.py: converters that flatten nested score dictionaries
into tidy DataFrames, and figure builders (bar charts and a heatmap) that
operate on those frames.
"""

import json
import warnings
from pathlib import Path
from typing import Dict, Any, List, Optional

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

# Set plotly template
pio.templates.default = "plotly_white"
# NOTE: this silences every warning for the whole process once the module is
# imported; kept as-is because it is existing module-level behavior.
warnings.filterwarnings('ignore')

# Cluster label that the dataframe builders drop.
_EXCLUDED_CLUSTER = "No properties"

# Column suffixes that mark auxiliary (non-value) quality columns.
_AUXILIARY_SUFFIXES = ('_ci_lower', '_ci_upper', '_ci_mean', '_significant')

# Footnote shown on figures that carry significance markers.
_SIGNIFICANCE_NOTE = "* = Statistically significant (CI does not contain 0)"


def _ci_fields(prefix: str, ci: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the ``<prefix>_ci_lower/upper/mean`` columns for one interval.

    An absent interval (``None``) yields no columns; missing bounds inside a
    present interval default to 0.
    """
    if ci is None:
        return {}
    return {
        f'{prefix}_ci_lower': ci.get('lower', 0),
        f'{prefix}_ci_upper': ci.get('upper', 0),
        f'{prefix}_ci_mean': ci.get('mean', 0),
    }


def _quality_fields(metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten the per-metric quality entries of one score record.

    For every metric in ``metrics['quality']`` this emits, in order, the value,
    its delta, the delta significance flag and (when present) the value and
    delta confidence-interval columns.
    """
    quality = metrics.get('quality') or {}
    quality_delta = metrics.get('quality_delta') or {}
    quality_ci = metrics.get('quality_ci') or {}
    quality_delta_ci = metrics.get('quality_delta_ci') or {}
    quality_delta_significant = metrics.get('quality_delta_significant') or {}

    fields: Dict[str, Any] = {}
    for metric_name, value in quality.items():
        fields[f'quality_{metric_name}'] = value
        fields[f'quality_delta_{metric_name}'] = quality_delta.get(metric_name, 0)
        fields[f'quality_delta_{metric_name}_significant'] = (
            quality_delta_significant.get(metric_name, False)
        )
        fields.update(_ci_fields(f'quality_{metric_name}', quality_ci.get(metric_name)))
        fields.update(
            _ci_fields(f'quality_delta_{metric_name}', quality_delta_ci.get(metric_name))
        )
    return fields


def create_model_cluster_dataframe(model_cluster_scores: Dict[str, Any]) -> pd.DataFrame:
    """Convert model-cluster scores to a tidy dataframe.

    Args:
        model_cluster_scores: Mapping ``model -> cluster -> metrics dict``.

    Returns:
        One row per (model, cluster) pair, excluding the "No properties"
        cluster. Columns cover size, proportion, proportion delta, their
        confidence intervals when present, the proportion-delta significance
        flag and the flattened quality metrics.
    """
    rows = []
    for model, clusters in model_cluster_scores.items():
        for cluster, metrics in clusters.items():
            # Filter out "No properties" clusters
            if cluster == _EXCLUDED_CLUSTER:
                continue

            row = {
                'model': model,
                'cluster': cluster,
                'size': metrics.get('size', 0),
                'proportion': metrics.get('proportion', 0),
                'proportion_delta': metrics.get('proportion_delta', 0),
            }
            row.update(_ci_fields('proportion', metrics.get('proportion_ci')))
            row.update(_ci_fields('proportion_delta', metrics.get('proportion_delta_ci')))
            row['proportion_delta_significant'] = metrics.get(
                'proportion_delta_significant', False
            )
            row.update(_quality_fields(metrics))
            rows.append(row)

    return pd.DataFrame(rows)


def create_cluster_dataframe(cluster_scores: Dict[str, Any]) -> pd.DataFrame:
    """Convert cluster scores to a tidy dataframe.

    Args:
        cluster_scores: Mapping ``cluster -> metrics dict``.

    Returns:
        One row per cluster, excluding the "No properties" cluster, with size,
        proportion, its confidence interval when present and the flattened
        quality metrics.
    """
    rows = []
    for cluster, metrics in cluster_scores.items():
        # Filter out "No properties" clusters
        if cluster == _EXCLUDED_CLUSTER:
            continue

        row = {
            'cluster': cluster,
            'size': metrics.get('size', 0),
            'proportion': metrics.get('proportion', 0),
        }
        row.update(_ci_fields('proportion', metrics.get('proportion_ci')))
        row.update(_quality_fields(metrics))
        rows.append(row)

    return pd.DataFrame(rows)


def create_model_dataframe(model_scores: Dict[str, Any]) -> pd.DataFrame:
    """Convert model scores to a tidy dataframe.

    Args:
        model_scores: Mapping ``model -> metrics dict``.

    Returns:
        One row per model with size, proportion, its confidence interval when
        present and the flattened quality metrics.
    """
    rows = []
    for model, metrics in model_scores.items():
        row = {
            'model': model,
            'size': metrics.get('size', 0),
            'proportion': metrics.get('proportion', 0),
        }
        row.update(_ci_fields('proportion', metrics.get('proportion_ci')))
        row.update(_quality_fields(metrics))
        rows.append(row)

    return pd.DataFrame(rows)


def get_quality_metrics(df: pd.DataFrame) -> List[str]:
    """Extract quality metric names from dataframe columns.

    A column counts as a quality value column when it starts with ``quality_``
    and does not end with a CI or significance suffix. Only the leading
    ``quality_`` prefix is stripped, so a metric whose own name contains
    ``quality_`` is returned intact.

    Note:
        ``quality_delta_<metric>`` columns also match and are returned as
        ``delta_<metric>``.
    """
    prefix = 'quality_'
    return [
        col[len(prefix):]
        for col in df.columns
        if isinstance(col, str)
        and col.startswith(prefix)
        and not col.endswith(_AUXILIARY_SUFFIXES)
    ]


def _has_columns(frame: pd.DataFrame, *columns: Optional[str]) -> bool:
    """Return True when every given column name is set and present in ``frame``."""
    return all(bool(col) and col in frame.columns for col in columns)


def _group_values(frame: pd.DataFrame, column: str) -> List[Any]:
    """Return the unique values of ``column``, or nothing if the column is absent."""
    if column not in frame.columns:
        return []
    return list(frame[column].unique())


def _is_significant(flag: Any) -> bool:
    """Interpret one significance cell; missing values (NaN/None) are not significant."""
    return bool(pd.notna(flag) and flag)


def _bar_trace(frame: pd.DataFrame, x_col: str, metric_col: str, name: str,
               visible: bool, ci_lower_col: Optional[str] = None,
               ci_upper_col: Optional[str] = None) -> go.Bar:
    """Build one bar trace, with asymmetric CI error bars when both CI columns exist."""
    bar_kwargs = dict(x=frame[x_col], y=frame[metric_col], name=name, visible=visible)
    if _has_columns(frame, ci_lower_col, ci_upper_col):
        bar_kwargs['error_y'] = dict(
            type='data',
            array=frame[ci_upper_col] - frame[metric_col],
            arrayminus=frame[metric_col] - frame[ci_lower_col],
            # Always on: the dropdown toggles only the trace's own visibility,
            # so error bars hidden here would never show up.
            visible=True,
        )
    return go.Bar(**bar_kwargs)


def _add_significance_markers(fig: go.Figure, frame: pd.DataFrame, x_col: str,
                              metric_col: str, significant_col: str) -> None:
    """Put a red asterisk above every row of ``frame`` flagged as significant."""
    for x_value, y_value, flag in zip(frame[x_col], frame[metric_col], frame[significant_col]):
        if _is_significant(flag):
            fig.add_annotation(
                x=x_value,
                y=y_value,
                text="*",
                showarrow=False,
                font=dict(size=16, color="red"),
                yshift=10,
            )


def _add_significance_note(fig: go.Figure) -> None:
    """Append the footnote explaining the asterisk markers.

    ``add_annotation`` is used so the note is appended; passing a list through
    ``update_layout(annotations=...)`` updates the annotations already on the
    figure positionally instead.
    """
    fig.add_annotation(
        text=_SIGNIFICANCE_NOTE,
        showarrow=False,
        xref="paper",
        yref="paper",
        x=0.01,
        y=0.01,
        xanchor="left",
        yanchor="bottom",
        font=dict(size=10, color="red"),
    )


def _view_menu(n_grouped: int, aggregated_label: str, grouped_label: str) -> Dict[str, Any]:
    """Build the two-option dropdown switching between aggregated and grouped bars.

    Trace 0 is the aggregated trace; the following ``n_grouped`` traces are the
    per-group ones. ``barmode`` is a layout attribute, so it goes in the second
    (layout) dict of the ``update`` method's args.
    """
    def _button(label: str, show_aggregated: bool) -> Dict[str, Any]:
        visibility = [show_aggregated] + [not show_aggregated] * n_grouped
        return dict(
            label=label,
            method="update",
            args=[{"visible": visibility}, {"barmode": "group"}],
        )

    return dict(
        buttons=[_button(aggregated_label, True), _button(grouped_label, False)],
        direction="down",
        showactive=True,
        x=0.95,
        xanchor="right",
        y=1.25,
        yanchor="top",
    )


def _axis_label(name: Any, fallback: str) -> str:
    """Turn a column name into an axis title; non-string names use ``fallback``."""
    if isinstance(name, str):
        return name.replace('_', ' ').title()
    return fallback


def create_interactive_cluster_plot(cluster_df: pd.DataFrame, model_cluster_df: pd.DataFrame,
                                    metric_col: str, title: str,
                                    ci_lower_col: Optional[str] = None,
                                    ci_upper_col: Optional[str] = None,
                                    significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive cluster plot with dropdown for view mode.

    Args:
        cluster_df: Per-cluster frame (``cluster`` as a column or as the index).
        model_cluster_df: Per-(model, cluster) frame used for the grouped view.
        metric_col: Column plotted on the y axis.
        title: Figure and subplot title.
        ci_lower_col: Optional lower CI bound column.
        ci_upper_col: Optional upper CI bound column.
        significant_col: Optional flag column in ``cluster_df``; flagged
            clusters get a red asterisk.

    Returns:
        A figure whose first trace is the aggregated bars (visible by default),
        followed by one initially hidden trace per model. Clusters are ordered
        by ``metric_col`` descending. The asterisks refer to the aggregated
        values and are not toggled by the dropdown.
    """
    fig = make_subplots(
        rows=1, cols=1,
        specs=[[{"secondary_y": False}]],
        subplot_titles=[title]
    )

    # Prepare cluster_df - reset index if cluster is the index
    if 'cluster' not in cluster_df.columns and cluster_df.index.name == 'cluster':
        cluster_df = cluster_df.reset_index()

    # Sort clusters by metric value in descending order for consistent ordering
    cluster_df = cluster_df.sort_values(metric_col, ascending=False)

    # Aggregated view (default) - using cluster_df
    fig.add_trace(
        _bar_trace(cluster_df, 'cluster', metric_col, 'Aggregated (All Models)',
                   True, ci_lower_col, ci_upper_col)
    )

    # Grouped by model view - using model_cluster_df
    models = _group_values(model_cluster_df, 'model')
    for model in models:
        model_rows = model_cluster_df[model_cluster_df['model'] == model]
        # Align each model's rows with the aggregated cluster order
        model_rows = (
            model_rows.set_index('cluster').reindex(cluster_df['cluster']).reset_index()
        )
        fig.add_trace(
            _bar_trace(model_rows, 'cluster', metric_col, f'Model: {model}',
                       False, ci_lower_col, ci_upper_col)
        )

    fig.update_layout(
        title=title,
        xaxis_title="Cluster",
        yaxis_title=metric_col.replace('_', ' ').title(),
        barmode='group',
        height=500,
        showlegend=True,
    )

    # Red asterisks mark clusters whose significance flag is set
    if _has_columns(cluster_df, significant_col):
        _add_significance_markers(fig, cluster_df, 'cluster', metric_col, significant_col)
        _add_significance_note(fig)

    fig.update_layout(
        updatemenus=[_view_menu(len(models), "Aggregated (All Models)", "Grouped by Model")]
    )

    return fig


def create_interactive_heatmap(df: pd.DataFrame, value_col: str, title: str,
                               pivot_index: str = 'model', pivot_columns: str = 'cluster',
                               significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive heatmap with hover information.

    Args:
        df: Long-format frame with one row per (pivot_index, pivot_columns) pair.
        value_col: Column providing the cell values. A name containing
            ``delta`` selects a diverging colorscale centred on 0.
        title: Figure title.
        pivot_index: Column shown along the x axis.
        pivot_columns: Column shown along the y axis.
        significant_col: Optional flag column; flagged cells get a red asterisk.

    Returns:
        The heatmap figure. When ``pivot_index`` is ``'model'`` the models are
        ordered by their mean value (descending); otherwise the
        ``pivot_columns`` entries are ordered by their mean instead.
    """
    sort_index_axis = pivot_index == 'model'

    def _ordered(pivoted: pd.DataFrame, order: pd.Index) -> pd.DataFrame:
        """Apply the mean-based ordering to the axis it was computed for."""
        if sort_index_axis:
            return pivoted.reindex(order)
        return pivoted.reindex(columns=order)

    pivot_df = df.pivot(index=pivot_index, columns=pivot_columns, values=value_col)

    # Sort by mean values for consistent ordering
    mean_axis = 1 if sort_index_axis else 0
    order = pivot_df.mean(axis=mean_axis).sort_values(ascending=False).index

    # Transpose so pivot_index runs along x and pivot_columns along y
    pivot_df = _ordered(pivot_df, order).T

    is_delta = 'delta' in value_col
    fig = go.Figure(data=go.Heatmap(
        z=pivot_df.values,
        x=pivot_df.columns,
        y=pivot_df.index,
        colorscale='RdBu_r' if is_delta else 'Viridis',
        zmid=0 if is_delta else None,
        text=pivot_df.values.round(3),
        texttemplate="%{text}",
        textfont={"size": 10},
        hoverongaps=False
    ))

    fig.update_layout(
        title=title,
        xaxis_title=_axis_label(pivot_index, "Model"),
        yaxis_title=_axis_label(pivot_columns, "Cluster"),
        height=500,
    )

    # Add significance markers if available
    if _has_columns(df, significant_col):
        sig_pivot = df.pivot(index=pivot_index, columns=pivot_columns, values=significant_col)
        sig_pivot = _ordered(sig_pivot, order).T

        for y_label in pivot_df.index:
            for x_label in pivot_df.columns:
                # Combinations absent from df pivot to NaN, which is truthy;
                # _is_significant treats them as not significant.
                if _is_significant(sig_pivot.loc[y_label, x_label]):
                    fig.add_annotation(
                        x=x_label,
                        y=y_label,
                        text="*",
                        showarrow=False,
                        font=dict(size=16, color="red"),
                        xshift=10,
                        yshift=10
                    )
        _add_significance_note(fig)

    return fig


def create_interactive_model_plot(model_df: pd.DataFrame, model_cluster_df: pd.DataFrame,
                                  metric_col: str, title: str,
                                  ci_lower_col: Optional[str] = None,
                                  ci_upper_col: Optional[str] = None,
                                  significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive model plot with dropdown for view mode.

    Args:
        model_df: Per-model frame (``model`` as a column or as the index).
        model_cluster_df: Per-(model, cluster) frame used for the grouped view.
        metric_col: Column plotted on the y axis.
        title: Figure and subplot title.
        ci_lower_col: Optional lower CI bound column.
        ci_upper_col: Optional upper CI bound column.
        significant_col: Optional flag column in ``model_df``; flagged models
            get a red asterisk.

    Returns:
        A figure whose first trace is the aggregated bars (visible by default),
        followed by one initially hidden trace per cluster. The asterisks refer
        to the aggregated values and are not toggled by the dropdown.
    """
    fig = make_subplots(
        rows=1, cols=1,
        specs=[[{"secondary_y": False}]],
        subplot_titles=[title]
    )

    # Prepare model_df - reset index if model is the index
    if 'model' not in model_df.columns and model_df.index.name == 'model':
        model_df = model_df.reset_index()

    # Aggregated view (default) - using model_df
    fig.add_trace(
        _bar_trace(model_df, 'model', metric_col, 'Aggregated (All Clusters)',
                   True, ci_lower_col, ci_upper_col)
    )

    # Grouped by cluster view - using model_cluster_df
    clusters = _group_values(model_cluster_df, 'cluster')
    for cluster in clusters:
        cluster_rows = model_cluster_df[model_cluster_df['cluster'] == cluster]
        fig.add_trace(
            _bar_trace(cluster_rows, 'model', metric_col, f'Cluster: {cluster}',
                       False, ci_lower_col, ci_upper_col)
        )

    fig.update_layout(
        title=title,
        xaxis_title="Model",
        yaxis_title=metric_col.replace('_', ' ').title(),
        barmode='group',
        height=500,
        showlegend=True
    )

    # Add significance markers if available (for aggregated view)
    if _has_columns(model_df, significant_col):
        _add_significance_markers(fig, model_df, 'model', metric_col, significant_col)
        _add_significance_note(fig)

    fig.update_layout(
        updatemenus=[_view_menu(len(clusters), "Aggregated (All Clusters)", "Grouped by Cluster")]
    )

    return fig


def create_interactive_model_cluster_plot(df: pd.DataFrame, metric_col: str, title: str,
                                          ci_lower_col: Optional[str] = None,
                                          ci_upper_col: Optional[str] = None,
                                          significant_col: Optional[str] = None) -> go.Figure:
    """Create an interactive model-cluster plot with grouped bars.

    Args:
        df: Per-(model, cluster) frame.
        metric_col: Column plotted on the y axis.
        title: Figure title.
        ci_lower_col: Optional lower CI bound column.
        ci_upper_col: Optional upper CI bound column.
        significant_col: Optional flag column; flagged rows get a red asterisk
            placed at the row's cluster and metric value.

    Returns:
        A grouped bar chart with clusters on the x axis and one color per model.
    """
    bar_kwargs = dict(x='cluster', y=metric_col, color='model', title=title, barmode='group')
    if _has_columns(df, ci_lower_col, ci_upper_col):
        # Asymmetric error bars: distances from the value to each CI bound
        bar_kwargs['error_y'] = df[ci_upper_col] - df[metric_col]
        bar_kwargs['error_y_minus'] = df[metric_col] - df[ci_lower_col]
    fig = px.bar(df, **bar_kwargs)

    # Add significance markers if available
    if _has_columns(df, significant_col):
        _add_significance_markers(fig, df, 'cluster', metric_col, significant_col)
        _add_significance_note(fig)

    fig.update_layout(
        height=500,
        xaxis_title="Cluster",
        yaxis_title=metric_col.replace('_', ' ').title()
    )

    return fig