Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import gradio as gr | |
| from datetime import datetime | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
# Copy variables from a local .env file into os.environ so that the
# os.getenv("GROQ_API_KEY") lookup in explore() can see them; variables that
# are already set in the real environment are left as they are.
load_dotenv(override=False)
class ExplorationPathGenerator:
    """Generate art-history exploration graphs through an OpenAI-compatible chat API.

    Requests go to Groq's OpenAI-compatible endpoint unless another ``base_url``
    is given. The model is asked for a JSON object shaped like
    ``_RESPONSE_SCHEMA`` (a top-level ``"nodes"`` list).
    """

    # NOTE(review): hosted providers retire model IDs over time; confirm this one
    # is still served, or pass ``model=`` to the constructor.
    DEFAULT_MODEL = "mixtral-8x7b-32768"
    DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"

    # Response shape shown verbatim to the model in the user prompt.
    _RESPONSE_SCHEMA = """{
  "nodes": [
    {
      "id": "unique_string",
      "title": "node_title",
      "description": "detailed_description",
      "depth": number,
      "connections": [
        {
          "target_id": "connected_node_id",
          "relevance_score": float
        }
      ]
    }
  ]
}"""

    def __init__(
        self,
        api_key: str,
        model: str = DEFAULT_MODEL,
        base_url: str = DEFAULT_BASE_URL,
        temperature: float = 0.7,
        max_tokens: int = 4000,
    ):
        """Create the API client.

        Args:
            api_key: API key for the endpoint at ``base_url``.
            model: chat model ID sent with every request.
            base_url: OpenAI-compatible API root (Groq by default).
            temperature: sampling temperature for the chat completion call.
            max_tokens: completion token limit for the chat completion call.
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    @staticmethod
    def _parse_json_response(content):
        """Decode the JSON object contained in a model reply.

        Chat models frequently wrap JSON in ```json fences or put a sentence
        before/after it, which makes a bare ``json.loads`` fail. This decodes one
        JSON value starting at the first ``{`` and ignores whatever follows it.

        Raises:
            ValueError: if *content* is not a non-empty string or holds no
                decodable JSON object (``json.JSONDecodeError`` is a ValueError).
        """
        if not isinstance(content, str) or not content.strip():
            raise ValueError("Model returned no text content")
        start = content.find("{")
        if start == -1:
            raise ValueError("Model response did not contain a JSON object")
        result, _end = json.JSONDecoder().raw_decode(content, start)
        return result

    def generate_exploration_path(self, query: str, selected_path=None, exploration_parameters=None):
        """Ask the model for an exploration graph for *query*.

        Args:
            query: the user's art-history question.
            selected_path: JSON-serializable history of previously chosen nodes
                (defaults to ``[]``).
            exploration_parameters: JSON-serializable dict of extra hints
                (defaults to ``{}``).

        Returns:
            The decoded JSON object on success, or ``{"error": message}`` if the
            request, the reply, or JSON decoding fails (the error is also printed).
        """
        try:
            if selected_path is None:
                selected_path = []
            if exploration_parameters is None:
                exploration_parameters = {}
            system_prompt = (
                "You are an expert art historian AI that helps users explore art history topics by generating "
                "interconnected exploration paths. Generate a JSON response with nodes representing concepts, "
                "artworks, or historical events, and connections showing their relationships. "
                # Fenced or chatty replies are the most common cause of parse failures.
                "Respond with the JSON object only, without markdown code fences or commentary."
            )
            user_prompt = (
                f"Query: {query}\n"
                f"Selected Path: {json.dumps(selected_path)}\n"
                f"Parameters: {json.dumps(exploration_parameters)}\n"
                "Generate an exploration path that includes:\n"
                "- Multiple interconnected nodes\n"
                "- Clear relationships between nodes\n"
                "- Depth-based organization\n"
                "- Relevant historical context\n"
                "Response must be valid JSON with this structure:\n"
                + self._RESPONSE_SCHEMA
            )
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            choice = response.choices[0]
            try:
                return self._parse_json_response(choice.message.content)
            except ValueError as err:
                # finish_reason "length" means generation stopped at max_tokens, so the
                # JSON is incomplete; say so instead of reporting a bare decode error.
                if getattr(choice, "finish_reason", None) == "length":
                    raise ValueError(
                        f"Model response was truncated at max_tokens={self.max_tokens}; "
                        "the JSON is incomplete"
                    ) from err
                raise
        except Exception as e:
            print(f"Error generating exploration path: {e}")
            return {"error": str(e)}
def _json_for_script(value):
    """Serialize *value* to JSON that can be pasted inside an HTML script element.

    json.dumps leaves "<", ">" and "&" untouched, so a string containing
    "</script>" would end the script element early. The \\u escapes used here
    decode to the same characters in JavaScript but contain no markup.
    """
    return (
        json.dumps(value)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )


def _coerce_depth(value):
    """Return *value* as a non-negative int, or 0 if it cannot be converted."""
    try:
        return max(0, int(value))
    except (TypeError, ValueError, OverflowError):
        return 0


def create_interactive_graph(nodes):
    """Render *nodes* as an interactive D3.js force-directed graph for gr.HTML.

    Args:
        nodes: list of node dicts. Each node is read for "id" (nodes without one,
            or non-dict entries, are skipped), "title" (defaults to the id),
            "description" (defaults to ""), "depth" (coerced to a non-negative
            int, default 0) and "connections" (dicts with "target_id" and an
            optional "relevance_score"). ``None`` is treated as an empty list.

    Returns:
        An HTML string: a sandboxed <iframe> whose srcdoc is the full graph page.

    Why an iframe: script elements inserted as markup (which gr.HTML appears to
    do) are not executed by browsers, so an inline D3 script would never run. A
    srcdoc iframe gets its own document where the scripts do run, and
    sandbox="allow-scripts" isolates the model-generated content from the app.
    """
    import html  # only needed here, to escape the page for the srcdoc attribute

    valid_nodes = [n for n in (nodes or []) if isinstance(n, dict) and "id" in n]

    # Ids are normalized to strings so numeric ids and string target_ids still match.
    nodes_data = [
        {
            "id": str(node["id"]),
            "title": str(node.get("title", node["id"])),
            "description": str(node.get("description", "")),
            "depth": _coerce_depth(node.get("depth", 0)),
        }
        for node in valid_nodes
    ]
    known_ids = {item["id"] for item in nodes_data}

    links_data = []
    for node in valid_nodes:
        for conn in node.get("connections") or []:
            if not isinstance(conn, dict) or conn.get("target_id") is None:
                continue
            target = str(conn["target_id"])
            # d3.forceLink throws "node not found" for ids missing from the node
            # list, which would abort the whole render; drop such links instead.
            if target not in known_ids:
                continue
            links_data.append({
                "source": str(node["id"]),
                "target": target,
                "value": conn.get("relevance_score", 1),
            })

    nodes_json = _json_for_script(nodes_data)
    links_json = _json_for_script(links_data)

    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://d3js.org/d3.v7.min.js"></script>
<style>
    body {{
        margin: 0;
        overflow: hidden;
    }}
    #graph-container {{
        width: 100%;
        height: 600px;
        border: 1px solid #ddd;
        border-radius: 4px;
        box-sizing: border-box;
    }}
    .node {{
        cursor: pointer;
    }}
    .node text {{
        font-size: 12px;
        font-family: Arial, sans-serif;
    }}
    .link {{
        stroke: #999;
        stroke-opacity: 0.6;
    }}
    .tooltip {{
        position: absolute;
        max-width: 320px;
        padding: 8px;
        background: rgba(0, 0, 0, 0.8);
        color: white;
        border-radius: 4px;
        font-size: 12px;
        pointer-events: none;
    }}
</style>
</head>
<body>
<div id="graph-container"></div>
<script>
    // Data (escaped on the Python side so it cannot end this script element)
    const data = {{
        nodes: {nodes_json},
        links: {links_json}
    }};

    // Set up the SVG container
    const container = document.getElementById('graph-container');
    const width = container.clientWidth;
    const height = container.clientHeight;
    const svg = d3.select("#graph-container")
        .append("svg")
        .attr("width", width)
        .attr("height", height);

    // Add zoom capabilities
    const g = svg.append("g");
    const zoom = d3.zoom()
        .scaleExtent([0.1, 4])
        .on("zoom", (event) => g.attr("transform", event.transform));
    svg.call(zoom);

    // Create a force simulation
    const simulation = d3.forceSimulation(data.nodes)
        .force("link", d3.forceLink(data.links).id(d => d.id))
        .force("charge", d3.forceManyBody().strength(-400))
        .force("center", d3.forceCenter(width / 2, height / 2));

    // Create the links
    const link = g.append("g")
        .selectAll("line")
        .data(data.links)
        .join("line")
        .attr("class", "link")
        .attr("stroke", "#999")
        .attr("stroke-width", 1);

    // Create the nodes
    const node = g.append("g")
        .selectAll("g")
        .data(data.nodes)
        .join("g")
        .attr("class", "node")
        .call(d3.drag()
            .on("start", dragstarted)
            .on("drag", dragged)
            .on("end", dragended));

    // Add circles to nodes, colored by depth
    node.append("circle")
        .attr("r", 20)
        .attr("fill", d => ['#FF9999', '#99FF99', '#9999FF'][d.depth % 3]);

    // Add labels to nodes
    node.append("text")
        .text(d => d.title)
        .attr("x", 25)
        .attr("y", 5);

    // Add tooltip
    const tooltip = d3.select("body").append("div")
        .attr("class", "tooltip")
        .style("opacity", 0);

    // Hover effects. The tooltip is built with .text() so model-generated
    // titles and descriptions are shown as text, never parsed as HTML.
    node.on("mouseover", function(event, d) {{
        tooltip.transition()
            .duration(200)
            .style("opacity", .9);
        tooltip.selectAll("*").remove();
        tooltip.append("strong").text(d.title);
        tooltip.append("br");
        tooltip.append("span").text(d.description);
        tooltip.style("left", (event.pageX + 10) + "px")
            .style("top", (event.pageY - 10) + "px");
    }})
    .on("mouseout", function() {{
        tooltip.transition()
            .duration(500)
            .style("opacity", 0);
    }});

    // Click handler.
    // NOTE(review): nothing in this app defines window.gradio, so this is a no-op
    // unless the host page provides it. TODO: wire node selection back to Gradio.
    node.on("click", function(event, d) {{
        if (window.gradio) {{
            window.gradio.dispatch("select", d);
        }}
    }});

    // Update positions on each tick
    simulation.on("tick", () => {{
        link
            .attr("x1", d => d.source.x)
            .attr("y1", d => d.source.y)
            .attr("x2", d => d.target.x)
            .attr("y2", d => d.target.y);
        node.attr("transform", d => `translate(${{d.x}},${{d.y}})`);
    }});

    // Drag functions: pin the node while dragging, release it afterwards
    function dragstarted(event, d) {{
        if (!event.active) simulation.alphaTarget(0.3).restart();
        d.fx = d.x;
        d.fy = d.y;
    }}
    function dragged(event, d) {{
        d.fx = event.x;
        d.fy = event.y;
    }}
    function dragended(event, d) {{
        if (!event.active) simulation.alphaTarget(0);
        d.fx = null;
        d.fy = null;
    }}
</script>
</body>
</html>
"""
    return (
        '<iframe title="Interactive exploration graph" sandbox="allow-scripts" '
        'style="width: 100%; height: 610px; border: none;" '
        f'srcdoc="{html.escape(html_content, quote=True)}"></iframe>'
    )
def explore(query: str, path_history: str = "[]", parameters: str = "{}", depth: int = 5, domain: str = "") -> tuple:
    """Generate an exploration path for *query* and render it as a graph.

    Args:
        query: the art-history question; must not be blank.
        path_history: JSON text for previously selected nodes; blank means [].
        parameters: JSON object text with extra generation hints; blank means {}.
        depth: exploration depth from the UI slider; sent to the model as
            ``max_depth`` unless the parameters JSON already sets that key.
        domain: optional period/movement; when non-empty it overrides
            ``parameters["domain"]``.

    Returns:
        ``(result_json, graph_html, summary)``. On any failure ``result_json``
        is an error record with "error", "status", "timestamp" and "query", the
        graph is a placeholder, and ``summary`` starts with "Error:".
    """
    try:
        if not query or not query.strip():
            raise ValueError("Please enter an exploration query")

        # Validate inputs before touching the API. A cleared textbox arrives as "",
        # which is treated like the default value rather than as invalid JSON.
        try:
            selected_path = json.loads(path_history) if (path_history or "").strip() else []
            exploration_parameters = json.loads(parameters) if (parameters or "").strip() else {}
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON input: {str(e)}") from e
        if not isinstance(exploration_parameters, dict):
            raise ValueError("Additional Parameters must be a JSON object, e.g. {}")

        # Add domain to parameters if provided
        if domain:
            exploration_parameters["domain"] = domain
        # Forward the depth slider, which was previously ignored.
        if depth is not None:
            exploration_parameters.setdefault("max_depth", int(depth))

        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")
        generator = ExplorationPathGenerator(api_key=api_key)

        result = generator.generate_exploration_path(
            query=query,
            selected_path=selected_path,
            exploration_parameters=exploration_parameters,
        )
        # The generator reports failures as {"error": ...} instead of raising;
        # surface them rather than rendering an empty graph as if it succeeded.
        if not isinstance(result, dict):
            raise ValueError("Model response was not a JSON object")
        if "error" in result:
            raise RuntimeError(result["error"])
        nodes = result.get("nodes") or []
        if not isinstance(nodes, list):
            raise ValueError('Model response field "nodes" is not a list')

        graph_html = create_interactive_graph(nodes)
        if nodes:
            summary = "Click on nodes in the graph to see detailed information"
        else:
            summary = "The model returned no nodes for this query; try rephrasing it."
        return json.dumps(result), graph_html, summary
    except Exception as e:
        error_response = {
            "error": str(e),
            "status": "failed",
            "timestamp": datetime.now().isoformat(),
            "query": query,
        }
        return json.dumps(error_response), "<div>Error generating visualization</div>", f"Error: {str(e)}"
def create_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks app for the exploration path generator.

    The left column holds the inputs (query, path-history JSON, extra-parameters
    JSON, depth slider, domain) and the Generate button. The right column holds
    the raw JSON result in a collapsed accordion, the HTML graph, and a details
    box. The button and the example rows feed the same five inputs to explore(),
    in explore()'s argument order.
    """
    # NOTE(review): nesting reconstructed from an indentation-stripped copy;
    # confirm the graph and details box belong outside the accordion.
    intro_text = """
# Art History Exploration Path Generator
Generate interactive exploration paths through art history topics.
Drag nodes to rearrange, zoom with mouse wheel, and click nodes for details.
"""
    example_rows = [
        ["Explore the evolution of Renaissance painting techniques", "[]", "{}", 5, "Renaissance"],
        ["Investigate the influence of Japanese art on Impressionism", "[]", "{}", 7, "Impressionism"],
        ["Analyze the development of Cubism through Picasso's work", "[]", "{}", 6, "Cubism"],
    ]

    with gr.Blocks(
        title="Art History Exploration Path Generator",
        theme=gr.themes.Soft(),
        css="#graph-visualization {min-height: 600px;}",
    ) as app:
        gr.Markdown(intro_text)

        with gr.Row():
            with gr.Column(scale=1):
                query_box = gr.Textbox(
                    label="Exploration Query",
                    placeholder="Enter your art history exploration query...",
                    lines=2,
                )
                history_box = gr.Textbox(label="Path History (JSON)", placeholder="[]", lines=3, value="[]")
                params_box = gr.Textbox(label="Additional Parameters (JSON)", placeholder="{}", lines=3, value="{}")
                depth_slider = gr.Slider(label="Exploration Depth", minimum=1, maximum=10, value=5, step=1)
                domain_box = gr.Textbox(
                    label="Domain Context",
                    placeholder="Optional: Specify art history period or movement",
                    lines=1,
                )
                run_button = gr.Button("Generate Exploration Path", variant="primary")

            with gr.Column(scale=2):
                with gr.Accordion("Exploration Result", open=False):
                    raw_json = gr.JSON(label="Raw Result")
                graph_view = gr.HTML(
                    label="Interactive Exploration Graph",
                    value="<div>Generate a path to see the visualization</div>",
                    elem_id="graph-visualization",
                )
                details_box = gr.Textbox(label="Node Details", lines=5, placeholder="Click on nodes to see details")

        explore_inputs = [query_box, history_box, params_box, depth_slider, domain_box]
        run_button.click(fn=explore, inputs=explore_inputs, outputs=[raw_json, graph_view, details_box])
        gr.Examples(examples=example_rows, inputs=list(explore_inputs))

    return app
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    try:
        print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
        demo = create_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
        )
    except Exception as e:
        print(f"Failed to launch interface: {e}")
        # Re-raise so the traceback is shown and the process exits non-zero;
        # swallowing the error made a failed launch exit with status 0.
        raise