Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,8 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
2 |
import json
|
3 |
import os
|
4 |
from dotenv import load_dotenv
|
@@ -7,6 +11,8 @@ from typing import Dict, Any, Optional, Union
|
|
7 |
from datetime import datetime
|
8 |
import networkx as nx
|
9 |
import plotly.graph_objects as go
|
|
|
|
|
10 |
|
11 |
# Load environment variables
|
12 |
load_dotenv()
|
@@ -21,79 +27,83 @@ except Exception as e:
|
|
21 |
print(f"Error initializing ExplorationPathGenerator: {e}")
|
22 |
raise
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
# Add edges
|
65 |
-
fig.add_trace(go.Scatter(
|
66 |
-
x=edge_x, y=edge_y,
|
67 |
-
line=dict(width=0.5, color='#888'),
|
68 |
-
hoverinfo='none',
|
69 |
-
mode='lines'
|
70 |
-
))
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
hoverinfo='text',
|
77 |
-
text=[node['title'] for node in nodes],
|
78 |
-
hovertext=node_text,
|
79 |
-
marker=dict(
|
80 |
-
showscale=False,
|
81 |
-
color=node_color,
|
82 |
-
size=node_size,
|
83 |
-
line_width=2
|
84 |
)
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
93 |
-
plot_bgcolor='white'
|
94 |
-
)
|
95 |
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
def summarize_node(node_data: Dict[str, Any]) -> str:
|
99 |
"""Generate a summary for a selected node"""
|
@@ -112,22 +122,17 @@ Connections:
|
|
112 |
|
113 |
return summary
|
114 |
|
115 |
-
def handle_node_click(node_id: str, current_result: Dict[str, Any]) -> tuple[
|
116 |
"""Handle node click event and update visualization"""
|
117 |
nodes = current_result.get('nodes', [])
|
118 |
selected_node = next((node for node in nodes if node['id'] == node_id), None)
|
119 |
|
120 |
if selected_node:
|
121 |
-
# Update graph to highlight selected node and its connections
|
122 |
-
fig = create_interactive_graph(nodes) # Create new graph
|
123 |
summary = summarize_node(selected_node)
|
124 |
else:
|
125 |
-
fig = create_interactive_graph(nodes) # Create default graph
|
126 |
summary = "No node selected"
|
127 |
|
128 |
-
return
|
129 |
-
|
130 |
-
# [Previous helper functions remain the same: format_output, parse_json_input, validate_parameters]
|
131 |
|
132 |
def explore(
|
133 |
query: str,
|
@@ -135,11 +140,16 @@ def explore(
|
|
135 |
parameters: str = "{}",
|
136 |
depth: int = 5,
|
137 |
domain: str = ""
|
138 |
-
) -> tuple[str,
|
139 |
"""Generate exploration path and visualization"""
|
140 |
try:
|
141 |
-
#
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
143 |
result = generator.generate_exploration_path(
|
144 |
query=query,
|
145 |
selected_path=selected_path,
|
@@ -149,13 +159,10 @@ def explore(
|
|
149 |
if not isinstance(result, dict):
|
150 |
raise ValueError("Invalid response format from generator")
|
151 |
|
152 |
-
# Create interactive graph
|
153 |
-
graph = create_interactive_graph(result.get('nodes', []))
|
154 |
-
|
155 |
# Initial summary (no node selected)
|
156 |
summary = "Click on a node to see details"
|
157 |
|
158 |
-
return
|
159 |
|
160 |
except Exception as e:
|
161 |
error_response = {
|
@@ -169,10 +176,14 @@ def explore(
|
|
169 |
}
|
170 |
}
|
171 |
print(f"Error in explore function: {e}")
|
172 |
-
return
|
173 |
|
174 |
def create_interface() -> gr.Blocks:
|
175 |
"""Create and configure the Gradio interface"""
|
|
|
|
|
|
|
|
|
176 |
|
177 |
with gr.Blocks(title="Art History Exploration Path Generator") as interface:
|
178 |
gr.Markdown("# Art History Exploration Path Generator")
|
@@ -214,7 +225,9 @@ def create_interface() -> gr.Blocks:
|
|
214 |
with gr.Accordion("Exploration Result", open=False):
|
215 |
text_output = gr.JSON(label="Raw Result")
|
216 |
|
217 |
-
|
|
|
|
|
218 |
node_summary = gr.Textbox(
|
219 |
label="Node Details",
|
220 |
lines=5,
|
@@ -224,17 +237,19 @@ def create_interface() -> gr.Blocks:
|
|
224 |
# Store current result for node clicking
|
225 |
current_result = gr.State({})
|
226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
generate_btn.click(
|
228 |
-
fn=
|
229 |
inputs=[query_input, path_history, parameters, depth, domain],
|
230 |
-
outputs=[text_output,
|
231 |
-
)
|
232 |
-
|
233 |
-
# Handle node clicks
|
234 |
-
graph_output.click(
|
235 |
-
fn=handle_node_click,
|
236 |
-
inputs=["node_id", current_result],
|
237 |
-
outputs=[graph_output, node_summary]
|
238 |
)
|
239 |
|
240 |
# Examples
|
|
|
1 |
import gradio as gr
|
2 |
+
import dash
|
3 |
+
from dash import html, dcc
|
4 |
+
import dash_cytoscape as cyto
|
5 |
+
from dash.dependencies import Input, Output, State
|
6 |
import json
|
7 |
import os
|
8 |
from dotenv import load_dotenv
|
|
|
11 |
from datetime import datetime
|
12 |
import networkx as nx
|
13 |
import plotly.graph_objects as go
|
14 |
+
import threading
|
15 |
+
from flask import Flask
|
16 |
|
17 |
# Load environment variables
|
18 |
load_dotenv()
|
|
|
27 |
print(f"Error initializing ExplorationPathGenerator: {e}")
|
28 |
raise
|
29 |
|
30 |
+
class GraphVisualizer:
|
31 |
+
def __init__(self, port=8050):
|
32 |
+
self.app = dash.Dash(__name__)
|
33 |
+
self.port = port
|
34 |
+
cyto.load_extra_layouts()
|
35 |
+
|
36 |
+
self.app.layout = html.Div([
|
37 |
+
cyto.Cytoscape(
|
38 |
+
id='exploration-graph',
|
39 |
+
layout={'name': 'cose'},
|
40 |
+
style={'width': '100%', 'height': '600px'},
|
41 |
+
elements=[],
|
42 |
+
stylesheet=[
|
43 |
+
{
|
44 |
+
'selector': 'node',
|
45 |
+
'style': {
|
46 |
+
'label': 'data(label)',
|
47 |
+
'background-color': '#6FB1FC',
|
48 |
+
'content': 'data(label)',
|
49 |
+
'text-wrap': 'wrap',
|
50 |
+
'text-max-width': '100px'
|
51 |
+
}
|
52 |
+
},
|
53 |
+
{
|
54 |
+
'selector': 'edge',
|
55 |
+
'style': {
|
56 |
+
'curve-style': 'bezier',
|
57 |
+
'target-arrow-shape': 'triangle',
|
58 |
+
'line-color': '#ccc',
|
59 |
+
'target-arrow-color': '#ccc',
|
60 |
+
'label': 'data(label)',
|
61 |
+
'font-size': '10px'
|
62 |
+
}
|
63 |
+
}
|
64 |
+
]
|
65 |
+
),
|
66 |
+
html.Div(id='node-details', style={'padding': '20px'})
|
67 |
+
])
|
68 |
+
|
69 |
+
self._setup_callbacks()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
+
def _setup_callbacks(self):
|
72 |
+
@self.app.callback(
|
73 |
+
Output('node-details', 'children'),
|
74 |
+
[Input('exploration-graph', 'tapNodeData')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
)
|
76 |
+
def display_node_details(node_data):
|
77 |
+
if not node_data:
|
78 |
+
return "Click a node to see details"
|
79 |
+
return html.Div([
|
80 |
+
html.H4(node_data['label']),
|
81 |
+
html.P(node_data.get('description', 'No description available'))
|
82 |
+
])
|
|
|
|
|
|
|
83 |
|
84 |
+
def update_graph(self, nodes):
|
85 |
+
elements = []
|
86 |
+
for node in nodes:
|
87 |
+
elements.append({
|
88 |
+
'data': {
|
89 |
+
'id': node['id'],
|
90 |
+
'label': node['title'],
|
91 |
+
'description': node['description'],
|
92 |
+
'depth': node['depth']
|
93 |
+
}
|
94 |
+
})
|
95 |
+
for conn in node['connections']:
|
96 |
+
elements.append({
|
97 |
+
'data': {
|
98 |
+
'source': node['id'],
|
99 |
+
'target': conn['target_id'],
|
100 |
+
'label': f"Score: {conn.get('relevance_score', 'N/A')}"
|
101 |
+
}
|
102 |
+
})
|
103 |
+
return elements
|
104 |
+
|
105 |
+
def run(self):
|
106 |
+
self.app.run_server(port=self.port, debug=False)
|
107 |
|
108 |
def summarize_node(node_data: Dict[str, Any]) -> str:
|
109 |
"""Generate a summary for a selected node"""
|
|
|
122 |
|
123 |
return summary
|
124 |
|
125 |
+
def handle_node_click(node_id: str, current_result: Dict[str, Any]) -> tuple[str, str]:
|
126 |
"""Handle node click event and update visualization"""
|
127 |
nodes = current_result.get('nodes', [])
|
128 |
selected_node = next((node for node in nodes if node['id'] == node_id), None)
|
129 |
|
130 |
if selected_node:
|
|
|
|
|
131 |
summary = summarize_node(selected_node)
|
132 |
else:
|
|
|
133 |
summary = "No node selected"
|
134 |
|
135 |
+
return current_result, summary
|
|
|
|
|
136 |
|
137 |
def explore(
|
138 |
query: str,
|
|
|
140 |
parameters: str = "{}",
|
141 |
depth: int = 5,
|
142 |
domain: str = ""
|
143 |
+
) -> tuple[str, str]:
|
144 |
"""Generate exploration path and visualization"""
|
145 |
try:
|
146 |
+
# Parse inputs
|
147 |
+
try:
|
148 |
+
selected_path = json.loads(path_history)
|
149 |
+
exploration_parameters = json.loads(parameters)
|
150 |
+
except json.JSONDecodeError as e:
|
151 |
+
raise ValueError(f"Invalid JSON input: {str(e)}")
|
152 |
+
|
153 |
result = generator.generate_exploration_path(
|
154 |
query=query,
|
155 |
selected_path=selected_path,
|
|
|
159 |
if not isinstance(result, dict):
|
160 |
raise ValueError("Invalid response format from generator")
|
161 |
|
|
|
|
|
|
|
162 |
# Initial summary (no node selected)
|
163 |
summary = "Click on a node to see details"
|
164 |
|
165 |
+
return json.dumps(result), summary
|
166 |
|
167 |
except Exception as e:
|
168 |
error_response = {
|
|
|
176 |
}
|
177 |
}
|
178 |
print(f"Error in explore function: {e}")
|
179 |
+
return json.dumps(error_response), "Error occurred"
|
180 |
|
181 |
def create_interface() -> gr.Blocks:
|
182 |
"""Create and configure the Gradio interface"""
|
183 |
+
visualizer = GraphVisualizer()
|
184 |
+
# Start Dash server in a separate thread
|
185 |
+
dash_thread = threading.Thread(target=visualizer.run, daemon=True)
|
186 |
+
dash_thread.start()
|
187 |
|
188 |
with gr.Blocks(title="Art History Exploration Path Generator") as interface:
|
189 |
gr.Markdown("# Art History Exploration Path Generator")
|
|
|
225 |
with gr.Accordion("Exploration Result", open=False):
|
226 |
text_output = gr.JSON(label="Raw Result")
|
227 |
|
228 |
+
# Embed Dash visualization
|
229 |
+
gr.HTML(f'<iframe src="http://localhost:{visualizer.port}" style="width:100%; height:700px; border:none;"></iframe>')
|
230 |
+
|
231 |
node_summary = gr.Textbox(
|
232 |
label="Node Details",
|
233 |
lines=5,
|
|
|
237 |
# Store current result for node clicking
|
238 |
current_result = gr.State({})
|
239 |
|
240 |
+
def explore_and_visualize(*args):
|
241 |
+
result, summary = explore(*args)
|
242 |
+
if isinstance(result, str):
|
243 |
+
result_dict = json.loads(result)
|
244 |
+
else:
|
245 |
+
result_dict = result
|
246 |
+
visualizer.update_graph(result_dict.get('nodes', []))
|
247 |
+
return result, summary
|
248 |
+
|
249 |
generate_btn.click(
|
250 |
+
fn=explore_and_visualize,
|
251 |
inputs=[query_input, path_history, parameters, depth, domain],
|
252 |
+
outputs=[text_output, node_summary]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
)
|
254 |
|
255 |
# Examples
|