Spaces:
Runtime error
Runtime error
remove another comment
Browse files- wanderlust.py +38 -56
wanderlust.py
CHANGED
@@ -4,7 +4,6 @@ import os
|
|
4 |
import ipyleaflet
|
5 |
from openai import OpenAI, NotFoundError
|
6 |
from openai.types.beta import Thread
|
7 |
-
from openai.types.beta.threads import Run
|
8 |
|
9 |
import time
|
10 |
|
@@ -13,9 +12,7 @@ import solara
|
|
13 |
center_default = (0, 0)
|
14 |
zoom_default = 2
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
messages = solara.reactive(messages_default)
|
19 |
zoom_level = solara.reactive(zoom_default)
|
20 |
center = solara.reactive(center_default)
|
21 |
markers = solara.reactive([])
|
@@ -25,6 +22,7 @@ openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
|
25 |
model = "gpt-4-1106-preview"
|
26 |
|
27 |
|
|
|
28 |
tools = [
|
29 |
{
|
30 |
"type": "function",
|
@@ -80,7 +78,6 @@ tools = [
|
|
80 |
|
81 |
|
82 |
def update_map(longitude, latitude, zoom):
|
83 |
-
print("update_map", longitude, latitude, zoom)
|
84 |
center.set((latitude, longitude))
|
85 |
zoom_level.set(zoom)
|
86 |
return "Map updated"
|
@@ -111,12 +108,9 @@ def ai_call(tool_call):
|
|
111 |
|
112 |
@solara.component
|
113 |
def Map():
|
114 |
-
print("Map", zoom_level.value, center.value, markers.value)
|
115 |
ipyleaflet.Map.element( # type: ignore
|
116 |
zoom=zoom_level.value,
|
117 |
-
# on_zoom=zoom_level.set,
|
118 |
center=center.value,
|
119 |
-
# on_center=center.set,
|
120 |
scroll_wheel_zoom=True,
|
121 |
layers=[
|
122 |
ipyleaflet.TileLayer.element(url=url),
|
@@ -134,7 +128,6 @@ def ChatInterface():
|
|
134 |
run_id: solara.Reactive[str] = solara.use_reactive(None)
|
135 |
|
136 |
thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
|
137 |
-
print("thread id:", thread.id)
|
138 |
|
139 |
def add_message(value: str):
|
140 |
if value == "":
|
@@ -149,7 +142,6 @@ def ChatInterface():
|
|
149 |
assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
|
150 |
tools=tools,
|
151 |
).id
|
152 |
-
print("Run id:", run_id.value)
|
153 |
|
154 |
def poll():
|
155 |
if not run_id.value:
|
@@ -159,7 +151,8 @@ def ChatInterface():
|
|
159 |
try:
|
160 |
run = openai.beta.threads.runs.retrieve(
|
161 |
run_id.value, thread_id=thread.id
|
162 |
-
)
|
|
|
163 |
except NotFoundError:
|
164 |
continue
|
165 |
if run.status == "requires_action":
|
@@ -167,6 +160,7 @@ def ChatInterface():
|
|
167 |
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
168 |
tool_output = ai_call(tool_call)
|
169 |
tool_outputs.append(tool_output)
|
|
|
170 |
openai.beta.threads.runs.submit_tool_outputs(
|
171 |
thread_id=thread.id,
|
172 |
run_id=run_id.value,
|
@@ -182,27 +176,10 @@ def ChatInterface():
|
|
182 |
run_id.set(None)
|
183 |
completed = True
|
184 |
time.sleep(0.1)
|
185 |
-
retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id)
|
186 |
-
messages.set(retrieved_messages.data)
|
187 |
|
188 |
result = solara.use_thread(poll, dependencies=[run_id.value])
|
189 |
|
190 |
-
|
191 |
-
print("handle", message)
|
192 |
-
messages = []
|
193 |
-
if message.role == "assistant":
|
194 |
-
tools_calls = message.get("tool_calls", [])
|
195 |
-
for tool_call in tools_calls:
|
196 |
-
messages.append(ai_call(tool_call))
|
197 |
-
return messages
|
198 |
-
|
199 |
-
def handle_initial():
|
200 |
-
print("handle initial", messages.value)
|
201 |
-
for message in messages.value:
|
202 |
-
handle_message(message)
|
203 |
-
|
204 |
-
solara.use_effect(handle_initial, [])
|
205 |
-
# result = solara.use_thread(ask, dependencies=[messages.value])
|
206 |
with solara.Column(
|
207 |
classes=["chat-interface"],
|
208 |
):
|
@@ -214,16 +191,25 @@ def ChatInterface():
|
|
214 |
"overflow-y": "auto",
|
215 |
"height": "100px",
|
216 |
"flex-direction": "column-reverse",
|
217 |
-
}
|
|
|
218 |
):
|
219 |
for message in reversed(messages.value):
|
220 |
with solara.Row(style={"align-items": "flex-start"}):
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
solara.Text(
|
223 |
message.content[0].text.value,
|
224 |
classes=["chat-message", "user-message"],
|
225 |
)
|
226 |
-
assert len(message.content) == 1
|
227 |
elif message.role == "assistant":
|
228 |
if message.content[0].text.value:
|
229 |
solara.v.Icon(
|
@@ -246,8 +232,6 @@ def ChatInterface():
|
|
246 |
repr(message),
|
247 |
classes=["chat-message", "assistant-message"],
|
248 |
)
|
249 |
-
elif message["role"] == "tool":
|
250 |
-
pass # no need to display
|
251 |
else:
|
252 |
solara.v.Icon(
|
253 |
children=["mdi-compass-outline"],
|
@@ -272,21 +256,6 @@ def ChatInterface():
|
|
272 |
|
273 |
@solara.component
|
274 |
def Page():
|
275 |
-
reset_counter, set_reset_counter = solara.use_state(0)
|
276 |
-
print("reset", reset_counter, f"chat-{reset_counter}")
|
277 |
-
|
278 |
-
def reset_ui():
|
279 |
-
set_reset_counter(reset_counter + 1)
|
280 |
-
|
281 |
-
def save():
|
282 |
-
with open("log.json", "w") as f:
|
283 |
-
json.dump(messages.value, f)
|
284 |
-
|
285 |
-
def load():
|
286 |
-
with open("log.json", "r") as f:
|
287 |
-
messages.set(json.load(f))
|
288 |
-
reset_ui()
|
289 |
-
|
290 |
with solara.Column(
|
291 |
classes=["ui-container"],
|
292 |
gap="5vh",
|
@@ -299,16 +268,12 @@ def Page():
|
|
299 |
unsafe_innerHTML="Wanderlust",
|
300 |
style={"display": "inline-block"},
|
301 |
)
|
302 |
-
# with solara.Row(gap="10px"):
|
303 |
-
# solara.Button("Save", on_click=save)
|
304 |
-
# solara.Button("Load", on_click=load)
|
305 |
-
# solara.Button("Soft reset", on_click=reset_ui)
|
306 |
with solara.Row(
|
307 |
justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
|
308 |
):
|
309 |
-
ChatInterface()
|
310 |
with solara.Column(classes=["map-container"]):
|
311 |
-
Map()
|
312 |
|
313 |
solara.Style(
|
314 |
"""
|
@@ -335,13 +300,30 @@ def Page():
|
|
335 |
height: 100%;
|
336 |
width: 38vw;
|
337 |
justify-content: center;
|
338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
}
|
340 |
.map-container{
|
341 |
width: 50vw;
|
342 |
height: 100%;
|
343 |
justify-content: center;
|
344 |
}
|
|
|
|
|
|
|
345 |
@media screen and (max-aspect-ratio: 1/1) {
|
346 |
.ui-container{
|
347 |
padding: 30px;
|
|
|
4 |
import ipyleaflet
|
5 |
from openai import OpenAI, NotFoundError
|
6 |
from openai.types.beta import Thread
|
|
|
7 |
|
8 |
import time
|
9 |
|
|
|
12 |
center_default = (0, 0)
|
13 |
zoom_default = 2
|
14 |
|
15 |
+
messages = solara.reactive([])
|
|
|
|
|
16 |
zoom_level = solara.reactive(zoom_default)
|
17 |
center = solara.reactive(center_default)
|
18 |
markers = solara.reactive([])
|
|
|
22 |
model = "gpt-4-1106-preview"
|
23 |
|
24 |
|
25 |
+
# Declare tools for openai assistant to use
|
26 |
tools = [
|
27 |
{
|
28 |
"type": "function",
|
|
|
78 |
|
79 |
|
80 |
def update_map(longitude, latitude, zoom):
|
|
|
81 |
center.set((latitude, longitude))
|
82 |
zoom_level.set(zoom)
|
83 |
return "Map updated"
|
|
|
108 |
|
109 |
@solara.component
|
110 |
def Map():
|
|
|
111 |
ipyleaflet.Map.element( # type: ignore
|
112 |
zoom=zoom_level.value,
|
|
|
113 |
center=center.value,
|
|
|
114 |
scroll_wheel_zoom=True,
|
115 |
layers=[
|
116 |
ipyleaflet.TileLayer.element(url=url),
|
|
|
128 |
run_id: solara.Reactive[str] = solara.use_reactive(None)
|
129 |
|
130 |
thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
|
|
|
131 |
|
132 |
def add_message(value: str):
|
133 |
if value == "":
|
|
|
142 |
assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
|
143 |
tools=tools,
|
144 |
).id
|
|
|
145 |
|
146 |
def poll():
|
147 |
if not run_id.value:
|
|
|
151 |
try:
|
152 |
run = openai.beta.threads.runs.retrieve(
|
153 |
run_id.value, thread_id=thread.id
|
154 |
+
)
|
155 |
+
# Above will raise NotFoundError when run creation is still in progress
|
156 |
except NotFoundError:
|
157 |
continue
|
158 |
if run.status == "requires_action":
|
|
|
160 |
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
|
161 |
tool_output = ai_call(tool_call)
|
162 |
tool_outputs.append(tool_output)
|
163 |
+
messages.set([*messages.value, tool_output])
|
164 |
openai.beta.threads.runs.submit_tool_outputs(
|
165 |
thread_id=thread.id,
|
166 |
run_id=run_id.value,
|
|
|
176 |
run_id.set(None)
|
177 |
completed = True
|
178 |
time.sleep(0.1)
|
|
|
|
|
179 |
|
180 |
result = solara.use_thread(poll, dependencies=[run_id.value])
|
181 |
|
182 |
+
# Create DOM for chat interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
with solara.Column(
|
184 |
classes=["chat-interface"],
|
185 |
):
|
|
|
191 |
"overflow-y": "auto",
|
192 |
"height": "100px",
|
193 |
"flex-direction": "column-reverse",
|
194 |
+
},
|
195 |
+
classes=["chat-box"],
|
196 |
):
|
197 |
for message in reversed(messages.value):
|
198 |
with solara.Row(style={"align-items": "flex-start"}):
|
199 |
+
# Catch "messages" that are actually tool calls
|
200 |
+
if isinstance(message, dict):
|
201 |
+
icon = (
|
202 |
+
"mdi-map"
|
203 |
+
if message["output"] == "Map updated"
|
204 |
+
else "mdi-map-marker"
|
205 |
+
)
|
206 |
+
solara.v.Icon(children=[icon], style_="padding-top: 10px;")
|
207 |
+
solara.Markdown(message["output"])
|
208 |
+
elif message.role == "user":
|
209 |
solara.Text(
|
210 |
message.content[0].text.value,
|
211 |
classes=["chat-message", "user-message"],
|
212 |
)
|
|
|
213 |
elif message.role == "assistant":
|
214 |
if message.content[0].text.value:
|
215 |
solara.v.Icon(
|
|
|
232 |
repr(message),
|
233 |
classes=["chat-message", "assistant-message"],
|
234 |
)
|
|
|
|
|
235 |
else:
|
236 |
solara.v.Icon(
|
237 |
children=["mdi-compass-outline"],
|
|
|
256 |
|
257 |
@solara.component
|
258 |
def Page():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
with solara.Column(
|
260 |
classes=["ui-container"],
|
261 |
gap="5vh",
|
|
|
268 |
unsafe_innerHTML="Wanderlust",
|
269 |
style={"display": "inline-block"},
|
270 |
)
|
|
|
|
|
|
|
|
|
271 |
with solara.Row(
|
272 |
justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
|
273 |
):
|
274 |
+
ChatInterface()
|
275 |
with solara.Column(classes=["map-container"]):
|
276 |
+
Map()
|
277 |
|
278 |
solara.Style(
|
279 |
"""
|
|
|
300 |
height: 100%;
|
301 |
width: 38vw;
|
302 |
justify-content: center;
|
303 |
+
position: relative;
|
304 |
+
}
|
305 |
+
.chat-interface:after {
|
306 |
+
content: "";
|
307 |
+
position: absolute;
|
308 |
+
z-index: 1;
|
309 |
+
top: 0;
|
310 |
+
left: 0;
|
311 |
+
pointer-events: none;
|
312 |
+
background-image: linear-gradient(to top, rgba(255,255,255,0), rgba(255,255,255, 1) 100%);
|
313 |
+
width: 100%;
|
314 |
+
height: 15%;
|
315 |
+
}
|
316 |
+
.chat-box > :last-child{
|
317 |
+
padding-top: 7.5vh;
|
318 |
}
|
319 |
.map-container{
|
320 |
width: 50vw;
|
321 |
height: 100%;
|
322 |
justify-content: center;
|
323 |
}
|
324 |
+
.user-message{
|
325 |
+
font-weight: bold;
|
326 |
+
}
|
327 |
@media screen and (max-aspect-ratio: 1/1) {
|
328 |
.ui-container{
|
329 |
padding: 30px;
|