# d2g-demo / app.py
# Author: Tialo
# Last commit: 2dad6a6 "Update app.py" (verified)
import html
import json
from typing import List, Dict, Optional

import gradio as gr
import networkx as nx
import plotly.graph_objects as go
from langchain_openai.chat_models import ChatOpenAI

from dialog2graph.pipelines.model_storage import ModelStorage
from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline
from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType
def initialize_pipeline():
    """Build a fresh Dialog2Graph LLM pipeline.

    A new ``ModelStorage`` is created on every call. A ``ChatOpenAI`` model
    configured with ``model_name="gpt-3.5-turbo"`` is registered in it under
    the key ``"my_filling_model"``, and that key is used as the pipeline's
    filling LLM.

    Returns:
        A ``D2GLLMPipeline`` named ``"d2g_pipeline"``.
    """
    filling_key = "my_filling_model"
    storage = ModelStorage()
    storage.add(
        filling_key,
        config={"model_name": "gpt-3.5-turbo"},
        model_type=ChatOpenAI,
    )
    return D2GLLMPipeline(
        "d2g_pipeline",
        model_storage=storage,
        filling_llm=filling_key,
    )
def load_dialog_data(json_file: str, custom_dialog_json: Optional[str] = None) -> List[List[Dict[str, str]]]:
    """Load dialog data and normalize it to ``list[list[dict]]``.

    Args:
        json_file: Dataset name. ``"custom"`` parses ``custom_dialog_json``;
            any other value is read from ``f"{json_file}.json"`` relative to
            the current working directory.
        custom_dialog_json: Raw JSON text, used only when ``json_file`` is
            ``"custom"``.

    Returns:
        A list of dialogs, each a list of turn dicts. A top-level list whose
        first element is a list is returned unchanged. Any other non-empty
        list is wrapped as a single dialog. Any other non-empty value is
        wrapped twice. ``[]`` is returned for empty data and for every load
        failure. Failures are reported through ``gr.Warning``, because a
        ``gr.Error`` that is only constructed and never raised shows nothing.
    """
    if json_file == "custom":
        # Don't fall back to reading "custom.json" when the textbox is empty.
        if not custom_dialog_json or not custom_dialog_json.strip():
            gr.Warning("Custom dialog JSON is empty.")
            return []
        try:
            data = json.loads(custom_dialog_json)
        except json.JSONDecodeError as e:
            gr.Warning(f"Invalid JSON format in custom dialog: {str(e)}")
            return []
    else:
        file_path = f"{json_file}.json"
        try:
            # Explicit encoding so non-ASCII dialog text loads the same on every platform.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            gr.Warning(f"File {file_path} not found!")
            return []
        except json.JSONDecodeError:
            gr.Warning(f"Invalid JSON format in {file_path}!")
            return []

    if not data:
        return []
    if isinstance(data, list):
        if isinstance(data[0], list):
            # Already list[list[dict]]: several dialogs.
            return data
        # list[dict] (or any other list) is treated as a single dialog.
        return [data]
    # A lone object (e.g. a single turn dict) becomes one dialog with one turn.
    return [[data]]
