baconnier commited on
Commit
3c74d99
·
verified ·
1 Parent(s): fad8082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -195
app.py CHANGED
@@ -1,281 +1,300 @@
1
- import gradio as gr
2
- import json
3
  import os
4
- from dotenv import load_dotenv
5
- from art_explorer import ExplorationPathGenerator
6
- from typing import Dict, Any, Optional, Union
7
  from datetime import datetime
8
- import networkx as nx
9
- import plotly.graph_objects as go
10
 
11
  # Load environment variables
12
  load_dotenv()
13
 
14
- # Initialize the generator with error handling
15
- try:
16
- api_key = os.getenv("GROQ_API_KEY")
17
- if not api_key:
18
- raise ValueError("GROQ_API_KEY not found in environment variables")
19
- generator = ExplorationPathGenerator(api_key=api_key)
20
- except Exception as e:
21
- print(f"Error initializing ExplorationPathGenerator: {e}")
22
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def create_interactive_graph(nodes):
25
  """Create an interactive graph visualization using D3.js"""
26
- html_content = """
27
- <div id="network-container" style="width: 100%; height: 600px; border: 1px solid #ddd; border-radius: 4px; background-color: #ffffff;"></div>
28
- <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/7.8.5/d3.min.js"></script>
29
- <style>
30
- .node { cursor: pointer; }
31
- .node text {
32
- font-size: 12px;
33
- font-family: Arial, sans-serif;
34
- }
35
- .link {
36
- stroke: #999;
37
- stroke-opacity: 0.6;
38
- }
39
- .node-tooltip {
40
- position: absolute;
41
- padding: 8px;
42
- background: rgba(0, 0, 0, 0.8);
43
- color: #fff;
44
- border-radius: 4px;
45
- font-size: 12px;
46
- pointer-events: none;
47
- }
48
- </style>
49
- <script>
50
- (function() {
51
- // Convert Python data to JavaScript
52
- const data = {
53
- nodes: """ + json.dumps([{
54
- 'id': node['id'],
55
- 'title': node['title'],
56
- 'description': node['description'],
57
- 'depth': node['depth']
58
- } for node in nodes]) + """,
59
- links: """ + json.dumps([{
60
- 'source': node['id'],
61
- 'target': conn['target_id'],
62
- 'value': conn.get('relevance_score', 1)
63
- } for node in nodes for conn in node.get('connections', [])]) + """
64
- };
65
 
66
- // Clear any existing visualization
67
- const container = document.getElementById('network-container');
68
- container.innerHTML = '';
69
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  const width = container.clientWidth;
71
  const height = container.clientHeight;
72
-
73
- const svg = d3.select("#network-container")
74
  .append("svg")
75
  .attr("width", width)
76
  .attr("height", height);
77
-
78
- // Add zoom behavior
79
  const g = svg.append("g");
80
- svg.call(d3.zoom()
81
  .scaleExtent([0.1, 4])
82
- .on("zoom", (event) => {
83
- g.attr("transform", event.transform);
84
- }));
85
-
86
- // Create the force simulation
87
  const simulation = d3.forceSimulation(data.nodes)
88
- .force("link", d3.forceLink(data.links).id(d => d.id).distance(100))
89
- .force("charge", d3.forceManyBody().strength(-300))
90
- .force("center", d3.forceCenter(width / 2, height / 2))
91
- .force("collision", d3.forceCollide().radius(50));
92
-
93
  // Create the links
94
  const link = g.append("g")
95
  .selectAll("line")
96
  .data(data.links)
97
  .join("line")
98
- .attr("class", "link")
99
- .attr("stroke-width", d => Math.sqrt(d.value || 1));
100
-
101
  // Create the nodes
102
  const node = g.append("g")
103
- .selectAll(".node")
104
  .data(data.nodes)
105
  .join("g")
106
- .attr("class", "node")
107
  .call(d3.drag()
108
  .on("start", dragstarted)
109
  .on("drag", dragged)
110
  .on("end", dragended));
111
-
112
  // Add circles to nodes
113
  node.append("circle")
114
  .attr("r", 20)
115
  .attr("fill", d => ['#FF9999', '#99FF99', '#9999FF'][d.depth % 3]);
116
-
117
- // Add text labels
118
  node.append("text")
119
- .attr("dy", 30)
120
- .attr("text-anchor", "middle")
121
  .text(d => d.title)
122
- .call(wrap, 60);
123
-
124
- // Create tooltip
125
- const tooltip = d3.select("body")
126
- .append("div")
127
- .attr("class", "node-tooltip")
128
  .style("opacity", 0);
129
-
130
  // Add hover effects
131
- node.on("mouseover", function(event, d) {
132
  tooltip.transition()
133
  .duration(200)
134
  .style("opacity", .9);
135
- tooltip.html(`Title: ${d.title}<br/>Description: ${d.description}`)
136
  .style("left", (event.pageX + 10) + "px")
137
  .style("top", (event.pageY - 10) + "px");
138
- })
139
- .on("mouseout", function(d) {
140
  tooltip.transition()
141
  .duration(500)
142
  .style("opacity", 0);
143
- });
144
-
145
  // Add click handler
146
- node.on("click", function(event, d) {
147
- if (window.gradio_client) {
148
- window.gradio_client.dispatch("select", d);
149
- }
150
- });
151
-
152
  // Update positions on each tick
153
- simulation.on("tick", () => {
154
  link
155
  .attr("x1", d => d.source.x)
156
  .attr("y1", d => d.source.y)
157
  .attr("x2", d => d.target.x)
158
  .attr("y2", d => d.target.y);
159
-
160
- node
161
- .attr("transform", d => `translate(${d.x},${d.y})`);
162
- });
163
-
164
  // Drag functions
165
- function dragstarted(event) {
166
  if (!event.active) simulation.alphaTarget(0.3).restart();
167
- event.subject.fx = event.subject.x;
168
- event.subject.fy = event.subject.y;
169
- }
170
-
171
- function dragged(event) {
172
- event.subject.fx = event.x;
173
- event.subject.fy = event.y;
174
- }
175
-
176
- function dragended(event) {
177
  if (!event.active) simulation.alphaTarget(0);
178
- event.subject.fx = null;
179
- event.subject.fy = null;
180
- }
181
-
182
- // Text wrapping function
183
- function wrap(text, width) {
184
- text.each(function() {
185
- const text = d3.select(this);
186
- const words = text.text().split(/\s+/).reverse();
187
- let word;
188
- let line = [];
189
- let lineNumber = 0;
190
- const lineHeight = 1.1;
191
- const y = text.attr("y");
192
- const dy = parseFloat(text.attr("dy"));
193
- let tspan = text.text(null).append("tspan").attr("x", 0).attr("y", y).attr("dy", dy + "px");
194
-
195
- while (word = words.pop()) {
196
- line.push(word);
197
- tspan.text(line.join(" "));
198
- if (tspan.node().getComputedTextLength() > width) {
199
- line.pop();
200
- tspan.text(line.join(" "));
201
- line = [word];
202
- tspan = text.append("tspan").attr("x", 0).attr("y", y).attr("dy", ++lineNumber * lineHeight + dy + "px").text(word);
203
- }
204
- }
205
- });
206
- }
207
- })();
208
- </script>
209
  """
