Spaces:
Sleeping
Sleeping
import html
import json
import os
from datetime import datetime

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

from prompts import SYSTEM_PROMPT, format_exploration_prompt
# Populate os.environ from a .env file (when one is found) so that GROQ_API_KEY
# can be supplied locally; variables already set in the real environment win.
load_dotenv(override=False)
class ExplorationPathGenerator:
    """Build knowledge-exploration graphs from an LLM behind an OpenAI-compatible API.

    The model is prompted with ``SYSTEM_PROMPT`` / ``format_exploration_prompt`` and its
    JSON reply (an exploration summary plus standard and emergent knowledge axes) is
    flattened into a list of node dicts with ``id``, ``title``, ``description``,
    ``depth`` and ``connections`` keys, which can then be rendered as HTML cards.
    """

    # Badge background colours, indexed by ``depth % len(_DEPTH_COLORS)``. Four entries
    # so depth-3 (innovative value) nodes get their own colour instead of wrapping
    # around to the depth-0 summary colour.
    _DEPTH_COLORS = ("#FF9999", "#99FF99", "#9999FF", "#FFCC66")

    def __init__(
        self,
        api_key: str,
        model: str = "mixtral-8x7b-32768",
        base_url: str = "https://api.groq.com/openai/v1",
    ):
        """Create the chat-completions client.

        Args:
            api_key: API key for the endpoint (Groq by default).
            model: Chat model name sent with every request.
                NOTE(review): confirm this model is still served by the endpoint.
            base_url: OpenAI-compatible endpoint URL.
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    def generate_exploration_path(self, query: str, selected_path=None, exploration_parameters=None):
        """Ask the model to explore ``query`` and convert its reply into graph nodes.

        Args:
            query: The user's exploration query.
            selected_path: Previously selected path; defaults to an empty list.
            exploration_parameters: Extra parameters forwarded to the prompt
                formatter; defaults to an empty dict.

        Returns:
            ``{"nodes": [...]}`` on success, or ``{"error": "<message>"}`` if the
            request, JSON parsing or graph conversion fails. Failures are printed
            and never raised, so callers must check for the ``"error"`` key.
        """
        try:
            if selected_path is None:
                selected_path = []
            if exploration_parameters is None:
                exploration_parameters = {}

            formatted_prompt = format_exploration_prompt(
                user_query=query,
                selected_path=selected_path,
                exploration_parameters=exploration_parameters,
            )
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": formatted_prompt},
                ],
                temperature=0.7,
                max_tokens=4000,
            )
            result = self._parse_json_response(response.choices[0].message.content)
            return {"nodes": self._build_graph_nodes(result)}
        except Exception as e:
            # Deliberate boundary: the UI layer consumes an error dict, not an exception.
            print(f"Error generating exploration path: {e}")
            return {"error": str(e)}

    @staticmethod
    def _parse_json_response(content):
        """Parse a model reply as JSON, tolerating Markdown fences and surrounding prose.

        Raises:
            ValueError: if ``content`` is empty/None or contains no ``{...}`` object.
            json.JSONDecodeError: if the extracted span is not valid JSON.
        """
        if not content:
            raise ValueError("Model returned an empty response")
        text = content.strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass
        # Chat models frequently wrap JSON in ```json fences or add a sentence around
        # it; fall back to the outermost {...} span.
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            raise ValueError("Model response did not contain a JSON object")
        return json.loads(text[start:end + 1])

    @staticmethod
    def _build_graph_nodes(result):
        """Flatten the model's exploration JSON into a list of node dicts.

        Layout: a ``meta_*`` summary node (depth 0) connects to every ``std_*``
        standard axis (depth 1, relevance 0.8) and every ``emg_*`` emergent axis
        (depth 2, relevance 0.6). Standard axes connect to their ``val_*`` potential
        values (depth 2) and emergent axes to their ``inv_*`` innovative values
        (depth 3), using the model's 0-100 scores divided by 100. All ids share
        one running counter, and each axis node is appended after its value nodes.

        The axis lists, ``current_values`` and the value lists are optional and
        treated as empty when absent. Other fields are required and raise
        ``KeyError``/``TypeError`` when missing.
        """
        nodes = []
        counter = 0

        def next_id(prefix):
            nonlocal counter
            counter += 1
            return f"{prefix}_{counter}"

        meta_node = {
            "id": next_id("meta"),
            "title": "Exploration Summary",
            "description": result["exploration_summary"]["current_context"],
            "depth": 0,
            "connections": [],
        }
        nodes.append(meta_node)

        axes = result["knowledge_axes"]

        for axis in axes.get("standard_axes", []):
            current_values = ", ".join(str(v) for v in axis.get("current_values", []))
            axis_node = {
                "id": next_id("std"),
                "title": axis["name"],
                "description": f"Current values: {current_values}",
                "depth": 1,
                "connections": [],
            }
            meta_node["connections"].append({"target_id": axis_node["id"], "relevance_score": 0.8})
            for value in axis.get("potential_values", []):
                value_node = {
                    "id": next_id("val"),
                    "title": value["value"],
                    "description": value["contextual_rationale"],
                    "depth": 2,
                    "connections": [],
                }
                nodes.append(value_node)
                axis_node["connections"].append({
                    "target_id": value_node["id"],
                    "relevance_score": value["relevance_score"] / 100,
                })
            nodes.append(axis_node)

        for axis in axes.get("emergent_axes", []):
            emergent_node = {
                "id": next_id("emg"),
                "title": f"{axis['name']} (Emergent)",
                "description": f"Parent axis: {axis['parent_axis']}",
                "depth": 2,
                "connections": [],
            }
            meta_node["connections"].append({"target_id": emergent_node["id"], "relevance_score": 0.6})
            for value in axis.get("innovative_values", []):
                value_node = {
                    "id": next_id("inv"),
                    "title": value["value"],
                    "description": value["discovery_potential"],
                    "depth": 3,
                    "connections": [],
                }
                nodes.append(value_node)
                emergent_node["connections"].append({
                    "target_id": value_node["id"],
                    "relevance_score": value["innovation_score"] / 100,
                })
            nodes.append(emergent_node)

        return nodes

    def create_visualization_html(self, nodes):
        """Render ``nodes`` as a vertical list of simple HTML cards.

        Titles, descriptions and ids originate from model output (steered by the
        user's query), so they are HTML-escaped before embedding; otherwise a reply
        could inject arbitrary markup or script into the page.
        """
        parts = [
            "<div style='padding: 20px; font-family: Arial, sans-serif;'>",
            """
            <style>
                .node-card {
                    border: 1px solid #ddd;
                    border-radius: 8px;
                    padding: 15px;
                    margin: 10px 0;
                    background-color: #f9f9f9;
                }
                .node-title {
                    font-weight: bold;
                    color: #2c3e50;
                    margin-bottom: 8px;
                }
                .node-description {
                    color: #34495e;
                    margin-bottom: 8px;
                }
                .node-connections {
                    font-size: 0.9em;
                    color: #7f8c8d;
                }
                .depth-indicator {
                    display: inline-block;
                    padding: 3px 8px;
                    border-radius: 12px;
                    font-size: 0.8em;
                    margin-bottom: 5px;
                }
            </style>
            """,
        ]

        for node in nodes:
            depth = node["depth"]
            color = self._DEPTH_COLORS[depth % len(self._DEPTH_COLORS)]
            parts.append(f"""
            <div class='node-card'>
                <div class='depth-indicator' style='background-color: {color}'>
                    Depth: {html.escape(str(depth))}
                </div>
                <div class='node-title'>{html.escape(str(node['title']))}</div>
                <div class='node-description'>{html.escape(str(node['description']))}</div>
                <div class='node-connections'>
            """)
            connections = node.get("connections")
            if connections:
                parts.append("<strong>Connections:</strong><ul>")
                for conn in connections:
                    item = f"<li>Connected to: {html.escape(str(conn['target_id']))}"
                    if "relevance_score" in conn:
                        item += f" (Relevance: {conn['relevance_score']:.2f})"
                    parts.append(item + "</li>")
                parts.append("</ul>")
            parts.append("</div></div>")

        parts.append("</div>")
        return "".join(parts)
def explore(query: str, path_history: str = "[]", parameters: str = "{}", depth: int = 5, domain: str = "") -> tuple:
    """Generate an exploration path for ``query`` and render it for the UI.

    Args:
        query: The exploration query; must contain non-whitespace text.
        path_history: JSON array text describing the previously selected path.
        parameters: JSON object text with extra exploration parameters; ``domain``
            and ``depth`` are merged into it, overriding keys of the same name.
        depth: Requested exploration depth, coerced to ``int`` in case the UI
            delivers a float.
        domain: Optional domain context.

    Returns:
        ``(result_json, visualization_html, summary_text)``. On any failure
        (invalid input, missing ``GROQ_API_KEY``, or a generator error),
        ``result_json`` is a JSON object with ``error``, ``status="failed"``,
        ``timestamp`` and ``query``, and ``summary_text`` starts with ``"Error:"``.
    """
    try:
        if not query or not query.strip():
            raise ValueError("Query must not be empty")

        # Validate user-supplied inputs before touching credentials or the network.
        try:
            selected_path = json.loads(path_history)
            exploration_parameters = json.loads(parameters)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON input: {str(e)}") from e
        if not isinstance(selected_path, list):
            raise ValueError("path_history must be a JSON array")
        if not isinstance(exploration_parameters, dict):
            raise ValueError("parameters must be a JSON object")

        exploration_parameters.update({
            "domain": domain,
            "depth": int(depth),
        })

        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")
        generator = ExplorationPathGenerator(api_key=api_key)

        result = generator.generate_exploration_path(
            query=query,
            selected_path=selected_path,
            exploration_parameters=exploration_parameters,
        )
        # generate_exploration_path reports failures as {"error": ...} instead of
        # raising; surface them as errors rather than as a "0 nodes" success.
        if "error" in result:
            raise RuntimeError(result["error"])

        nodes = result.get("nodes", [])
        graph_html = generator.create_visualization_html(nodes)
        summary = f"Exploration path generated with {len(nodes)} nodes"
        return json.dumps(result), graph_html, summary
    except Exception as e:
        error_response = {
            "error": str(e),
            "status": "failed",
            "timestamp": datetime.now().isoformat(),
            "query": query,
        }
        return json.dumps(error_response), "<div>Error generating visualization</div>", f"Error: {str(e)}"
def create_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks app and wire its button to ``explore``.

    Left column: query text, a depth slider (1-10, default 5) and an optional
    domain box. Right column: raw JSON result, HTML visualization and a summary.
    Two hidden textboxes supply ``explore`` with a path history of ``"[]"`` and
    parameters of ``"{}"``.
    """
    heading = (
        "\n"
        "# Knowledge Exploration Path Generator\n"
        "Generate interactive exploration paths through complex topics.\n"
    )

    with gr.Blocks(title="Art History Exploration Path Generator", theme=gr.themes.Soft()) as interface:
        gr.Markdown(heading)

        with gr.Row():
            with gr.Column(scale=1):
                query_box = gr.Textbox(
                    label="Exploration Query",
                    placeholder="Enter your exploration query...",
                    lines=2,
                )
                depth_slider = gr.Slider(label="Exploration Depth", minimum=1, maximum=10, value=5, step=1)
                domain_box = gr.Textbox(
                    label="Domain Context",
                    placeholder="Optional: Specify domain context",
                    lines=1,
                )
                run_button = gr.Button("Generate Exploration Path", variant="primary")

            with gr.Column(scale=2):
                json_view = gr.JSON(label="Raw Result")
                html_view = gr.HTML(label="Visualization")
                summary_box = gr.Textbox(label="Summary", lines=2)

        # Invisible inputs carrying explore()'s path history and parameter payloads.
        hidden_path = gr.Textbox(value="[]", visible=False)
        hidden_params = gr.Textbox(value="{}", visible=False)

        run_button.click(
            fn=explore,
            inputs=[query_box, hidden_path, hidden_params, depth_slider, domain_box],
            outputs=[json_view, html_view, summary_box],
        )

    return interface
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    # share=True asks Gradio for a public share link in addition to the local server.
    try:
        print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
        demo = create_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
        )
    except Exception as e:
        print(f"Failed to launch interface: {e}")
        # Re-raise so a failed launch ends with a traceback and a non-zero exit
        # status instead of looking like a clean shutdown.
        raise