def create_network_visualization(graph: nx.Graph, seed: Optional[int] = None) -> go.Figure:
    """Create an interactive Plotly figure from a NetworkX graph.

    Node positions come from ``nx.spring_layout(k=1, iterations=50)``.
    - Each node is drawn as a marker labelled with its id.
    - Hovering a node shows ``"Node <id><br><label>"``, where ``<label>`` is
      the node's ``label`` attribute, or its id when that attribute is missing.
    - Marker colour encodes the node's degree. On a directed graph this counts
      incoming plus outgoing edges.

    Args:
        graph: Graph to draw. Edges are read via ``graph.edges()``, which
            yields ``(u, v)`` pairs for plain, directed and multi-graphs alike.
        seed: Optional random seed forwarded to ``nx.spring_layout``. The
            default ``None`` keeps the previous behaviour, where the layout is
            random on every call. Pass an int for a reproducible layout.

    Returns:
        A ``go.Figure`` with an edge line trace followed by a node trace.
    """
    pos = nx.spring_layout(graph, k=1, iterations=50, seed=seed)

    nodes = list(graph.nodes())
    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    node_text = [
        f"Node {node}<br>{graph.nodes[node].get('label', str(node))}"
        for node in nodes
    ]
    # ``degree`` counts both directions on a DiGraph. ``neighbors()`` would
    # count only successors, so sink nodes would show 0 "connections".
    node_degrees = [graph.degree(node) for node in nodes]

    # All edges are drawn as one polyline; ``None`` breaks it between edges.
    edge_x = []
    edge_y = []
    for u, v in graph.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_text,
        text=[str(node) for node in nodes],
        textposition="middle center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_degrees,
            size=20,
            colorbar=dict(
                thickness=15,
                len=0.5,
                x=1.02,
                title="Node Connections",
                xanchor="left"
            ),
            line=dict(width=2)
        )
    )

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        title=dict(
                            text='Dialog Graph Visualization',
                            font=dict(
                                size=16,
                            ),
                        ),
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        annotations=[dict(
                            text="Hover over nodes for more information",
                            showarrow=False,
                            xref="paper", yref="paper",
                            x=0.005, y=-0.002,
                            xanchor='left', yanchor='bottom',
                            font=dict(color="#888", size=12)
                        )],
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        plot_bgcolor='white'
                    ))
    return fig
def create_chat_visualization(dialog_data: List[Dict[str, str]], dialog_index: int = 0, total_dialogs: int = 1) -> str:
    """Render one dialog as chat-bubble HTML with a "Dialog N of M" header.

    Args:
        dialog_data: Turns of a single dialog. Each turn must have
            ``"participant"`` and ``"text"`` keys; a missing key raises
            ``KeyError``.
        dialog_index: Zero-based index of this dialog, shown one-based.
        total_dialogs: Total number of dialogs, shown in the header.

    Returns:
        An HTML string. ``"assistant"`` turns are left-aligned blue bubbles;
        every other participant is rendered as a right-aligned green "User"
        bubble. Turn text is HTML-escaped, because it can come from the
        user-supplied custom JSON. Without escaping, that JSON could inject
        markup or scripts into the page.
    """
    parts = [
        "\n"
        '<div style="margin-bottom: 10px; text-align: center; font-weight: bold; color: #666;">\n'
        f"Dialog {dialog_index + 1} of {total_dialogs}\n"
        "</div>\n"
        '<div style="max-height: 500px; overflow-y: auto; border: 1px solid #ddd; border-radius: 10px; padding: 20px; background-color: #f9f9f9;">\n'
    ]
    for turn in dialog_data:
        participant = turn["participant"]
        safe_text = html.escape(str(turn["text"]))
        if participant == "assistant":
            # Assistant: left-aligned, blue, tail on the bottom-left corner.
            justify, background, corner, label_color, label = (
                "flex-start", "#e3f2fd", "left", "#1976d2", "Assistant")
        else:
            # Anyone else: right-aligned, green, tail on the bottom-right corner.
            justify, background, corner, label_color, label = (
                "flex-end", "#e8f5e8", "right", "#388e3c", "User")
        parts.append(
            "\n"
            f'<div style="display: flex; justify-content: {justify}; margin-bottom: 15px;">\n'
            f'<div style="max-width: 70%; background-color: {background}; padding: 12px 16px; border-radius: 18px; border-bottom-{corner}-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);">\n'
            f'<div style="font-weight: bold; color: {label_color}; font-size: 12px; margin-bottom: 4px;">{label}</div>\n'
            f'<div style="color: #333; line-height: 1.4;">{safe_text}</div>\n'
            "</div>\n"
            "</div>\n"
        )
    parts.append("</div>")
    return "".join(parts)
