d2g-demo / app.py
Tialo's picture
Update app.py
4e580b0 verified
raw
history blame
11.4 kB
import html
import json
import logging
from typing import List, Dict, Optional

import gradio as gr
import networkx as nx
import plotly.graph_objects as go
from langchain_openai.chat_models import ChatOpenAI

from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline
from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType
from dialog2graph.pipelines.model_storage import ModelStorage
# Initialize the pipeline
def initialize_pipeline(model_name: str = "gpt-3.5-turbo"):
    """Build a dialog2graph LLM pipeline backed by a single chat model.

    A fresh ``ModelStorage`` is created on every call. One ``ChatOpenAI``
    model is registered in it under the key ``"my_filling_model"``, and that
    key is passed to the pipeline as its ``filling_llm``.

    Args:
        model_name: Value placed in the model config's ``"model_name"`` entry.
            Defaults to ``"gpt-3.5-turbo"``, the model this demo previously
            hard-coded, so existing callers see no change.

    Returns:
        A ``D2GLLMPipeline`` named ``"d2g_pipeline"``.
    """
    filling_key = "my_filling_model"
    storage = ModelStorage()
    storage.add(
        filling_key,
        config={"model_name": model_name},
        model_type=ChatOpenAI,
    )
    return D2GLLMPipeline("d2g_pipeline", model_storage=storage, filling_llm=filling_key)
def load_dialog_data(json_file: str, custom_dialog_json: Optional[str] = None) -> List[Dict[str, str]]:
    """Load dialog data from ``<json_file>.json`` or from a custom JSON string.

    Args:
        json_file: Dataset name without the ``.json`` extension (e.g.
            ``"dialog1"``), or the literal ``"custom"`` to parse
            ``custom_dialog_json`` instead of reading a file.
        custom_dialog_json: Raw JSON text. Only consulted when ``json_file``
            is ``"custom"``.

    Returns:
        The decoded JSON value. It is expected to be a list of
        ``{"text": ..., "participant": ...}`` dicts, but the structure is not
        validated here. On any load or parse failure, a Gradio warning is shown
        and an empty list is returned, so callers can test ``if not data``.
    """
    if json_file == "custom":
        # Never fall back to a file for the custom choice. Previously an empty
        # textbox made this function try to open "custom.json".
        if not (custom_dialog_json or "").strip():
            gr.Warning("Custom dialog JSON is empty.")
            return []
        try:
            return json.loads(custom_dialog_json)
        except json.JSONDecodeError as e:
            # gr.Error only has an effect when *raised*. gr.Warning can simply
            # be called to show the message while still returning [].
            gr.Warning(f"Invalid JSON format in custom dialog: {e}")
            return []

    file_path = f"{json_file}.json"
    try:
        # Explicit UTF-8 so non-ASCII dialogs load the same on every platform.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        gr.Warning(f"File {file_path} not found!")
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        gr.Warning(f"Invalid JSON format in {file_path}!")
        return []
def create_network_visualization(graph: nx.Graph, seed: Optional[int] = None) -> go.Figure:
    """Create a Plotly network visualization from a NetworkX graph.

    Node placement uses ``nx.spring_layout``. Each node is drawn as a marker
    labelled with its id and coloured by its degree. Each edge is drawn as a
    grey line segment, and an invisible marker at the edge midpoint carries
    the edge's hover label.

    Args:
        graph: Any NetworkX graph, including directed graphs and multigraphs.
            A node's optional ``"label"`` attribute is shown in its hover text.
            An edge's optional ``"label"`` attribute is shown when hovering the
            edge midpoint; it falls back to ``"u-v"``.
        seed: Optional random seed forwarded to ``nx.spring_layout`` so the
            same graph gets the same layout. ``None`` (the default) keeps the
            previous unseeded behaviour.

    Returns:
        A ``go.Figure`` containing, in drawing order, the edge-line trace, the
        edge-label hover trace and the node trace.
    """
    pos = nx.spring_layout(graph, k=1, iterations=50, seed=seed)

    # Node positions, hover text, visible id labels and colour values.
    node_x, node_y = [], []
    node_hover, node_ids, node_degrees = [], [], []
    for node, attrs in graph.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_label = attrs.get('label', str(node))
        node_hover.append(f"Node {node}<br>{node_label}")
        node_ids.append(str(node))
        # degree() counts in- and out-edges on directed graphs; neighbors()
        # would only count successors.
        node_degrees.append(graph.degree(node))

    # Edge segments plus one midpoint per edge for hover labels. Iterating
    # with data=True also works for multigraphs, where graph.edges[(u, v)]
    # would require an edge key.
    edge_x, edge_y = [], []
    mid_x, mid_y, edge_hover = [], [], []
    for u, v, attrs in graph.edges(data=True):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        # The None entries break the polyline so each edge is its own segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        mid_x.append((x0 + x1) / 2)
        mid_y.append((y0 + y1) / 2)
        edge_hover.append(str(attrs.get('label', f"{u}-{v}")))

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    # Transparent markers: invisible, but they still receive hover events.
    edge_label_trace = go.Scatter(
        x=mid_x, y=mid_y,
        mode='markers',
        hoverinfo='text',
        hovertext=edge_hover,
        marker=dict(size=10, color='#888', opacity=0),
    )

    # Built once. The old code created a marker dict and then replaced it
    # wholesale with this one.
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_hover,
        text=node_ids,
        textposition="middle center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_degrees,
            size=20,
            colorbar=dict(
                thickness=15,
                len=0.5,
                x=1.02,
                title="Node Connections",
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )

    fig = go.Figure(
        data=[edge_trace, edge_label_trace, node_trace],
        layout=go.Layout(
            title=dict(
                text='Dialog Graph Visualization',
                font=dict(size=16),
            ),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Hover over nodes for more information",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002,
                xanchor='left', yanchor='bottom',
                font=dict(color="#888", size=12),
            )],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white',
        ),
    )
    return fig
