Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import gradio as gr
|
3 |
+
import plotly.graph_objects as go
|
4 |
+
import plotly.express as px
|
5 |
+
import networkx as nx
|
6 |
+
from typing import List, Dict, Any
|
7 |
+
|
8 |
+
|
9 |
+
from langchain_openai.chat_models import ChatOpenAI
|
10 |
+
from dialog2graph.pipelines.model_storage import ModelStorage
|
11 |
+
from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline
|
12 |
+
from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType
|
13 |
+
|
14 |
+
# Initialize the pipeline
|
15 |
+
def initialize_pipeline():
    """Build the dialog-to-graph LLM pipeline.

    A fresh ``ModelStorage`` is created and a single ``ChatOpenAI`` model
    (``model_name="gpt-3.5-turbo"``) is registered in it under the key
    ``"my_filling_model"``.

    Returns:
        A ``D2GLLMPipeline`` named ``"d2g_pipeline"`` that is given the
        storage and uses the registered key as its ``filling_llm``.
    """
    filling_model_key = "my_filling_model"

    storage = ModelStorage()
    storage.add(
        filling_model_key,
        config={"model_name": "gpt-3.5-turbo"},
        model_type=ChatOpenAI,
    )

    pipeline = D2GLLMPipeline(
        "d2g_pipeline",
        model_storage=storage,
        filling_llm=filling_model_key,
    )
    return pipeline
|
23 |
+
|
24 |
+
def load_dialog_data(json_file: str) -> List[Dict[str, str]]:
    """Load dialog data from ``<json_file>.json``.

    Args:
        json_file: Path of the dataset file *without* the ``.json``
            extension; the extension is appended here.

    Returns:
        The parsed JSON content, or an empty list when the file does not
        exist or does not contain valid JSON (a UI warning is emitted in
        both failure cases).
    """
    file_path = f"{json_file}.json"
    try:
        # Explicit encoding so non-ASCII dialog text does not depend on the
        # platform's locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # gr.Error is an exception class: merely instantiating it (without
        # `raise`) notifies nobody. gr.Warning is the call-style notification
        # and lets this function keep its "return [] on failure" contract.
        gr.Warning(f"File {file_path} not found!")
        return []
    except json.JSONDecodeError:
        gr.Warning(f"Invalid JSON format in {file_path}!")
        return []
