baconnier commited on
Commit
f53ad88
·
verified ·
1 Parent(s): 3d2805d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -95
app.py CHANGED
@@ -1,4 +1,8 @@
1
  import gradio as gr
 
 
 
 
2
  import json
3
  import os
4
  from dotenv import load_dotenv
@@ -7,6 +11,8 @@ from typing import Dict, Any, Optional, Union
7
  from datetime import datetime
8
  import networkx as nx
9
  import plotly.graph_objects as go
 
 
10
 
11
  # Load environment variables
12
  load_dotenv()
@@ -21,79 +27,83 @@ except Exception as e:
21
  print(f"Error initializing ExplorationPathGenerator: {e}")
22
  raise
23
 
24
- def create_interactive_graph(nodes):
25
- """Create an interactive graph visualization using plotly"""
26
- G = nx.DiGraph()
27
-
28
- # Add nodes to the graph
29
- node_x = []
30
- node_y = []
31
- node_text = []
32
- node_color = []
33
- node_size = []
34
-
35
- # Create layout
36
- pos = nx.spring_layout(G, k=1, iterations=50)
37
-
38
- for node in nodes:
39
- G.add_node(node['id'])
40
- x, y = pos[node['id']]
41
- node_x.append(x)
42
- node_y.append(y)
43
- node_text.append(f"Title: {node['title']}<br>Description: {node['description']}")
44
- node_color.append(['#FF9999', '#99FF99', '#9999FF', '#FFFF99'][node['depth'] % 4])
45
- node_size.append(30) # Base size for all nodes
46
-
47
- edge_x = []
48
- edge_y = []
49
- edge_text = []
50
-
51
- # Add edges
52
- for node in nodes:
53
- for conn in node['connections']:
54
- G.add_edge(node['id'], conn['target_id'])
55
- x0, y0 = pos[node['id']]
56
- x1, y1 = pos[conn['target_id']]
57
- edge_x.extend([x0, x1, None])
58
- edge_y.extend([y0, y1, None])
59
- edge_text.append(f"Score: {conn.get('relevance_score', 'N/A')}")
60
-
61
- # Create the interactive plot
62
- fig = go.Figure()
63
-
64
- # Add edges
65
- fig.add_trace(go.Scatter(
66
- x=edge_x, y=edge_y,
67
- line=dict(width=0.5, color='#888'),
68
- hoverinfo='none',
69
- mode='lines'
70
- ))
71
 
72
- # Add nodes
73
- fig.add_trace(go.Scatter(
74
- x=node_x, y=node_y,
75
- mode='markers+text',
76
- hoverinfo='text',
77
- text=[node['title'] for node in nodes],
78
- hovertext=node_text,
79
- marker=dict(
80
- showscale=False,
81
- color=node_color,
82
- size=node_size,
83
- line_width=2
84
  )
85
- ))
86
-
87
- fig.update_layout(
88
- showlegend=False,
89
- hovermode='closest',
90
- margin=dict(b=0, l=0, r=0, t=0),
91
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
92
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
93
- plot_bgcolor='white'
94
- )
95
 
96
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def summarize_node(node_data: Dict[str, Any]) -> str:
99
  """Generate a summary for a selected node"""
@@ -112,22 +122,17 @@ Connections:
112
 
113
  return summary
114
 
115
- def handle_node_click(node_id: str, current_result: Dict[str, Any]) -> tuple[go.Figure, str]:
116
  """Handle node click event and update visualization"""
117
  nodes = current_result.get('nodes', [])
118
  selected_node = next((node for node in nodes if node['id'] == node_id), None)
119
 
120
  if selected_node:
121
- # Update graph to highlight selected node and its connections
122
- fig = create_interactive_graph(nodes) # Create new graph
123
  summary = summarize_node(selected_node)
124
  else:
125
- fig = create_interactive_graph(nodes) # Create default graph
126
  summary = "No node selected"
127
 
128
- return fig, summary
129
-
130
- # [Previous helper functions remain the same: format_output, parse_json_input, validate_parameters]
131
 
