""" | |
Enhanced visualization for topic modeling analysis results | |
""" | |
import gradio as gr | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import numpy as np | |


def create_topic_visualization(analysis_results):
    """
    Create enhanced visualizations for topic modeling analysis results

    Args:
        analysis_results (dict): Analysis results from the topic modeling analysis

    Returns:
        list: List of Gradio components with visualizations
    """
    # Initialize output components list
    output_components = []

    # Check if we have valid results
    if not analysis_results or "analyses" not in analysis_results:
        return [gr.Markdown("No analysis results found.")]
    # Process each prompt
    for prompt, analyses in analysis_results["analyses"].items():
        # Process Topic Modeling analysis if available
        if "topic_modeling" in analyses:
            topic_results = analyses["topic_modeling"]

            # Enhanced error checking and messaging
            if "error" in topic_results:
                output_components.append(gr.Markdown("## ⚠️ Topic Modeling Error"))
                output_components.append(gr.Markdown(f"Error: {topic_results['error']}"))
                output_components.append(gr.Markdown("Suggestions:"))
                output_components.append(gr.Markdown("1. Try with longer text samples - topic modeling typically needs 100+ words per document"))
                output_components.append(gr.Markdown("2. Reduce the number of topics (2-3 for short texts)"))
                output_components.append(gr.Markdown("3. Try the Bag of Words or N-gram analysis for shorter texts"))
                continue
            # Show method and number of topics
            method = topic_results.get("method", "lda").upper()
            n_topics = topic_results.get("n_topics", 3)

            # Check if n_topics was adjusted
            if "adjusted_n_topics" in topic_results and topic_results["adjusted_n_topics"] != topic_results.get("original_n_topics", n_topics):
                output_components.append(gr.Markdown(
                    f"## Topic Modeling Analysis ({method}, {topic_results['adjusted_n_topics']} topics) " +
                    f"*Adjusted from {topic_results.get('original_n_topics', n_topics)} due to limited text content*"
                ))
                n_topics = topic_results["adjusted_n_topics"]
            else:
                output_components.append(gr.Markdown(f"## Topic Modeling Analysis ({method}, {n_topics} topics)"))
            # Check for warnings
            if "warnings" in topic_results:
                if isinstance(topic_results["warnings"], list):
                    for warning in topic_results["warnings"]:
                        output_components.append(gr.Markdown(f"⚠️ **Warning**: {warning}"))
                else:
                    output_components.append(gr.Markdown(f"⚠️ **Warning**: {topic_results['warnings']}"))
            if "warning" in topic_results:
                output_components.append(gr.Markdown(f"⚠️ **Warning**: {topic_results['warning']}"))

            # Show models being compared
            models = topic_results.get("models", [])
            if len(models) >= 2:
                output_components.append(gr.Markdown(f"### Comparing responses from {models[0]} and {models[1]}"))
            # Show topic quality metrics if available
            # (the table below assumes two model names, so require them explicitly)
            if "coherence_scores" in topic_results and len(models) >= 2:
                coherence_scores = topic_results["coherence_scores"]
                diversity_scores = topic_results.get("diversity_scores", {})
                coherence_html = f"""
                <div style="margin: 20px 0; padding: 15px; background-color: #f8f9fa; border-radius: 5px;">
                    <h4 style="margin-top: 0;">Topic Quality Metrics</h4>
                    <table style="width: 100%; border-collapse: collapse;">
                        <tr>
                            <th style="text-align: left; padding: 8px; border-bottom: 1px solid #ddd;">Metric</th>
                            <th style="text-align: center; padding: 8px; border-bottom: 1px solid #ddd;">{models[0]}</th>
                            <th style="text-align: center; padding: 8px; border-bottom: 1px solid #ddd;">{models[1]}</th>
                            <th style="text-align: center; padding: 8px; border-bottom: 1px solid #ddd;">Combined</th>
                        </tr>
                        <tr>
                            <td style="padding: 8px; border-bottom: 1px solid #ddd;">Topic Coherence</td>
                            <td style="text-align: center; padding: 8px; border-bottom: 1px solid #ddd;">
                                {coherence_scores.get(models[0], 0):.2f}
                            </td>
                            <td style="text-align: center; padding: 8px; border-bottom: 1px solid #ddd;">
                                {coherence_scores.get(models[1], 0):.2f}
                            </td>
                            <td style="text-align: center; padding: 8px; border-bottom: 1px solid #ddd;">
                                {coherence_scores.get("combined", 0):.2f}
                            </td>
                        </tr>
                        <tr>
                            <td style="padding: 8px;">Topic Diversity</td>
                            <td style="text-align: center; padding: 8px;">
                                {diversity_scores.get(models[0], 0):.2f}
                            </td>
                            <td style="text-align: center; padding: 8px;">
                                {diversity_scores.get(models[1], 0):.2f}
                            </td>
                            <td style="text-align: center; padding: 8px;">
                                {diversity_scores.get("combined", 0):.2f}
                            </td>
                        </tr>
                    </table>
                    <p style="margin-bottom: 0; font-size: 0.9em; color: #666;">
                        Higher coherence scores indicate more semantically coherent topics.<br>
                        Higher diversity scores indicate less overlap between topics.
                    </p>
                </div>
                """
                output_components.append(gr.HTML(coherence_html))
            # Visualize topics
            topics = topic_results.get("topics", [])
            if topics:
                output_components.append(gr.Markdown("### Discovered Topics"))

                # Create a topic word cloud using HTML/CSS for better visibility
                for topic in topics:
                    topic_id = topic.get("id", 0)
                    words = topic.get("words", [])
                    weights = topic.get("weights", [])

                    if words and weights and len(words) == len(weights):
                        # Generate a word cloud-like div using HTML/CSS
                        word_cloud_html = f"""
                        <div style="margin-bottom: 25px;">
                            <h4 style="margin-bottom: 10px;">Topic {topic_id + 1}</h4>
                            <div style="display: flex; flex-wrap: wrap; gap: 10px; background: #f9f9f9; padding: 15px; border-radius: 5px;">
                        """

                        # Sort words by weight for better visualization
                        word_weight_pairs = sorted(zip(words, weights), key=lambda x: x[1], reverse=True)

                        # Add each word with size based on weight
                        for word, weight in word_weight_pairs:
                            # Scale weight to a reasonable font size (min 14px, max 28px)
                            font_size = 14 + min(14, round(weight * 30))
                            # Color based on weight (darker = higher weight)
                            color_intensity = max(0, min(90, int(100 - weight * 100)))
                            color = f"hsl(210, 70%, {color_intensity}%)"
                            word_cloud_html += f"""
                            <span style="font-size: {font_size}px; color: {color}; margin: 3px;
                                         padding: 5px; border-radius: 3px; background: rgba(0,0,0,0.03);">
                                {word}
                            </span>
                            """

                        word_cloud_html += """
                            </div>
                        </div>
                        """
                        output_components.append(gr.HTML(word_cloud_html))
                # Add a proper bar chart visualization for topic words
                for topic in topics[:min(3, len(topics))]:  # Show charts for max 3 topics to avoid clutter
                    topic_id = topic.get("id", 0)
                    words = topic.get("words", [])
                    weights = topic.get("weights", [])

                    if words and weights and len(words) == len(weights):
                        # Create dataframe for plotting
                        df = pd.DataFrame({
                            'word': words,
                            'weight': weights
                        })
                        # Sort by weight
                        df = df.sort_values('weight', ascending=False)
                        # Limit to top N words for clarity
                        df = df.head(10)

                        # Create bar chart
                        fig = px.bar(
                            df, x='weight', y='word',
                            title=f"Topic {topic_id + 1} Top Words",
                            labels={'word': 'Word', 'weight': 'Weight'},
                            height=300,
                            orientation='h'  # Horizontal bars
                        )
                        # Improve layout
                        fig.update_layout(
                            margin=dict(l=10, r=10, t=40, b=10),
                            yaxis={'categoryorder': 'total ascending'}
                        )
                        output_components.append(gr.Plot(value=fig))
            # Visualize topic distributions for each model
            model_topics = topic_results.get("model_topics", {})
            if model_topics and all(model in model_topics for model in models):
                output_components.append(gr.Markdown("### Topic Distribution by Model"))

                # Create multi-model topic distribution comparison
                distribution_data = []
                for model in models:
                    if model in model_topics:
                        distribution = model_topics[model]
                        for i, weight in enumerate(distribution):
                            if i < 10:  # Limit to 10 topics max
                                distribution_data.append({
                                    'Model': model,
                                    'Topic': f"Topic {i + 1}",
                                    'Weight': weight
                                })

                if distribution_data:
                    df = pd.DataFrame(distribution_data)

                    # Create grouped bar chart
                    fig = px.bar(
                        df, x='Topic', y='Weight', color='Model',
                        barmode='group',
                        title="Topic Distribution Comparison",
                        height=400
                    )
                    output_components.append(gr.Plot(value=fig))
            # Visualize topic differences as a heatmap
            # (building the comparison key below needs two model names)
            comparisons = topic_results.get("comparisons", {})
            if comparisons and len(models) >= 2:
                comparison_key = f"{models[0]} vs {models[1]}"
                if comparison_key in comparisons:
                    output_components.append(gr.Markdown("### Topic Similarity Analysis"))

                    # Get JS divergence
                    js_divergence = comparisons[comparison_key].get("js_divergence", 0)

                    # Create a divergence meter
                    divergence_html = f"""
                    <div style="margin: 20px 0; padding: 20px; background-color: #f8f9fa; border-radius: 5px; text-align: center;">
                        <h4 style="margin-top: 0;">Topic Distribution Divergence</h4>
                        <div style="display: flex; align-items: center; justify-content: center;">
                            <div style="width: 300px; height: 40px; background: linear-gradient(to right, #1a9850, #ffffbf, #d73027); border-radius: 5px; position: relative; margin: 10px 0;">
                                <div style="position: absolute; height: 40px; width: 2px; background-color: #000; left: {min(300, max(0, js_divergence * 300))}px;"></div>
                            </div>
                        </div>
                        <div style="display: flex; justify-content: space-between; width: 300px; margin: 0 auto;">
                            <span>Similar (0.0)</span>
                            <span>Different (1.0)</span>
                        </div>
                        <p style="margin-top: 10px; font-weight: bold;">Score: {js_divergence:.3f}</p>
                        <p style="margin-bottom: 0; font-size: 0.9em; color: #666;">
                            Jensen-Shannon divergence measures the difference between topic distributions.<br>
                            Lower values indicate more similar topic distributions between models.
                        </p>
                    </div>
                    """
                    output_components.append(gr.HTML(divergence_html))
                    # Create similarity matrix heatmap if available
                    similarity_matrix = topic_results.get("similarity_matrix", [])
                    if similarity_matrix and len(similarity_matrix) > 0:
                        # Use the matrix rows/columns directly as heatmap z-values
                        z_data = similarity_matrix

                        # Create heatmap
                        fig = go.Figure(data=go.Heatmap(
                            z=z_data,
                            x=[f"{models[1]} Topic {i + 1}" for i in range(len(similarity_matrix[0]))],
                            y=[f"{models[0]} Topic {i + 1}" for i in range(len(similarity_matrix))],
                            colorscale='Viridis',
                            showscale=True,
                            colorbar=dict(title="Similarity")
                        ))
                        fig.update_layout(
                            title="Topic Similarity Matrix",
                            height=400,
                            margin=dict(l=50, r=50, t=50, b=50)
                        )
                        output_components.append(gr.Plot(value=fig))
                    # Show best matching topics
                    matched_topics = topic_results.get("matched_topics", [])
                    if matched_topics:
                        output_components.append(gr.Markdown("### Most Similar Topic Pairs"))

                        # Create HTML table for matched topics
                        matched_topics_html = """
                        <div style="margin: 20px 0;">
                            <table style="width: 100%; border-collapse: collapse;">
                                <tr>
                                    <th style="padding: 8px; border-bottom: 2px solid #ddd; text-align: left;">Topic Pair</th>
                                    <th style="padding: 8px; border-bottom: 2px solid #ddd; text-align: left;">Top Words in Model 1</th>
                                    <th style="padding: 8px; border-bottom: 2px solid #ddd; text-align: left;">Top Words in Model 2</th>
                                    <th style="padding: 8px; border-bottom: 2px solid #ddd; text-align: center;">Similarity</th>
                                </tr>
                        """

                        # Sort by similarity, highest first
                        sorted_matches = sorted(matched_topics, key=lambda x: x['similarity'], reverse=True)

                        for match in sorted_matches:
                            # Format words with commas (show top 5 words)
                            words1 = ", ".join(match["set1_topic_words"][:5])
                            words2 = ", ".join(match["set2_topic_words"][:5])

                            # Calculate color based on similarity (green for high, red for low)
                            similarity = match["similarity"]
                            color = f"hsl({int(120 * similarity)}, 70%, 50%)"

                            matched_topics_html += f"""
                                <tr>
                                    <td style="padding: 8px; border-bottom: 1px solid #ddd;">
                                        {models[0]} Topic {match['set1_topic_id'] + 1} ↔ {models[1]} Topic {match['set2_topic_id'] + 1}
                                    </td>
                                    <td style="padding: 8px; border-bottom: 1px solid #ddd;">{words1}</td>
                                    <td style="padding: 8px; border-bottom: 1px solid #ddd;">{words2}</td>
                                    <td style="padding: 8px; border-bottom: 1px solid #ddd; text-align: center; font-weight: bold; color: {color};">
                                        {similarity:.2f}
                                    </td>
                                </tr>
                            """

                        matched_topics_html += """
                            </table>
                        </div>
                        """
                        output_components.append(gr.HTML(matched_topics_html))
    # If no components were added, show a message
    if len(output_components) <= 1:
        output_components.append(gr.Markdown("No detailed Topic Modeling analysis found in results."))

    return output_components


def process_and_visualize_topic_analysis(analysis_results):
    """
    Process the topic modeling analysis results and create visualization components

    Args:
        analysis_results (dict): The analysis results

    Returns:
        list: List of Gradio components for visualization
    """
    try:
        print("Starting visualization of topic modeling analysis results")
        components = create_topic_visualization(analysis_results)
        print(f"Completed topic modeling visualization with {len(components)} components")
        return components
    except Exception as e:
        import traceback
        error_msg = f"Topic modeling visualization error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return [
            gr.Markdown("**Error during topic modeling visualization:**"),
            gr.Markdown(f"```\n{str(e)}\n```"),
            gr.Markdown("Try adjusting the number of topics or using longer text inputs.")
        ]