Spaces:
Sleeping
Sleeping
""" | |
Visualization for topic modeling analysis results | |
""" | |
from visualization.ngram_visualizer import create_ngram_visualization | |
import gradio as gr | |
import json | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
def _topic_word_charts(topics):
    """
    Build the "Discovered Topics" heading and one word-weight bar chart per topic.

    Args:
        topics (list): Topic dicts read via the keys "id", "words" and "weights"

    Returns:
        list: Gradio components; empty when ``topics`` is empty
    """
    if not topics:
        return []

    components = [gr.Markdown("### Discovered Topics")]
    for topic in topics:
        topic_id = topic.get("id", 0)
        words = topic.get("words", [])
        weights = topic.get("weights", [])

        # A chart only makes sense when every word has exactly one weight
        if not (words and weights and len(words) == len(weights)):
            continue

        # Heaviest words first so the bars read left-to-right by importance
        df = pd.DataFrame({'word': words, 'weight': weights})
        df = df.sort_values('weight', ascending=False)

        fig = px.bar(
            df, x='word', y='weight',
            title=f"Topic {topic_id+1} Top Words",
            labels={'word': 'Word', 'weight': 'Weight'},
            height=300
        )
        components.append(gr.Plot(value=fig))
    return components


def _model_distribution_chart(models, model_topics):
    """
    Build the grouped bar chart comparing each model's topic distribution.

    Args:
        models (list): Model names, in display order
        model_topics (dict): Maps a model name to its sequence of topic weights

    Returns:
        list: Heading plus plot, or an empty list when there is nothing to plot
    """
    # `all()` is vacuously true for an empty `models` list, which would
    # otherwise render an empty chart; require at least one model.
    if not (models and model_topics and all(model in model_topics for model in models)):
        return []

    fig = go.Figure()
    for model in models:
        distribution = model_topics[model]
        fig.add_trace(go.Bar(
            x=[f"Topic {number}" for number in range(1, len(distribution) + 1)],
            y=distribution,
            name=model
        ))
    fig.update_layout(
        title="Topic Distributions Comparison",
        xaxis_title="Topic",
        yaxis_title="Weight",
        barmode='group',
        height=400
    )
    return [gr.Markdown("### Topic Distribution by Model"), gr.Plot(value=fig)]


def _split_comparison_key(comparison_key, models):
    """
    Derive the two legend names for a "<model1> vs <model2>" comparison key.

    Only the first " vs " is treated as the separator, so a key with more than
    one occurrence no longer raises. When the separator is absent, the first
    two entries of ``models`` are used, then generic placeholder names.

    Args:
        comparison_key (str): Key of one entry in the "comparisons" mapping
        models (list): Model names from the same topic modeling result

    Returns:
        tuple: ``(model1, model2)`` names for the two bar traces
    """
    first, separator, second = comparison_key.partition(" vs ")
    if separator:
        return first, second
    if len(models) >= 2:
        return models[0], models[1]
    return "Model 1", "Model 2"


def _comparison_charts(comparisons, models):
    """
    Build the "Topic Distribution Differences" section.

    Args:
        comparisons (dict): Maps a comparison key to a dict read via the keys
            "js_divergence" and "topic_differences"
        models (list): Model names, used as a fallback for legend labels

    Returns:
        list: Gradio components; empty when ``comparisons`` is empty
    """
    if not comparisons:
        return []

    components = [gr.Markdown("### Topic Distribution Differences")]
    for comparison_key, comparison_data in comparisons.items():
        js_divergence = comparison_data.get("js_divergence", 0)
        topic_differences = comparison_data.get("topic_differences", [])

        components.append(gr.Markdown(
            f"**{comparison_key}** - Jensen-Shannon Divergence: {js_divergence:.4f}"
        ))

        if not topic_differences:
            continue

        model1, model2 = _split_comparison_key(comparison_key, models)

        # Both traces share the same x axis, so build the labels once
        topic_labels = [f"Topic {d['topic_id']+1}" for d in topic_differences]

        fig = go.Figure()
        fig.add_trace(go.Bar(
            x=topic_labels,
            y=[d["model1_weight"] for d in topic_differences],
            name=model1
        ))
        fig.add_trace(go.Bar(
            x=topic_labels,
            y=[d["model2_weight"] for d in topic_differences],
            name=model2
        ))
        fig.update_layout(
            title="Topic Weight Comparison",
            xaxis_title="Topic",
            yaxis_title="Weight",
            barmode='group',
            height=400
        )
        components.append(gr.Plot(value=fig))
    return components


def create_topic_visualization(analysis_results):
    """
    Create visualizations for topic modeling analysis results

    For every entry under ``analysis_results["analyses"]`` that contains a
    "topic_modeling" result, this emits a heading, one bar chart per
    discovered topic, a grouped chart of per-model topic distributions and
    one grouped chart per model comparison.

    Args:
        analysis_results (dict): Analysis results from the topic modeling analysis

    Returns:
        list: List of gradio components with visualizations
    """
    # Check if we have valid results
    if not analysis_results or "analyses" not in analysis_results:
        return [gr.Markdown("No analysis results found.")]

    output_components = []

    # The prompt text (the dict key) is never displayed, so iterate values only
    for analyses in analysis_results["analyses"].values():
        if "topic_modeling" not in analyses:
            continue
        topic_results = analyses["topic_modeling"]

        # Show method and number of topics
        method = topic_results.get("method", "lda").upper()
        n_topics = topic_results.get("n_topics", 3)
        output_components.append(gr.Markdown(f"## Topic Modeling Analysis ({method}, {n_topics} topics)"))

        # Show models being compared
        models = topic_results.get("models", [])
        if len(models) >= 2:
            output_components.append(gr.Markdown(f"### Comparing responses from {models[0]} and {models[1]}"))

        output_components.extend(_topic_word_charts(topic_results.get("topics", [])))
        output_components.extend(_model_distribution_chart(models, topic_results.get("model_topics", {})))
        output_components.extend(_comparison_charts(topic_results.get("comparisons", {}), models))

    # Nothing beyond (at most) a single heading was produced: tell the user
    if len(output_components) <= 1:
        output_components.append(gr.Markdown("No detailed Topic Modeling analysis found in results."))

    return output_components
def process_and_visualize_topic_analysis(analysis_results):
    """
    Render topic modeling results, converting any failure into a visible message

    Delegates to create_topic_visualization. If that call raises, the error
    text and traceback are printed and returned inside a single Markdown
    component instead of propagating the exception.

    Args:
        analysis_results (dict): The analysis results

    Returns:
        list: List of gradio components for visualization
    """
    try:
        print("Starting visualization of topic modeling analysis results")
        components = create_topic_visualization(analysis_results)
    except Exception as exc:
        import traceback
        error_msg = f"Topic modeling visualization error: {exc}\n{traceback.format_exc()}"
        print(error_msg)
        failure_note = gr.Markdown(f"**Error during topic modeling visualization:**\n\n```\n{error_msg}\n```")
        return [failure_note]
    return components