Spaces:
Sleeping
Sleeping
import gradio as gr | |
import json | |
import os | |
from dotenv import load_dotenv | |
from art_explorer import ExplorationPathGenerator | |
from typing import Dict, Any, Optional, Union | |
from datetime import datetime | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
# Populate the process environment (e.g. GROQ_API_KEY) via python-dotenv.
load_dotenv()

# Build the single module-level generator shared by every request. A missing
# API key or a constructor failure is reported on stdout and then re-raised so
# the application never starts without a working generator.
try:
    api_key = os.getenv("GROQ_API_KEY")
    if api_key:
        generator = ExplorationPathGenerator(api_key=api_key)
    else:
        raise ValueError("GROQ_API_KEY not found in environment variables")
except Exception as init_error:
    print(f"Error initializing ExplorationPathGenerator: {init_error}")
    raise
def create_graph_visualization(nodes, output_path=None):
    """Render exploration nodes as a directed graph image and return its path.

    Each node dict must provide ``id``, ``title``, ``description`` and
    ``depth``. An optional ``connections`` list holds dicts with a
    ``target_id`` and an optional ``relevance_score``, which is stored as the
    edge weight (default 0). Connections pointing at ids that are not among
    ``nodes`` are skipped.

    Args:
        nodes: Iterable of node dicts as described above.
        output_path: Destination for the PNG. When omitted, a new uniquely
            named file is created in the system temp directory, so concurrent
            requests never overwrite each other's images.

    Returns:
        The path of the written PNG, or ``None`` if rendering failed (the
        error is printed).
    """
    import tempfile  # only needed to pick a unique default output file

    fig = None
    try:
        # Materialize once: the nodes are walked twice below, and a one-shot
        # iterator would otherwise come back empty on the second pass.
        nodes = list(nodes)

        graph = nx.DiGraph()
        for node in nodes:
            graph.add_node(node['id'],
                           title=node['title'],
                           description=node['description'],
                           depth=node['depth'])

        for node in nodes:
            for conn in node.get('connections') or []:
                target = conn['target_id']
                # add_edge would silently create a node without title/depth for
                # an unknown id, breaking the colour and label lookups below.
                if target not in graph:
                    continue
                graph.add_edge(node['id'],
                               target,
                               weight=conn.get('relevance_score', 0))

        fig, ax = plt.subplots(figsize=(12, 8))

        # Cycle through a fixed palette so each depth level gets its own colour.
        palette = ['#FF9999', '#99FF99', '#9999FF', '#FFFF99']
        node_colors = [palette[graph.nodes[n]['depth'] % len(palette)] for n in graph.nodes()]

        pos = nx.spring_layout(graph, k=1, iterations=50)
        nx.draw(graph, pos,
                ax=ax,
                node_color=node_colors,
                with_labels=True,
                labels={n: graph.nodes[n]['title'] for n in graph.nodes()},
                node_size=2000,
                font_size=8,
                font_weight='bold',
                arrows=True,
                edge_color='gray',
                width=1)

        if output_path is None:
            with tempfile.NamedTemporaryFile(prefix='exploration_graph_',
                                             suffix='.png',
                                             delete=False) as tmp:
                output_path = tmp.name
        fig.savefig(output_path, format='png', dpi=300, bbox_inches='tight')
        return output_path
    except Exception as e:
        print(f"Error creating graph visualization: {e}")
        return None
    finally:
        # Release the figure even when drawing or saving failed; pyplot keeps
        # every unclosed figure alive for the life of the process.
        if fig is not None:
            plt.close(fig)
def format_output(result: Dict[str, Any]) -> str:
    """Serialize *result* as indented JSON text for display.

    Non-ASCII characters are kept as-is. If *result* cannot be serialized, a
    JSON error envelope (``error`` / ``status`` / ``message``) is returned
    instead, so the caller always receives a valid JSON string.
    """
    try:
        rendered = json.dumps(result, indent=2, ensure_ascii=False)
    except Exception as exc:
        failure = {
            "error": str(exc),
            "status": "failed",
            "message": "Failed to format output",
        }
        rendered = json.dumps(failure, indent=2)
    return rendered
