"""Gradio app that turns dialogs into a Dialog2Graph graph and visualises it with Plotly."""
import html
import json
import logging
from typing import Any, Dict, List, Optional

import gradio as gr
import networkx as nx
import plotly.graph_objects as go
from langchain_openai.chat_models import ChatOpenAI

from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline
from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType
from dialog2graph.pipelines.model_storage import ModelStorage

logger = logging.getLogger(__name__)


# Initialize the pipeline
def initialize_pipeline() -> D2GLLMPipeline:
    """Build a D2GLLMPipeline whose filling LLM is ``gpt-3.5-turbo`` via ChatOpenAI.

    A fresh ModelStorage and pipeline are created on every call.
    """
    ms = ModelStorage()
    ms.add(
        "my_filling_model",
        config={"model_name": "gpt-3.5-turbo"},
        model_type=ChatOpenAI,
    )
    return D2GLLMPipeline("d2g_pipeline", model_storage=ms, filling_llm="my_filling_model")


def load_dialog_data(json_file: str, custom_dialog_json: Optional[str] = None) -> List[List[Dict[str, str]]]:
    """Load dialogs from a JSON file or custom JSON text, normalised to ``list[list[dict]]``.

    Args:
        json_file: Dataset name; ``"<json_file>.json"`` is read relative to the
            current working directory. The value ``"custom"`` parses
            ``custom_dialog_json`` instead of reading a file.
        custom_dialog_json: Raw JSON text, used only when ``json_file == "custom"``.

    Returns:
        A list of dialogs, each a list of turn dicts having ``"text"`` and
        ``"participant"`` keys, or ``[]`` on any failure. Failures are reported
        with ``gr.Warning`` (shown as a toast inside a Gradio event).

    Accepted shapes: ``list[list[dict]]`` (returned as-is), ``list[dict]``
    (one dialog, wrapped once) and a bare ``dict`` (one turn, wrapped twice).
    """
    if json_file == "custom":
        # An empty textbox used to fall through and read "custom.json" instead.
        if not custom_dialog_json or not custom_dialog_json.strip():
            gr.Warning("Custom dialog JSON is empty.")
            return []
        try:
            data = json.loads(custom_dialog_json)
        except json.JSONDecodeError as e:
            # gr.Error is an exception class: constructing it without raising shows nothing.
            gr.Warning(f"Invalid JSON format in custom dialog: {str(e)}")
            return []
    else:
        file_path = f"{json_file}.json"
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            gr.Warning(f"File {file_path} not found!")
            return []
        except (json.JSONDecodeError, UnicodeDecodeError):
            gr.Warning(f"Invalid JSON format in {file_path}!")
            return []

    if not data:
        return []

    if isinstance(data, list):
        # A list whose first item is a list is already list[list[dict]];
        # any other list is treated as a single dialog.
        dialogs = data if isinstance(data[0], list) else [data]
    else:
        dialogs = [[data]]

    # Every turn is later read as turn["participant"] / turn["text"]; reject
    # malformed input here with a readable message instead of a KeyError later.
    for number, dialog in enumerate(dialogs, start=1):
        if not isinstance(dialog, list) or not all(
            isinstance(turn, dict) and "text" in turn and "participant" in turn
            for turn in dialog
        ):
            gr.Warning(
                f"Dialog {number} must be a list of objects with 'text' and 'participant' fields."
            )
            return []
    return dialogs


def create_network_visualization(graph: nx.Graph) -> go.Figure:
    """Create a Plotly network figure from a NetworkX graph.

    Nodes are placed with a spring layout, labelled with their id, and coloured
    by ``len(graph.neighbors(node))`` (successors only for a directed graph).
    Hovering a node shows its id and its ``label`` attribute, falling back to
    the id when the attribute is missing.
    """
    pos = nx.spring_layout(graph, k=1, iterations=50)
    nodes = list(graph.nodes())

    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    # Plotly hover text renders <br> as a line break.
    node_text = [f"Node {node}<br>{graph.nodes[node].get('label', str(node))}" for node in nodes]
    node_adjacencies = [len(list(graph.neighbors(node))) for node in nodes]

    edge_x: List[Optional[float]] = []
    edge_y: List[Optional[float]] = []
    for source, target in graph.edges():
        (x0, y0), (x1, y1) = pos[source], pos[target]
        # The trailing None breaks the polyline so each edge is its own segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=2, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        hovertext=node_text,
        text=[str(node) for node in nodes],
        textposition="middle center",
        marker=dict(
            showscale=True,
            colorscale="YlGnBu",
            reversescale=True,
            color=node_adjacencies,
            size=20,
            colorbar=dict(
                thickness=15,
                len=0.5,
                x=1.02,
                title="Node Connections",
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )

    return go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title=dict(text="Dialog Graph Visualization", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text="Hover over nodes for more information",
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=0.005,
                    y=-0.002,
                    xanchor="left",
                    yanchor="bottom",
                    font=dict(color="#888", size=12),
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor="white",
        ),
    )


