File size: 9,611 Bytes
14bac19
f533950
14bac19
 
 
 
 
 
 
 
f533950
 
 
 
 
14bac19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f533950
14bac19
 
f533950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14bac19
f533950
14bac19
 
f533950
14bac19
f533950
 
 
14bac19
f533950
 
 
14bac19
 
 
 
f533950
14bac19
f533950
14bac19
 
f533950
 
 
 
14bac19
f533950
 
 
 
 
 
 
 
 
 
 
 
 
14bac19
f533950
 
14bac19
f533950
 
 
 
 
 
 
 
 
 
14bac19
 
 
 
f533950
14bac19
 
 
f533950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14bac19
 
f533950
14bac19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f533950
 
 
 
 
 
 
 
 
 
 
 
 
 
14bac19
 
 
 
f533950
14bac19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""
Improved visualization for topic modeling analysis results
"""
import gradio as gr
import json
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('topic_visualizer')

def create_topic_visualization(analysis_results):
    """
    Create visualizations for topic modeling analysis results.

    Args:
        analysis_results (dict): Analysis results from the topic modeling
            analysis. Expected shape (assumed from usage below — confirm
            against the producer): {"analyses": {prompt: {"topic_modeling": {...}}}}.

    Returns:
        list: List of gradio components with visualizations.
    """
    # Initialize output components list
    output_components = []

    # Check if we have valid results
    if not analysis_results or "analyses" not in analysis_results:
        logger.warning("No valid analysis results found")
        return [gr.Markdown("No analysis results found.")]

    try:
        # Process each prompt's topic-modeling section (other analysis
        # types under the same prompt are ignored here).
        for prompt, analyses in analysis_results["analyses"].items():
            if "topic_modeling" not in analyses:
                continue
            topic_results = analyses["topic_modeling"]

            # Surface analysis-level errors and move on to the next prompt.
            if "error" in topic_results:
                error_msg = topic_results.get("error", "Unknown error in topic modeling")
                logger.warning("Topic modeling error: %s", error_msg)
                output_components.append(
                    gr.Markdown(f"**Error in topic modeling analysis:** {error_msg}")
                )
                continue

            # Show method and number of topics
            method = topic_results.get("method", "lda").upper()
            n_topics = topic_results.get("n_topics", 3)
            logger.info("Creating visualization for %s with %s topics", method, n_topics)

            # Get models being compared; the layout assumes a pairwise comparison.
            models = topic_results.get("models", [])
            if len(models) < 2:
                logger.warning("Not enough models found in results")
                output_components.append(
                    gr.Markdown("Topic modeling requires at least two models to compare.")
                )
                continue

            output_components.append(
                gr.Markdown(f"### Topic Modeling Analysis ({method}, {n_topics} topics)")
            )
            output_components.append(
                gr.Markdown(f"Comparing responses from **{models[0]}** and **{models[1]}**")
            )

            output_components.extend(_topic_word_components(topic_results))
            output_components.extend(
                _distribution_components(topic_results, models, n_topics)
            )
            output_components.extend(_similarity_components(topic_results))

    except Exception as e:
        # Top-level boundary: log and degrade to an inline error message.
        logger.error(f"Error in create_topic_visualization: {str(e)}")
        output_components.append(gr.Markdown(f"**Error creating topic visualization:** {str(e)}"))

    # If no components were added, show a message
    if not output_components:
        output_components.append(gr.Markdown("No detailed Topic Modeling analysis found in results."))

    return output_components


def _topic_word_components(topic_results):
    """Render the discovered topics (top words per topic) as markdown components."""
    components = []
    topics = topic_results.get("topics", [])
    if topics:
        components.append(gr.Markdown("#### Discovered Topics"))
        # Display topics in a more readable format
        for i, topic in enumerate(topics):
            topic_id = topic.get("id", i)
            words = topic.get("words", [])
            if words:
                topic_words = ", ".join(words[:5])  # Show top 5 words
                components.append(gr.Markdown(f"**Topic {topic_id+1}**: {topic_words}"))
    return components


def _distribution_components(topic_results, models, n_topics):
    """Render per-model topic distributions as text plus a grouped bar chart."""
    components = []
    model_topics = topic_results.get("model_topics", {})
    # Only render when every compared model has a distribution.
    if not (model_topics and all(model in model_topics for model in models)):
        return components

    components.append(gr.Markdown("#### Topic Distribution by Model"))

    # Display topic distributions in a readable format
    for model in models:
        dist = model_topics[model]
        dist_str = ", ".join(f"Topic {i+1}: {v:.2f}" for i, v in enumerate(dist[:n_topics]))
        components.append(gr.Markdown(f"**{model}**: {dist_str}"))

    # Create multi-model topic distribution visualization
    try:
        # Long-form rows: one (model, topic, weight) record per bar.
        model_data = [
            {"Model": model, "Topic": f"Topic {i+1}", "Weight": weight}
            for model in models
            for i, weight in enumerate(model_topics[model][:n_topics])
        ]

        if model_data:
            df = pd.DataFrame(model_data)

            # Create grouped bar chart
            fig = px.bar(
                df,
                x="Topic",
                y="Weight",
                color="Model",
                title="Topic Distribution Comparison",
                barmode="group",
                height=400
            )

            fig.update_layout(
                xaxis_title="Topic",
                yaxis_title="Weight",
                legend_title="Model"
            )

            components.append(gr.Plot(value=fig))
    except Exception as e:
        # Plotting is best-effort: report the failure inline instead of aborting.
        logger.error(f"Error creating topic distribution plot: {str(e)}")
        components.append(gr.Markdown(f"*Error creating visualization: {str(e)}*"))

    return components


def _similarity_components(topic_results):
    """Render Jensen-Shannon divergence metrics with a plain-language reading."""
    components = []
    comparisons = topic_results.get("comparisons", {})
    if not comparisons:
        return components

    components.append(gr.Markdown("#### Similarity Metrics"))

    for comparison_key, comparison_data in comparisons.items():
        js_div = comparison_data.get("js_divergence", 0)

        # Jensen-Shannon divergence interpretation (thresholds are heuristic).
        if js_div < 0.2:
            similarity_text = "very similar"
        elif js_div < 0.4:
            similarity_text = "somewhat similar"
        elif js_div < 0.6:
            similarity_text = "moderately different"
        else:
            similarity_text = "very different"

        components.append(gr.Markdown(
            f"**Topic Distribution Divergence**: {js_div:.4f} - Topic distributions are {similarity_text}"
        ))

        # Explain what the metric means
        components.append(gr.Markdown(
            "*Lower divergence values indicate more similar topic distributions between models*"
        ))

    return components


def process_and_visualize_topic_analysis(analysis_results):
    """
    Process the topic modeling analysis results and create visualization components.

    Args:
        analysis_results (dict): The analysis results.

    Returns:
        list: List of gradio components for visualization.
    """
    try:
        logger.info("Starting visualization of topic modeling analysis results")
        # Debug output: log the structure of analysis_results so missing-key
        # problems are diagnosable from the logs. The isinstance guard keeps
        # a non-dict input from raising TypeError on the `in` check.
        if isinstance(analysis_results, dict) and "analyses" in analysis_results:
            for prompt, analyses in analysis_results["analyses"].items():
                topic_results = analyses.get("topic_modeling")
                if topic_results is None:
                    continue
                logger.info("Found topic_modeling results with keys: %s", topic_results.keys())
                if "models" in topic_results:
                    logger.info("Models: %s", topic_results["models"])
                if "topics" in topic_results:
                    logger.info("Found %d topics", len(topic_results["topics"]))
                if "model_topics" in topic_results:
                    logger.info("Model_topics keys: %s", topic_results["model_topics"].keys())

        return create_topic_visualization(analysis_results)
    except Exception as e:
        # Boundary handler: never let visualization errors propagate to the UI
        # layer — render the traceback inline instead.
        import traceback
        error_msg = f"Topic modeling visualization error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return [gr.Markdown(f"**Error during topic modeling visualization:**\n\n```\n{error_msg}\n```")]