def process_dialog_and_visualize(dialog_choice: str, custom_dialog: str = "", current_dialog_index: int = 0) -> tuple:
    """Run the Dialog2Graph pipeline on the chosen dataset and build the UI outputs.

    All loaded dialogs go into graph generation. Only the dialog at
    ``current_dialog_index`` (clamped to the valid range) is rendered in the
    chat pane.

    Returns:
        A 7-tuple: (figure or None, summary markdown, chat HTML, clamped
        dialog index, number of dialogs, and two visibility updates for the
        navigation controls). When loading fails, or any exception is raised,
        a message goes into the summary slot and navigation is hidden.
    """
    try:
        custom_json = custom_dialog if dialog_choice == "custom" else None
        dialogs = load_dialog_data(dialog_choice, custom_json)
        dialog_count = len(dialogs)
        if not dialogs:
            return (None, "Failed to load dialog data", "", 0, dialog_count,
                    gr.update(visible=False), gr.update(visible=False))

        index = max(0, min(current_dialog_index, dialog_count - 1))

        pipeline = initialize_pipeline()
        graph, _report = pipeline.invoke(PipelineRawDataType(dialogs=dialogs))
        nx_graph = graph.graph

        fig = create_network_visualization(nx_graph)
        shown_dialog = dialogs[index]
        chat_viz = create_chat_visualization(shown_dialog, index, dialog_count)

        node_count = nx_graph.number_of_nodes()
        edge_count = nx_graph.number_of_edges()
        turn_count = sum(map(len, dialogs))
        summary = "\n".join([
            "",
            "## Graph Summary",
            f"- **Number of nodes**: {node_count}",
            f"- **Number of edges**: {edge_count}",
            f"- **Total dialogs**: {dialog_count}",
            f"- **Total dialog turns**: {turn_count}",
            f"- **Currently viewing**: Dialog {index + 1} ({len(shown_dialog)} turns)",
            "## Processing Report",
            f"Generated graph from {dialog_count} dialog(s) with {turn_count} total turns resulting in {node_count} nodes and {edge_count} edges.",
            "",
        ])

        # Prev/next navigation only makes sense with more than one dialog.
        nav_visible = dialog_count > 1
        return (fig, summary, chat_viz, index, dialog_count,
                gr.update(visible=nav_visible), gr.update(visible=nav_visible))
    except Exception as exc:
        return (None, f"Error processing dialog: {exc}", "", 0, 0,
                gr.update(visible=False), gr.update(visible=False))
