Spaces:
Sleeping
Sleeping
""" | |
Improved visualization for topic modeling analysis results | |
""" | |
import gradio as gr | |
import json | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import logging | |
# Module logging: configure the root logger at INFO with timestamped
# "time - name - level - message" records (per the logging docs this is a
# no-op if the root logger already has handlers), then fetch the named
# logger that this module's functions write to.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('topic_visualizer')
def _describe_js_divergence(js_div):
    """Map a Jensen-Shannon divergence to the coarse similarity phrase shown to users."""
    if js_div < 0.2:
        return "very similar"
    if js_div < 0.4:
        return "somewhat similar"
    if js_div < 0.6:
        return "moderately different"
    return "very different"


def _topic_distribution_figure(models, model_topics, n_topics):
    """Build the grouped bar chart of per-model topic weights; return None when there is no data."""
    rows = [
        {"Model": model, "Topic": f"Topic {i + 1}", "Weight": weight}
        for model in models
        for i, weight in enumerate(model_topics[model][:n_topics])
    ]
    if not rows:
        return None
    fig = px.bar(
        pd.DataFrame(rows),
        x="Topic",
        y="Weight",
        color="Model",
        title="Topic Distribution Comparison",
        barmode="group",
        height=400,
    )
    fig.update_layout(xaxis_title="Topic", yaxis_title="Weight", legend_title="Model")
    return fig


def _append_topic_results(topic_results, output_components):
    """Append the Markdown/Plot components for one prompt's topic_modeling entry."""
    # An upstream failure is reported through an "error" key instead of results.
    if "error" in topic_results:
        # `or` (rather than a .get default) also replaces an explicit None/""
        # value; a .get default could never apply since the key exists.
        error_msg = topic_results.get("error") or "Unknown error in topic modeling"
        logger.warning("Topic modeling error: %s", error_msg)
        output_components.append(gr.Markdown(f"**Error in topic modeling analysis:** {error_msg}"))
        return

    # Show method and number of topics
    method = (topic_results.get("method") or "lda").upper()
    n_topics = topic_results.get("n_topics", 3)
    logger.info("Creating visualization for %s with %s topics", method, n_topics)

    # Get models being compared
    models = topic_results.get("models") or []
    if len(models) < 2:
        logger.warning("Not enough models found in results")
        output_components.append(gr.Markdown("Topic modeling requires at least two models to compare."))
        return

    output_components.append(gr.Markdown(f"### Topic Modeling Analysis ({method}, {n_topics} topics)"))
    # Name every model; with exactly two this reads "**A** and **B**".
    bolded = [f"**{model}**" for model in models]
    output_components.append(
        gr.Markdown(f"Comparing responses from {', '.join(bolded[:-1])} and {bolded[-1]}")
    )

    # Discovered topics: top words of each topic.
    topics = topic_results.get("topics") or []
    if topics:
        output_components.append(gr.Markdown("#### Discovered Topics"))
        for i, topic in enumerate(topics):
            words = topic.get("words", [])
            if not words:
                continue
            topic_id = topic.get("id", i)
            # Ids default to the 0-based list index and are shown 1-based;
            # a non-integer id (e.g. a string label) is shown as-is instead
            # of raising TypeError.
            label = topic_id + 1 if isinstance(topic_id, (int, np.integer)) else topic_id
            topic_words = ", ".join(str(word) for word in words[:5])  # top 5 words
            output_components.append(gr.Markdown(f"**Topic {label}**: {topic_words}"))

    # Per-model topic distributions, shown only when every model has one.
    model_topics = topic_results.get("model_topics") or {}
    if model_topics and all(model in model_topics for model in models):
        output_components.append(gr.Markdown("#### Topic Distribution by Model"))
        for model in models:
            dist_str = ", ".join(
                f"Topic {i + 1}: {v:.2f}" for i, v in enumerate(model_topics[model][:n_topics])
            )
            output_components.append(gr.Markdown(f"**{model}**: {dist_str}"))
        # A plotting failure is reported inline; the text output above is kept.
        try:
            fig = _topic_distribution_figure(models, model_topics, n_topics)
            if fig is not None:
                output_components.append(gr.Plot(value=fig))
        except Exception as e:
            logger.exception("Error creating topic distribution plot")
            output_components.append(gr.Markdown(f"*Error creating visualization: {str(e)}*"))

    # Similarity metrics between models.
    comparisons = topic_results.get("comparisons") or {}
    if comparisons:
        output_components.append(gr.Markdown("#### Similarity Metrics"))
        for comparison_key, comparison_data in comparisons.items():
            # Tag each line with its key only when several comparisons would
            # otherwise render identical labels.
            tag = f" ({comparison_key})" if len(comparisons) > 1 else ""
            try:
                js_div = float(comparison_data.get("js_divergence"))
            except (TypeError, ValueError):
                js_div = float("nan")
            if not np.isfinite(js_div):
                # A missing/invalid value is reported as unavailable rather than
                # being treated as 0 (which would read as "very similar").
                output_components.append(
                    gr.Markdown(f"**Topic Distribution Divergence**{tag}: not available")
                )
                continue
            output_components.append(gr.Markdown(
                f"**Topic Distribution Divergence**{tag}: {js_div:.4f} - "
                f"Topic distributions are {_describe_js_divergence(js_div)}"
            ))
        # Explain what the metric means (once, after all comparisons).
        output_components.append(gr.Markdown(
            "*Lower divergence values indicate more similar topic distributions between models*"
        ))


