# Encyclopedia / app.py (Hugging Face Space by baconnier)
# Last commit: 3c74d99 "Update app.py" (verified), 12.6 kB
import os
import json
import gradio as gr
from datetime import datetime
from dotenv import load_dotenv
from openai import OpenAI
# Load environment variables (GROQ_API_KEY is read later via os.getenv in explore())
# from a .env file using python-dotenv.
load_dotenv()
class ExplorationPathGenerator:
    """Generate art-history exploration paths by prompting an OpenAI-compatible chat API.

    By default the client talks to Groq's OpenAI-compatible endpoint.
    """

    # Defaults used by the original app. Both are overridable because hosted
    # model names and endpoints change over time.
    DEFAULT_MODEL = "mixtral-8x7b-32768"
    DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"

    def __init__(self, api_key: str, model: str = DEFAULT_MODEL, base_url: str = DEFAULT_BASE_URL):
        """Create the API client.

        Args:
            api_key: API key passed to the OpenAI client.
            model: Chat-completion model name.
            base_url: Base URL of the OpenAI-compatible API.
        """
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url
        )
        self.model = model

    @staticmethod
    def _extract_json(content):
        """Parse JSON from a model reply that may contain extra text.

        Chat models often wrap JSON in Markdown code fences or surround it with
        prose, which makes a bare ``json.loads`` fail. The raw text is tried first,
        then the span from the first ``{`` to the last ``}``.

        Raises:
            ValueError: if the reply is empty or holds no parseable JSON object
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        if content is None or not content.strip():
            raise ValueError("Model returned an empty response")
        text = content.strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass  # fall back to extracting the outermost {...} span below
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            raise ValueError("Model response did not contain a JSON object")
        return json.loads(text[start:end + 1])

    def generate_exploration_path(self, query: str, selected_path=None, exploration_parameters=None):
        """Ask the model for an exploration graph for ``query``.

        Args:
            query: Free-text exploration query.
            selected_path: Previously selected path; JSON-serialised into the prompt.
                Defaults to an empty list.
            exploration_parameters: Extra parameters; JSON-serialised into the prompt.
                Defaults to an empty dict.

        Returns:
            The parsed JSON object (expected to contain a ``nodes`` list). On any
            failure, a dict ``{"error": <message>}`` is returned instead of raising.
        """
        try:
            if selected_path is None:
                selected_path = []
            if exploration_parameters is None:
                exploration_parameters = {}
            system_prompt = """You are an expert art historian AI that helps users explore art history topics by generating
interconnected exploration paths. Generate a JSON response with nodes representing concepts, artworks, or historical
events, and connections showing their relationships."""
            user_prompt = f"""Query: {query}
