baconnier committed
Commit 8f7fdc3 · verified
1 Parent(s): a066e1a

Update app.py

Files changed (1):
app.py (+53 -79)
app.py CHANGED
@@ -2,9 +2,15 @@ import os
 import openai
 import gradio as gr
 import json
+from typing import Optional, Dict, Any, List
 import plotly.graph_objects as go
+import instructor
+from models import ArtHistoryResponse, Location
 from variables import CONTEXTUAL_ZOOM_PROMPT, CONTEXTUAL_ZOOM_default_response

+# Enable instructor
+instructor.patch()
+
 class ArtExplorer:
     def __init__(self):
         self.client = openai.OpenAI(
@@ -16,17 +22,22 @@ class ArtExplorer:
             "selections": {}
         }

-    def create_map(self, locations):
+    def create_map(self, locations: List[Location]) -> go.Figure:
         """Create a Plotly map figure from location data"""
         if not locations:
-            locations = [{"name": "Paris", "lat": 48.8566, "lon": 2.3522}]
+            locations = [Location(
+                name="Paris",
+                lat=48.8566,
+                lon=2.3522,
+                relevance="Default location"
+            )]

         fig = go.Figure(go.Scattermapbox(
-            lat=[loc.get('lat') for loc in locations],
-            lon=[loc.get('lon') for loc in locations],
+            lat=[loc.lat for loc in locations],
+            lon=[loc.lon for loc in locations],
             mode='markers',
             marker=go.scattermapbox.Marker(size=10),
-            text=[loc.get('name') for loc in locations]
+            text=[loc.name for loc in locations]
         ))

         fig.update_layout(
@@ -39,29 +50,8 @@ class ArtExplorer:
         )
         return fig

-    def clean_json_string(self, json_str: str) -> str:
-        """Clean and prepare JSON string for parsing"""
-        # Remove any leading/trailing whitespace and newlines
-        json_str = json_str.strip()
-
-        # Remove any BOM or special characters at the start
-        if json_str.startswith('\ufeff'):
-            json_str = json_str[1:]
-
-        # Ensure it starts with {
-        if not json_str.startswith('{'):
-            start_idx = json_str.find('{')
-            if start_idx != -1:
-                json_str = json_str[start_idx:]
-
-        # Ensure it ends with }
-        end_idx = json_str.rfind('}')
-        if end_idx != -1:
-            json_str = json_str[:end_idx+1]
-
-        return json_str
-
-    def get_llm_response(self, query: str, zoom_context: dict = None) -> dict:
+    def get_llm_response(self, query: str, zoom_context: Optional[Dict[str, Any]] = None) -> ArtHistoryResponse:
+        """Get response from LLM with proper validation"""
        try:
            print("\n=== Starting LLM Request ===")
            print(f"Input query: {query}")
@@ -87,39 +77,17 @@ class ArtExplorer:
                )}
            ]

-            print("\nPrepared messages for LLM:")
-            print(json.dumps(messages, indent=2))
-
            print("\nSending request to LLM...")
            response = self.client.chat.completions.create(
                model="mixtral-8x7b-32768",
                messages=messages,
                temperature=0.1,
-                max_tokens=2048
+                max_tokens=2048,
+                response_model=ArtHistoryResponse
            )
-            print("\nReceived raw response from LLM:")
-            print(response)
-
-            # Get the response content and clean it
-            content = response.choices[0].message.content
-            print("\nExtracted content from response:")
-            print(content)
-
-            # Clean the JSON string
-            cleaned_content = self.clean_json_string(content)
-            print("\nCleaned JSON string:")
-            print(cleaned_content)
-
-            try:
-                result = json.loads(cleaned_content)
-                print("\nSuccessfully parsed JSON:")
-                print(json.dumps(result, indent=2))
-                return result
-            except json.JSONDecodeError as e:
-                print(f"\nJSON parsing error: {str(e)}")
-                print(f"Failed content: {cleaned_content}")
-                return self.get_default_response()
-
+            print("\nReceived validated response from LLM")
+            return response
+
        except Exception as e:
            print(f"\nError in LLM response: {str(e)}")
            print(f"Full error details: {e.__class__.__name__}")
@@ -127,10 +95,12 @@ class ArtExplorer:
            print(traceback.format_exc())
            return self.get_default_response()

-    def get_default_response(self):
-        return CONTEXTUAL_ZOOM_default_response
+    def get_default_response(self) -> ArtHistoryResponse:
+        """Return default response when LLM fails"""
+        return ArtHistoryResponse(**CONTEXTUAL_ZOOM_default_response)

-    def create_interface(self):
+    def create_interface(self) -> gr.Blocks:
+        """Create the Gradio interface"""
        with gr.Blocks() as demo:
            gr.Markdown("# Art History Explorer")

@@ -160,6 +130,7 @@ class ArtExplorer:
                geo_zoom = gr.Button("🔍 Zoom Geography")

            with gr.Row():
+                # Style axis
                style_select = gr.Dropdown(
                    choices=["Classical", "Modern"],
                    multiselect=True,
@@ -169,48 +140,50 @@ class ArtExplorer:
                style_explanation = gr.Markdown()
                style_zoom = gr.Button("🔍 Zoom Styles")

-            def initial_search(query):
+            def initial_search(query: str) -> tuple:
                """Handle the initial search query"""
-                config = self.get_llm_response(query)
-                temporal_config = config["axis_configurations"]["temporal"]["current_zoom"]
-                geographical_config = config["axis_configurations"]["geographical"]["current_zoom"]
-                style_config = config["axis_configurations"]["style"]["current_zoom"]
+                response = self.get_llm_response(query)
+
+                temporal_config = response.axis_configurations["temporal"].current_zoom
+                geographical_config = response.axis_configurations["geographical"].current_zoom
+                style_config = response.axis_configurations["style"].current_zoom

-                map_fig = self.create_map(geographical_config["locations"])
+                map_fig = self.create_map(geographical_config.locations)

                return (
-                    temporal_config["range"],
+                    temporal_config.range,
                    map_fig,
-                    style_config["options"],
-                    temporal_config["explanation"],
-                    geographical_config.get("explanation", ""),
-                    style_config["explanation"]
+                    style_config.options,
+                    temporal_config.explanation,
+                    geographical_config.explanation,
+                    style_config.explanation
                )

-            def zoom_axis(query, axis_name, current_value):
+            def zoom_axis(query: str, axis_name: str, current_value: Any) -> tuple:
                """Handle zoom events for any axis"""
                self.current_state["zoom_level"] += 1
-                config = self.get_llm_response(
+                response = self.get_llm_response(
                    query,
                    zoom_context={axis_name: current_value}
                )
-                axis_config = config["axis_configurations"][axis_name]["current_zoom"]
+
+                axis_config = response.axis_configurations[axis_name].current_zoom

                if axis_name == "temporal":
                    return (
-                        axis_config["range"],
-                        axis_config["explanation"]
+                        axis_config.range,
+                        axis_config.explanation
                    )
                elif axis_name == "geographical":
-                    map_fig = self.create_map(axis_config["locations"])
+                    map_fig = self.create_map(axis_config.locations)
                    return (
                        map_fig,
-                        axis_config.get("explanation", "")
+                        axis_config.explanation
                    )
                else: # style
                    return (
-                        axis_config["options"],
-                        axis_config["explanation"]
+                        axis_config.options,
+                        axis_config.explanation
                    )

            # Connect event handlers
@@ -248,7 +221,8 @@ class ArtExplorer:
        return demo

 def main():
-    print("Starting initialization...")
+    """Main entry point"""
+    print("Starting Art History Explorer...")
    explorer = ArtExplorer()
    print("Created ArtExplorer instance")
    demo = explorer.create_interface()
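Note: the updated app.py imports ArtHistoryResponse and Location from a models module that is not part of this commit. The following is a hypothetical sketch of that module, inferred only from the attribute accesses visible in the diff above (axis_configurations[...].current_zoom with .range, .options, .locations and .explanation, plus the Location(name=..., lat=..., lon=..., relevance=...) constructor). The ZoomLevel class name, the field types and the optional defaults are assumptions, not committed code.

# models.py -- hypothetical sketch, not part of this commit
from typing import Dict, List, Optional
from pydantic import BaseModel


class Location(BaseModel):
    # app.py reads .name, .lat and .lon, and builds Location(..., relevance=...)
    name: str
    lat: float
    lon: float
    relevance: Optional[str] = None


class ZoomLevel(BaseModel):
    # the object behind .current_zoom; which fields are populated depends on the axis,
    # so every axis-specific field is optional here
    range: Optional[List[int]] = None           # temporal axis, e.g. a start/end year pair
    options: Optional[List[str]] = None         # style axis choices
    locations: Optional[List[Location]] = None  # geographical axis markers
    explanation: str = ""


class AxisConfiguration(BaseModel):
    current_zoom: ZoomLevel


class ArtHistoryResponse(BaseModel):
    # indexed in app.py by "temporal", "geographical" and "style"
    axis_configurations: Dict[str, AxisConfiguration]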
 
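Note: the commit calls instructor.patch() at module level with no arguments, while the pattern usually shown in the instructor documentation patches a specific OpenAI client and then passes response_model to chat.completions.create, which returns the validated Pydantic object directly. Below is a minimal usage sketch of that per-client form; the base_url and api_key handling are placeholders, since the commit's openai.OpenAI(...) arguments are not visible in the hunk, and ArtHistoryResponse refers to the hypothetical models sketch above.

# usage sketch, assuming instructor's per-client patching form
import os
import openai
import instructor

from models import ArtHistoryResponse  # hypothetical module sketched above

# patch one client instance; base_url and api_key are placeholders, not the commit's values
client = instructor.patch(
    openai.OpenAI(
        base_url=os.environ.get("OPENAI_BASE_URL"),
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
)

result = client.chat.completions.create(
    model="mixtral-8x7b-32768",
    messages=[{"role": "user", "content": "Impressionism in late 19th-century France"}],
    temperature=0.1,
    max_tokens=2048,
    response_model=ArtHistoryResponse,  # instructor parses and validates into this model
)

# result is an ArtHistoryResponse instance, so callers can use attribute access
print(result.axis_configurations["temporal"].current_zoom.explanation)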