# GradioApp/visualization/topic_visualizer.py
# Author: Ryan — last update commit f533950 (file size ~9.61 kB)
"""
Improved visualization for topic modeling analysis results
"""
import gradio as gr
import json
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import logging
# Set up logging
# NOTE(review): calling basicConfig at import time configures the root logger
# for the entire process; in a library module it is usually better to leave
# configuration to the application entry point — confirm this is intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger shared by all visualization helpers in this file.
logger = logging.getLogger('topic_visualizer')
def create_topic_visualization(analysis_results):
    """
    Create visualizations for topic modeling analysis results.

    Args:
        analysis_results (dict): Analysis results from the topic modeling
            analysis. Expected shape:
            {"analyses": {prompt: {"topic_modeling": {...}}}}.

    Returns:
        list: List of gradio components with visualizations. Always contains
        at least one component (a fallback message when nothing was rendered).
    """
    output_components = []
    # Bail out early when there is nothing to visualize.
    if not analysis_results or "analyses" not in analysis_results:
        logger.warning("No valid analysis results found")
        return [gr.Markdown("No analysis results found.")]
    try:
        for prompt, analyses in analysis_results["analyses"].items():
            if "topic_modeling" not in analyses:
                continue
            topic_results = analyses["topic_modeling"]
            # Surface upstream analysis errors instead of rendering partial output.
            if "error" in topic_results:
                error_msg = topic_results.get("error", "Unknown error in topic modeling")
                logger.warning("Topic modeling error: %s", error_msg)
                output_components.append(
                    gr.Markdown(f"**Error in topic modeling analysis:** {error_msg}")
                )
                continue
            method = topic_results.get("method", "lda").upper()
            n_topics = topic_results.get("n_topics", 3)
            logger.info("Creating visualization for %s with %s topics", method, n_topics)
            models = topic_results.get("models", [])
            # A comparison is meaningless with fewer than two models.
            if not models or len(models) < 2:
                logger.warning("Not enough models found in results")
                output_components.append(
                    gr.Markdown("Topic modeling requires at least two models to compare.")
                )
                continue
            output_components.append(
                gr.Markdown(f"### Topic Modeling Analysis ({method}, {n_topics} topics)")
            )
            output_components.append(
                gr.Markdown(f"Comparing responses from **{models[0]}** and **{models[1]}**")
            )
            # Discovered topics section.
            output_components.extend(_render_topics(topic_results.get("topics", [])))
            # Per-model topic distributions (text + grouped bar chart) — only
            # rendered when every compared model has a distribution.
            model_topics = topic_results.get("model_topics", {})
            if model_topics and all(model in model_topics for model in models):
                output_components.extend(
                    _render_distributions(models, model_topics, n_topics)
                )
            # Jensen-Shannon divergence summary.
            output_components.extend(
                _render_similarity(topic_results.get("comparisons", {}))
            )
    except Exception as e:
        logger.error("Error in create_topic_visualization: %s", e)
        output_components.append(
            gr.Markdown(f"**Error creating topic visualization:** {str(e)}")
        )
    # If no components were added, show a fallback message.
    if not output_components:
        output_components.append(
            gr.Markdown("No detailed Topic Modeling analysis found in results.")
        )
    return output_components


def _render_topics(topics):
    """Return markdown components listing each discovered topic's top words."""
    components = []
    if topics:
        components.append(gr.Markdown("#### Discovered Topics"))
        for i, topic in enumerate(topics):
            topic_id = topic.get("id", i)
            words = topic.get("words", [])
            if words:
                topic_words = ", ".join(words[:5])  # Show top 5 words
                components.append(gr.Markdown(f"**Topic {topic_id+1}**: {topic_words}"))
    return components


def _render_distributions(models, model_topics, n_topics):
    """Return per-model topic-weight text plus a grouped bar chart component.

    Assumes each entry of *model_topics* is an indexable sequence of numeric
    weights, truncated here to the first *n_topics* values.
    """
    components = [gr.Markdown("#### Topic Distribution by Model")]
    for model in models:
        if model in model_topics:
            dist = model_topics[model]
            dist_str = ", ".join(
                f"Topic {i+1}: {v:.2f}" for i, v in enumerate(dist[:n_topics])
            )
            components.append(gr.Markdown(f"**{model}**: {dist_str}"))
    # Plot creation is isolated so a plotting failure degrades to a message
    # rather than aborting the whole visualization.
    try:
        model_data = [
            {"Model": model, "Topic": f"Topic {i+1}", "Weight": weight}
            for model in models
            if model in model_topics
            for i, weight in enumerate(model_topics[model][:n_topics])
        ]
        if model_data:
            df = pd.DataFrame(model_data)
            fig = px.bar(
                df,
                x="Topic",
                y="Weight",
                color="Model",
                title="Topic Distribution Comparison",
                barmode="group",
                height=400,
            )
            fig.update_layout(
                xaxis_title="Topic",
                yaxis_title="Weight",
                legend_title="Model",
            )
            components.append(gr.Plot(value=fig))
    except Exception as e:
        logger.error("Error creating topic distribution plot: %s", e)
        components.append(gr.Markdown(f"*Error creating visualization: {str(e)}*"))
    return components


def _render_similarity(comparisons):
    """Return markdown components interpreting Jensen-Shannon divergence values."""
    components = []
    if comparisons:
        components.append(gr.Markdown("#### Similarity Metrics"))
        for comparison_data in comparisons.values():
            js_div = comparison_data.get("js_divergence", 0)
            # Map the divergence value onto a human-readable similarity label.
            if js_div < 0.2:
                similarity_text = "very similar"
            elif js_div < 0.4:
                similarity_text = "somewhat similar"
            elif js_div < 0.6:
                similarity_text = "moderately different"
            else:
                similarity_text = "very different"
            components.append(gr.Markdown(
                f"**Topic Distribution Divergence**: {js_div:.4f} - Topic distributions are {similarity_text}"
            ))
            # Explain what the metric means
            components.append(gr.Markdown(
                "*Lower divergence values indicate more similar topic distributions between models*"
            ))
    return components
def process_and_visualize_topic_analysis(analysis_results):
    """
    Process the topic modeling analysis results and create visualization components.

    Args:
        analysis_results (dict): The analysis results, expected to contain an
            "analyses" mapping of prompt -> per-analysis results.

    Returns:
        list: List of gradio components for visualization; on failure, a single
        markdown component containing the error and traceback.
    """
    try:
        logger.info("Starting visualization of topic modeling analysis results")
        # Debug output: log the structure of the payload so malformed results
        # are diagnosable from the logs. Guard against a None payload, which
        # create_topic_visualization handles gracefully on its own.
        if analysis_results and "analyses" in analysis_results:
            for prompt, analyses in analysis_results["analyses"].items():
                topic_results = analyses.get("topic_modeling")
                if topic_results is None:
                    continue
                logger.info("Found topic_modeling results with keys: %s", topic_results.keys())
                if "models" in topic_results:
                    logger.info("Models: %s", topic_results["models"])
                if "topics" in topic_results:
                    logger.info("Found %d topics", len(topic_results["topics"]))
                if "model_topics" in topic_results:
                    logger.info("Model_topics keys: %s", topic_results["model_topics"].keys())
        return create_topic_visualization(analysis_results)
    except Exception as e:
        import traceback
        error_msg = f"Topic modeling visualization error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return [gr.Markdown(f"**Error during topic modeling visualization:**\n\n```\n{error_msg}\n```")]