def create_chat_visualization(dialog_data: List[Dict[str, str]], dialog_index: int = 0, total_dialogs: int = 1) -> str:
    """Render one dialog as chat-bubble HTML headed by "Dialog i of n".

    Assistant turns are left-aligned on a blue background; every other
    participant is right-aligned on green and labelled "User". Turn text is
    HTML-escaped because it may come from the user-editable custom JSON.
    """
    parts = [
        '<div style="max-height: 500px; overflow-y: auto; padding: 10px; '
        'border: 1px solid #ddd; border-radius: 8px;">',
        '<div style="text-align: center; font-weight: bold; color: #555; margin-bottom: 10px;">'
        f"Dialog {dialog_index + 1} of {total_dialogs}</div>",
    ]
    for turn in dialog_data:
        text = html.escape(str(turn["text"]))
        if turn["participant"] == "assistant":
            align, background, label = "flex-start", "#e3f2fd", "Assistant"
        else:
            align, background, label = "flex-end", "#e8f5e9", "User"
        parts.append(
            f'<div style="display: flex; justify-content: {align}; margin: 8px 0;">'
            f'<div style="background-color: {background}; color: #212121; padding: 8px 12px; '
            'border-radius: 12px; max-width: 75%;">'
            f'<div style="font-size: 0.8em; font-weight: bold; margin-bottom: 4px;">{label}</div>'
            f'<div style="white-space: pre-wrap;">{text}</div>'
            "</div></div>"
        )
    parts.append("</div>")
    return "".join(parts)


def _build_summary(
    num_nodes: int,
    num_edges: int,
    dialog_data_list: List[List[Dict[str, str]]],
    current_dialog_index: int,
) -> str:
    """Format the Markdown summary of a generated graph and the dialog being viewed."""
    total_dialogs = len(dialog_data_list)
    total_turns = sum(len(dialog) for dialog in dialog_data_list)
    current_turns = len(dialog_data_list[current_dialog_index])
    return f"""
## Graph Summary
- **Number of nodes**: {num_nodes}
- **Number of edges**: {num_edges}
- **Total dialogs**: {total_dialogs}
- **Total dialog turns**: {total_turns}
- **Currently viewing**: Dialog {current_dialog_index + 1} ({current_turns} turns)

## Processing Report
Generated graph from {total_dialogs} dialog(s) with {total_turns} total turns resulting in {num_nodes} nodes and {num_edges} edges.
"""


def _process_dialog(dialog_choice: str, custom_dialog: str = "", current_dialog_index: int = 0) -> tuple:
    """Load dialogs, run the pipeline on all of them, and build every UI output.

    Returns ``(figure, summary_md, chat_html, index, total_dialogs, nav_update,
    graph_data)``. ``graph_data`` is ``{"dialogs", "num_nodes", "num_edges"}``
    on success and ``None`` on failure, so navigation can reuse exactly what
    was graphed without reloading the inputs.
    """
    hidden = gr.update(visible=False)
    try:
        dialog_data_list = load_dialog_data(dialog_choice, custom_dialog if dialog_choice == "custom" else None)
        if not dialog_data_list:
            return None, "Failed to load dialog data", "", 0, 0, hidden, None

        total = len(dialog_data_list)
        # Clamp a possibly stale index to the freshly loaded dialogs.
        index = max(0, min(current_dialog_index, total - 1))

        pipe = initialize_pipeline()
        # The graph is built from all dialogs; only the chat view is per-dialog.
        graph, _report = pipe.invoke(PipelineRawDataType(dialogs=dialog_data_list))
        nx_graph = graph.graph

        graph_data = {
            "dialogs": dialog_data_list,
            "num_nodes": nx_graph.number_of_nodes(),
            "num_edges": nx_graph.number_of_edges(),
        }
        return (
            create_network_visualization(nx_graph),
            _build_summary(graph_data["num_nodes"], graph_data["num_edges"], dialog_data_list, index),
            create_chat_visualization(dialog_data_list[index], index, total),
            index,
            total,
            # Navigation buttons only make sense with more than one dialog.
            gr.update(visible=total > 1),
            graph_data,
        )
    except Exception as e:
        # UI boundary: show the error in the summary panel and keep the traceback in the logs.
        logger.exception("Failed to process dialog %r", dialog_choice)
        return None, f"Error processing dialog: {str(e)}", "", 0, 0, hidden, None


def process_dialog_and_visualize(dialog_choice: str, custom_dialog: str = "", current_dialog_index: int = 0) -> tuple:
    """Process the selected dialog and create the visualisation outputs.

    Returns ``(figure, summary_md, chat_html, index, total_dialogs, nav_update,
    nav_update)``. The navigation visibility update is repeated to keep this
    function's original 7-tuple shape.
    """
    *outputs, nav_update, _graph_data = _process_dialog(dialog_choice, custom_dialog, current_dialog_index)
    return (*outputs, nav_update, nav_update)


