Iisakki commited on
Commit
1d70dc9
·
1 Parent(s): 06e15bc

remove another comment

Browse files
Files changed (1) hide show
  1. wanderlust.py +38 -56
wanderlust.py CHANGED
@@ -4,7 +4,6 @@ import os
4
  import ipyleaflet
5
  from openai import OpenAI, NotFoundError
6
  from openai.types.beta import Thread
7
- from openai.types.beta.threads import Run
8
 
9
  import time
10
 
@@ -13,9 +12,7 @@ import solara
13
  center_default = (0, 0)
14
  zoom_default = 2
15
 
16
- messages_default = []
17
-
18
- messages = solara.reactive(messages_default)
19
  zoom_level = solara.reactive(zoom_default)
20
  center = solara.reactive(center_default)
21
  markers = solara.reactive([])
@@ -25,6 +22,7 @@ openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
25
  model = "gpt-4-1106-preview"
26
 
27
 
 
28
  tools = [
29
  {
30
  "type": "function",
@@ -80,7 +78,6 @@ tools = [
80
 
81
 
82
  def update_map(longitude, latitude, zoom):
83
- print("update_map", longitude, latitude, zoom)
84
  center.set((latitude, longitude))
85
  zoom_level.set(zoom)
86
  return "Map updated"
@@ -111,12 +108,9 @@ def ai_call(tool_call):
111
 
112
  @solara.component
113
  def Map():
114
- print("Map", zoom_level.value, center.value, markers.value)
115
  ipyleaflet.Map.element( # type: ignore
116
  zoom=zoom_level.value,
117
- # on_zoom=zoom_level.set,
118
  center=center.value,
119
- # on_center=center.set,
120
  scroll_wheel_zoom=True,
121
  layers=[
122
  ipyleaflet.TileLayer.element(url=url),
@@ -134,7 +128,6 @@ def ChatInterface():
134
  run_id: solara.Reactive[str] = solara.use_reactive(None)
135
 
136
  thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
137
- print("thread id:", thread.id)
138
 
139
  def add_message(value: str):
140
  if value == "":
@@ -149,7 +142,6 @@ def ChatInterface():
149
  assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
150
  tools=tools,
151
  ).id
152
- print("Run id:", run_id.value)
153
 
154
  def poll():
155
  if not run_id.value:
@@ -159,7 +151,8 @@ def ChatInterface():
159
  try:
160
  run = openai.beta.threads.runs.retrieve(
161
  run_id.value, thread_id=thread.id
162
- ) # When run is complete
 
163
  except NotFoundError:
164
  continue
165
  if run.status == "requires_action":
@@ -167,6 +160,7 @@ def ChatInterface():
167
  for tool_call in run.required_action.submit_tool_outputs.tool_calls:
168
  tool_output = ai_call(tool_call)
169
  tool_outputs.append(tool_output)
 
170
  openai.beta.threads.runs.submit_tool_outputs(
171
  thread_id=thread.id,
172
  run_id=run_id.value,
@@ -182,27 +176,10 @@ def ChatInterface():
182
  run_id.set(None)
183
  completed = True
184
  time.sleep(0.1)
185
- retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id)
186
- messages.set(retrieved_messages.data)
187
 
188
  result = solara.use_thread(poll, dependencies=[run_id.value])
189
 
190
- def handle_message(message):
191
- print("handle", message)
192
- messages = []
193
- if message.role == "assistant":
194
- tools_calls = message.get("tool_calls", [])
195
- for tool_call in tools_calls:
196
- messages.append(ai_call(tool_call))
197
- return messages
198
-
199
- def handle_initial():
200
- print("handle initial", messages.value)
201
- for message in messages.value:
202
- handle_message(message)
203
-
204
- solara.use_effect(handle_initial, [])
205
- # result = solara.use_thread(ask, dependencies=[messages.value])
206
  with solara.Column(
207
  classes=["chat-interface"],
208
  ):
@@ -214,16 +191,25 @@ def ChatInterface():
214
  "overflow-y": "auto",
215
  "height": "100px",
216
  "flex-direction": "column-reverse",
217
- }
 
218
  ):
219
  for message in reversed(messages.value):
220
  with solara.Row(style={"align-items": "flex-start"}):
221
- if message.role == "user":
 
 
 
 
 
 
 
 
 
222
  solara.Text(
223
  message.content[0].text.value,
224
  classes=["chat-message", "user-message"],
225
  )
226
- assert len(message.content) == 1
227
  elif message.role == "assistant":
228
  if message.content[0].text.value:
229
  solara.v.Icon(
@@ -246,8 +232,6 @@ def ChatInterface():
246
  repr(message),
247
  classes=["chat-message", "assistant-message"],
248
  )
249
- elif message["role"] == "tool":
250
- pass # no need to display
251
  else:
252
  solara.v.Icon(
253
  children=["mdi-compass-outline"],
@@ -272,21 +256,6 @@ def ChatInterface():
272
 
273
  @solara.component
274
  def Page():
275
- reset_counter, set_reset_counter = solara.use_state(0)
276
- print("reset", reset_counter, f"chat-{reset_counter}")
277
-
278
- def reset_ui():
279
- set_reset_counter(reset_counter + 1)
280
-
281
- def save():
282
- with open("log.json", "w") as f:
283
- json.dump(messages.value, f)
284
-
285
- def load():
286
- with open("log.json", "r") as f:
287
- messages.set(json.load(f))
288
- reset_ui()
289
-
290
  with solara.Column(
291
  classes=["ui-container"],
292
  gap="5vh",
@@ -299,16 +268,12 @@ def Page():
299
  unsafe_innerHTML="Wanderlust",
300
  style={"display": "inline-block"},
301
  )
302
- # with solara.Row(gap="10px"):
303
- # solara.Button("Save", on_click=save)
304
- # solara.Button("Load", on_click=load)
305
- # solara.Button("Soft reset", on_click=reset_ui)
306
  with solara.Row(
307
  justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
308
  ):
309
- ChatInterface().key(f"chat-{reset_counter}")
310
  with solara.Column(classes=["map-container"]):
311
- Map() # .key(f"map-{reset_counter}")
312
 
313
  solara.Style(
314
  """
@@ -335,13 +300,30 @@ def Page():
335
  height: 100%;
336
  width: 38vw;
337
  justify-content: center;
338
- background: linear-gradient(0deg, transparent 75%, white 100%);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  }
340
  .map-container{
341
  width: 50vw;
342
  height: 100%;
343
  justify-content: center;
344
  }
 
 
 
345
  @media screen and (max-aspect-ratio: 1/1) {
346
  .ui-container{
347
  padding: 30px;
 
4
  import ipyleaflet
5
  from openai import OpenAI, NotFoundError
6
  from openai.types.beta import Thread
 
7
 
8
  import time
9
 
 
12
  center_default = (0, 0)
13
  zoom_default = 2
14
 
15
+ messages = solara.reactive([])
 
 
16
  zoom_level = solara.reactive(zoom_default)
17
  center = solara.reactive(center_default)
18
  markers = solara.reactive([])
 
22
  model = "gpt-4-1106-preview"
23
 
24
 
25
+ # Declare tools for openai assistant to use
26
  tools = [
27
  {
28
  "type": "function",
 
78
 
79
 
80
  def update_map(longitude, latitude, zoom):
 
81
  center.set((latitude, longitude))
82
  zoom_level.set(zoom)
83
  return "Map updated"
 
108
 
109
  @solara.component
110
  def Map():
 
111
  ipyleaflet.Map.element( # type: ignore
112
  zoom=zoom_level.value,
 
113
  center=center.value,
 
114
  scroll_wheel_zoom=True,
115
  layers=[
116
  ipyleaflet.TileLayer.element(url=url),
 
128
  run_id: solara.Reactive[str] = solara.use_reactive(None)
129
 
130
  thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
 
131
 
132
  def add_message(value: str):
133
  if value == "":
 
142
  assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
143
  tools=tools,
144
  ).id
 
145
 
146
  def poll():
147
  if not run_id.value:
 
151
  try:
152
  run = openai.beta.threads.runs.retrieve(
153
  run_id.value, thread_id=thread.id
154
+ )
155
+ # Above will raise NotFoundError when run creation is still in progress
156
  except NotFoundError:
157
  continue
158
  if run.status == "requires_action":
 
160
  for tool_call in run.required_action.submit_tool_outputs.tool_calls:
161
  tool_output = ai_call(tool_call)
162
  tool_outputs.append(tool_output)
163
+ messages.set([*messages.value, tool_output])
164
  openai.beta.threads.runs.submit_tool_outputs(
165
  thread_id=thread.id,
166
  run_id=run_id.value,
 
176
  run_id.set(None)
177
  completed = True
178
  time.sleep(0.1)
 
 
179
 
180
  result = solara.use_thread(poll, dependencies=[run_id.value])
181
 
182
+ # Create DOM for chat interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  with solara.Column(
184
  classes=["chat-interface"],
185
  ):
 
191
  "overflow-y": "auto",
192
  "height": "100px",
193
  "flex-direction": "column-reverse",
194
+ },
195
+ classes=["chat-box"],
196
  ):
197
  for message in reversed(messages.value):
198
  with solara.Row(style={"align-items": "flex-start"}):
199
+ # Catch "messages" that are actually tool calls
200
+ if isinstance(message, dict):
201
+ icon = (
202
+ "mdi-map"
203
+ if message["output"] == "Map updated"
204
+ else "mdi-map-marker"
205
+ )
206
+ solara.v.Icon(children=[icon], style_="padding-top: 10px;")
207
+ solara.Markdown(message["output"])
208
+ elif message.role == "user":
209
  solara.Text(
210
  message.content[0].text.value,
211
  classes=["chat-message", "user-message"],
212
  )
 
213
  elif message.role == "assistant":
214
  if message.content[0].text.value:
215
  solara.v.Icon(
 
232
  repr(message),
233
  classes=["chat-message", "assistant-message"],
234
  )
 
 
235
  else:
236
  solara.v.Icon(
237
  children=["mdi-compass-outline"],
 
256
 
257
  @solara.component
258
  def Page():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  with solara.Column(
260
  classes=["ui-container"],
261
  gap="5vh",
 
268
  unsafe_innerHTML="Wanderlust",
269
  style={"display": "inline-block"},
270
  )
 
 
 
 
271
  with solara.Row(
272
  justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
273
  ):
274
+ ChatInterface()
275
  with solara.Column(classes=["map-container"]):
276
+ Map()
277
 
278
  solara.Style(
279
  """
 
300
  height: 100%;
301
  width: 38vw;
302
  justify-content: center;
303
+ position: relative;
304
+ }
305
+ .chat-interface:after {
306
+ content: "";
307
+ position: absolute;
308
+ z-index: 1;
309
+ top: 0;
310
+ left: 0;
311
+ pointer-events: none;
312
+ background-image: linear-gradient(to top, rgba(255,255,255,0), rgba(255,255,255, 1) 100%);
313
+ width: 100%;
314
+ height: 15%;
315
+ }
316
+ .chat-box > :last-child{
317
+ padding-top: 7.5vh;
318
  }
319
  .map-container{
320
  width: 50vw;
321
  height: 100%;
322
  justify-content: center;
323
  }
324
+ .user-message{
325
+ font-weight: bold;
326
+ }
327
  @media screen and (max-aspect-ratio: 1/1) {
328
  .ui-container{
329
  padding: 30px;