def parse_json_input(json_str: str, default_value: Any) -> Any:
    """Decode user-supplied JSON text, falling back to *default_value*.

    Falsy input, whitespace-only input, and the bare literals ``{}`` / ``[]``
    (after stripping) return *default_value* itself without parsing.
    Malformed JSON is reported on stdout and also yields *default_value*.
    Any other valid JSON is returned as decoded, whatever its type.
    """
    trivial_inputs = {'', '{}', '[]'}
    if not json_str:
        return default_value
    if json_str.strip() in trivial_inputs:
        return default_value
    try:
        decoded = json.loads(json_str)
    except json.JSONDecodeError as err:
        print(f"JSON parse error: {err}")
        decoded = default_value
    return decoded
def validate_parameters(
    depth: int,
    domain: str,
    parameters: Dict[str, Any],
    *,
    min_depth: int = 1,
    max_depth: int = 10,
) -> Dict[str, Any]:
    """Build the exploration parameter dict from UI inputs plus user overrides.

    The base dict holds the clamped ``depth``, the ``domain`` (``None`` when
    missing or blank) and an empty ``previous_explorations`` list. Keys from
    *parameters* are merged on top when it is a dict (anything else is
    ignored), so they may override the base values. ``depth`` is re-validated
    after the merge, so an override cannot escape ``[min_depth, max_depth]``.
    A non-numeric override falls back to the clamped *depth* argument.

    Args:
        depth: Requested depth from the UI; clamped into range.
        domain: Optional domain context; blank or ``None`` becomes ``None``.
        parameters: Extra user parameters to merge in.
        min_depth: Lower depth bound (inclusive).
        max_depth: Upper depth bound (inclusive).

    Returns:
        A new dict with at least ``depth``, ``domain`` and
        ``previous_explorations``.
    """
    def _clamp(value):
        # Force value into [min_depth, max_depth].
        return max(min_depth, min(max_depth, value))

    base_depth = _clamp(depth)
    validated_params = {
        "depth": base_depth,
        "domain": domain if domain and domain.strip() else None,
        "previous_explorations": []
    }
    if isinstance(parameters, dict):
        validated_params.update(parameters)

    # A user-supplied "depth" must obey the same bounds as the slider value.
    requested = validated_params["depth"]
    if isinstance(requested, (int, float)) and not isinstance(requested, bool):
        validated_params["depth"] = _clamp(requested)
    else:
        validated_params["depth"] = base_depth
    return validated_params
def explore(
    query: str,
    path_history: str = "[]",
    parameters: str = "{}",
    depth: int = 5,
    domain: str = ""
) -> tuple[str, Optional[str]]:
    """Generate an exploration path for *query* and, when possible, a graph image.

    Args:
        query: Free-text art history question; must contain non-whitespace text.
        path_history: JSON array of previously selected path entries; blank,
            ``[]`` or ``{}`` means no history.
        parameters: JSON object of extra exploration parameters, merged by
            ``validate_parameters``.
        depth: Requested exploration depth (clamped by ``validate_parameters``).
        domain: Optional domain context; blank means none.

    Returns:
        ``(json_text, graph_path)``. ``json_text`` is the formatted generator
        result, or an error envelope with ``"status": "failed"`` if anything
        raised. ``graph_path`` is the rendered graph image path, or ``None``
        when there were no nodes or rendering failed.
    """
    try:
        # `not query` also covers None, which would otherwise surface as an
        # AttributeError from .strip() instead of a readable message.
        if not query or not query.strip():
            raise ValueError("Query cannot be empty")

        selected_path = parse_json_input(path_history, [])
        # parse_json_input returns whatever JSON type was typed (null, object,
        # number, ...). The history must be an array, so reject anything else
        # here rather than forwarding it to the generator.
        if not isinstance(selected_path, list):
            raise ValueError("Path history must be a JSON array")

        custom_parameters = parse_json_input(parameters, {})
        exploration_parameters = validate_parameters(depth, domain, custom_parameters)

        print(f"Processing query: {query}")
        print(f"Parameters: {json.dumps(exploration_parameters, indent=2)}")

        result = generator.generate_exploration_path(
            query=query,
            selected_path=selected_path,
            exploration_parameters=exploration_parameters
        )

        if not isinstance(result, dict):
            raise ValueError("Invalid response format from generator")

        # Create graph visualization if we have nodes
        graph_path = None
        if result.get("nodes"):
            graph_path = create_graph_visualization(result["nodes"])

        return format_output(result), graph_path

    except Exception as e:
        error_response = {
            "error": str(e),
            "status": "failed",
            "message": "Failed to generate exploration path",
            "details": {
                "query": query,
                "depth": depth,
                "domain": domain
            }
        }
        print(f"Error in explore function: {e}")
        return format_output(error_response), None