# Create the Gradio interface
def create_gradio_app() -> gr.Blocks:
    """Build the Dialog2Graph Blocks UI: dataset picker, graph plot, summary and chat viewer."""
    with gr.Blocks(title="Dialog2Graph Visualizer") as app:
        gr.Markdown("# Dialog2Graph Interactive Visualizer")
        gr.Markdown("Select a dialog dataset to process and visualize as a graph network using Plotly.")

        # Per-session state for dialog navigation.
        current_dialog_index_state = gr.State(0)
        total_dialogs_state = gr.State(0)
        # Dialogs and graph stats from the last successful run (None until then).
        # Navigation reads this, so it always matches the graph that is shown.
        graph_data_state = gr.State(None)

        with gr.Row():
            with gr.Column(scale=1):
                dialog_selector = gr.Radio(
                    choices=["dialog1", "dialog2", "dialog3", "dialog4", "custom"],
                    label="Select Dialog Dataset",
                    value="dialog1",
                    info="Choose one of the available dialog datasets or use custom JSON",
                )
                custom_dialog_input = gr.Textbox(
                    label="Custom Dialog JSON",
                    placeholder='[{"text": "Hello! How can I help?", "participant": "assistant"}, '
                    '{"text": "I need assistance", "participant": "user"}]',
                    lines=8,
                    visible=False,
                    info="Enter dialog data as JSON array with 'text' and 'participant' fields",
                )
                process_btn = gr.Button(
                    "Process Dialog & Generate Graph",
                    variant="primary",
                    size="lg",
                )
                with gr.Accordion("Dialog Datasets Info", open=False):
                    gr.Markdown("""
- **dialog1**: Hotel booking conversation
- **dialog2**: Food delivery conversation
- **dialog3**: Technical support conversation
- **dialog4**: Multiple dialogs (calendar, support, subscription)
- **custom**: Provide your own dialog as JSON
""")
            with gr.Column(scale=3):
                plot_output = gr.Plot(label="Graph Visualization")

        with gr.Row():
            with gr.Column(scale=1):
                summary_output = gr.Markdown(label="Analysis Summary")
            with gr.Column(scale=1):
                gr.Markdown("### Dialog Conversation")
                # Navigation controls for multiple dialogs
                with gr.Row(visible=False) as nav_row:
                    prev_btn = gr.Button("← Previous Dialog", size="sm")
                    next_btn = gr.Button("Next Dialog →", size="sm")
                chat_output = gr.HTML(label="Chat Visualization")

        # Output order matches the tuple returned by _process_dialog.
        process_outputs = [
            plot_output,
            summary_output,
            chat_output,
            current_dialog_index_state,
            total_dialogs_state,
            nav_row,
            graph_data_state,
        ]

        def navigate_dialog(direction: int, current_index: int, graph_data: Optional[Dict[str, Any]]):
            """Move the chat view by ``direction`` dialogs, clamped to the processed range.

            When nothing has been processed or there is only one dialog, the
            summary and chat are left unchanged (``gr.update()``), not cleared.
            """
            if not graph_data or len(graph_data["dialogs"]) <= 1:
                return current_index, gr.update(), gr.update()
            dialogs = graph_data["dialogs"]
            new_index = max(0, min(current_index + direction, len(dialogs) - 1))
            summary = _build_summary(graph_data["num_nodes"], graph_data["num_edges"], dialogs, new_index)
            chat_viz = create_chat_visualization(dialogs[new_index], new_index, len(dialogs))
            return new_index, summary, chat_viz

        def toggle_custom_input(choice):
            """Show the custom JSON textbox only when "custom" is selected."""
            return gr.update(visible=(choice == "custom"))

        def auto_process(choice, custom_text):
            """Re-run processing when a bundled dataset is picked; custom waits for the button."""
            if choice == "custom":
                return (
                    None,
                    "Select 'Process Dialog & Generate Graph' to process custom dialog",
                    "",
                    0,
                    0,
                    gr.update(visible=False),
                    None,
                )
            # A newly selected dataset starts at its first dialog, not a stale index.
            return _process_dialog(choice, custom_text, 0)

        dialog_selector.change(
            fn=toggle_custom_input,
            inputs=[dialog_selector],
            outputs=[custom_dialog_input],
        )

        process_btn.click(
            fn=_process_dialog,
            inputs=[dialog_selector, custom_dialog_input, current_dialog_index_state],
            outputs=process_outputs,
        )

        nav_inputs = [current_dialog_index_state, graph_data_state]
        nav_outputs = [current_dialog_index_state, summary_output, chat_output]
        prev_btn.click(
            fn=lambda curr_idx, graph_data: navigate_dialog(-1, curr_idx, graph_data),
            inputs=nav_inputs,
            outputs=nav_outputs,
        )
        next_btn.click(
            fn=lambda curr_idx, graph_data: navigate_dialog(1, curr_idx, graph_data),
            inputs=nav_inputs,
            outputs=nav_outputs,
        )

        dialog_selector.change(
            fn=auto_process,
            inputs=[dialog_selector, custom_dialog_input],
            outputs=process_outputs,
        )

    return app


if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()