Spaces:
Sleeping
Sleeping
File size: 12,974 Bytes
4862c84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 |
"""
Plots tab for the LMM-Vibes Gradio app.
This module provides functionality to display the model cluster proportion and quality plots.
"""
import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from typing import Tuple, List
from .state import app_state
def create_proportion_plot(show_ci: bool = False) -> Tuple[go.Figure, str]:
    """Create a grouped bar plot of proportion by property and model.

    Reads the ``model_cluster_df`` entry of ``app_state``. Every distinct
    ``cluster`` value is abbreviated to ``P1``, ``P2``, ... (numbered in sorted
    order) for the x-axis. Rows whose cluster is ``"No properties"`` are left
    out of the plot, the x-axis ordering and the legend text; the numbering is
    still computed over all clusters, so one index may be skipped.

    Args:
        show_ci: When true and the frame has both ``proportion_ci_lower`` and
            ``proportion_ci_upper`` columns, asymmetric error bars are drawn
            from them.

    Returns:
        A ``(figure, markdown)`` pair. ``figure`` is ``None`` when no frame is
        loaded or the frame is empty, and ``markdown`` then holds the reason.
        Otherwise ``markdown`` lists the abbreviation-to-property mapping,
        followed by a confidence-interval note when ``show_ci`` is set.

    Side effects:
        Writes the figure to ``model_cluster_proportion_plot.html`` (relative
        path, resolved against the current working directory).
    """
    if app_state.get("model_cluster_df") is None:
        return None, "No model cluster data loaded. Please load data first."

    model_cluster_df = app_state["model_cluster_df"]
    # Guard before touching any column so an empty frame never raises.
    if model_cluster_df.empty:
        return None, "No model cluster data available."

    hidden_cluster = "No properties"

    # Number the properties over every cluster (hidden one included) so the
    # abbreviations of the remaining clusters do not shift.
    unique_properties = sorted(model_cluster_df['cluster'].unique())
    property_mapping = {prop: f"P{i+1}" for i, prop in enumerate(unique_properties)}
    visible_mapping = {
        prop: abbr for prop, abbr in property_mapping.items() if prop != hidden_cluster
    }

    # Filter first, then copy: the column assignments below must land on an
    # independent frame, not on a slice of another frame.
    plot_df = model_cluster_df[model_cluster_df['cluster'] != hidden_cluster].copy()
    plot_df['proportion'] = pd.to_numeric(plot_df['proportion'], errors='coerce')
    plot_df['property_abbr'] = plot_df['cluster'].map(property_mapping)

    has_ci_columns = (
        'proportion_ci_lower' in plot_df.columns
        and 'proportion_ci_upper' in plot_df.columns
    )
    draw_error_bars = show_ci and has_ci_columns
    if draw_error_bars:
        # Plotly expects error bar lengths relative to the bar height;
        # missing bounds become zero-length bars.
        upper_len = plot_df['proportion_ci_upper'] - plot_df['proportion']
        lower_len = plot_df['proportion'] - plot_df['proportion_ci_lower']
        plot_df['y_error'] = upper_len.fillna(0)
        plot_df['y_error_minus'] = lower_len.fillna(0)

    # Grouped bars: property on the x-axis, one bar colour per model.
    fig = px.bar(
        plot_df,
        x="property_abbr",
        y="proportion",
        color="model",
        barmode="group",
        title="Proportion by Property and Model",
        labels={"proportion": "Proportion", "property_abbr": "Property", "model": "Model"},
        error_y="y_error" if draw_error_bars else None,
        error_y_minus="y_error_minus" if draw_error_bars else None,
    )

    # Pin the category order so the axis reads P1, P2, P3, ... regardless of
    # the row order in the frame.
    fig.update_xaxes(categoryorder='array', categoryarray=list(visible_mapping.values()))
    fig.update_layout(xaxis_tickangle=45)

    # Save figure to file.
    fig.write_html("model_cluster_proportion_plot.html")

    # Legend text: one entry per plotted property.
    mapping_text = "**Property Mapping:**\n\n" + "".join(
        f"**{abbr}:** {prop}\n\n" for prop, abbr in visible_mapping.items()
    )

    if show_ci:
        if has_ci_columns:
            mapping_text += "---\n\n**Confidence Intervals:**\n"
            mapping_text += "Error bars show 95% confidence intervals for proportion values.\n"
        else:
            mapping_text += "---\n\n**Note:** Confidence interval data not available in the loaded dataset.\n"

    return fig, mapping_text
