Tialo commited on
Commit
080d211
·
verified ·
1 Parent(s): 05bdd7c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +284 -0
app.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ import plotly.graph_objects as go
4
+ import plotly.express as px
5
+ import networkx as nx
6
+ from typing import List, Dict, Any
7
+
8
+
9
+ from langchain_openai.chat_models import ChatOpenAI
10
+ from dialog2graph.pipelines.model_storage import ModelStorage
11
+ from dialog2graph.pipelines.d2g_llm.pipeline import D2GLLMPipeline
12
+ from dialog2graph.pipelines.helpers.parse_data import PipelineRawDataType
13
+
14
+ # Initialize the pipeline
15
+ def initialize_pipeline():
16
+ ms = ModelStorage()
17
+ ms.add(
18
+ "my_filling_model",
19
+ config={"model_name": "gpt-3.5-turbo"},
20
+ model_type=ChatOpenAI,
21
+ )
22
+ return D2GLLMPipeline("d2g_pipeline", model_storage=ms, filling_llm="my_filling_model")
23
+
24
+ def load_dialog_data(json_file: str) -> List[Dict[str, str]]:
25
+ """Load dialog data from JSON file"""
26
+ file_path = f"{json_file}.json"
27
+ try:
28
+ with open(file_path, 'r') as f:
29
+ return json.load(f)
30
+ except FileNotFoundError:
31
+ gr.Error(f"File {file_path} not found!")
32
+ return []
33
+ except json.JSONDecodeError:
34
+ gr.Error(f"Invalid JSON format in {file_path}!")
35
+ return []
36
+
37
+ def create_network_visualization(graph: nx.Graph) -> go.Figure:
38
+ """Create a Plotly network visualization from NetworkX graph"""
39
+
40
+ # Get node positions using spring layout
41
+ pos = nx.spring_layout(graph, k=1, iterations=50)
42
+
43
+ # Extract node and edge information
44
+ node_x = []
45
+ node_y = []
46
+ node_text = []
47
+ node_ids = []
48
+
49
+ for node in graph.nodes():
50
+ x, y = pos[node]
51
+ node_x.append(x)
52
+ node_y.append(y)
53
+
54
+ # Get node attributes if available
55
+ node_attrs = graph.nodes[node]
56
+ node_label = node_attrs.get('label', str(node))
57
+ node_text.append(f"Node {node}<br>{node_label}")
58
+ node_ids.append(node)
59
+
60
+ # Create edge traces
61
+ edge_x = []
62
+ edge_y = []
63
+ edge_info = []
64
+
65
+ for edge in graph.edges():
66
+ x0, y0 = pos[edge[0]]
67
+ x1, y1 = pos[edge[1]]
68
+ edge_x.extend([x0, x1, None])
69
+ edge_y.extend([y0, y1, None])
70
+
71
+ # Get edge attributes if available
72
+ edge_attrs = graph.edges[edge]
73
+ edge_label = edge_attrs.get('label', f"{edge[0]}-{edge[1]}")
74
+ edge_info.append(edge_label)
75
+
76
+ # Create the edge trace
77
+ edge_trace = go.Scatter(
78
+ x=edge_x, y=edge_y,
79
+ line=dict(width=2, color='#888'),
80
+ hoverinfo='none',
81
+ mode='lines'
82
+ )
83
+
84
+ # Create the node trace
85
+ node_trace = go.Scatter(
86
+ x=node_x, y=node_y,
87
+ mode='markers+text',
88
+ hoverinfo='text',
89
+ hovertext=node_text,
90
+ text=[str(node) for node in node_ids],
91
+ textposition="middle center",
92
+ marker=dict(
93
+ size=20,
94
+ line=dict(width=2)
95
+ )
96
+ )
97
+
98
+ # Color nodes by number of connections
99
+ node_adjacencies = []
100
+ for node in graph.nodes():
101
+ node_adjacencies.append(len(list(graph.neighbors(node))))
102
+
103
+ # Update marker color
104
+ node_trace.marker = dict(
105
+ showscale=True,
106
+ colorscale='YlGnBu',
107
+ reversescale=True,
108
+ color=node_adjacencies,
109
+ size=20,
110
+ colorbar=dict(
111
+ thickness=15,
112
+ len=0.5,
113
+ x=1.02,
114
+ title="Node Connections",
115
+ xanchor="left"
116
+ ),
117
+ line=dict(width=2)
118
+ )
119
+
120
+ # Create the figure
121
+ fig = go.Figure(data=[edge_trace, node_trace],
122
+ layout=go.Layout(
123
+ title=dict(
124
+ text='Dialog Graph Visualization',
125
+ font=dict(
126
+ size=16,
127
+ ),
128
+ ),
129
+ showlegend=False,
130
+ hovermode='closest',
131
+ margin=dict(b=20,l=5,r=5,t=40),
132
+ annotations=[ dict(
133
+ text="Hover over nodes for more information",
134
+ showarrow=False,
135
+ xref="paper", yref="paper",
136
+ x=0.005, y=-0.002,
137
+ xanchor='left', yanchor='bottom',
138
+ font=dict(color="#888", size=12)
139
+ )],
140
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
141
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
142
+ plot_bgcolor='white'
143
+ ))
144
+
145
+ return fig
146
+
147
+ def create_chat_visualization(dialog_data: List[Dict[str, str]]) -> str:
148
+ """Create a chat-like visualization of the dialog"""
149
+ chat_html = """
150
+ <div style="max-height: 500px; overflow-y: auto; border: 1px solid #ddd; border-radius: 10px; padding: 20px; background-color: #f9f9f9;">
151
+ """
152
+
153
+ for i, turn in enumerate(dialog_data):
154
+ participant = turn['participant']
155
+ text = turn['text']
156
+
157
+ if participant == 'assistant':
158
+ # Assistant messages on the left with blue background
159
+ chat_html += f"""
160
+ <div style="display: flex; justify-content: flex-start; margin-bottom: 15px;">
161
+ <div style="max-width: 70%; background-color: #e3f2fd; padding: 12px 16px; border-radius: 18px; border-bottom-left-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);">
162
+ <div style="font-weight: bold; color: #1976d2; font-size: 12px; margin-bottom: 4px;">Assistant</div>
163
+ <div style="color: #333; line-height: 1.4;">{text}</div>
164
+ </div>
165
+ </div>
166
+ """
167
+ else:
168
+ # User messages on the right with green background
169
+ chat_html += f"""
170
+ <div style="display: flex; justify-content: flex-end; margin-bottom: 15px;">
171
+ <div style="max-width: 70%; background-color: #e8f5e8; padding: 12px 16px; border-radius: 18px; border-bottom-right-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);">
172
+ <div style="font-weight: bold; color: #388e3c; font-size: 12px; margin-bottom: 4px;">User</div>
173
+ <div style="color: #333; line-height: 1.4;">{text}</div>
174
+ </div>
175
+ </div>
176
+ """
177
+
178
+ chat_html += "</div>"
179
+ return chat_html
180
+
181
+ def process_dialog_and_visualize(dialog_choice: str) -> tuple:
182
+ """Process the selected dialog and create visualization"""
183
+ try:
184
+ # Load the selected dialog data
185
+ dialog_data = load_dialog_data(dialog_choice)
186
+
187
+ if not dialog_data:
188
+ return None, "Failed to load dialog data", ""
189
+
190
+ # Initialize pipeline
191
+ pipe = initialize_pipeline()
192
+
193
+ # Process the data
194
+ data = PipelineRawDataType(dialogs=dialog_data)
195
+ graph, report = pipe.invoke(data)
196
+
197
+ # Create visualization
198
+ fig = create_network_visualization(graph.graph)
199
+
200
+ # Create chat visualization
201
+ chat_viz = create_chat_visualization(dialog_data)
202
+
203
+ # Create summary information
204
+ num_nodes = graph.graph.number_of_nodes()
205
+ num_edges = graph.graph.number_of_edges()
206
+
207
+ summary = f"""
208
+ ## Graph Summary
209
+ - **Number of nodes**: {num_nodes}
210
+ - **Number of edges**: {num_edges}
211
+ - **Dialog turns**: {len(dialog_data)}
212
+
213
+ ## Processing Report
214
+ Generated graph from {len(dialog_data)} dialog turns with {num_nodes} nodes and {num_edges} edges.
215
+ """
216
+
217
+ return fig, summary, chat_viz
218
+
219
+ except Exception as e:
220
+ return None, f"Error processing dialog: {str(e)}", ""
221
+
222
+ # Create the Gradio interface
223
+ def create_gradio_app():
224
+ with gr.Blocks(title="Dialog2Graph Visualizer") as app:
225
+ gr.Markdown("# Dialog2Graph Interactive Visualizer")
226
+ gr.Markdown("Select a dialog dataset to process and visualize as a graph network using Plotly.")
227
+
228
+ with gr.Row():
229
+ with gr.Column(scale=1):
230
+ dialog_selector = gr.Radio(
231
+ choices=["dialog1", "dialog2", "dialog3"],
232
+ label="Select Dialog Dataset",
233
+ value="dialog1",
234
+ info="Choose one of the available dialog datasets"
235
+ )
236
+
237
+ process_btn = gr.Button(
238
+ "Process Dialog & Generate Graph",
239
+ variant="primary",
240
+ size="lg"
241
+ )
242
+
243
+ with gr.Accordion("Dialog Datasets Info", open=False):
244
+ gr.Markdown("""
245
+ - **dialog1**: Hotel booking conversation
246
+ - **dialog2**: Food delivery conversation
247
+ - **dialog3**: Technical support conversation
248
+ """)
249
+
250
+ with gr.Column(scale=3):
251
+ plot_output = gr.Plot(label="Graph Visualization")
252
+
253
+ with gr.Row():
254
+ with gr.Column(scale=1):
255
+ summary_output = gr.Markdown(label="Analysis Summary")
256
+
257
+ with gr.Column(scale=1):
258
+ gr.Markdown("### Dialog Conversation")
259
+ chat_output = gr.HTML(label="Chat Visualization")
260
+
261
+ # Event handlers
262
+ process_btn.click(
263
+ fn=process_dialog_and_visualize,
264
+ inputs=[dialog_selector],
265
+ outputs=[plot_output, summary_output, chat_output]
266
+ )
267
+
268
+ # Auto-process on selection change
269
+ dialog_selector.change(
270
+ fn=process_dialog_and_visualize,
271
+ inputs=[dialog_selector],
272
+ outputs=[plot_output, summary_output, chat_output]
273
+ )
274
+
275
+ return app
276
+
277
+ if __name__ == "__main__":
278
+ app = create_gradio_app()
279
+ app.launch(
280
+ server_name="0.0.0.0",
281
+ server_port=7860,
282
+ share=True,
283
+ debug=True
284
+ )