"""
Improved visualization for topic modeling analysis results
"""

import gradio as gr
import json
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import logging
import traceback

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger('topic_visualizer')


def _describe_divergence(js_div):
    """Map a Jensen-Shannon divergence value to a qualitative similarity label.

    Uses the fixed cut-offs 0.2 / 0.4 / 0.6; lower means more similar.
    """
    if js_div < 0.2:
        return "very similar"
    if js_div < 0.4:
        return "somewhat similar"
    if js_div < 0.6:
        return "moderately different"
    return "very different"


def _topic_word_components(topics):
    """Build Markdown components listing the top five words of each topic.

    Args:
        topics (list): Topic dicts, each optionally carrying "id" (0-based,
            defaults to the list position) and "words" (ranked word list).

    Returns:
        list: Gradio components; topics without words are skipped.
    """
    components = [gr.Markdown("#### Discovered Topics")]
    for i, topic in enumerate(topics):
        topic_id = topic.get("id", i)
        words = topic.get("words", [])
        if words:
            topic_words = ", ".join(words[:5])  # Show top 5 words
            components.append(gr.Markdown(f"**Topic {topic_id+1}**: {topic_words}"))
    return components


def _topic_distribution_components(models, model_topics, n_topics):
    """Build per-model distribution summaries plus a grouped bar chart.

    Args:
        models (list): Model names; every name must be a key of model_topics.
        model_topics (dict): Model name -> sequence of topic weights.
        n_topics (int): Only the first n_topics weights of each model are shown.

    Returns:
        list: Gradio components. A failure while building the plot is logged
        and reported inline instead of being raised.
    """
    components = [gr.Markdown("#### Topic Distribution by Model")]

    # Display topic distributions in a readable format
    for model in models:
        dist = model_topics[model]
        dist_str = ", ".join(f"Topic {i+1}: {v:.2f}" for i, v in enumerate(dist[:n_topics]))
        components.append(gr.Markdown(f"**{model}**: {dist_str}"))

    # Create multi-model topic distribution visualization
    try:
        model_data = [
            {"Model": model, "Topic": f"Topic {i+1}", "Weight": weight}
            for model in models
            for i, weight in enumerate(model_topics[model][:n_topics])
        ]
        if model_data:
            df = pd.DataFrame(model_data)
            fig = px.bar(
                df,
                x="Topic",
                y="Weight",
                color="Model",
                title="Topic Distribution Comparison",
                barmode="group",
                height=400,
            )
            fig.update_layout(
                xaxis_title="Topic",
                yaxis_title="Weight",
                legend_title="Model",
            )
            components.append(gr.Plot(value=fig))
    except Exception as e:
        logger.error(f"Error creating topic distribution plot: {e}")
        components.append(gr.Markdown(f"*Error creating visualization: {e}*"))

    return components


def _similarity_components(comparisons):
    """Build Markdown components describing the divergence of each comparison.

    Args:
        comparisons (dict): Comparison key -> dict that may contain
            "js_divergence" (defaults to 0 when absent).

    Returns:
        list: A header, one line per comparison, and a single explanatory note.
    """
    components = [gr.Markdown("#### Similarity Metrics")]
    for comparison_data in comparisons.values():
        js_div = comparison_data.get("js_divergence", 0)
        similarity_text = _describe_divergence(js_div)
        components.append(gr.Markdown(
            f"**Topic Distribution Divergence**: {js_div:.4f} - Topic distributions are {similarity_text}"
        ))
    # The explanation is generic, so it is emitted once rather than per comparison.
    components.append(gr.Markdown(
        "*Lower divergence values indicate more similar topic distributions between models*"
    ))
    return components


