bluenevus commited on
Commit
49fb757
·
1 Parent(s): 195b95a

Update app.py via AI Editor

Browse files
Files changed (1) hide show
  1. app.py +135 -54
app.py CHANGED
@@ -5,6 +5,7 @@ import uuid
5
  import shutil
6
  import json
7
  import tempfile
 
8
  from flask import Flask, request as flask_request, make_response
9
  import dash
10
  from dash import dcc, html, Input, Output, State, callback_context
@@ -25,6 +26,10 @@ os.makedirs(SESSION_DIR_BASE, exist_ok=True)
25
 
26
  openai.api_key = os.environ.get("OPENAI_API_KEY")
27
 
 
 
 
 
28
  def get_session_id():
29
  sid = flask_request.cookies.get("asktricare_session_id")
30
  if not sid:
@@ -75,6 +80,39 @@ def load_system_prompt():
75
  logger.error(f"Failed to load system prompt: {e}")
76
  return "You are Ask Tricare, a helpful assistant for TRICARE health benefits. Respond conversationally, and cite relevant sources when possible. If you do not know, say so."
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  app = dash.Dash(
79
  __name__,
80
  server=app_flask,
@@ -201,18 +239,20 @@ def assign_session_id(_):
201
  Output("stream-interval", "disabled"),
202
  Output("stream-interval", "n_intervals"),
203
  Output("chat-name-input", "style"),
 
 
204
  Input("session-id", "data"),
205
  Input("send-btn", "n_clicks"),
206
  Input("file-upload", "contents"),
207
  Input("new-chat-btn", "n_clicks"),
 
208
  State("file-upload", "filename"),
209
  State("user-input", "value"),
210
  State("chat-name-input", "value"),
211
- State("stream-interval", "n_intervals"),
212
  State("chat-name-input", "style"),
213
  prevent_initial_call=False
214
  )
215
- def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_names, user_input, chat_name, stream_n, chat_name_style):
216
  trigger = callback_context.triggered[0]['prop_id'].split('.')[0] if callback_context.triggered else ""
217
  if not session_id:
218
  session_id = get_session_id()
@@ -303,27 +343,32 @@ def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_
303
  # If chat name input box is not yet visible, show it
304
  if show_chat_name.get("display", "none") == "none":
305
  show_chat_name["display"] = "block"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  return (
307
- [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])],
308
- [
309
- html.Li(
310
- html.Span(
311
- chat["name"],
312
- style={"fontSize": "0.92rem"}
313
- )
314
- ) for chat in state.get("chat_histories", [])[-6:]
315
- ],
316
- [
317
- chat_message_card(msg['content'], is_user=(msg['role'] == "user"))
318
- for msg in state.get("messages", [])
319
- ] + (
320
- [chat_message_card(state["stream_buffer"], is_user=False)]
321
- if state.get("streaming", False) and state.get("stream_buffer", "") else []
322
- ),
323
  "",
324
  not state.get("streaming", False),
325
  0,
326
- show_chat_name
 
 
327
  )
328
  # If input box is visible and has a value, save chat history
329
  else:
@@ -342,24 +387,82 @@ def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_
342
  save_session_state(session_id)
343
  logger.info(f"Session {session_id}: Saved chat history '{chat_title}'")
344
  show_chat_name["display"] = "none"
 
 
 
 
 
 
 
 
 
345
  return (
346
- [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])],
347
- [
348
- html.Li(
349
- html.Span(
350
- chat["name"],
351
- style={"fontSize": "0.92rem", "fontWeight": "bold"}
352
- )
353
- ) for chat in state.get("chat_histories", [])[-6:]
354
- ],
355
  [],
356
  error,
357
  not state.get("streaming", False),
358
  0,
359
- show_chat_name
 
 
360
  )
361
 
362
- # Build Uploads, Chat History and Chat Window
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  chat_history = state.get("messages", [])
364
  uploads = state.get("uploads", [])
365
  chat_histories = state.get("chat_histories", [])
@@ -378,30 +481,8 @@ def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, file_
378
  if state.get("streaming", False):
379
  if state.get("stream_buffer", ""):
380
  chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