def create_chat_visualization(dialog_data: List[Dict[str, str]]) -> str:
    """Render dialog turns as chat-style HTML bubbles.

    Turns whose ``participant`` is ``'assistant'`` become left-aligned blue
    bubbles titled "Assistant". Every other participant becomes a
    right-aligned green bubble titled "User".

    Args:
        dialog_data: Sequence of turns. Each turn must have ``'participant'``
            and ``'text'`` keys; a missing key raises ``KeyError``.

    Returns:
        An HTML fragment: a scrollable container ``<div>`` holding one bubble
        per turn. Turn text is HTML-escaped.
    """
    # Per-speaker styling: (justify-content, bubble background,
    # squared-off corner property, name colour, displayed name).
    assistant_style = ("flex-start", "#e3f2fd", "border-bottom-left-radius", "#1976d2", "Assistant")
    user_style = ("flex-end", "#e8f5e8", "border-bottom-right-radius", "#388e3c", "User")

    parts = [
        '\n<div style="max-height: 500px; overflow-y: auto; border: 1px solid #ddd; '
        'border-radius: 10px; padding: 20px; background-color: #f9f9f9;">\n'
    ]
    for turn in dialog_data:
        participant = turn['participant']
        # Escape the text: it can come from user-supplied custom JSON. Raw
        # '<' or '&' would otherwise be interpreted as markup (HTML injection).
        text = html.escape(str(turn['text']))
        justify, background, corner, name_color, name = (
            assistant_style if participant == 'assistant' else user_style
        )
        parts.append(
            f'\n<div style="display: flex; justify-content: {justify}; margin-bottom: 15px;">\n'
            f'<div style="max-width: 70%; background-color: {background}; padding: 12px 16px; '
            f'border-radius: 18px; {corner}: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);">\n'
            f'<div style="font-weight: bold; color: {name_color}; font-size: 12px; '
            f'margin-bottom: 4px;">{name}</div>\n'
            f'<div style="color: #333; line-height: 1.4;">{text}</div>\n'
            '</div>\n'
            '</div>\n'
        )
    parts.append("</div>")
    # One join instead of repeated += keeps this linear in the dialog length.
    return "".join(parts)
def process_dialog_and_visualize(dialog_choice: str, custom_dialog: str = "") -> tuple:
    """Run the dialog2graph pipeline on the chosen dialog and build all UI outputs.

    Args:
        dialog_choice: A dataset name (loaded from ``<name>.json``) or
            ``"custom"``.
        custom_dialog: JSON text. Only used when ``dialog_choice`` is
            ``"custom"``.

    Returns:
        ``(figure, summary_markdown, chat_html)``. If loading fails or any
        step raises, the figure is ``None``, the summary holds an error
        message and the chat HTML is ``""``.
    """
    try:
        dialog_data = load_dialog_data(
            dialog_choice, custom_dialog if dialog_choice == "custom" else None
        )
        if not dialog_data:
            return None, "Failed to load dialog data", ""

        pipe = initialize_pipeline()
        # The pipeline also returns a report, which is not surfaced in the UI.
        graph, _report = pipe.invoke(PipelineRawDataType(dialogs=dialog_data))

        nx_graph = graph.graph
        fig = create_network_visualization(nx_graph)
        chat_viz = create_chat_visualization(dialog_data)

        num_nodes = nx_graph.number_of_nodes()
        num_edges = nx_graph.number_of_edges()
        num_turns = len(dialog_data)
        # Build the Markdown line by line so every line is flush-left. Lines
        # indented 4+ spaces in an indented triple-quoted string would render
        # as a code block.
        summary = "\n".join([
            "",
            "## Graph Summary",
            f"- **Number of nodes**: {num_nodes}",
            f"- **Number of edges**: {num_edges}",
            f"- **Dialog turns**: {num_turns}",
            "",
            "## Processing Report",
            f"Generated graph from {num_turns} dialog turns with {num_nodes} nodes and {num_edges} edges.",
            "",
        ])
        return fig, summary, chat_viz
    except Exception as e:
        # UI boundary: show the error in the summary panel instead of failing
        # the event handler, but keep the full traceback in the server logs.
        logging.getLogger(__name__).exception("Error processing dialog %r", dialog_choice)
        return None, f"Error processing dialog: {str(e)}", ""