def create_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks UI for the exploration path generator.

    The left column collects the query, the path history and extra
    parameters (both JSON text), the depth slider and the domain context,
    plus a generate button. The right column shows the JSON result and the
    graph image. The button and the example rows drive the same five inputs,
    in the positional order ``explore`` expects.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    intro_markdown = "\n".join([
        "# Art History Exploration Path Generator",
        "## Features:",
        "- Dynamic exploration path generation",
        "- Contextual understanding of art history",
        "- Multi-dimensional analysis",
        "- Customizable exploration depth",
        "- Interactive visualization",
        "## Usage:",
        "1. Enter your art history query",
        "2. Adjust exploration depth (1-10)",
        "3. Optionally specify domain context",
        "4. View generated exploration path and visualization",
    ])

    # Each row: query, path history, parameters, depth, domain.
    sample_rows = [
        ["Explore the evolution of Renaissance painting techniques", "[]", "{}", 5, "Renaissance"],
        ["Investigate the influence of Japanese art on Impressionism", "[]", "{}", 7, "Impressionism"],
        ["Analyze the development of Cubism through Picasso's work", "[]", "{}", 6, "Cubism"]
    ]

    with gr.Blocks(title="Art History Exploration Path Generator") as app:
        gr.Markdown(intro_markdown)

        with gr.Row():
            with gr.Column():
                query_box = gr.Textbox(
                    label="Exploration Query",
                    placeholder="Enter your art history exploration query...",
                    lines=2
                )
                history_box = gr.Textbox(
                    label="Path History (JSON)",
                    placeholder="[]",
                    lines=3,
                    value="[]"
                )
                params_box = gr.Textbox(
                    label="Additional Parameters (JSON)",
                    placeholder="{}",
                    lines=3,
                    value="{}"
                )
                depth_slider = gr.Slider(
                    label="Exploration Depth",
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1
                )
                domain_box = gr.Textbox(
                    label="Domain Context",
                    placeholder="Optional: Specify art history period or movement",
                    lines=1
                )
                run_button = gr.Button("Generate Exploration Path")

            with gr.Column():
                result_view = gr.JSON(label="Exploration Result")
                graph_view = gr.Image(label="Exploration Graph")

        explore_inputs = [query_box, history_box, params_box, depth_slider, domain_box]

        gr.Examples(examples=sample_rows, inputs=explore_inputs)

        run_button.click(
            fn=explore,
            inputs=explore_inputs,
            outputs=[result_view, graph_view]
        )

    return app
if __name__ == "__main__":
    try:
        print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
        # Create and launch interface
        demo = create_interface()
        demo.launch()
    except Exception as e:
        print(f"Failed to launch interface: {e}")
        # Re-raise so the process exits with a non-zero status instead of
        # reporting success after a failed launch (mirrors the generator
        # initialisation, which also prints and re-raises).
        raise