def create_quality_plot(quality_metric: str = "helpfulness", show_ci: bool = False) -> Tuple[go.Figure, str]:
    """Create a grouped bar plot of quality by property and model.

    Reads the ``model_cluster_df`` entry of ``app_state`` and plots the column
    ``quality_<quality_metric>``. Every distinct ``cluster`` value is
    abbreviated to ``P1``, ``P2``, ... (numbered in sorted order). Rows whose
    cluster is ``"No properties"`` are left out of the plot, the x-axis
    ordering and the legend text; the numbering is still computed over all
    clusters, so one index may be skipped.

    Args:
        quality_metric: Metric name without the ``quality_`` prefix.
        show_ci: When true and the frame has both
            ``quality_<metric>_ci_lower`` and ``quality_<metric>_ci_upper``
            columns, asymmetric error bars are drawn from them.

    Returns:
        A ``(figure, markdown)`` pair. ``figure`` is ``None`` when no frame is
        loaded, the frame is empty, the metric column is missing, or the
        column has no numeric values; ``markdown`` then holds the reason.
        Otherwise ``markdown`` lists the abbreviation-to-property mapping,
        followed by a confidence-interval note when ``show_ci`` is set.

    Side effects:
        Writes the figure to ``model_cluster_quality_<quality_metric>_plot.html``
        (relative path, resolved against the current working directory).
    """
    if app_state.get("model_cluster_df") is None:
        return None, "No model cluster data loaded. Please load data first."

    model_cluster_df = app_state["model_cluster_df"]
    if model_cluster_df.empty:
        return None, "No model cluster data available."

    prefix = "quality_"
    quality_col = f"{prefix}{quality_metric}"
    if quality_col not in model_cluster_df.columns:
        # List the metrics that do exist for a more helpful error message.
        # Strip only the leading prefix: str.replace would also remove any
        # later "quality_" inside the metric name.
        derived_suffixes = ("_ci_lower", "_ci_upper", "_ci_mean", "_significant", "_delta")
        available_metrics = [
            col[len(prefix):]
            for col in model_cluster_df.columns
            if col.startswith(prefix) and not col.endswith(derived_suffixes)
        ]
        if not available_metrics:
            return None, f"No quality metrics found in the data. Available columns: {list(model_cluster_df.columns)}"
        return None, f"Quality metric '{quality_metric}' not found. Available metrics: {available_metrics}"

    # Validity is judged on the whole column, before any rows are hidden.
    if pd.to_numeric(model_cluster_df[quality_col], errors='coerce').isna().all():
        return None, f"No valid quality data found for metric '{quality_metric}'. All values are missing or invalid."

    hidden_cluster = "No properties"

    # Number the properties over every cluster (hidden one included) so the
    # abbreviations of the remaining clusters do not shift.
    unique_properties = sorted(model_cluster_df['cluster'].unique())
    property_mapping = {prop: f"P{i+1}" for i, prop in enumerate(unique_properties)}
    visible_mapping = {
        prop: abbr for prop, abbr in property_mapping.items() if prop != hidden_cluster
    }

    # Filter first, then copy: the column assignments below must land on an
    # independent frame, not on a slice of another frame.
    plot_df = model_cluster_df[model_cluster_df['cluster'] != hidden_cluster].copy()
    plot_df[quality_col] = pd.to_numeric(plot_df[quality_col], errors='coerce')
    plot_df['property_abbr'] = plot_df['cluster'].map(property_mapping)

    ci_lower_col = f"{quality_col}_ci_lower"
    ci_upper_col = f"{quality_col}_ci_upper"
    has_ci_columns = ci_lower_col in plot_df.columns and ci_upper_col in plot_df.columns
    draw_error_bars = show_ci and has_ci_columns
    if draw_error_bars:
        # Plotly expects error bar lengths relative to the bar height;
        # missing bounds become zero-length bars.
        upper_len = plot_df[ci_upper_col] - plot_df[quality_col]
        lower_len = plot_df[quality_col] - plot_df[ci_lower_col]
        plot_df['y_error'] = upper_len.fillna(0)
        plot_df['y_error_minus'] = lower_len.fillna(0)

    # Grouped bars: property on the x-axis, one bar colour per model.
    fig = px.bar(
        plot_df,
        x="property_abbr",
        y=quality_col,
        color="model",
        barmode="group",
        title=f"Quality ({quality_metric.title()}) by Property and Model",
        labels={quality_col: f"Quality ({quality_metric.title()})", "property_abbr": "Property", "model": "Model"},
        error_y="y_error" if draw_error_bars else None,
        error_y_minus="y_error_minus" if draw_error_bars else None,
    )

    # Pin the category order so the axis reads P1, P2, P3, ... regardless of
    # the row order in the frame.
    fig.update_xaxes(categoryorder='array', categoryarray=list(visible_mapping.values()))
    fig.update_layout(xaxis_tickangle=45)

    # Save figure to file.
    fig.write_html(f"model_cluster_quality_{quality_metric}_plot.html")

    # Legend text: one entry per plotted property.
    mapping_text = "**Property Mapping:**\n\n" + "".join(
        f"**{abbr}:** {prop}\n\n" for prop, abbr in visible_mapping.items()
    )

    if show_ci:
        if has_ci_columns:
            mapping_text += "---\n\n**Confidence Intervals:**\n"
            mapping_text += f"Error bars show 95% confidence intervals for {quality_metric} values.\n"
        else:
            mapping_text += "---\n\n**Note:** Confidence interval data not available for this quality metric.\n"

    return fig, mapping_text