Selected Path: {json.dumps(selected_path)}
Parameters: {json.dumps(exploration_parameters)}
Generate an exploration path that includes:
- Multiple interconnected nodes
- Clear relationships between nodes
- Depth-based organization
- Relevant historical context
Response must be valid JSON with this structure:
{{
"nodes": [
{{
"id": "unique_string",
"title": "node_title",
"description": "detailed_description",
"depth": number,
"connections": [
{{
"target_id": "connected_node_id",
"relevance_score": float
}}
]
}}
]
}}"""
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.7,
                max_tokens=4000
            )
            result = self._extract_json(response.choices[0].message.content)
            # Callers use dict methods on the result, so reject lists and scalars here.
            if not isinstance(result, dict):
                raise ValueError(f"Expected a JSON object from the model, got {type(result).__name__}")
            return result
        except Exception as e:
            # API boundary: report failures as an error dict (existing contract).
            print(f"Error generating exploration path: {e}")
            return {"error": str(e)}
def _coerce_depth(value):
    """Return ``value`` as a non-negative int (0 when it is not a usable number)."""
    try:
        return abs(int(value))
    except (TypeError, ValueError, OverflowError):
        return 0


def _prepare_graph_data(nodes):
    """Convert model-generated nodes into the ``(nodes, links)`` lists D3 consumes.

    Model output is not guaranteed to be well formed, so this function tolerates bad input:
    - ``None`` is treated as an empty list, and non-dict entries are skipped.
    - Missing fields get defaults, and ids are stringified.
    - Duplicate ids keep the first node.
    - Links whose target is not a known node are dropped. ``d3.forceLink``
      throws on unknown ids, which would prevent the whole graph from rendering.
    """
    nodes_data = []
    kept = []
    known_ids = set()
    for index, node in enumerate(nodes or []):
        if not isinstance(node, dict):
            continue
        raw_id = node.get('id')
        node_id = f"node_{index}" if raw_id is None else str(raw_id)
        if node_id in known_ids:
            continue
        known_ids.add(node_id)
        kept.append((node_id, node))
        nodes_data.append({
            'id': node_id,
            'title': str(node.get('title') or node_id),
            'description': str(node.get('description') or ''),
            'depth': _coerce_depth(node.get('depth', 0)),
        })

    links_data = []
    for node_id, node in kept:
        connections = node.get('connections') or []
        if not isinstance(connections, list):
            continue
        for conn in connections:
            if not isinstance(conn, dict):
                continue
            target = conn.get('target_id')
            if target is None or str(target) not in known_ids:
                continue
            links_data.append({
                'source': node_id,
                'target': str(target),
                'value': conn.get('relevance_score', 1),
            })
    return nodes_data, links_data


def _json_for_script(value):
    """Serialise ``value`` as JSON that is safe to inline inside a <script> element.

    Escaping ``<``, ``>`` and ``&`` as JSON unicode escapes keeps text such as
    ``</script>`` from terminating the script block early.
    """
    return (
        json.dumps(value)
        .replace("&", "\\u0026")
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
    )


def create_interactive_graph(nodes):
    """Create an interactive graph visualization using D3.js.

    The D3 page is returned inside a sandboxed ``<iframe srcdoc=...>``. When markup
    is inserted directly into the Gradio page, its ``<script>`` tags do not run;
    inside the iframe they do. Node text is serialised with script-safe JSON, and
    the tooltip is built from text nodes, so model-generated text is never parsed
    as HTML.

    Args:
        nodes: Node dicts shaped like the prompt's schema (``id``, ``title``,
            ``description``, ``depth``, optional ``connections`` with
            ``target_id``/``relevance_score``). Malformed entries are tolerated;
            ``None`` is treated as an empty list.

    Returns:
        An HTML string consisting of a single ``<iframe>`` element.
    """
    from html import escape

    nodes_data, links_data = _prepare_graph_data(nodes)
    nodes_json = _json_for_script(nodes_data)
    links_json = _json_for_script(links_data)

    page = f"""
<!DOCTYPE html>
<html>
<head>
<script src="https://d3js.org/d3.v7.min.js"></script>
<style>
#graph-container {{
    width: 100%;
    height: 600px;
    border: 1px solid #ddd;
    border-radius: 4px;
}}
.node {{
    cursor: pointer;
}}
.node text {{
    font-size: 12px;
    font-family: Arial, sans-serif;
}}
.link {{
    stroke: #999;
    stroke-opacity: 0.6;
}}
.tooltip {{
    position: absolute;
    padding: 8px;
    background: rgba(0, 0, 0, 0.8);
    color: white;
    border-radius: 4px;
    font-size: 12px;
    pointer-events: none;
}}
</style>
</head>
<body>
<div id="graph-container"></div>
<script>
// Data (pre-escaped so model text cannot close this script element)
const data = {{
    nodes: {nodes_json},
    links: {links_json}
}};
// Set up the SVG container
const container = document.getElementById('graph-container');
const width = container.clientWidth;
const height = container.clientHeight;
const svg = d3.select("#graph-container")
    .append("svg")
    .attr("width", width)
    .attr("height", height);
// Add zoom capabilities
const g = svg.append("g");
const zoom = d3.zoom()
    .scaleExtent([0.1, 4])
    .on("zoom", (event) => g.attr("transform", event.transform));
svg.call(zoom);
// Create a force simulation
const simulation = d3.forceSimulation(data.nodes)
    .force("link", d3.forceLink(data.links).id(d => d.id))
    .force("charge", d3.forceManyBody().strength(-400))
    .force("center", d3.forceCenter(width / 2, height / 2));
// Create the links
const link = g.append("g")
    .selectAll("line")
    .data(data.links)
    .join("line")
    .attr("stroke", "#999")
    .attr("stroke-width", 1);