def create_topic_visualization(analysis_results):
    """
    Create visualizations for topic modeling analysis results.

    Every prompt under ``analysis_results["analyses"]`` that has a
    ``"topic_modeling"`` entry is rendered in turn. Failures are isolated per
    prompt. The error is logged and shown as a Markdown component, and the
    remaining prompts are still rendered.

    Args:
        analysis_results (dict): Analysis results from the topic modeling
            analysis, with an ``"analyses"`` mapping of prompt -> analyses dict.

    Returns:
        list: List of gradio components with visualizations. If nothing could
        be rendered, the list holds a single placeholder Markdown component.
    """
    # Check if we have valid results
    if not analysis_results or "analyses" not in analysis_results:
        logger.warning("No valid analysis results found")
        return [gr.Markdown("No analysis results found.")]

    output_components = []
    try:
        for prompt, analyses in analysis_results["analyses"].items():
            try:
                if "topic_modeling" in analyses:
                    _append_topic_results(analyses["topic_modeling"], output_components)
            except Exception as e:
                # Keep whatever this prompt already produced and move on, so one
                # malformed entry does not hide the other prompts' output.
                logger.exception("Error in create_topic_visualization for prompt %r", prompt)
                output_components.append(gr.Markdown(f"**Error creating topic visualization:** {str(e)}"))
    except Exception as e:
        # "analyses" itself is unusable (e.g. it has no .items()).
        logger.exception("Error in create_topic_visualization")
        output_components.append(gr.Markdown(f"**Error creating topic visualization:** {str(e)}"))

    # If no components were added, show a message
    if not output_components:
        output_components.append(gr.Markdown("No detailed Topic Modeling analysis found in results."))
    return output_components
def process_and_visualize_topic_analysis(analysis_results):
    """
    Log a summary of the topic modeling results, then build their visualization.

    Rendering is delegated to ``create_topic_visualization``. Any exception
    escaping it is logged and turned into a single Markdown component that
    shows the error message and traceback, so the caller always gets a list.

    Args:
        analysis_results (dict): The analysis results. None or an empty dict
            is passed straight through to ``create_topic_visualization``. The
            diagnostic logging no longer raises on such input first.

    Returns:
        list: List of gradio components for visualization
    """
    try:
        logger.info("Starting visualization of topic modeling analysis results")
        _log_topic_results_structure(analysis_results)
        return create_topic_visualization(analysis_results)
    except Exception as e:
        import traceback
        error_msg = f"Topic modeling visualization error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return [gr.Markdown(f"**Error during topic modeling visualization:**\n\n```\n{error_msg}\n```")]


def _log_topic_results_structure(analysis_results):
    """Debug-log the keys/sizes of each topic_modeling entry; never raises."""
    try:
        if not analysis_results or "analyses" not in analysis_results:
            return
        for analyses in analysis_results["analyses"].values():
            if "topic_modeling" not in analyses:
                continue
            topic_results = analyses["topic_modeling"]
            logger.debug("Found topic_modeling results with keys: %s", list(topic_results))
            if "models" in topic_results:
                logger.debug("Models: %s", topic_results["models"])
            if "topics" in topic_results:
                logger.debug("Found %d topics", len(topic_results["topics"]))
            if "model_topics" in topic_results:
                logger.debug("Model_topics keys: %s", list(topic_results["model_topics"]))
    except Exception:
        # Purely diagnostic: a malformed entry must not prevent the
        # visualization, which reports its own errors.
        logger.debug("Could not summarize topic modeling results structure", exc_info=True)