def get_available_quality_metrics() -> List[str]:
    """Get available quality metrics from the loaded DataFrame.

    A metric is any column of ``app_state["model_cluster_df"]`` whose name
    starts with ``quality_`` and does not end in one of the derived-column
    suffixes (``_ci_lower``, ``_ci_upper``, ``_ci_mean``, ``_significant``,
    ``_delta``). The returned names have the leading ``quality_`` removed, so
    that ``f"quality_{name}"`` is again the column name.

    Returns:
        The metric names in column order. When no frame is loaded, or the
        frame has no such columns, a fixed default list is returned instead.
    """
    default_metrics = ["helpfulness", "accuracy", "harmlessness", "honesty"]

    model_cluster_df = app_state.get("model_cluster_df")
    if model_cluster_df is None:
        return default_metrics

    prefix = "quality_"
    derived_suffixes = ("_ci_lower", "_ci_upper", "_ci_mean", "_significant", "_delta")

    # Slice off the prefix instead of str.replace: replace would also delete
    # any later "quality_" inside the name (quality_data_quality_score would
    # become data_score, which no longer maps back to a column).
    available_quality_metrics = [
        col[len(prefix):]
        for col in model_cluster_df.columns
        if col.startswith(prefix) and not col.endswith(derived_suffixes)
    ]

    # If no quality metrics found, provide defaults.
    return available_quality_metrics or default_metrics
def update_quality_metric_dropdown() -> gr.Dropdown:
    """Build a fresh "Quality Metric" dropdown from the currently available metrics.

    The choices come from ``get_available_quality_metrics()``; the first
    choice is pre-selected, with ``"helpfulness"`` as the fallback when the
    list is empty.
    """
    metric_choices = get_available_quality_metrics()

    if metric_choices:
        selected_metric = metric_choices[0]
    else:
        selected_metric = "helpfulness"

    dropdown_kwargs = dict(
        label="Quality Metric",
        choices=metric_choices,
        value=selected_metric,
        info="Select which quality metric to display",
    )
    return gr.Dropdown(**dropdown_kwargs)
