# 525GradioApp / visualization/topic_visualizer.py
# Ryan, "update" (commit 14bac19), 7.59 kB
"""
Visualization for topic modeling analysis results
"""
from visualization.ngram_visualizer import create_ngram_visualization
import gradio as gr
import json
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def _topic_label(topic_id):
    """Return a display label such as "Topic 1" for a 0-based topic id.

    Non-numeric ids (e.g. strings) are shown verbatim instead of raising TypeError.
    """
    try:
        return f"Topic {topic_id + 1}"
    except TypeError:
        return f"Topic {topic_id}"


def _format_divergence(value):
    """Format a divergence value with 4 decimals, or "n/a" if it is not numeric."""
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return "n/a"


def _split_comparison_key(comparison_key):
    """Split a "<model1> vs <model2>" comparison key into its two model names.

    Keys without a " vs " separator fall back to generic names rather than
    raising ValueError during tuple unpacking.
    """
    left, sep, right = str(comparison_key).partition(" vs ")
    if not sep:
        return "Model 1", "Model 2"
    return left, right


def _topic_word_components(topics):
    """Build a "Discovered Topics" header plus one top-words bar chart per topic.

    Returns an empty list when no topic has matching, non-empty word and weight
    lists, so a header is never shown without charts beneath it.
    """
    charts = []
    for topic in topics or []:
        if not isinstance(topic, dict):
            continue
        words = topic.get("words", [])
        weights = topic.get("weights", [])
        # Only chart topics whose word and weight lists line up one-to-one.
        if not words or not weights or len(words) != len(weights):
            continue
        df = pd.DataFrame({"word": words, "weight": weights})
        df = df.sort_values("weight", ascending=False)
        fig = px.bar(
            df,
            x="word",
            y="weight",
            title=f"{_topic_label(topic.get('id', 0))} Top Words",
            labels={"word": "Word", "weight": "Weight"},
            height=300,
        )
        charts.append(gr.Plot(value=fig))
    if not charts:
        return []
    return [gr.Markdown("### Discovered Topics"), *charts]


def _model_distribution_components(models, model_topics):
    """Build a grouped bar chart comparing each model's topic distribution.

    Args:
        models (list): Model names in display order. When empty, every key of
            ``model_topics`` is plotted instead.
        model_topics (dict): Model name -> sequence of per-topic weights.

    Returns:
        list: A header and one plot, or an empty list if some requested model
        has no distribution.
    """
    model_topics = model_topics or {}
    # An empty model list would pass the all() check vacuously and yield an
    # empty chart, so fall back to whatever distributions are present.
    plot_models = list(models) or list(model_topics)
    if not plot_models or not all(model in model_topics for model in plot_models):
        return []

    fig = go.Figure()
    for model in plot_models:
        distribution = model_topics[model]
        fig.add_trace(go.Bar(
            x=[_topic_label(i) for i in range(len(distribution))],
            y=distribution,
            name=model,
        ))
    fig.update_layout(
        title="Topic Distributions Comparison",
        xaxis_title="Topic",
        yaxis_title="Weight",
        barmode="group",
        height=400,
    )
    return [gr.Markdown("### Topic Distribution by Model"), gr.Plot(value=fig)]