381
- return upload_cards, chat_history_items, chat_cards, error, False, 0, show_chat_name
382
- return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0, show_chat_name
383
-
384
- @app.callback(
385
- Output("chat-window", "children"),
386
- Output("stream-interval", "disabled"),
387
- Input("stream-interval", "n_intervals"),
388
- State("session-id", "data"),
389
- prevent_initial_call=True
390
- )
391
- def poll_stream(n_intervals, session_id):
392
- session_lock = get_session_lock(session_id)
393
- with session_lock:
394
- load_session_state(session_id)
395
- state = get_session_state(session_id)
396
- chat_history = state.get("messages", [])
397
- chat_cards = []
398
- for msg in chat_history:
399
- chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
400
- if state.get("streaming", False):
401
- if state.get("stream_buffer", ""):
402
- chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
403
- return chat_cards, False
404
- return chat_cards, True
405
 
406
  @app_flask.after_request
407
  def set_session_cookie(resp):
 
5
  import shutil
6
  import json
7
  import tempfile
8
+ import glob
9
  from flask import Flask, request as flask_request, make_response
10
  import dash
11
  from dash import dcc, html, Input, Output, State, callback_context
 
26
 
27
  openai.api_key = os.environ.get("OPENAI_API_KEY")
28
 
29
+ EMBEDDING_INDEX = {}
30
+ EMBEDDING_TEXTS = {}
31
+ EMBEDDING_MODEL = "text-embedding-ada-002"
32
+
33
  def get_session_id():
34
  sid = flask_request.cookies.get("asktricare_session_id")
35
  if not sid:
 
80
  logger.error(f"Failed to load system prompt: {e}")
81
  return "You are Ask Tricare, a helpful assistant for TRICARE health benefits. Respond conversationally, and cite relevant sources when possible. If you do not know, say so."
82
 
83
+ def embed_docs_folder():
84
+ global EMBEDDING_INDEX, EMBEDDING_TEXTS
85
+ docs_folder = os.path.join(os.getcwd(), "docs")
86
+ if not os.path.isdir(docs_folder):
87
+ logger.warning(f"Docs folder '{docs_folder}' does not exist. Skipping embedding.")
88
+ return
89
+ doc_files = []
90
+ for ext in ("*.txt", "*.md", "*.pdf"):
91
+ doc_files.extend(glob.glob(os.path.join(docs_folder, ext)))
92
+ for doc_path in doc_files:
93
+ fname = os.path.basename(doc_path)
94
+ if fname in EMBEDDING_INDEX:
95
+ continue
96
+ try:
97
+ with open(doc_path, "r", encoding="utf-8", errors="ignore") as f:
98
+ text = f.read()
99
+ if not text.strip():
100
+ continue
101
+ # OpenAI recommends chunking long texts; here, we take the first 1000 tokens per doc as a simple approach
102
+ chunk = text[:4000]
103
+ response = openai.Embedding.create(
104
+ input=[chunk],
105
+ model=EMBEDDING_MODEL
106
+ )
107
+ embedding = response['data'][0]['embedding']
108
+ EMBEDDING_INDEX[fname] = embedding
109
+ EMBEDDING_TEXTS[fname] = chunk
110
+ logger.info(f"Embedded doc: {fname}")
111
+ except Exception as e:
112
+ logger.error(f"Embedding failed for {fname}: {e}")
113
+
114
+ embed_docs_folder()
115
+
116
  app = dash.Dash(
117
  __name__,
118
  server=app_flask,
 
239
  Output("stream-interval", "disabled"),
240
  Output("stream-interval", "n_intervals"),
241
  Output("chat-name-input", "style"),
242
+ Output("chat-window", "children"),
243
+ Output("stream-interval", "disabled"),
244
  Input("session-id", "data"),
245
  Input("send-btn", "n_clicks"),
246
  Input("file-upload", "contents"),
247
  Input("new-chat-btn", "n_clicks"),
248
+ Input("stream-interval", "n_intervals"),
249
  State("file-upload", "filename"),
250
  State("user-input", "value"),
251
  State("chat-name-input", "value"),
 
252
  State("chat-name-input", "style"),
253
  prevent_initial_call=False
254
  )
255
+ def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, stream_n, file_names, user_input, chat_name, chat_name_style):
256
  trigger = callback_context.triggered[0]['prop_id'].split('.')[0] if callback_context.triggered else ""
257
  if not session_id:
258
  session_id = get_session_id()
 
343
  # If chat name input box is not yet visible, show it
344
  if show_chat_name.get("display", "none") == "none":
345
  show_chat_name["display"] = "block"
346
+ upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
347
+ chat_history_items = [
348
+ html.Li(
349
+ html.Span(
350
+ chat["name"],
351
+ style={"fontSize": "0.92rem"}
352
+ )
353
+ ) for chat in state.get("chat_histories", [])[-6:]
354
+ ]
355
+ chat_cards = [
356
+ chat_message_card(msg['content'], is_user=(msg['role'] == "user"))
357
+ for msg in state.get("messages", [])
358
+ ] + (
359
+ [chat_message_card(state["stream_buffer"], is_user=False)]
360
+ if state.get("streaming", False) and state.get("stream_buffer", "") else []
361
+ )
362
  return (
363
+ upload_cards,
364
+ chat_history_items,
365
+ chat_cards,
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  "",
367
  not state.get("streaming", False),
368
  0,
369
+ show_chat_name,
370
+ chat_cards,
371
+ not state.get("streaming", False)
372
  )
373
  # If input box is visible and has a value, save chat history
374
  else:
 
387
  save_session_state(session_id)
388
  logger.info(f"Session {session_id}: Saved chat history '{chat_title}'")
389
  show_chat_name["display"] = "none"
390
+ upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
391
+ chat_history_items = [
392
+ html.Li(
393
+ html.Span(
394
+ chat["name"],
395
+ style={"fontSize": "0.92rem", "fontWeight": "bold"}
396
+ )
397
+ ) for chat in state.get("chat_histories", [])[-6:]
398
+ ]
399
  return (
400
+ upload_cards,
401
+ chat_history_items,
 
 
 
 
 
 
 
402
  [],
403
  error,
404
  not state.get("streaming", False),
405
  0,
406
+ show_chat_name,
407
+ [],
408
+ not state.get("streaming", False)
409
  )
410
 
411
+ # Handle polling for streaming
412
+ if trigger == "stream-interval":
413
+ chat_history = state.get("messages", [])
414
+ chat_cards = []
415
+ for msg in chat_history:
416
+ chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
417
+ if state.get("streaming", False):
418
+ if state.get("stream_buffer", ""):
419
+ chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
420
+ upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
421
+ chat_history_items = [
422
+ html.Li(
423
+ html.Span(
424
+ chat["name"],
425
+ style={"fontSize": "0.92rem", "fontWeight": "bold" if i == len(state.get("chat_histories", []))-1 else "normal"}
426
+ )
427
+ ) for i, chat in enumerate(state.get("chat_histories", [])[-6:])
428
+ ]
429
+ return (
430
+ upload_cards,
431
+ chat_history_items,
432
+ chat_cards,
433
+ "",
434
+ False,
435
+ stream_n+1,
436
+ chat_name_style,
437
+ chat_cards,
438
+ False
439
+ )
440
+ else:
441
+ chat_cards = []
442
+ for msg in state.get("messages", []):
443
+ chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
444
+ upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
445
+ chat_history_items = [
446
+ html.Li(
447
+ html.Span(
448
+ chat["name"],
449
+ style={"fontSize": "0.92rem", "fontWeight": "bold" if i == len(state.get("chat_histories", []))-1 else "normal"}
450
+ )
451
+ ) for i, chat in enumerate(state.get("chat_histories", [])[-6:])
452
+ ]
453
+ return (
454
+ upload_cards,
455
+ chat_history_items,
456
+ chat_cards,
457
+ "",
458
+ True,
459
+ 0,
460
+ chat_name_style,
461
+ chat_cards,
462
+ True
463
+ )
464
+
465
+ # Default: Build Uploads, Chat History and Chat Window
466
  chat_history = state.get("messages", [])
467
  uploads = state.get("uploads", [])
468
  chat_histories = state.get("chat_histories", [])
 
481
  if state.get("streaming", False):
482
  if state.get("stream_buffer", ""):
483
  chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
484
+ return upload_cards, chat_history_items, chat_cards, error, False, 0, chat_name_style, chat_cards, False
485
+ return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0, chat_name_style, chat_cards, True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
  @app_flask.after_request
488
  def set_session_cookie(resp):