Spaces:
Paused
Paused
Update app.py via AI Editor
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ import uuid
|
|
5 |
import shutil
|
6 |
import json
|
7 |
import tempfile
|
|
|
8 |
from flask import Flask, request as flask_request, make_response
|
9 |
import dash
|
10 |
from dash import dcc, html, Input, Output, State, callback_context
|
@@ -25,6 +26,10 @@ os.makedirs(SESSION_DIR_BASE, exist_ok=True)
|
|
25 |
|
26 |
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
27 |
|
|
|
|
|
|
|
|
|
28 |
def get_session_id():
|
29 |
sid = flask_request.cookies.get("asktricare_session_id")
|
30 |
if not sid:
|
@@ -75,6 +80,39 @@ def load_system_prompt():
|
|
75 |
logger.error(f"Failed to load system prompt: {e}")
|
76 |
return "You are Ask Tricare, a helpful assistant for TRICARE health benefits. Respond conversationally, and cite relevant sources when possible. If you do not know, say so."
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
app = dash.Dash(
|
79 |
__name__,
|
80 |
server=app_flask,
|
@@ -201,18 +239,20 @@ def assign_session_id(_):
|
|
201 |
Output("stream-interval", "disabled"),
|
202 |
Output("stream-interval", "n_intervals"),
|
203 |
Output("chat-name-input", "style"),
|
|
|
|
|
204 |
Input("session-id", "data"),
|
205 |
Input("send-btn", "n_clicks"),
|
206 |
Input("file-upload", "contents"),
|
207 |
Input("new-chat-btn", "n_clicks"),
|
|
|
208 |
State("file-upload", "filename"),
|
209 |
State("user-input", "value"),
|
210 |
State("chat-name-input", "value"),
|
211 |
-
State("stream-interval", "n_intervals"),
|
212 |
State("chat-name-input", "style"),
|
213 |
prevent_initial_call=False
|
214 |
)
|
215 |
-
def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_names, user_input, chat_name,
|
216 |
trigger = callback_context.triggered[0]['prop_id'].split('.')[0] if callback_context.triggered else ""
|
217 |
if not session_id:
|
218 |
session_id = get_session_id()
|
@@ -303,27 +343,32 @@ def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_
|
|
303 |
# If chat name input box is not yet visible, show it
|
304 |
if show_chat_name.get("display", "none") == "none":
|
305 |
show_chat_name["display"] = "block"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
return (
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
html.Span(
|
311 |
-
chat["name"],
|
312 |
-
style={"fontSize": "0.92rem"}
|
313 |
-
)
|
314 |
-
) for chat in state.get("chat_histories", [])[-6:]
|
315 |
-
],
|
316 |
-
[
|
317 |
-
chat_message_card(msg['content'], is_user=(msg['role'] == "user"))
|
318 |
-
for msg in state.get("messages", [])
|
319 |
-
] + (
|
320 |
-
[chat_message_card(state["stream_buffer"], is_user=False)]
|
321 |
-
if state.get("streaming", False) and state.get("stream_buffer", "") else []
|
322 |
-
),
|
323 |
"",
|
324 |
not state.get("streaming", False),
|
325 |
0,
|
326 |
-
show_chat_name
|
|
|
|
|
327 |
)
|
328 |
# If input box is visible and has a value, save chat history
|
329 |
else:
|
@@ -342,24 +387,82 @@ def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_
|
|
342 |
save_session_state(session_id)
|
343 |
logger.info(f"Session {session_id}: Saved chat history '{chat_title}'")
|
344 |
show_chat_name["display"] = "none"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
return (
|
346 |
-
|
347 |
-
|
348 |
-
html.Li(
|
349 |
-
html.Span(
|
350 |
-
chat["name"],
|
351 |
-
style={"fontSize": "0.92rem", "fontWeight": "bold"}
|
352 |
-
)
|
353 |
-
) for chat in state.get("chat_histories", [])[-6:]
|
354 |
-
],
|
355 |
[],
|
356 |
error,
|
357 |
not state.get("streaming", False),
|
358 |
0,
|
359 |
-
show_chat_name
|
|
|
|
|
360 |
)
|
361 |
|
362 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
chat_history = state.get("messages", [])
|
364 |
uploads = state.get("uploads", [])
|
365 |
chat_histories = state.get("chat_histories", [])
|
@@ -378,30 +481,8 @@ def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_
|
|
378 |
if state.get("streaming", False):
|
379 |
if state.get("stream_buffer", ""):
|
380 |
chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
|
381 |
-
return upload_cards, chat_history_items, chat_cards, error, False, 0,
|
382 |
-
return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0,
|
383 |
-
|
384 |
-
@app.callback(
|
385 |
-
Output("chat-window", "children"),
|
386 |
-
Output("stream-interval", "disabled"),
|
387 |
-
Input("stream-interval", "n_intervals"),
|
388 |
-
State("session-id", "data"),
|
389 |
-
prevent_initial_call=True
|
390 |
-
)
|
391 |
-
def poll_stream(n_intervals, session_id):
|
392 |
-
session_lock = get_session_lock(session_id)
|
393 |
-
with session_lock:
|
394 |
-
load_session_state(session_id)
|
395 |
-
state = get_session_state(session_id)
|
396 |
-
chat_history = state.get("messages", [])
|
397 |
-
chat_cards = []
|
398 |
-
for msg in chat_history:
|
399 |
-
chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
|
400 |
-
if state.get("streaming", False):
|
401 |
-
if state.get("stream_buffer", ""):
|
402 |
-
chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
|
403 |
-
return chat_cards, False
|
404 |
-
return chat_cards, True
|
405 |
|
406 |
@app_flask.after_request
|
407 |
def set_session_cookie(resp):
|
|
|
5 |
import shutil
|
6 |
import json
|
7 |
import tempfile
|
8 |
+
import glob
|
9 |
from flask import Flask, request as flask_request, make_response
|
10 |
import dash
|
11 |
from dash import dcc, html, Input, Output, State, callback_context
|
|
|
26 |
|
27 |
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
28 |
|
29 |
+
EMBEDDING_INDEX = {}
|
30 |
+
EMBEDDING_TEXTS = {}
|
31 |
+
EMBEDDING_MODEL = "text-embedding-ada-002"
|
32 |
+
|
33 |
def get_session_id():
|
34 |
sid = flask_request.cookies.get("asktricare_session_id")
|
35 |
if not sid:
|
|
|
80 |
logger.error(f"Failed to load system prompt: {e}")
|
81 |
return "You are Ask Tricare, a helpful assistant for TRICARE health benefits. Respond conversationally, and cite relevant sources when possible. If you do not know, say so."
|
82 |
|
83 |
+
def embed_docs_folder():
|
84 |
+
global EMBEDDING_INDEX, EMBEDDING_TEXTS
|
85 |
+
docs_folder = os.path.join(os.getcwd(), "docs")
|
86 |
+
if not os.path.isdir(docs_folder):
|
87 |
+
logger.warning(f"Docs folder '{docs_folder}' does not exist. Skipping embedding.")
|
88 |
+
return
|
89 |
+
doc_files = []
|
90 |
+
for ext in ("*.txt", "*.md", "*.pdf"):
|
91 |
+
doc_files.extend(glob.glob(os.path.join(docs_folder, ext)))
|
92 |
+
for doc_path in doc_files:
|
93 |
+
fname = os.path.basename(doc_path)
|
94 |
+
if fname in EMBEDDING_INDEX:
|
95 |
+
continue
|
96 |
+
try:
|
97 |
+
with open(doc_path, "r", encoding="utf-8", errors="ignore") as f:
|
98 |
+
text = f.read()
|
99 |
+
if not text.strip():
|
100 |
+
continue
|
101 |
+
# OpenAI recommends chunking long texts; here, we take the first 1000 tokens per doc as a simple approach
|
102 |
+
chunk = text[:4000]
|
103 |
+
response = openai.Embedding.create(
|
104 |
+
input=[chunk],
|
105 |
+
model=EMBEDDING_MODEL
|
106 |
+
)
|
107 |
+
embedding = response['data'][0]['embedding']
|
108 |
+
EMBEDDING_INDEX[fname] = embedding
|
109 |
+
EMBEDDING_TEXTS[fname] = chunk
|
110 |
+
logger.info(f"Embedded doc: {fname}")
|
111 |
+
except Exception as e:
|
112 |
+
logger.error(f"Embedding failed for {fname}: {e}")
|
113 |
+
|
114 |
+
embed_docs_folder()
|
115 |
+
|
116 |
app = dash.Dash(
|
117 |
__name__,
|
118 |
server=app_flask,
|
|
|
239 |
Output("stream-interval", "disabled"),
|
240 |
Output("stream-interval", "n_intervals"),
|
241 |
Output("chat-name-input", "style"),
|
242 |
+
Output("chat-window", "children"),
|
243 |
+
Output("stream-interval", "disabled"),
|
244 |
Input("session-id", "data"),
|
245 |
Input("send-btn", "n_clicks"),
|
246 |
Input("file-upload", "contents"),
|
247 |
Input("new-chat-btn", "n_clicks"),
|
248 |
+
Input("stream-interval", "n_intervals"),
|
249 |
State("file-upload", "filename"),
|
250 |
State("user-input", "value"),
|
251 |
State("chat-name-input", "value"),
|
|
|
252 |
State("chat-name-input", "style"),
|
253 |
prevent_initial_call=False
|
254 |
)
|
255 |
+
def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, stream_n, file_names, user_input, chat_name, chat_name_style):
|
256 |
trigger = callback_context.triggered[0]['prop_id'].split('.')[0] if callback_context.triggered else ""
|
257 |
if not session_id:
|
258 |
session_id = get_session_id()
|
|
|
343 |
# If chat name input box is not yet visible, show it
|
344 |
if show_chat_name.get("display", "none") == "none":
|
345 |
show_chat_name["display"] = "block"
|
346 |
+
upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
|
347 |
+
chat_history_items = [
|
348 |
+
html.Li(
|
349 |
+
html.Span(
|
350 |
+
chat["name"],
|
351 |
+
style={"fontSize": "0.92rem"}
|
352 |
+
)
|
353 |
+
) for chat in state.get("chat_histories", [])[-6:]
|
354 |
+
]
|
355 |
+
chat_cards = [
|
356 |
+
chat_message_card(msg['content'], is_user=(msg['role'] == "user"))
|
357 |
+
for msg in state.get("messages", [])
|
358 |
+
] + (
|
359 |
+
[chat_message_card(state["stream_buffer"], is_user=False)]
|
360 |
+
if state.get("streaming", False) and state.get("stream_buffer", "") else []
|
361 |
+
)
|
362 |
return (
|
363 |
+
upload_cards,
|
364 |
+
chat_history_items,
|
365 |
+
chat_cards,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
"",
|
367 |
not state.get("streaming", False),
|
368 |
0,
|
369 |
+
show_chat_name,
|
370 |
+
chat_cards,
|
371 |
+
not state.get("streaming", False)
|
372 |
)
|
373 |
# If input box is visible and has a value, save chat history
|
374 |
else:
|
|
|
387 |
save_session_state(session_id)
|
388 |
logger.info(f"Session {session_id}: Saved chat history '{chat_title}'")
|
389 |
show_chat_name["display"] = "none"
|
390 |
+
upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
|
391 |
+
chat_history_items = [
|
392 |
+
html.Li(
|
393 |
+
html.Span(
|
394 |
+
chat["name"],
|
395 |
+
style={"fontSize": "0.92rem", "fontWeight": "bold"}
|
396 |
+
)
|
397 |
+
) for chat in state.get("chat_histories", [])[-6:]
|
398 |
+
]
|
399 |
return (
|
400 |
+
upload_cards,
|
401 |
+
chat_history_items,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
[],
|
403 |
error,
|
404 |
not state.get("streaming", False),
|
405 |
0,
|
406 |
+
show_chat_name,
|
407 |
+
[],
|
408 |
+
not state.get("streaming", False)
|
409 |
)
|
410 |
|
411 |
+
# Handle polling for streaming
|
412 |
+
if trigger == "stream-interval":
|
413 |
+
chat_history = state.get("messages", [])
|
414 |
+
chat_cards = []
|
415 |
+
for msg in chat_history:
|
416 |
+
chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
|
417 |
+
if state.get("streaming", False):
|
418 |
+
if state.get("stream_buffer", ""):
|
419 |
+
chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
|
420 |
+
upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
|
421 |
+
chat_history_items = [
|
422 |
+
html.Li(
|
423 |
+
html.Span(
|
424 |
+
chat["name"],
|
425 |
+
style={"fontSize": "0.92rem", "fontWeight": "bold" if i == len(state.get("chat_histories", []))-1 else "normal"}
|
426 |
+
)
|
427 |
+
) for i, chat in enumerate(state.get("chat_histories", [])[-6:])
|
428 |
+
]
|
429 |
+
return (
|
430 |
+
upload_cards,
|
431 |
+
chat_history_items,
|
432 |
+
chat_cards,
|
433 |
+
"",
|
434 |
+
False,
|
435 |
+
stream_n+1,
|
436 |
+
chat_name_style,
|
437 |
+
chat_cards,
|
438 |
+
False
|
439 |
+
)
|
440 |
+
else:
|
441 |
+
chat_cards = []
|
442 |
+
for msg in state.get("messages", []):
|
443 |
+
chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
|
444 |
+
upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
|
445 |
+
chat_history_items = [
|
446 |
+
html.Li(
|
447 |
+
html.Span(
|
448 |
+
chat["name"],
|
449 |
+
style={"fontSize": "0.92rem", "fontWeight": "bold" if i == len(state.get("chat_histories", []))-1 else "normal"}
|
450 |
+
)
|
451 |
+
) for i, chat in enumerate(state.get("chat_histories", [])[-6:])
|
452 |
+
]
|
453 |
+
return (
|
454 |
+
upload_cards,
|
455 |
+
chat_history_items,
|
456 |
+
chat_cards,
|
457 |
+
"",
|
458 |
+
True,
|
459 |
+
0,
|
460 |
+
chat_name_style,
|
461 |
+
chat_cards,
|
462 |
+
True
|
463 |
+
)
|
464 |
+
|
465 |
+
# Default: Build Uploads, Chat History and Chat Window
|
466 |
chat_history = state.get("messages", [])
|
467 |
uploads = state.get("uploads", [])
|
468 |
chat_histories = state.get("chat_histories", [])
|
|
|
481 |
if state.get("streaming", False):
|
482 |
if state.get("stream_buffer", ""):
|
483 |
chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
|
484 |
+
return upload_cards, chat_history_items, chat_cards, error, False, 0, chat_name_style, chat_cards, False
|
485 |
+
return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0, chat_name_style, chat_cards, True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
|
487 |
@app_flask.after_request
|
488 |
def set_session_cookie(resp):
|