def _prompt_components(topic_results):
    """Build every visualization component for one prompt's topic-modeling results.

    Args:
        topic_results (dict): The "topic_modeling" entry for a single prompt.

    Returns:
        list: Gradio components. When the results carry an "error" or name
        fewer than two models, only a single explanatory message is returned.
    """
    # Check for errors in the analysis
    if "error" in topic_results:
        # Fall back to a generic message when the stored error is empty/None.
        error_msg = topic_results["error"] or "Unknown error in topic modeling"
        logger.warning(f"Topic modeling error: {error_msg}")
        return [gr.Markdown(f"**Error in topic modeling analysis:** {error_msg}")]

    method = topic_results.get("method", "lda").upper()
    n_topics = topic_results.get("n_topics", 3)
    logger.info(f"Creating visualization for {method} with {n_topics} topics")

    # A comparison needs at least two models.
    models = topic_results.get("models", [])
    if not models or len(models) < 2:
        logger.warning("Not enough models found in results")
        return [gr.Markdown("Topic modeling requires at least two models to compare.")]

    components = [
        gr.Markdown(f"### Topic Modeling Analysis ({method}, {n_topics} topics)"),
        gr.Markdown(f"Comparing responses from **{models[0]}** and **{models[1]}**"),
    ]

    topics = topic_results.get("topics", [])
    if topics:
        components.extend(_topic_word_components(topics))

    # Distributions are shown only when every compared model has one.
    model_topics = topic_results.get("model_topics", {})
    if model_topics and all(model in model_topics for model in models):
        components.extend(_topic_distribution_components(models, model_topics, n_topics))

    comparisons = topic_results.get("comparisons", {})
    if comparisons:
        components.extend(_similarity_components(comparisons))

    return components


def _error_component(exc):
    """Log a visualization failure and return a Markdown component describing it."""
    logger.error(f"Error in create_topic_visualization: {exc}")
    return gr.Markdown(f"**Error creating topic visualization:** {exc}")


def create_topic_visualization(analysis_results):
    """
    Create visualizations for topic modeling analysis results

    Args:
        analysis_results (dict): Analysis results from the topic modeling analysis.
            Expected to hold an "analyses" mapping of prompt -> per-prompt dict;
            prompts whose dict has a "topic_modeling" entry are visualized.

    Returns:
        list: List of gradio components with visualizations. A failure while
        rendering one prompt is reported inline and does not stop the others.
    """
    # Check if we have valid results
    if not analysis_results or "analyses" not in analysis_results:
        logger.warning("No valid analysis results found")
        return [gr.Markdown("No analysis results found.")]

    output_components = []
    try:
        for prompt, analyses in analysis_results["analyses"].items():
            if "topic_modeling" not in analyses:
                continue
            try:
                output_components.extend(_prompt_components(analyses["topic_modeling"]))
            except Exception as e:
                # Isolate failures so one malformed prompt does not hide the rest.
                output_components.append(_error_component(e))
    except Exception as e:
        # Malformed top-level structure (e.g. "analyses" is not a mapping).
        output_components.append(_error_component(e))

    # If no components were added, show a message
    if not output_components:
        output_components.append(gr.Markdown("No detailed Topic Modeling analysis found in results."))

    return output_components


def _log_topic_results_structure(analysis_results):
    """Log the keys of every topic-modeling result, as a debugging aid.

    Does nothing when analysis_results is empty/None or has no "analyses".
    """
    if not analysis_results or "analyses" not in analysis_results:
        return
    for prompt, analyses in analysis_results["analyses"].items():
        if "topic_modeling" not in analyses:
            continue
        topic_results = analyses["topic_modeling"]
        logger.info(f"Found topic_modeling results with keys: {topic_results.keys()}")
        if "models" in topic_results:
            logger.info(f"Models: {topic_results['models']}")
        if "topics" in topic_results:
            logger.info(f"Found {len(topic_results['topics'])} topics")
        if "model_topics" in topic_results:
            logger.info(f"Model_topics keys: {topic_results['model_topics'].keys()}")


def process_and_visualize_topic_analysis(analysis_results):
    """
    Process the topic modeling analysis results and create visualization components

    Args:
        analysis_results (dict): The analysis results (may be empty or None).

    Returns:
        list: List of gradio components for visualization. Any unexpected
        exception is returned as a Markdown component containing the traceback.
    """
    try:
        logger.info("Starting visualization of topic modeling analysis results")
        # Debug output - log the structure of analysis_results
        _log_topic_results_structure(analysis_results)
        return create_topic_visualization(analysis_results)
    except Exception as e:
        error_msg = f"Topic modeling visualization error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return [gr.Markdown(f"**Error during topic modeling visualization:**\n\n```\n{error_msg}\n```")]