210
  return html_content
211
-
212
- def summarize_node(node_data: Dict[str, Any]) -> str:
213
- """Generate a summary for a selected node"""
214
- summary = f"""
215
- Title: {node_data['title']}
216
- Depth: {node_data['depth']}
217
- Description: {node_data['description']}
218
-
219
- Connections:
220
- """
221
- for conn in node_data['connections']:
222
- summary += f"- Connected to: {conn['target_id']}"
223
- if 'relevance_score' in conn:
224
- summary += f" (Relevance: {conn['relevance_score']})"
225
- summary += "\n"
226
-
227
- return summary
228
 
229
- def explore(
230
- query: str,
231
- path_history: str = "[]",
232
- parameters: str = "{}",
233
- depth: int = 5,
234
- domain: str = ""
235
- ) -> tuple[str, go.Figure, str]:
236
  """Generate exploration path and visualization"""
237
  try:
 
 
 
 
 
 
 
238
  # Parse inputs
239
  try:
240
  selected_path = json.loads(path_history)
241
  exploration_parameters = json.loads(parameters)
242
  except json.JSONDecodeError as e:
243
  raise ValueError(f"Invalid JSON input: {str(e)}")
244
-
 
 
 
 
 
245
  result = generator.generate_exploration_path(
246
  query=query,
247
  selected_path=selected_path,
248
  exploration_parameters=exploration_parameters
249
  )
250
 
251
- if not isinstance(result, dict):
252
- raise ValueError("Invalid response format from generator")
253
-
254
  # Create visualization
255
- fig = create_interactive_graph(result.get('nodes', []))
256
 
257
  # Initial summary
