File size: 7,781 Bytes
d6a8ad7
0fc768b
 
4244d2c
0fc768b
f134ba7
10cff5d
47f3b7c
 
c6219ec
4244d2c
 
8f7fdc3
f134ba7
 
10cff5d
 
 
 
f134ba7
10cff5d
 
c38a10f
47f3b7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fc768b
10cff5d
 
 
 
 
 
 
 
 
0fc768b
 
47f3b7c
10cff5d
 
 
 
 
 
 
0fc768b
f134ba7
10cff5d
 
 
f134ba7
10cff5d
 
47f3b7c
10cff5d
 
 
 
 
 
f134ba7
0fc768b
10cff5d
 
 
 
 
47f3b7c
 
10cff5d
 
 
 
 
 
 
 
 
 
 
 
47f3b7c
10cff5d
 
 
 
 
 
 
47f3b7c
 
 
 
 
 
10cff5d
 
 
 
 
 
 
 
 
 
 
 
 
47f3b7c
10cff5d
47f3b7c
 
10cff5d
47f3b7c
 
10cff5d
 
 
 
 
 
47f3b7c
10cff5d
 
 
 
 
47f3b7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10cff5d
 
 
 
47f3b7c
 
 
10cff5d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import gradio as gr
import json
import os
from dotenv import load_dotenv
from art_explorer import ExplorationPathGenerator
from typing import Dict, Any, Optional, Union
from datetime import datetime
import networkx as nx
import matplotlib.pyplot as plt

# Load environment variables
load_dotenv()

# Build the single shared ExplorationPathGenerator at import time. An unset or
# empty GROQ_API_KEY, or any failure in the constructor, is printed and then
# re-raised so the module never finishes importing without a generator.
try:
  api_key = os.getenv("GROQ_API_KEY")
  if api_key:
      generator = ExplorationPathGenerator(api_key=api_key)
  else:
      raise ValueError("GROQ_API_KEY not found in environment variables")
except Exception as init_error:
  print(f"Error initializing ExplorationPathGenerator: {init_error}")
  raise

def create_graph_visualization(nodes, output_path='temp_graph.png'):
  """Render exploration nodes as a directed graph and save it as a PNG.

  Args:
      nodes: Iterable of node dicts. Each needs an ``'id'``. The other keys
          are optional:
          - ``'title'``: defaults to the id.
          - ``'description'``: defaults to ``''``.
          - ``'depth'``: defaults to 0.
          - ``'connections'``: defaults to none.
          Each connection is a dict with a ``'target_id'`` and an optional
          ``'relevance_score'``, stored as the edge weight (default 0).
          Connections without a target are skipped.
      output_path: File the PNG is written to. Defaults to
          ``'temp_graph.png'`` in the current working directory, which is
          overwritten on every call.

  Returns:
      ``output_path`` on success, or ``None`` if anything failed. The error
      is printed rather than raised.
  """
  fig = None
  try:
      # Materialise once: the nodes are iterated twice below.
      nodes = list(nodes or [])
      graph = nx.DiGraph()

      # Register every node with its display metadata.
      for node in nodes:
          graph.add_node(node['id'],
                         title=node.get('title', str(node['id'])),
                         description=node.get('description', ''),
                         depth=node.get('depth', 0))

      # Add edges. A target_id that is not among ``nodes`` still creates a
      # bare (attribute-less) node here. The .get() defaults used below keep
      # such nodes drawable instead of failing the whole render with KeyError.
      for node in nodes:
          for conn in node.get('connections') or []:
              target = conn.get('target_id')
              if target is None:
                  continue
              graph.add_edge(node['id'], target,
                             weight=conn.get('relevance_score', 0))

      fig = plt.figure(figsize=(12, 8))

      # Colour nodes by depth, cycling through a fixed four-colour palette.
      colors = ['#FF9999', '#99FF99', '#9999FF', '#FFFF99']
      node_colors = [colors[(graph.nodes[n].get('depth') or 0) % len(colors)]
                     for n in graph.nodes()]
      labels = {n: graph.nodes[n].get('title', str(n)) for n in graph.nodes()}

      pos = nx.spring_layout(graph, k=1, iterations=50)

      # Draw on an explicit full-figure axes of ``fig`` instead of relying on
      # whichever figure pyplot currently considers active.
      nx.draw(graph, pos,
              ax=fig.add_axes((0, 0, 1, 1)),
              node_color=node_colors,
              with_labels=True,
              labels=labels,
              node_size=2000,
              font_size=8,
              font_weight='bold',
              arrows=True,
              edge_color='gray',
              width=1)

      fig.savefig(output_path, format='png', dpi=300, bbox_inches='tight')
      return output_path
  except Exception as e:
      print(f"Error creating graph visualization: {e}")
      return None
  finally:
      # Release the figure on every path. Previously a failure during
      # layout/draw/save skipped plt.close() and leaked the open figure.
      if fig is not None:
          plt.close(fig)