def _comparison_components(comparisons):
    """Build the divergence summary and per-topic weight chart for each model pair.

    Args:
        comparisons (dict): "<model1> vs <model2>" -> dict that may hold
            "js_divergence" and a "topic_differences" list of dicts with
            "topic_id", "model1_weight" and "model2_weight".

    Returns:
        list: Gradio components, or an empty list when there are no comparisons.
    """
    if not isinstance(comparisons, dict) or not comparisons:
        return []

    components = [gr.Markdown("### Topic Distribution Differences")]
    for comparison_key, comparison_data in comparisons.items():
        data = comparison_data if isinstance(comparison_data, dict) else {}
        js_text = _format_divergence(data.get("js_divergence", 0))
        components.append(gr.Markdown(
            f"**{comparison_key}** - Jensen-Shannon Divergence: {js_text}"
        ))

        topic_differences = data.get("topic_differences") or []
        if not topic_differences:
            continue

        model1, model2 = _split_comparison_key(comparison_key)
        # Missing topic ids fall back to the entry's position; missing weights
        # become None so Plotly leaves a gap instead of the code raising KeyError.
        labels = [
            _topic_label(diff.get("topic_id", index))
            for index, diff in enumerate(topic_differences)
        ]
        fig = go.Figure()
        fig.add_trace(go.Bar(
            x=labels,
            y=[diff.get("model1_weight") for diff in topic_differences],
            name=model1,
        ))
        fig.add_trace(go.Bar(
            x=labels,
            y=[diff.get("model2_weight") for diff in topic_differences],
            name=model2,
        ))
        fig.update_layout(
            title="Topic Weight Comparison",
            xaxis_title="Topic",
            yaxis_title="Weight",
            barmode="group",
            height=400,
        )
        components.append(gr.Plot(value=fig))
    return components


def create_topic_visualization(analysis_results):
    """
    Create visualizations for topic modeling analysis results.

    Args:
        analysis_results (dict): Analysis results from the topic modeling
            analysis. This function reads an "analyses" mapping of
            prompt -> per-prompt analyses. A per-prompt "topic_modeling" dict
            may provide "method", "n_topics", "models", "topics",
            "model_topics" and "comparisons".

    Returns:
        list: List of gradio components with visualizations. It ends with an
        explanatory message when no topic, distribution or comparison section
        could be rendered.
    """
    # Check if we have valid results
    if not analysis_results or "analyses" not in analysis_results:
        return [gr.Markdown("No analysis results found.")]

    output_components = []
    # True once any detail section (charts or comparisons) has been rendered.
    # The fallback message depends on this flag, not on a component count.
    has_details = False

    for analyses in analysis_results["analyses"].values():
        topic_results = (
            analyses.get("topic_modeling") if isinstance(analyses, dict) else None
        )
        if not isinstance(topic_results, dict):
            continue

        # Show method and number of topics
        method = str(topic_results.get("method") or "lda").upper()
        n_topics = topic_results.get("n_topics", 3)
        output_components.append(
            gr.Markdown(f"## Topic Modeling Analysis ({method}, {n_topics} topics)")
        )

        # Show models being compared
        models = topic_results.get("models") or []
        if len(models) >= 2:
            output_components.append(
                gr.Markdown(f"### Comparing responses from {models[0]} and {models[1]}")
            )

        sections = (
            _topic_word_components(topic_results.get("topics") or []),
            _model_distribution_components(models, topic_results.get("model_topics")),
            _comparison_components(topic_results.get("comparisons")),
        )
        for section in sections:
            if section:
                output_components.extend(section)
                has_details = True

    # If no detailed sections were added, show a message
    if not has_details:
        output_components.append(
            gr.Markdown("No detailed Topic Modeling analysis found in results.")
        )
    return output_components
def process_and_visualize_topic_analysis(analysis_results):
    """
    Build the Gradio components for a topic modeling analysis without raising.

    Args:
        analysis_results (dict): The analysis results, forwarded unchanged to
            create_topic_visualization.

    Returns:
        list: The components produced by create_topic_visualization. If that
        raises any Exception, a single gr.Markdown is returned instead. It
        shows the error text and traceback, which are also printed to stdout.
    """
    try:
        print("Starting visualization of topic modeling analysis results")
        components = create_topic_visualization(analysis_results)
    except Exception as exc:
        import traceback

        # format_exc() must run inside the except clause to capture this error.
        details = (
            "Topic modeling visualization error: "
            + str(exc)
            + "\n"
            + traceback.format_exc()
        )
        print(details)
        banner = "**Error during topic modeling visualization:**\n\n```\n"
        return [gr.Markdown(banner + details + "\n```")]
    else:
        return components