258
- summary = "Interact with nodes in the graph to explore relationships"
259
 
260
- return json.dumps(result), fig, summary
261
 
262
  except Exception as e:
263
  error_response = {
264
  "error": str(e),
265
  "status": "failed",
266
- "message": "Failed to generate exploration path",
267
- "details": {
268
- "query": query,
269
- "depth": depth,
270
- "domain": domain
271
- }
272
  }
273
- print(f"Error in explore function: {e}")
274
- return json.dumps(error_response), go.Figure(), "Error occurred"
275
 
276
  def create_interface() -> gr.Blocks:
277
  """Create and configure the Gradio interface"""
278
- with gr.Blocks(title="Art History Exploration Path Generator", theme=gr.themes.Soft()) as interface:
 
 
 
 
279
  gr.Markdown("""
280
  # Art History Exploration Path Generator
281
  Generate interactive exploration paths through art history topics.
@@ -319,11 +338,15 @@ def create_interface() -> gr.Blocks:
319
  with gr.Accordion("Exploration Result", open=False):
320
  text_output = gr.JSON(label="Raw Result")
321
 
322
- graph_output = gr.HTML(label="Interactive Exploration Graph")
 
 
 
 
323
  node_summary = gr.Textbox(
324
  label="Node Details",
325
  lines=5,
326
- placeholder="Interact with nodes to see details"
327
  )
328
 
329
  generate_btn.click(
@@ -332,7 +355,6 @@ def create_interface() -> gr.Blocks:
332
  outputs=[text_output, graph_output, node_summary]
333
  )
334
 
335
- # Add examples
336
  gr.Examples(
337
  examples=[
338
  ["Explore the evolution of Renaissance painting techniques", "[]", "{}", 5, "Renaissance"],
@@ -348,6 +370,10 @@ if __name__ == "__main__":
348
  try:
349
  print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
350
  demo = create_interface()
351
- demo.launch()
 
 
 
 
352
  except Exception as e:
353
  print(f"Failed to launch interface: {e}")
 
 
 
1
  import os
2
+ import json
3
+ import gradio as gr
 
4
  from datetime import datetime
5
+ from dotenv import load_dotenv
6
+ from openai import OpenAI
7
 
8
  # Load environment variables
9
  load_dotenv()
10
 
11
+ class ExplorationPathGenerator:
12
+ def __init__(self, api_key: str):
13
+ self.client = OpenAI(
14
+ api_key=api_key,
15
+ base_url="https://api.groq.com/openai/v1"
16
+ )
17
+
18
+ def generate_exploration_path(self, query: str, selected_path=None, exploration_parameters=None):
19
+ try:
20
+ if selected_path is None:
21
+ selected_path = []
22
+ if exploration_parameters is None:
23
+ exploration_parameters = {}
24
+
25
+ system_prompt = """You are an expert art historian AI that helps users explore art history topics by generating
26
+ interconnected exploration paths. Generate a JSON response with nodes representing concepts, artworks, or historical
27
+ events, and connections showing their relationships."""
28
+
29
+ user_prompt = f"""Query: {query}
30
+ Selected Path: {json.dumps(selected_path)}
31
+ Parameters: {json.dumps(exploration_parameters)}
32
+
33
+ Generate an exploration path that includes:
34
+ - Multiple interconnected nodes
35
+ - Clear relationships between nodes
36
+ - Depth-based organization
37
+ - Relevant historical context
38
+
39
+ Response must be valid JSON with this structure:
40
+ {{
41
+ "nodes": [
42
+ {{
43
+ "id": "unique_string",
44
+ "title": "node_title",
45
+ "description": "detailed_description",
46
+ "depth": number,
47
+ "connections": [
48
+ {{
49
+ "target_id": "connected_node_id",
50
+ "relevance_score": float
51
+ }}
52
+ ]
53
+ }}
54
+ ]
55
+ }}"""
56
+
57
+ response = self.client.chat.completions.create(
58
+ model="mixtral-8x7b-32768",
59
+ messages=[
60
+ {"role": "system", "content": system_prompt},
61
+ {"role": "user", "content": user_prompt}
62
+ ],
63
+ temperature=0.7,
64
+ max_tokens=4000
65
+ )
66
+
67
+ result = json.loads(response.choices[0].message.content)
68
+ return result
69
+
70
+ except Exception as e:
71
+ print(f"Error generating exploration path: {e}")
72
+ return {"error": str(e)}
73
 
74
  def create_interactive_graph(nodes):
75
  """Create an interactive graph visualization using D3.js"""
76
+ # First, let's create the data structure D3 expects
77
+ nodes_data = [{
78
+ 'id': node['id'],
79
+ 'title': node['title'],
80
+ 'description': node['description'],
81
+ 'depth': node['depth']
82
+ } for node in nodes]
83
+
84
+ links_data = [{
85
+ 'source': node['id'],
86
+ 'target': conn['target_id'],
87
+ 'value': conn.get('relevance_score', 1)
88
+ } for node in nodes for conn in node.get('connections', [])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ html_content = f"""
91
+ <!DOCTYPE html>
92
+ <html>
93
+ <head>
94
+ <script src="https://d3js.org/d3.v7.min.js"></script>
95
+ <style>
96
+ #graph-container {{
97
+ width: 100%;
98
+ height: 600px;
99
+ border: 1px solid #ddd;
100
+ border-radius: 4px;
101
+ }}
102
+ .node {{
103
+ cursor: pointer;
104
+ }}
105
+ .node text {{
106
+ font-size: 12px;
107
+ font-family: Arial, sans-serif;
108
+ }}
109
+ .link {{
110
+ stroke: #999;
111
+ stroke-opacity: 0.6;
112
+ }}
113
+ .tooltip {{
114
+ position: absolute;
115
+ padding: 8px;
116
+ background: rgba(0, 0, 0, 0.8);
117
+ color: white;
118
+ border-radius: 4px;
119
+ font-size: 12px;
120
+ pointer-events: none;
121
+ }}
122
+ </style>
123
+ </head>
124
+ <body>
125
+ <div id="graph-container"></div>
126
+ <script>
127
+ // Data
128
+ const data = {{
129
+ nodes: {json.dumps(nodes_data)},
130
+ links: {json.dumps(links_data)}
131
+ }};
132
+
133
+ // Set up the SVG container
134
+ const container = document.getElementById('graph-container');
135
  const width = container.clientWidth;