def format_output(result: Dict[str, Any]) -> str:
  """Serialise ``result`` as pretty-printed (2-space) JSON text.

  Non-ASCII characters are emitted unescaped. If ``json.dumps`` raises for
  ``result`` (e.g. a value that is not JSON-serialisable), a small JSON
  error envelope with ``"status": "failed"`` is returned instead of
  propagating the exception.
  """
  try:
      rendered = json.dumps(result, indent=2, ensure_ascii=False)
  except Exception as exc:
      envelope = {
          "error": str(exc),
          "status": "failed",
          "message": "Failed to format output",
      }
      rendered = json.dumps(envelope, indent=2)
  return rendered

def parse_json_input(json_str: str, default_value: Any) -> Any:
  """Decode JSON text from a UI text box, falling back to a default.

  The following return ``default_value`` without parsing:
  - falsy input;
  - input that, after trimming whitespace, is empty or exactly ``{}``/``[]``.
  Malformed JSON prints a parse-error message and also yields
  ``default_value``. Any other valid JSON is returned as decoded, whatever
  its type.
  """
  is_blank = not json_str or json_str.strip() in ('', '{}', '[]')
  if is_blank:
      return default_value
  try:
      decoded = json.loads(json_str)
  except json.JSONDecodeError as err:
      print(f"JSON parse error: {err}")
      decoded = default_value
  return decoded

def validate_parameters(
  depth: int,
  domain: str,
  parameters: Dict[str, Any],
  *,
  min_depth: int = 1,
  max_depth: int = 10,
) -> Dict[str, Any]:
  """Build the exploration-parameter dict from UI values plus custom overrides.

  Args:
      depth: Requested exploration depth.
      domain: Domain context. Blank or ``None`` is stored as ``None``.
      parameters: Extra parameters merged over the defaults. Ignored unless
          it is a dict.
      min_depth: Lower bound applied to the final ``"depth"`` (default 1).
      max_depth: Upper bound applied to the final ``"depth"`` (default 10).

  Returns:
      A dict with ``"depth"``, ``"domain"`` and ``"previous_explorations"``
      first, followed by any extra keys from ``parameters``. ``"depth"`` is
      an ``int`` clamped to ``[min_depth, max_depth]`` even when it was
      supplied through ``parameters``.

  Raises:
      ValueError / TypeError: if the final ``"depth"`` cannot be converted
          with ``int()``.
  """
  validated_params: Dict[str, Any] = {
      "depth": depth,
      "domain": domain if domain and domain.strip() else None,
      "previous_explorations": [],
  }
  if isinstance(parameters, dict):
      validated_params.update(parameters)
  # Clamp AFTER merging so a "depth" key in the custom parameters cannot
  # bypass the allowed range. int() also normalises float slider values.
  validated_params["depth"] = max(
      min_depth, min(max_depth, int(validated_params["depth"]))
  )
  return validated_params