// Create the nodes
const node = g.append("g")
    .selectAll("g")
    .data(data.nodes)
    .join("g")
    .call(d3.drag()
        .on("start", dragstarted)
        .on("drag", dragged)
        .on("end", dragended));
// Add circles to nodes
node.append("circle")
    .attr("r", 20)
    .attr("fill", d => ['#FF9999', '#99FF99', '#9999FF'][d.depth % 3]);
// Add labels to nodes
node.append("text")
    .text(d => d.title)
    .attr("x", 25)
    .attr("y", 5);
// Add tooltip
const tooltip = d3.select("body").append("div")
    .attr("class", "tooltip")
    .style("opacity", 0);
// Add hover effects
node.on("mouseover", function(event, d) {{
    tooltip.transition()
        .duration(200)
        .style("opacity", .9);
    // Build the tooltip from text nodes so model text is never parsed as HTML
    tooltip.html("");
    tooltip.append("strong").text(d.title);
    tooltip.append("br");
    tooltip.append("span").text(d.description);
    tooltip
        .style("left", (event.pageX + 10) + "px")
        .style("top", (event.pageY - 10) + "px");
}})
.on("mouseout", function() {{
    tooltip.transition()
        .duration(500)
        .style("opacity", 0);
}});
// Add click handler
node.on("click", function(event, d) {{
    if (window.gradio) {{
        window.gradio.dispatch("select", d);
    }}
}});
// Update positions on each tick
simulation.on("tick", () => {{
    link
        .attr("x1", d => d.source.x)
        .attr("y1", d => d.source.y)
        .attr("x2", d => d.target.x)
        .attr("y2", d => d.target.y);
    node.attr("transform", d => `translate(${{d.x}},${{d.y}})`);
}});
// Drag functions
function dragstarted(event, d) {{
    if (!event.active) simulation.alphaTarget(0.3).restart();
    d.fx = d.x;
    d.fy = d.y;
}}
function dragged(event, d) {{
    d.fx = event.x;
    d.fy = event.y;
}}
function dragended(event, d) {{
    if (!event.active) simulation.alphaTarget(0);
    d.fx = null;
    d.fy = null;
}}
</script>
</body>
</html>
"""
    # NOTE(review): window.gradio is not defined inside the iframe (and was not wired
    # up before either), so node clicks do not update the "Node Details" box.
    return (
        '<iframe title="Interactive Exploration Graph" '
        'style="width: 100%; height: 630px; border: none;" '
        'sandbox="allow-scripts" '
        f'srcdoc="{escape(page, quote=True)}"></iframe>'
    )
def _parse_json_input(text, expected_type, label):
    """Parse a JSON text field; blank or ``None`` input yields ``expected_type()``.

    Raises:
        ValueError: if the text is not valid JSON or does not decode to
            ``expected_type`` (``list`` or ``dict``).
    """
    if text is None or not text.strip():
        return expected_type()
    try:
        value = json.loads(text)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON input: {str(e)}") from e
    if not isinstance(value, expected_type):
        kind = "array" if expected_type is list else "object"
        raise ValueError(f"{label} must be a JSON {kind}, got {type(value).__name__}")
    return value


def explore(query: str, path_history: str = "[]", parameters: str = "{}", depth: int = 5, domain: str = "") -> tuple:
    """Generate exploration path and visualization.

    Args:
        query: Free-text exploration query.
        path_history: JSON array of previously selected path items (blank means ``[]``).
        parameters: JSON object of extra parameters (blank means ``{}``).
        depth: Requested exploration depth; sent to the model as ``parameters["depth"]``.
        domain: Optional domain context; sent as ``parameters["domain"]`` when non-empty.

    Returns:
        ``(result_json, graph_html, summary)``. On any failure, ``result_json`` is a
        JSON error record (``error``, ``status``, ``timestamp``, ``query``),
        ``graph_html`` is a placeholder ``<div>``, and ``summary`` starts with ``"Error: "``.
    """
    try:
        # Parse and validate inputs before creating an API client
        selected_path = _parse_json_input(path_history, list, "Path history")
        exploration_parameters = _parse_json_input(parameters, dict, "Parameters")
        # The depth slider value was previously ignored; forward it to the prompt
        if depth is not None:
            exploration_parameters["depth"] = int(depth)
        # Add domain to parameters if provided
        if domain:
            exploration_parameters["domain"] = domain

        # Initialize generator
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")
        generator = ExplorationPathGenerator(api_key=api_key)

        # Generate result
        result = generator.generate_exploration_path(
            query=query,
            selected_path=selected_path,
            exploration_parameters=exploration_parameters
        )
        # generate_exploration_path returns {"error": ...} instead of raising;
        # surface it rather than rendering an empty graph as if it had succeeded.
        if "error" in result:
            raise RuntimeError(result["error"])

        # Create visualization
        graph_html = create_interactive_graph(result.get('nodes', []))
        # Initial summary
        summary = "Click on nodes in the graph to see detailed information"
        return json.dumps(result), graph_html, summary
    except Exception as e:
        # UI boundary: report every failure in the output widgets.
        error_response = {
            "error": str(e),
            "status": "failed",
            "timestamp": datetime.now().isoformat(),
            "query": query
        }
        return json.dumps(error_response), "<div>Error generating visualization</div>", f"Error: {str(e)}"
def create_interface() -> gr.Blocks:
    """Build the Gradio Blocks app.

    The layout has two columns. The left column holds the query, path-history,
    parameter, depth and domain inputs plus the generate button. The right column
    holds the raw JSON result, the graph HTML and a node-details box. Example
    queries are listed below, and the button is wired to ``explore``.
    """
    intro_text = """