def create_gradio_app():
    """Build the Dialog2Graph visualizer UI and wire up its event handlers.

    Layout:
    - Top row, left column: the dataset selector; the custom-JSON textbox,
      shown only for "custom"; the process button; a dataset-info accordion.
    - Top row, wide right column: the graph plot.
    - Bottom row: the Markdown summary and the chat pane, which has prev/next
      navigation.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(title="Dialog2Graph Visualizer") as app:
        gr.Markdown("# Dialog2Graph Interactive Visualizer")
        gr.Markdown("Select a dialog dataset to process and visualize as a graph network using Plotly.")

        # Per-session state: index of the dialog shown in the chat pane and the
        # number of dialogs in the most recently processed dataset.
        current_dialog_index_state = gr.State(0)
        total_dialogs_state = gr.State(0)

        with gr.Row():
            with gr.Column(scale=1):
                dialog_selector = gr.Radio(
                    choices=["dialog1", "dialog2", "dialog3", "dialog4", "custom"],
                    label="Select Dialog Dataset",
                    value="dialog1",
                    info="Choose one of the available dialog datasets or use custom JSON"
                )
                custom_dialog_input = gr.Textbox(
                    label="Custom Dialog JSON",
                    placeholder='[{"text": "Hello! How can I help?", "participant": "assistant"}, {"text": "I need assistance", "participant": "user"}]',
                    lines=8,
                    visible=False,
                    info="Enter dialog data as JSON array with 'text' and 'participant' fields"
                )
                process_btn = gr.Button(
                    "Process Dialog & Generate Graph",
                    variant="primary",
                    size="lg"
                )
                with gr.Accordion("Dialog Datasets Info", open=False):
                    gr.Markdown(
                        "\n"
                        "- **dialog1**: Hotel booking conversation\n"
                        "- **dialog2**: Food delivery conversation\n"
                        "- **dialog3**: Technical support conversation\n"
                        "- **dialog4**: Multiple dialogs (calendar, support, subscription)\n"
                        "- **custom**: Provide your own dialog as JSON\n"
                    )
            with gr.Column(scale=3):
                plot_output = gr.Plot(label="Graph Visualization")

        with gr.Row():
            with gr.Column(scale=1):
                summary_output = gr.Markdown(label="Analysis Summary")
            with gr.Column(scale=1):
                gr.Markdown("### Dialog Conversation")
                # Navigation controls; made visible only when a dataset has several dialogs.
                with gr.Row(visible=False) as nav_row:
                    prev_btn = gr.Button("← Previous Dialog", size="sm")
                    next_btn = gr.Button("Next Dialog →", size="sm")
                chat_output = gr.HTML(label="Chat Visualization")

        def navigate_dialog(direction: int, current_index: int, total_dialogs: int, dialog_choice: str, custom_dialog: str):
            """Move the chat pane by ``direction`` dialogs, clamped to the dataset.

            Returns (index, summary, chat HTML). When navigation is not possible
            or reloading fails, the index is kept and the summary/chat panes
            are left as they are (``gr.update()`` means "no change") instead of
            being blanked.
            """
            if total_dialogs <= 1:
                return current_index, gr.update(), gr.update()
            new_index = max(0, min(current_index + direction, total_dialogs - 1))
            try:
                dialog_data_list = load_dialog_data(
                    dialog_choice, custom_dialog if dialog_choice == "custom" else None
                )
                if dialog_data_list and new_index < len(dialog_data_list):
                    current_dialog = dialog_data_list[new_index]
                    chat_viz = create_chat_visualization(current_dialog, new_index, len(dialog_data_list))
                    total_turns = sum(len(dialog) for dialog in dialog_data_list)
                    # NOTE(review): this summary omits the node/edge counts shown
                    # after processing; they are not kept in session state.
                    summary = (
                        "\n"
                        "## Graph Summary\n"
                        f"- **Total dialogs**: {len(dialog_data_list)}\n"
                        f"- **Total dialog turns**: {total_turns}\n"
                        f"- **Currently viewing**: Dialog {new_index + 1} ({len(current_dialog)} turns)\n"
                        "## Processing Report\n"
                        "Navigate between dialogs to view different conversations.\n"
                    )
                    return new_index, summary, chat_viz
            except Exception as exc:
                # UI boundary: keep the current view but tell the user why it didn't change.
                gr.Warning(f"Could not switch dialog: {exc}")
            return current_index, gr.update(), gr.update()

        def toggle_custom_input(choice):
            """Show the custom-JSON textbox only when "custom" is selected."""
            return gr.update(visible=(choice == "custom"))

        dialog_selector.change(
            fn=toggle_custom_input,
            inputs=[dialog_selector],
            outputs=[custom_dialog_input]
        )
        # NOTE(review): nav_row is listed twice because process_dialog_and_visualize
        # returns two visibility updates; both target the same row.
        process_btn.click(
            fn=process_dialog_and_visualize,
            inputs=[dialog_selector, custom_dialog_input, current_dialog_index_state],
            outputs=[plot_output, summary_output, chat_output, current_dialog_index_state, total_dialogs_state, nav_row, nav_row]
        )
        prev_btn.click(
            fn=lambda curr_idx, total, choice, custom: navigate_dialog(-1, curr_idx, total, choice, custom),
            inputs=[current_dialog_index_state, total_dialogs_state, dialog_selector, custom_dialog_input],
            outputs=[current_dialog_index_state, summary_output, chat_output]
        )
        next_btn.click(
            fn=lambda curr_idx, total, choice, custom: navigate_dialog(1, curr_idx, total, choice, custom),
            inputs=[current_dialog_index_state, total_dialogs_state, dialog_selector, custom_dialog_input],
            outputs=[current_dialog_index_state, summary_output, chat_output]
        )

        def auto_process(choice, custom_text, curr_idx):
            """Process a built-in dataset as soon as it is selected.

            "custom" is not auto-processed, so half-typed JSON is never sent
            to the pipeline. ``curr_idx`` belongs to the previously selected
            dataset, so a newly selected dataset always starts at dialog 0.
            """
            if choice == "custom":
                return (None, "Select 'Process Dialog & Generate Graph' to process custom dialog", "", 0, 0,
                        gr.update(visible=False), gr.update(visible=False))
            return process_dialog_and_visualize(choice, custom_text, 0)

        dialog_selector.change(
            fn=auto_process,
            inputs=[dialog_selector, custom_dialog_input, current_dialog_index_state],
            outputs=[plot_output, summary_output, chat_output, current_dialog_index_state, total_dialogs_state, nav_row, nav_row]
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio server with
    # launch()'s default settings. The module-level name `app` is kept as-is.
    app = create_gradio_app()
    app.launch()