def explore(
  query: str,
  path_history: str = "[]",
  parameters: str = "{}",
  depth: int = 5,
  domain: str = ""
) -> tuple[str, Optional[str]]:
  """Handle one "Generate Exploration Path" request from the UI.

  Parses the JSON text inputs, validates the exploration parameters and asks
  the module-level ``generator`` for a path. When the result contains nodes,
  they are also rendered to an image.

  Args:
      query: Free-text exploration query. Must not be blank.
      path_history: JSON text of the previously selected path. Blank or
          unparseable input falls back to ``[]``.
      parameters: JSON text of extra exploration parameters. Blank or
          unparseable input falls back to ``{}``.
      depth: Requested exploration depth, clamped by ``validate_parameters``.
      domain: Optional domain context. Blank means none.

  Returns:
      ``(json_text, image_path)``. On any failure ``json_text`` is an error
      envelope with ``"status": "failed"`` and ``image_path`` is ``None``.
      ``image_path`` is also ``None`` when the result has no nodes or
      rendering failed.
  """
  try:
      if not query.strip():
          raise ValueError("Query cannot be empty")

      # path_history is parsed before parameters, so any parse-error
      # messages print in that order.
      selected_path = parse_json_input(path_history, [])
      exploration_parameters = validate_parameters(
          depth, domain, parse_json_input(parameters, {})
      )

      print(f"Processing query: {query}")
      print(f"Parameters: {json.dumps(exploration_parameters, indent=2)}")

      result = generator.generate_exploration_path(
          query=query,
          selected_path=selected_path,
          exploration_parameters=exploration_parameters
      )
      if not isinstance(result, dict):
          raise ValueError("Invalid response format from generator")

      nodes = result.get("nodes")
      graph_path = create_graph_visualization(nodes) if nodes else None
      return format_output(result), graph_path

  except Exception as exc:
      failure = {
          "error": str(exc),
          "status": "failed",
          "message": "Failed to generate exploration path",
          "details": {"query": query, "depth": depth, "domain": domain},
      }
      print(f"Error in explore function: {exc}")
      return format_output(failure), None

def create_interface() -> gr.Blocks:
  """Assemble the Gradio Blocks UI for the exploration generator.

  Layout:
  - a Markdown header;
  - a row with two columns: the five inputs and the generate button on the
    left, the JSON result and graph image on the right;
  - a set of clickable examples.

  The button calls ``explore`` with the five inputs and routes its two
  return values to the JSON and image outputs.

  Returns:
      The constructed, not yet launched, ``gr.Blocks`` app.
  """
  with gr.Blocks(title="Art History Exploration Path Generator") as app:
      gr.Markdown("""# Art History Exploration Path Generator
      
## Features:
- Dynamic exploration path generation
- Contextual understanding of art history
- Multi-dimensional analysis
- Customizable exploration depth
- Interactive visualization

## Usage:
1. Enter your art history query
2. Adjust exploration depth (1-10)
3. Optionally specify domain context
4. View generated exploration path and visualization""")

      with gr.Row():
          # Left column: user inputs.
          with gr.Column():
              query_input = gr.Textbox(
                  lines=2,
                  label="Exploration Query",
                  placeholder="Enter your art history exploration query...",
              )
              path_history = gr.Textbox(
                  value="[]",
                  lines=3,
                  label="Path History (JSON)",
                  placeholder="[]",
              )
              parameters = gr.Textbox(
                  value="{}",
                  lines=3,
                  label="Additional Parameters (JSON)",
                  placeholder="{}",
              )
              depth = gr.Slider(
                  minimum=1,
                  maximum=10,
                  step=1,
                  value=5,
                  label="Exploration Depth",
              )
              domain = gr.Textbox(
                  lines=1,
                  label="Domain Context",
                  placeholder="Optional: Specify art history period or movement",
              )
              generate_btn = gr.Button("Generate Exploration Path")

          # Right column: outputs.
          with gr.Column():
              text_output = gr.JSON(label="Exploration Result")
              graph_output = gr.Image(label="Exploration Graph")

      # Order must match explore()'s positional parameters.
      exploration_inputs = [query_input, path_history, parameters, depth, domain]

      gr.Examples(
          examples=[
              ["Explore the evolution of Renaissance painting techniques", "[]", "{}", 5, "Renaissance"],
              ["Investigate the influence of Japanese art on Impressionism", "[]", "{}", 7, "Impressionism"],
              ["Analyze the development of Cubism through Picasso's work", "[]", "{}", 6, "Cubism"],
          ],
          inputs=exploration_inputs,
      )

      generate_btn.click(
          fn=explore,
          inputs=exploration_inputs,
          outputs=[text_output, graph_output],
      )

  return app

if __name__ == "__main__":
  try:
      print(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
      # Create and launch interface
      demo = create_interface()
      demo.launch()
  except Exception as e:
      print(f"Failed to launch interface: {e}")
      # Re-raise after reporting. Swallowing the error here let a failed
      # launch end the script normally (exit status 0), hiding the failure
      # from whatever started the process.
      raise