# Art History Exploration Path Generator
Generate interactive exploration paths through art history topics.
Drag nodes to rearrange, zoom with mouse wheel, and click nodes for details.
"""
    example_rows = [
        ["Explore the evolution of Renaissance painting techniques", "[]", "{}", 5, "Renaissance"],
        ["Investigate the influence of Japanese art on Impressionism", "[]", "{}", 7, "Impressionism"],
        ["Analyze the development of Cubism through Picasso's work", "[]", "{}", 6, "Cubism"],
    ]

    with gr.Blocks(
        css="#graph-visualization {min-height: 600px;}",
        theme=gr.themes.Soft(),
        title="Art History Exploration Path Generator",
    ) as app:
        gr.Markdown(intro_text)

        with gr.Row():
            # Left column: user inputs
            with gr.Column(scale=1):
                query_box = gr.Textbox(
                    lines=2,
                    label="Exploration Query",
                    placeholder="Enter your art history exploration query...",
                )
                history_box = gr.Textbox(
                    value="[]",
                    lines=3,
                    label="Path History (JSON)",
                    placeholder="[]",
                )
                params_box = gr.Textbox(
                    value="{}",
                    lines=3,
                    label="Additional Parameters (JSON)",
                    placeholder="{}",
                )
                depth_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=5,
                    label="Exploration Depth",
                )
                domain_box = gr.Textbox(
                    lines=1,
                    label="Domain Context",
                    placeholder="Optional: Specify art history period or movement",
                )
                run_button = gr.Button("Generate Exploration Path", variant="primary")

            # Right column: outputs
            with gr.Column(scale=2):
                with gr.Accordion("Exploration Result", open=False):
                    raw_json = gr.JSON(label="Raw Result")
                graph_view = gr.HTML(
                    elem_id="graph-visualization",
                    label="Interactive Exploration Graph",
                    value="<div>Generate a path to see the visualization</div>",
                )
                details_box = gr.Textbox(
                    lines=5,
                    label="Node Details",
                    placeholder="Click on nodes to see details",
                )

        explore_inputs = [query_box, history_box, params_box, depth_slider, domain_box]
        explore_outputs = [raw_json, graph_view, details_box]
        run_button.click(fn=explore, inputs=explore_inputs, outputs=explore_outputs)
        gr.Examples(examples=example_rows, inputs=explore_inputs)

    return app
if __name__ == "__main__":
    try:
        print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
        demo = create_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True
        )
    except Exception as e:
        print(f"Failed to launch interface: {e}")
        # Re-raise so the process exits non-zero and the full traceback reaches the
        # logs. Previously the failure was swallowed and the script exited with status 0.
        raise