# Create the Gradio interface
def create_gradio_app() -> gr.Blocks:
    """Assemble the Dialog2Graph Blocks UI and register its event listeners.

    Components are created in this order: two Markdown headers; a row holding
    a controls column (dataset radio, a hidden custom-JSON textbox, the
    process button, an info accordion) and a wider plot column; then a row
    with the summary Markdown beside the chat HTML.

    Listeners:
        * radio change -> show the custom-JSON textbox only for ``"custom"``;
        * button click -> ``process_dialog_and_visualize`` on the current inputs;
        * radio change -> auto-process built-in datasets. The custom choice
          only gets a prompt to use the button.

    Returns:
        The configured Blocks app. It is not launched here.
    """
    with gr.Blocks(title="Dialog2Graph Visualizer") as app:
        gr.Markdown("# Dialog2Graph Interactive Visualizer")
        gr.Markdown("Select a dialog dataset to process and visualize as a graph network using Plotly.")
        with gr.Row():
            # Left: input controls.
            with gr.Column(scale=1):
                dialog_selector = gr.Radio(
                    choices=["dialog1", "dialog2", "dialog3", "custom"],
                    label="Select Dialog Dataset",
                    value="dialog1",
                    info="Choose one of the available dialog datasets or use custom JSON"
                )
                # Hidden until "custom" is selected (see toggle_custom_input).
                custom_dialog_input = gr.Textbox(
                    label="Custom Dialog JSON",
                    placeholder='[{"text": "Hello! How can I help?", "participant": "assistant"}, {"text": "I need assistance", "participant": "user"}]',
                    lines=8,
                    visible=False,
                    info="Enter dialog data as JSON array with 'text' and 'participant' fields"
                )
                process_btn = gr.Button(
                    "Process Dialog & Generate Graph",
                    variant="primary",
                    size="lg"
                )
                with gr.Accordion("Dialog Datasets Info", open=False):
                    gr.Markdown("""
- **dialog1**: Hotel booking conversation
- **dialog2**: Food delivery conversation
- **dialog3**: Technical support conversation
- **custom**: Provide your own dialog as JSON
""")
            # Right: the graph, given three times the controls' width.
            with gr.Column(scale=3):
                plot_output = gr.Plot(label="Graph Visualization")
        with gr.Row():
            with gr.Column(scale=1):
                summary_output = gr.Markdown(label="Analysis Summary")
            with gr.Column(scale=1):
                gr.Markdown("### Dialog Conversation")
                chat_output = gr.HTML(label="Chat Visualization")
        # Event handlers
        def toggle_custom_input(choice: str):
            """Show the custom-JSON textbox only when "custom" is selected."""
            return gr.update(visible=(choice == "custom"))
        dialog_selector.change(
            fn=toggle_custom_input,
            inputs=[dialog_selector],
            outputs=[custom_dialog_input]
        )
        process_btn.click(
            fn=process_dialog_and_visualize,
            inputs=[dialog_selector, custom_dialog_input],
            outputs=[plot_output, summary_output, chat_output]
        )
        # Auto-process on selection change (but not for custom to avoid premature processing)
        def auto_process(choice: str, custom_text: str) -> tuple:
            """Process built-in datasets immediately; prompt for the button on "custom"."""
            if choice != "custom":
                return process_dialog_and_visualize(choice, custom_text)
            else:
                return None, "Select 'Process Dialog & Generate Graph' to process custom dialog", ""
        # Second listener on the same radio: it runs alongside the visibility
        # toggle and writes to the three result components.
        dialog_selector.change(
            fn=auto_process,
            inputs=[dialog_selector, custom_dialog_input],
            outputs=[plot_output, summary_output, chat_output]
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start it with launch()'s
    # default settings.
    app = create_gradio_app()
    app.launch()