132
  def explore(
133
  query: str,
@@ -135,11 +140,16 @@ def explore(
135
  parameters: str = "{}",
136
  depth: int = 5,
137
  domain: str = ""
138
- ) -> tuple[str, go.Figure, str]:
139
  """Generate exploration path and visualization"""
140
  try:
141
- # [Previous validation and processing code remains the same]
142
-
 
 
 
 
 
143
  result = generator.generate_exploration_path(
144
  query=query,
145
  selected_path=selected_path,
@@ -149,13 +159,10 @@ def explore(
149
  if not isinstance(result, dict):
150
  raise ValueError("Invalid response format from generator")
151
 
152
- # Create interactive graph
153
- graph = create_interactive_graph(result.get('nodes', []))
154
-
155
  # Initial summary (no node selected)
156
  summary = "Click on a node to see details"
157
 
158
- return format_output(result), graph, summary
159
 
160
  except Exception as e:
161
  error_response = {
@@ -169,10 +176,14 @@ def explore(
169
  }
170
  }
171
  print(f"Error in explore function: {e}")
172
- return format_output(error_response), None, "Error occurred"
173
 
174
  def create_interface() -> gr.Blocks:
175
  """Create and configure the Gradio interface"""
 
 
 
 
176
 
177
  with gr.Blocks(title="Art History Exploration Path Generator") as interface:
178
  gr.Markdown("# Art History Exploration Path Generator")
@@ -214,7 +225,9 @@ def create_interface() -> gr.Blocks:
214
  with gr.Accordion("Exploration Result", open=False):
215
  text_output = gr.JSON(label="Raw Result")
216
 
217
- graph_output = gr.Plot(label="Interactive Exploration Graph")
 
 
218
  node_summary = gr.Textbox(
219
  label="Node Details",
220
  lines=5,
@@ -224,17 +237,19 @@ def create_interface() -> gr.Blocks:
224
  # Store current result for node clicking
225
  current_result = gr.State({})
226
 
 
 
 
 
 
 
 
 
 
227
  generate_btn.click(
228
- fn=explore,
229
  inputs=[query_input, path_history, parameters, depth, domain],
230
- outputs=[text_output, graph_output, node_summary]
231
- )
232
-
233
- # Handle node clicks
234
- graph_output.click(
235
- fn=handle_node_click,
236
- inputs=["node_id", current_result],
237
- outputs=[graph_output, node_summary]
238
  )
239
 
240
  # Examples
 
1
  import gradio as gr
2
+ import dash
3
+ from dash import html, dcc
4
+ import dash_cytoscape as cyto
5
+ from dash.dependencies import Input, Output, State
6
  import json
7
  import os
8
  from dotenv import load_dotenv
 
11
  from datetime import datetime
12
  import networkx as nx
13
  import plotly.graph_objects as go
14
+ import threading
15
+ from flask import Flask
16
 
17
  # Load environment variables
18
  load_dotenv()
 
27
  print(f"Error initializing ExplorationPathGenerator: {e}")
28
  raise
29
 
30
+ class GraphVisualizer:
31
+ def __init__(self, port=8050):
32
+ self.app = dash.Dash(__name__)
33
+ self.port = port
34
+ cyto.load_extra_layouts()
35
+
36
+ self.app.layout = html.Div([
37
+ cyto.Cytoscape(
38
+ id='exploration-graph',
39
+ layout={'name': 'cose'},
40
+ style={'width': '100%', 'height': '600px'},
41
+ elements=[],
42
+ stylesheet=[
43
+ {
44
+ 'selector': 'node',
45
+ 'style': {
46
+ 'label': 'data(label)',
47
+ 'background-color': '#6FB1FC',
48
+ 'content': 'data(label)',
49
+ 'text-wrap': 'wrap',
50
+ 'text-max-width': '100px'
51
+ }
52
+ },
53
+ {
54
+ 'selector': 'edge',
55
+ 'style': {
56
+ 'curve-style': 'bezier',
57
+ 'target-arrow-shape': 'triangle',
58
+ 'line-color': '#ccc',
59
+ 'target-arrow-color': '#ccc',
60
+ 'label': 'data(label)',
61
+ 'font-size': '10px'
62
+ }
63
+ }
64
+ ]
65
+ ),
66
+ html.Div(id='node-details', style={'padding': '20px'})
67
+ ])
68
+
69
+ self._setup_callbacks()
 
 
 
 
 
 
 
70
 
71
+ def _setup_callbacks(self):
72
+ @self.app.callback(
73
+ Output('node-details', 'children'),
74
+ [Input('exploration-graph', 'tapNodeData')]
 
 
 
 
 
 
 
 
75
  )
76
+ def display_node_details(node_data):
77
+ if not node_data:
78
+ return "Click a node to see details"
79
+ return html.Div([
80
+ html.H4(node_data['label']),
81
+ html.P(node_data.get('description', 'No description available'))
82
+ ])
 
 
 
83
 
84
+ def update_graph(self, nodes):
85
+ elements = []
86
+ for node in nodes:
87
+ elements.append({
88
+ 'data': {
89
+ 'id': node['id'],
90
+ 'label': node['title'],
91
+ 'description': node['description'],
92
+ 'depth': node['depth']
93
+ }
94
+ })
95
+ for conn in node['connections']:
96
+ elements.append({
97
+ 'data': {
98
+ 'source': node['id'],
99
+ 'target': conn['target_id'],
100
+ 'label': f"Score: {conn.get('relevance_score', 'N/A')}"
101
+ }
102
+ })
103
+ return elements
104
+
105
+ def run(self):
106
+ self.app.run_server(port=self.port, debug=False)
107
 
108
  def summarize_node(node_data: Dict[str, Any]) -> str:
109
  """Generate a summary for a selected node"""
 
122
 
123
  return summary
124
 
125
+ def handle_node_click(node_id: str, current_result: Dict[str, Any]) -> tuple[str, str]:
126
  """Handle node click event and update visualization"""
127
  nodes = current_result.get('nodes', [])
128
  selected_node = next((node for node in nodes if node['id'] == node_id), None)
129
 
130
  if selected_node:
 
 
131
  summary = summarize_node(selected_node)
132
  else:
 
133
  summary = "No node selected"
134
 
135
+ return current_result, summary
 
 
136
 
137
  def explore(
138
  query: str,
 
140
  parameters: str = "{}",
141
  depth: int = 5,
142
  domain: str = ""
143
+ ) -> tuple[str, str]:
144
  """Generate exploration path and visualization"""
145
  try:
146
+ # Parse inputs
147
+ try:
148
+ selected_path = json.loads(path_history)
149
+ exploration_parameters = json.loads(parameters)
150
+ except json.JSONDecodeError as e:
151
+ raise ValueError(f"Invalid JSON input: {str(e)}")
152
+
153
  result = generator.generate_exploration_path(
154
  query=query,
155
  selected_path=selected_path,
 
159
  if not isinstance(result, dict):
160
  raise ValueError("Invalid response format from generator")
161
 
 
 
 
162
  # Initial summary (no node selected)
163
  summary = "Click on a node to see details"
164
 
165
+ return json.dumps(result), summary
166
 
167
  except Exception as e:
168
  error_response = {
 
176
  }
177
  }
178
  print(f"Error in explore function: {e}")
179
+ return json.dumps(error_response), "Error occurred"
180
 
181
  def create_interface() -> gr.Blocks:
182
  """Create and configure the Gradio interface"""
183
+ visualizer = GraphVisualizer()
184
+ # Start Dash server in a separate thread
185
+ dash_thread = threading.Thread(target=visualizer.run, daemon=True)
186
+ dash_thread.start()
187
 
188
  with gr.Blocks(title="Art History Exploration Path Generator") as interface:
189
  gr.Markdown("# Art History Exploration Path Generator")
 
225
  with gr.Accordion("Exploration Result", open=False):
226
  text_output = gr.JSON(label="Raw Result")
227
 
228
+ # Embed Dash visualization
229
+ gr.HTML(f'<iframe src="http://localhost:{visualizer.port}" style="width:100%; height:700px; border:none;"></iframe>')
230
+
231
  node_summary = gr.Textbox(
232
  label="Node Details",
233
  lines=5,
 
237
  # Store current result for node clicking
238
  current_result = gr.State({})
239
 
240
+ def explore_and_visualize(*args):
241
+ result, summary = explore(*args)
242
+ if isinstance(result, str):
243
+ result_dict = json.loads(result)
244
+ else:
245
+ result_dict = result
246
+ visualizer.update_graph(result_dict.get('nodes', []))
247
+ return result, summary
248
+
249
  generate_btn.click(
250
+ fn=explore_and_visualize,
251
  inputs=[query_input, path_history, parameters, depth, domain],
252
+ outputs=[text_output, node_summary]
 
 
 
 
 
 
 
253
  )
254
 
255
  # Examples