def update_quality_metric_visibility(plot_type: str) -> gr.Dropdown:
    """Build the "Quality Metric" dropdown, shown only for the quality plot type.

    The choices come from ``get_available_quality_metrics()``; the first
    choice is pre-selected, with ``"helpfulness"`` as the fallback when the
    list is empty. The dropdown is visible exactly when ``plot_type`` equals
    ``"quality"``.
    """
    metric_choices = get_available_quality_metrics()

    if metric_choices:
        selected_metric = metric_choices[0]
    else:
        selected_metric = "helpfulness"

    is_quality_plot = plot_type == "quality"

    return gr.Dropdown(
        label="Quality Metric",
        choices=metric_choices,
        value=selected_metric,
        info="Select which quality metric to display",
        visible=is_quality_plot,
    )
def create_plot_with_toggle(plot_type: str, quality_metric: str = "helpfulness", show_ci: bool = False) -> Tuple[go.Figure, str]:
    """Dispatch to the plot builder matching ``plot_type``.

    ``"frequency"`` delegates to ``create_proportion_plot(show_ci)`` and
    ``"quality"`` to ``create_quality_plot(quality_metric, show_ci)``. Any
    other value yields ``(None, <message naming the unknown type>)``.
    """
    # Guard clauses, one per supported plot type.
    if plot_type == "frequency":
        return create_proportion_plot(show_ci)

    if plot_type == "quality":
        return create_quality_plot(quality_metric, show_ci)

    return None, f"Unknown plot type: {plot_type}"
def create_plots_tab() -> Tuple[gr.Plot, gr.Markdown, gr.Checkbox, gr.Dropdown, gr.Dropdown]:
    """Create the plots tab interface with a toggle between frequency and quality plots.

    Instantiates, in order: an explanatory Markdown header, a row holding the
    plot-type dropdown, the quality-metric dropdown (initially hidden) and the
    confidence-interval checkbox (initially checked), then the plot area and
    an info Markdown below it. No event handlers are attached here.

    NOTE(review): presumably meant to be called inside an active
    ``gr.Blocks``/tab context so the components attach to it - confirm
    against the caller.

    Returns:
        ``(plot_display, plot_info, show_ci_checkbox, plot_type_dropdown,
        quality_metric_dropdown)``.
    """
    gr.Markdown("Interactive grouped bar plot showing either frequency (proportion) or quality metrics by property and model. **If the plot looks wonky, just unclick and re-click the \"Show Confidence Intervals\" checkbox to have it resize**")

    # Look the metrics up once; the same list feeds both the choices and the
    # initial value of the dropdown.
    available_metrics = get_available_quality_metrics()

    # Plot controls in a row
    with gr.Row():
        # Plot type toggle
        plot_type_dropdown = gr.Dropdown(
            label="Plot Type",
            choices=["frequency", "quality"],
            value="frequency",
            info="Choose between frequency (proportion) or quality metrics"
        )

        # Quality metric dropdown (only visible for quality plots)
        quality_metric_dropdown = gr.Dropdown(
            label="Quality Metric",
            choices=available_metrics,
            value=available_metrics[0] if available_metrics else "helpfulness",
            info="Select which quality metric to display",
            visible=False  # Initially hidden, shown when quality is selected
        )

        # Add checkbox for confidence intervals
        show_ci_checkbox = gr.Checkbox(
            label="Show Confidence Intervals",
            value=True,
            info="Display 95% confidence intervals as error bars (if available in data)"
        )

    plot_display = gr.Plot(
        label="Model-Cluster Analysis Plot",
        value=None
    )

    plot_info = gr.Markdown("")

    return plot_display, plot_info, show_ci_checkbox, plot_type_dropdown, quality_metric_dropdown