136
  const height = container.clientHeight;
137
+
138
+ const svg = d3.select("#graph-container")
139
  .append("svg")
140
  .attr("width", width)
141
  .attr("height", height);
142
+
143
+ // Add zoom capabilities
144
  const g = svg.append("g");
145
+ const zoom = d3.zoom()
146
  .scaleExtent([0.1, 4])
147
+ .on("zoom", (event) => g.attr("transform", event.transform));
148
+ svg.call(zoom);
149
+
150
+ // Create a force simulation
 
151
  const simulation = d3.forceSimulation(data.nodes)
152
+ .force("link", d3.forceLink(data.links).id(d => d.id))
153
+ .force("charge", d3.forceManyBody().strength(-400))
154
+ .force("center", d3.forceCenter(width / 2, height / 2));
155
+
 
156
  // Create the links
157
  const link = g.append("g")
158
  .selectAll("line")
159
  .data(data.links)
160
  .join("line")
161
+ .attr("stroke", "#999")
162
+ .attr("stroke-width", 1);
163
+
164
  // Create the nodes
165
  const node = g.append("g")
166
+ .selectAll("g")
167
  .data(data.nodes)
168
  .join("g")
 
169
  .call(d3.drag()
170
  .on("start", dragstarted)
171
  .on("drag", dragged)
172
  .on("end", dragended));
173
+
174
  // Add circles to nodes
175
  node.append("circle")
176
  .attr("r", 20)
177
  .attr("fill", d => ['#FF9999', '#99FF99', '#9999FF'][d.depth % 3]);
178
+
179
+ // Add labels to nodes
180
  node.append("text")
 
 
181
  .text(d => d.title)
182
+ .attr("x", 25)
183
+ .attr("y", 5);
184
+
185
+ // Add tooltip
186
+ const tooltip = d3.select("body").append("div")
187
+ .attr("class", "tooltip")
188
  .style("opacity", 0);
189
+
190
  // Add hover effects
191
+ node.on("mouseover", function(event, d) {{
192
  tooltip.transition()
193
  .duration(200)
194
  .style("opacity", .9);
195
+ tooltip.html(`<strong>${{d.title}}</strong><br/>${{d.description}}`)
196
  .style("left", (event.pageX + 10) + "px")
197
  .style("top", (event.pageY - 10) + "px");
198
+ }})
199
+ .on("mouseout", function() {{
200
  tooltip.transition()
201
  .duration(500)
202
  .style("opacity", 0);
203
+ }});
204
+
205
  // Add click handler