|
36 |
+
|
37 |
+
def create_network_visualization(graph: nx.Graph) -> go.Figure:
    """Create a Plotly network visualization from a NetworkX graph.

    Nodes are positioned with ``nx.spring_layout`` and drawn as markers
    labelled with ``str(node)``. Each marker is colored by the number of
    neighbors ``graph.neighbors`` reports for that node, and its hover text
    shows the node id plus its ``'label'`` attribute (falling back to
    ``str(node)``). Edges are drawn as plain grey line segments.

    Args:
        graph: The graph to draw.

    Returns:
        A figure with two scatter traces: edges first, then nodes.
    """
    # Node -> (x, y) position from the spring layout.
    pos = nx.spring_layout(graph, k=1, iterations=50)

    node_ids = list(graph.nodes())
    node_x = [pos[node][0] for node in node_ids]
    node_y = [pos[node][1] for node in node_ids]
    node_text = [
        f"Node {node}<br>{graph.nodes[node].get('label', str(node))}"
        for node in node_ids
    ]
    # Drives the marker color scale.
    node_adjacencies = [len(list(graph.neighbors(node))) for node in node_ids]

    edge_x = []
    edge_y = []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None breaks the line, so every edge is its own segment
        # inside the single edge trace.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        mode='lines',
    )

    # The marker is configured once here; the original built a placeholder
    # marker and then replaced it wholesale with this one.
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        hovertext=node_text,
        text=[str(node) for node in node_ids],
        textposition="middle center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_adjacencies,
            size=20,
            colorbar=dict(
                thickness=15,
                len=0.5,
                x=1.02,
                title="Node Connections",
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )

    layout = go.Layout(
        title=dict(
            text='Dialog Graph Visualization',
            font=dict(size=16),
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[dict(
            text="Hover over nodes for more information",
            showarrow=False,
            xref="paper", yref="paper",
            x=0.005, y=-0.002,
            xanchor='left', yanchor='bottom',
            font=dict(color="#888", size=12),
        )],
        # Hide axes decoration: positions are layout coordinates only.
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        plot_bgcolor='white',
    )

    return go.Figure(data=[edge_trace, node_trace], layout=layout)
|
146 |
+
|
147 |
+
def create_chat_visualization(dialog_data: List[Dict[str, str]]) -> str:
    """Render the dialog as a scrollable, chat-style HTML fragment.

    Turns whose ``'participant'`` is ``'assistant'`` become left-aligned blue
    bubbles labelled "Assistant"; every other participant becomes a
    right-aligned green bubble labelled "User". The turn text is HTML-escaped
    before being embedded.

    Args:
        dialog_data: Dialog turns, each a mapping with ``'participant'`` and
            ``'text'`` keys.

    Returns:
        An HTML string: one container ``<div>`` holding one bubble per turn.

    Raises:
        KeyError: If a turn lacks the ``'participant'`` or ``'text'`` key.
    """
    # Function-scope import so this block does not rely on file-level imports.
    from html import escape

    # (justify-content, bubble background, squared-off corner, label color, label)
    assistant_style = ("flex-start", "#e3f2fd", "left", "#1976d2", "Assistant")
    user_style = ("flex-end", "#e8f5e8", "right", "#388e3c", "User")

    parts = [
        '\n<div style="max-height: 500px; overflow-y: auto; border: 1px solid #ddd; '
        'border-radius: 10px; padding: 20px; background-color: #f9f9f9;">\n'
    ]

    for turn in dialog_data:
        participant = turn['participant']
        # Dialog text is data, not markup: without escaping, a "<" or "&" in a
        # turn would corrupt the bubble or inject arbitrary HTML into the page.
        text = escape(str(turn['text']))

        justify, background, corner, label_color, label = (
            assistant_style if participant == 'assistant' else user_style
        )
        parts.append(
            f'\n<div style="display: flex; justify-content: {justify}; margin-bottom: 15px;">\n'
            f'<div style="max-width: 70%; background-color: {background}; padding: 12px 16px; '
            f'border-radius: 18px; border-bottom-{corner}-radius: 4px; '
            f'box-shadow: 0 1px 2px rgba(0,0,0,0.1);">\n'
            f'<div style="font-weight: bold; color: {label_color}; font-size: 12px; '
            f'margin-bottom: 4px;">{label}</div>\n'
            f'<div style="color: #333; line-height: 1.4;">{text}</div>\n'
            '</div>\n'
            '</div>\n'
        )

    parts.append("</div>")
    return "".join(parts)
|
180 |
+
|
181 |
+
def process_dialog_and_visualize(dialog_choice: str) -> tuple:
    """Run the pipeline on the chosen dialog and build all UI outputs.

    Args:
        dialog_choice: Dataset name passed to ``load_dialog_data``.

    Returns:
        A ``(figure, summary_markdown, chat_html)`` triple. When the dialog
        cannot be loaded, or any step raises, the figure is ``None``, the
        summary slot carries the error message and the chat HTML is ``""``.
    """
    try:
        turns = load_dialog_data(dialog_choice)
        if not turns:
            return None, "Failed to load dialog data", ""

        # A new pipeline is built for every request.
        pipeline = initialize_pipeline()
        graph, _report = pipeline.invoke(PipelineRawDataType(dialogs=turns))

        figure = create_network_visualization(graph.graph)
        chat_html = create_chat_visualization(turns)

        node_count = graph.graph.number_of_nodes()
        edge_count = graph.graph.number_of_edges()

        summary_lines = [
            "",
            "## Graph Summary",
            f"- **Number of nodes**: {node_count}",
            f"- **Number of edges**: {edge_count}",
            f"- **Dialog turns**: {len(turns)}",
            "",
            "## Processing Report",
            f"Generated graph from {len(turns)} dialog turns with {node_count} nodes and {edge_count} edges.",
            "",
        ]
        return figure, "\n".join(summary_lines), chat_html

    except Exception as e:
        # UI boundary: surface the failure in the summary panel instead of
        # letting the event handler crash.
        return None, f"Error processing dialog: {str(e)}", ""
|
221 |
+
|
222 |
+
# Create the Gradio interface
|
223 |
+
def create_gradio_app():
    """Assemble the Gradio Blocks UI and return it (not yet launched).

    Layout: a first row with the dataset selector, the process button and an
    info accordion on the left and the graph plot on the right; a second row
    with the summary markdown next to the chat view. Both the button click
    and a change of the selector run ``process_dialog_and_visualize`` and
    feed the plot, summary and chat outputs.
    """
    datasets_info = (
        "\n"
        "- **dialog1**: Hotel booking conversation\n"
        "- **dialog2**: Food delivery conversation\n"
        "- **dialog3**: Technical support conversation\n"
    )

    with gr.Blocks(title="Dialog2Graph Visualizer") as app:
        gr.Markdown("# Dialog2Graph Interactive Visualizer")
        gr.Markdown("Select a dialog dataset to process and visualize as a graph network using Plotly.")

        with gr.Row():
            with gr.Column(scale=1):
                dialog_selector = gr.Radio(
                    choices=["dialog1", "dialog2", "dialog3"],
                    label="Select Dialog Dataset",
                    value="dialog1",
                    info="Choose one of the available dialog datasets",
                )

                process_btn = gr.Button(
                    "Process Dialog & Generate Graph",
                    variant="primary",
                    size="lg",
                )

                with gr.Accordion("Dialog Datasets Info", open=False):
                    gr.Markdown(datasets_info)

            with gr.Column(scale=3):
                plot_output = gr.Plot(label="Graph Visualization")

        with gr.Row():
            with gr.Column(scale=1):
                summary_output = gr.Markdown(label="Analysis Summary")

            with gr.Column(scale=1):
                gr.Markdown("### Dialog Conversation")
                chat_output = gr.HTML(label="Chat Visualization")

        # Same handler for the explicit button click and for a selection
        # change (auto-process), registered in that order.
        for register in (process_btn.click, dialog_selector.change):
            register(
                fn=process_dialog_and_visualize,
                inputs=[dialog_selector],
                outputs=[plot_output, summary_output, chat_output],
            )

    return app
|
276 |
+
|
277 |
+
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it.
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": True,  # NOTE(review): asks Gradio for a public share link; confirm this is intended
        "debug": True,
    }
    app = create_gradio_app()
    app.launch(**launch_options)
|