206
+ node.on("click", function(event, d) {{
207
+ if (window.gradio) {{
208
+ window.gradio.dispatch("select", d);
209
+ }}
210
+ }});
211
+
212
  // Update positions on each tick
213
+ simulation.on("tick", () => {{
214
  link
215
  .attr("x1", d => d.source.x)
216
  .attr("y1", d => d.source.y)
217
  .attr("x2", d => d.target.x)
218
  .attr("y2", d => d.target.y);
219
+
220
+ node.attr("transform", d => `translate(${{d.x}},${{d.y}})`);
221
+ }});
222
+
 
223
  // Drag functions
224
+ function dragstarted(event, d) {{
225
  if (!event.active) simulation.alphaTarget(0.3).restart();
226
+ d.fx = d.x;
227
+ d.fy = d.y;
228
+ }}
229
+
230
+ function dragged(event, d) {{
231
+ d.fx = event.x;
232
+ d.fy = event.y;
233
+ }}
234
+
235
+ function dragended(event, d) {{
236
  if (!event.active) simulation.alphaTarget(0);
237
+ d.fx = null;
238
+ d.fy = null;
239
+ }}
240
+ </script>
241
+ </body>
242
+ </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  """
244
  return html_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
+ def explore(query: str, path_history: str = "[]", parameters: str = "{}", depth: int = 5, domain: str = "") -> tuple:
 
 
 
 
 
 
247
  """Generate exploration path and visualization"""
248
  try:
249
+ # Initialize generator
250
+ api_key = os.getenv("GROQ_API_KEY")
251
+ if not api_key:
252
+ raise ValueError("GROQ_API_KEY not found in environment variables")
253
+
254
+ generator = ExplorationPathGenerator(api_key=api_key)
255
+
256
  # Parse inputs
257
  try:
258
  selected_path = json.loads(path_history)
259
  exploration_parameters = json.loads(parameters)
260
  except json.JSONDecodeError as e:
261
  raise ValueError(f"Invalid JSON input: {str(e)}")
262
+
263
+ # Add domain to parameters if provided
264
+ if domain:
265
+ exploration_parameters["domain"] = domain
266
+
267
+ # Generate result
268
  result = generator.generate_exploration_path(
269
  query=query,
270
  selected_path=selected_path,
271
  exploration_parameters=exploration_parameters
272
  )
273
 
 
 
 
274
  # Create visualization
275
+ graph_html = create_interactive_graph(result.get('nodes', []))
276
 
277
  # Initial summary
278
+ summary = "Click on nodes in the graph to see detailed information"
279
 
280
+ return json.dumps(result), graph_html, summary
281
 
282
  except Exception as e:
283
  error_response = {
284
  "error": str(e),
285
  "status": "failed",
286
+ "timestamp": datetime.now().isoformat(),
287
+ "query": query
 
 
 
 
288
  }
289
+ return json.dumps(error_response), "<div>Error generating visualization</div>", f"Error: {str(e)}"
 
290
 
291
  def create_interface() -> gr.Blocks:
292
  """Create and configure the Gradio interface"""
293
+ with gr.Blocks(
294
+ title="Art History Exploration Path Generator",
295
+ theme=gr.themes.Soft(),
296
+ css="#graph-visualization {min-height: 600px;}"
297
+ ) as interface:
298
  gr.Markdown("""
299
  # Art History Exploration Path Generator
300
  Generate interactive exploration paths through art history topics.
 
338
  with gr.Accordion("Exploration Result", open=False):
339
  text_output = gr.JSON(label="Raw Result")
340
 
341
+ graph_output = gr.HTML(
342
+ label="Interactive Exploration Graph",
343
+ value="<div>Generate a path to see the visualization</div>",
344
+ elem_id="graph-visualization"
345
+ )
346
  node_summary = gr.Textbox(
347
  label="Node Details",
348
  lines=5,
349
+ placeholder="Click on nodes to see details"
350
  )
351
 
352
  generate_btn.click(
 
355
  outputs=[text_output, graph_output, node_summary]
356
  )
357
 
 
358
  gr.Examples(
359
  examples=[
360
  ["Explore the evolution of Renaissance painting techniques", "[]", "{}", 5, "Renaissance"],
 
370
  try:
371
  print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
372
  demo = create_interface()
373
+ demo.launch(
374
+ server_name="0.0.0.0",
375
+ server_port=7860,
376
+ share=True
377
+ )
378
  except Exception as e:
379
  print(f"Failed to launch interface: {e}")