# ask-tricare / app.py
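"""Ask Tricare: a Dash/Flask chat application for TRICARE health-benefit questions.

Answers come from the OpenAI chat API, augmented by a lightweight retrieval layer:
documents in ./docs are embedded at startup and per-session uploads are embedded on
the fly, so the most relevant chunks can be prepended to the prompt.
"""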
import os
import threading
import logging
import uuid
import shutil
import json
import tempfile
import glob
from flask import Flask, request as flask_request, make_response
import dash
from dash import dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
import openai
import base64
import datetime
from werkzeug.utils import secure_filename
import numpy as np
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(threadName)s %(message)s")
logger = logging.getLogger("AskTricare")
app_flask = Flask(__name__)
SESSION_DATA = {}
SESSION_LOCKS = {}
SESSION_DIR_BASE = os.path.join(tempfile.gettempdir(), "asktricare_sessions")
os.makedirs(SESSION_DIR_BASE, exist_ok=True)
openai.api_key = os.environ.get("OPENAI_API_KEY")
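# In-memory vector store for the bundled ./docs corpus: filename -> embedding / source chunk.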
EMBEDDING_INDEX = {}
EMBEDDING_TEXTS = {}
EMBEDDING_MODEL = "text-embedding-ada-002"
def get_session_id():
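    """Return the session id from the asktricare_session_id cookie, or a fresh UUID."""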
    sid = flask_request.cookies.get("asktricare_session_id")
    if not sid:
        sid = str(uuid.uuid4())
    return sid
def get_session_dir(session_id):
    d = os.path.join(SESSION_DIR_BASE, session_id)
    os.makedirs(d, exist_ok=True)
    return d
def get_session_lock(session_id):
    if session_id not in SESSION_LOCKS:
        SESSION_LOCKS[session_id] = threading.Lock()
    return SESSION_LOCKS[session_id]
def get_session_state(session_id):
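    """Return the in-memory state dict for this session, creating a default one if needed."""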
    if session_id not in SESSION_DATA:
        SESSION_DATA[session_id] = {
            "messages": [],
            "uploads": [],
            "created": datetime.datetime.utcnow().isoformat(),
            "streaming": False,
            "stream_buffer": "",
            "chat_histories": []
        }
    return SESSION_DATA[session_id]
def save_session_state(session_id):
    state = get_session_state(session_id)
    d = get_session_dir(session_id)
    with open(os.path.join(d, "state.json"), "w") as f:
        json.dump(state, f)
def load_session_state(session_id):
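    """Refresh SESSION_DATA from the session's state.json on disk, if it exists."""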
    d = get_session_dir(session_id)
    path = os.path.join(d, "state.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            SESSION_DATA[session_id] = json.load(f)
def load_system_prompt():
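    """Load the system prompt from system_prompt.txt, falling back to a built-in default."""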
    prompt_path = os.path.join(os.getcwd(), "system_prompt.txt")
    try:
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read().strip()
    except Exception as e:
        logger.error(f"Failed to load system prompt: {e}")
        return "You are Ask Tricare, a helpful assistant for TRICARE health benefits. Respond conversationally, and cite relevant sources when possible. If you do not know, say so."
def embed_docs_folder():
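    """Embed the first 4,000 characters of each .txt/.md/.pdf file in ./docs with the
    OpenAI embeddings API and cache the vectors in EMBEDDING_INDEX, keyed by filename.
    Files are read in text mode only, so PDFs are not parsed before embedding."""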
    global EMBEDDING_INDEX, EMBEDDING_TEXTS
    docs_folder = os.path.join(os.getcwd(), "docs")
    if not os.path.isdir(docs_folder):
        logger.warning(f"Docs folder '{docs_folder}' does not exist. Skipping embedding.")
        return
    doc_files = []
    for ext in ("*.txt", "*.md", "*.pdf"):
        doc_files.extend(glob.glob(os.path.join(docs_folder, ext)))
    for doc_path in doc_files:
        fname = os.path.basename(doc_path)
        if fname in EMBEDDING_INDEX:
            continue
        try:
            with open(doc_path, "r", encoding="utf-8", errors="ignore") as f:
                text = f.read()
            if not text.strip():
                continue
            chunk = text[:4000]
            response = openai.Embedding.create(
                input=[chunk],
                model=EMBEDDING_MODEL
            )
            embedding = response['data'][0]['embedding']
            EMBEDDING_INDEX[fname] = embedding
            EMBEDDING_TEXTS[fname] = chunk
            logger.info(f"Embedded doc: {fname}")
        except Exception as e:
            logger.error(f"Embedding failed for {fname}: {e}")
embed_docs_folder()
def embed_user_doc(session_id, filename, text):
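    """Embed an uploaded document's first 4,000 characters and append it to the
    session's user_embeds.json vector store."""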
    session_dir = get_session_dir(session_id)
    if not text.strip():
        return
    try:
        chunk = text[:4000]
        response = openai.Embedding.create(
            input=[chunk],
            model=EMBEDDING_MODEL
        )
        embedding = response['data'][0]['embedding']
        user_embeds_path = os.path.join(session_dir, "user_embeds.json")
        if os.path.exists(user_embeds_path):
            with open(user_embeds_path, "r") as f:
                user_embeds = json.load(f)
        else:
            user_embeds = {"embeddings": [], "texts": [], "filenames": []}
        user_embeds["embeddings"].append(embedding)
        user_embeds["texts"].append(chunk)
        user_embeds["filenames"].append(filename)
        with open(user_embeds_path, "w") as f:
            json.dump(user_embeds, f)
        logger.info(f"Session {session_id}: Embedded user doc {filename}")
    except Exception as e:
        logger.error(f"Session {session_id}: Failed to embed user doc {filename}: {e}")
def get_user_embeddings(session_id):
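    """Load this session's uploaded-document embeddings; returns (matrix, texts, filenames)."""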
    session_dir = get_session_dir(session_id)
    user_embeds_path = os.path.join(session_dir, "user_embeds.json")
    if os.path.exists(user_embeds_path):
        with open(user_embeds_path, "r") as f:
            d = json.load(f)
        embeds = np.array(d.get("embeddings", []))
        texts = d.get("texts", [])
        filenames = d.get("filenames", [])
        return embeds, texts, filenames
    return np.array([]), [], []
def semantic_search(query, embed_matrix, texts, filenames, top_k=2):
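    """Rank stored chunks against the query by cosine similarity of their OpenAI
    embeddings and return the top_k matches as dicts with filename, text, and score."""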
    if len(embed_matrix) == 0:
        return []
    try:
        q_embed = openai.Embedding.create(input=[query], model=EMBEDDING_MODEL)["data"][0]["embedding"]
        q_embed = np.array(q_embed)
        embed_matrix = np.array(embed_matrix)
        scores = np.dot(embed_matrix, q_embed) / (np.linalg.norm(embed_matrix, axis=1) * np.linalg.norm(q_embed) + 1e-8)
        idx = np.argsort(scores)[::-1][:top_k]
        results = []
        for i in idx:
            results.append({"filename": filenames[i], "text": texts[i], "score": float(scores[i])})
        return results
    except Exception as e:
        logger.error(f"Semantic search error: {e}")
        return []
app = dash.Dash(
    __name__,
    server=app_flask,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.BOOTSTRAP, "/assets/custom.css"],
    update_title="Ask Tricare"
)
def chat_message_card(msg, is_user):
align = "end" if is_user else "start"
color = "primary" if is_user else "secondary"
avatar = "🧑" if is_user else "🤖"
return dbc.Card(
dbc.CardBody([
html.Div([
html.Span(avatar, style={"fontSize": "2rem"}),
html.Span(msg, style={"whiteSpace": "pre-wrap", "marginLeft": "0.75rem"})
], style={"display": "flex", "alignItems": "center", "justifyContent": align})
]),
className=f"mb-2 ms-3 me-3",
color=color,
inverse=is_user,
style={"maxWidth": "80%", "alignSelf": f"flex-{align}"}
)
def uploaded_file_card(filename, is_img):
    ext = os.path.splitext(filename)[1].lower()
    icon = "🖼️" if is_img else "📄"
    return dbc.Card(
        dbc.CardBody([
            html.Span(icon, style={"fontSize": "2rem", "marginRight": "0.5rem"}),
            html.Span(filename)
        ]),
        className="mb-2",
        color="tertiary"
    )
def disclaimer_card():
    return dbc.Card(
        dbc.CardBody([
            html.H5("Disclaimer", className="card-title"),
            html.P("This information is not private. Do not send PII or PHI. For official guidance visit the Tricare website.", style={"fontSize": "0.95rem"})
        ]),
        className="mb-2"
    )
def left_navbar_static():
    return html.Div([
        html.H3("Ask Tricare", className="mb-3 mt-3", style={"fontWeight": "bold"}),
        disclaimer_card(),
        dcc.Upload(
            id="file-upload",
            children=dbc.Button("Upload Document/Image", color="secondary", className="mb-2", style={"width": "100%"}),
            multiple=True,
            style={"width": "100%"}
        ),
        html.Div(id="upload-list"),
        html.Hr(),
        html.H5("Chat History", className="mb-2"),
        html.Ul(id="chat-history-list", style={"listStyle": "none", "paddingLeft": "0"}),
    ], style={"padding": "1rem", "backgroundColor": "#f8f9fa", "height": "100vh", "overflowY": "auto"})
def chat_box_card():
    return dbc.Card(
        dbc.CardBody([
            html.Div(id="chat-window", style={
                "height": "60vh",
                "overflowY": "auto",
                "display": "flex",
                "flexDirection": "column",
                "justifyContent": "flex-end",
                "backgroundColor": "#fff",
                "padding": "0.5rem",
                "borderRadius": "0.5rem"
            })
        ]),
        className="mt-3"
    )
def user_input_card():
    return dbc.Card(
        dbc.CardBody([
            html.Div([
                dcc.Textarea(
                    id="user-input",
                    placeholder="Type your question...",
                    style={"width": "100%", "height": "60px", "resize": "vertical", "wordWrap": "break-word"},
                    wrap="soft",
                    maxLength=1000,
                    n_blur=0,
                ),
                dcc.Store(id="enter-triggered", data=False),
                html.Div([
                    dbc.Button("Send", id="send-btn", color="primary", className="mt-2 me-2", style={"minWidth": "100px"}),
                    dbc.Button("New Chat", id="new-chat-btn", color="secondary", className="mt-2", style={"minWidth": "110px"}),
                ], style={"float": "right", "display": "flex", "gap": "0.5rem"}),
                dcc.Store(id="user-input-store", data="", storage_type="session"),
                html.Button(id='hidden-send', style={'display': 'none'})
            ], style={"marginTop": "1rem"}),
            html.Div(id="error-message", style={"color": "#bb2124", "marginTop": "0.5rem"}),
            dcc.Store(id="should-clear-input", data=False)
        ])
    )
def right_main_static():
    return html.Div([
        chat_box_card(),
        user_input_card(),
        dcc.Loading(id="loading", type="default", fullscreen=False, style={"position": "absolute", "top": "5%", "left": "50%"}),
        dcc.Interval(id="stream-interval", interval=400, n_intervals=0, disabled=True, max_intervals=1000),
        dcc.Store(id="client-question", data="")
    ], style={"padding": "1rem", "backgroundColor": "#fff", "height": "100vh", "overflowY": "auto"})
app.layout = html.Div([
    dcc.Store(id="session-id", storage_type="local"),
    dcc.Location(id="url"),
    dcc.Store(id="selected-history", data=None),
    html.Div([
        html.Div(left_navbar_static(), id='left-navbar', style={"width": "30vw", "height": "100vh", "position": "fixed", "left": 0, "top": 0, "zIndex": 2, "overflowY": "auto"}),
        html.Div(right_main_static(), id='right-main', style={"marginLeft": "30vw", "width": "70vw", "overflowY": "auto"})
    ], style={"display": "flex"}),
    dcc.Store(id="clear-input", data=False),
    dcc.Store(id="scroll-bottom", data=0),
    # clientside callback for textarea enter/shift-enter
    dcc.Store(id="enter-pressed", data=False)
])
# JS callback to intercept Enter/Shift+Enter for dcc.Textarea
app.clientside_callback(
    """
    function(n, value) {
        var ta = document.getElementById('user-input');
        if (!ta) return window.dash_clientside.no_update;
        if (!window._asktricare_enter_handler) {
            ta.addEventListener('keydown', function(e) {
                if (e.key === 'Enter' && !e.shiftKey) {
                    e.preventDefault();
                    var btn = document.getElementById('hidden-send');
                    if (btn) btn.click();
                }
            });
            window._asktricare_enter_handler = true;
        }
        return window.dash_clientside.no_update;
    }
    """,
    Output('enter-pressed', 'data'),
    Input('user-input', 'n_blur'),
    State('user-input', 'value')
)
def _is_supported_doc(filename):
    ext = os.path.splitext(filename)[1].lower()
    return ext in [".txt", ".pdf", ".md", ".docx"]
def _extract_text_from_upload(filepath, ext):
    # Only .txt and .md are extracted for now
    if ext in [".txt", ".md"]:
        try:
            with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                text = f.read()
            return text
        except Exception as e:
            logger.error(f"Error reading {filepath}: {e}")
            return ""
    else:
        return ""
@app.callback(
    Output("session-id", "data"),
    Input("url", "href"),
    prevent_initial_call=False
)
def assign_session_id(_):
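    """Resolve the session id from the cookie (or mint one) and warm its on-disk state."""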
    sid = get_session_id()
    d = get_session_dir(sid)
    load_session_state(sid)
    logger.info(f"Assigned session id: {sid}")
    return sid
@app.callback(
    Output("upload-list", "children"),
    Output("chat-history-list", "children"),
    Output("chat-window", "children"),
    Output("error-message", "children"),
    Output("stream-interval", "disabled"),
    Output("stream-interval", "n_intervals"),
    Output("user-input", "value"),
    Output("selected-history", "data"),
    Input("session-id", "data"),
    Input("send-btn", "n_clicks"),
    Input("file-upload", "contents"),
    Input("new-chat-btn", "n_clicks"),
    Input("stream-interval", "n_intervals"),
    Input({"type": "chat-history-item", "index": dash.ALL}, "n_clicks"),
    Input('hidden-send', 'n_clicks'),
    State("file-upload", "filename"),
    State("user-input", "value"),
    State("selected-history", "data"),
    State("chat-history-list", "children"),
    prevent_initial_call=False
)
def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, stream_n, chat_history_clicks, hidden_send_clicks, file_names, user_input, selected_history, chat_history_list_children):
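    """Single callback driving the UI: loads saved chats, handles uploads, starts a
    background streaming completion on send, saves/clears chats via New Chat, and
    polls the stream buffer while the assistant reply is being generated."""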
    trigger = callback_context.triggered[0]['prop_id'].split('.')[0] if callback_context.triggered else ""
    session_id = session_id or get_session_id()
    session_lock = get_session_lock(session_id)
    with session_lock:
        load_session_state(session_id)
        state = get_session_state(session_id)
        error = ""
        start_streaming = False
        chat_histories = state.get("chat_histories", [])
        uploads = state.get("uploads", [])
        # Handle clickable chat history loading
        history_index_clicked = None
        if trigger == "" and chat_history_clicks:
            for idx, n in enumerate(chat_history_clicks):
                if n:
                    history_index_clicked = idx
                    break
        elif trigger.startswith("{\"type\":\"chat-history-item\""):
            for idx, n in enumerate(chat_history_clicks):
                if n:
                    history_index_clicked = idx
                    break
        if history_index_clicked is not None and history_index_clicked < len(chat_histories):
            selected_dialog = chat_histories[history_index_clicked]["dialog"]
            state["messages"] = selected_dialog.copy()
            state["stream_buffer"] = ""
            state["streaming"] = False
            save_session_state(session_id)
            logger.info(f"Session {session_id}: Loaded chat history index {history_index_clicked}")
            # Rebuild downstream outputs with this dialog
            chat_cards = []
            for msg in state["messages"]:
                chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
            upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
            chat_history_items = [
                html.Li(
                    dbc.Button(
                        chat["name"],
                        id={"type": "chat-history-item", "index": i},
                        color="tertiary",
                        outline=False,
                        style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(chat_histories)-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                    )
                )
                for i, chat in enumerate(chat_histories[-6:])
            ]
            return (
                upload_cards,
                chat_history_items,
                chat_cards,
                "",
                True,
                0,
                "",
                history_index_clicked
            )
        # Handle File Upload
        file_was_uploaded_and_sent = False
        if trigger == "file-upload" and file_contents and file_names:
            uploads = []
            if not isinstance(file_contents, list):
                file_contents = [file_contents]
                file_names = [file_names]
            for c, n in zip(file_contents, file_names):
                header, data = c.split(',', 1)
                ext = os.path.splitext(n)[1].lower()
                is_img = ext in [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
                fname = secure_filename(f"{datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{n}")
                session_dir = get_session_dir(session_id)
                fp = os.path.join(session_dir, fname)
                with open(fp, "wb") as f:
                    f.write(base64.b64decode(data))
                uploads.append({"name": fname, "is_img": is_img, "path": fp})
                if _is_supported_doc(n) and not is_img:
                    text = _extract_text_from_upload(fp, ext)
                    if text.strip():
                        embed_user_doc(session_id, fname, text)
                        logger.info(f"Session {session_id}: Uploaded doc '{n}' embedded for user vector store")
            state["uploads"].extend(uploads)
            save_session_state(session_id)
            logger.info(f"Session {session_id}: Uploaded files {[u['name'] for u in uploads]}")
        # Determine if send was triggered (via send-btn, hidden-send, or enter)
        send_triggered = False
        if trigger == "send-btn" or trigger == "hidden-send":
            send_triggered = True
        # Handle Send
        if send_triggered and user_input and user_input.strip():
            question = user_input.strip()
            state["messages"].append({"role": "user", "content": question})
            state["streaming"] = True
            state["stream_buffer"] = ""
            save_session_state(session_id)
            def run_stream(session_id, messages, question):
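                # Runs in a background thread: builds the RAG-augmented prompt, streams the
                # completion, and persists partial output to the session's stream_buffer.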
                try:
                    system_prompt = load_system_prompt()
                    # Retrieve relevant context from global RAG
                    rag_chunks = []
                    try:
                        # Search global docs
                        global_embeds = []
                        global_texts = []
                        global_fnames = []
                        for fname, emb in EMBEDDING_INDEX.items():
                            global_embeds.append(emb)
                            global_texts.append(EMBEDDING_TEXTS[fname])
                            global_fnames.append(fname)
                        global_rag = semantic_search(question, global_embeds, global_texts, global_fnames, top_k=2)
                        if global_rag:
                            for r in global_rag:
                                rag_chunks.append(f"Global doc [{r['filename']}]:\n{r['text'][:1000]}")
                        # Search user docs
                        user_embeds, user_texts, user_fnames = get_user_embeddings(session_id)
                        user_rag = semantic_search(question, user_embeds, user_texts, user_fnames, top_k=2)
                        if user_rag:
                            for r in user_rag:
                                rag_chunks.append(f"User upload [{r['filename']}]:\n{r['text'][:1000]}")
                    except Exception as e:
                        logger.error(f"Session {session_id}: RAG error: {e}")
                    context_block = ""
                    if rag_chunks:
                        context_block = "The following sources may help answer the question:\n\n" + "\n\n".join(rag_chunks) + "\n\n"
                    msg_list = [{"role": "system", "content": system_prompt}]
                    if context_block:
                        msg_list.append({"role": "system", "content": context_block})
                    for m in messages:
                        msg_list.append({"role": m["role"], "content": m["content"]})
                    response = openai.ChatCompletion.create(
                        model="gpt-3.5-turbo",
                        messages=msg_list,
                        max_tokens=700,
                        temperature=0.2,
                        stream=True,
                    )
                    reply = ""
                    for chunk in response:
                        delta = chunk["choices"][0]["delta"]
                        content = delta.get("content", "")
                        if content:
                            reply += content
                            session_lock = get_session_lock(session_id)
                            with session_lock:
                                load_session_state(session_id)
                                state = get_session_state(session_id)
                                state["stream_buffer"] = reply
                                save_session_state(session_id)
                    session_lock = get_session_lock(session_id)
                    with session_lock:
                        load_session_state(session_id)
                        state = get_session_state(session_id)
                        state["messages"].append({"role": "assistant", "content": reply})
                        state["stream_buffer"] = ""
                        state["streaming"] = False
                        save_session_state(session_id)
                    logger.info(f"Session {session_id}: User: {question} | Assistant: {reply}")
                except Exception as e:
                    session_lock = get_session_lock(session_id)
                    with session_lock:
                        load_session_state(session_id)
                        state = get_session_state(session_id)
                        state["streaming"] = False
                        state["stream_buffer"] = ""
                        save_session_state(session_id)
                    logger.error(f"Session {session_id}: Streaming error: {e}")
            threading.Thread(target=run_stream, args=(session_id, list(state["messages"]), question), daemon=True).start()
            start_streaming = True
        # Handle New Chat button logic: auto-name and reset
        if trigger == "new-chat-btn":
            chat_dialog = list(state.get("messages", []))
            if not chat_dialog:
                error = "No chat to save. Start chatting!"
            else:
                chat_title = "Chat " + datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M")
                state.setdefault("chat_histories", []).append({
                    "name": chat_title,
                    "dialog": chat_dialog
                })
                state["messages"] = []
                state["stream_buffer"] = ""
                state["streaming"] = False
                save_session_state(session_id)
                logger.info(f"Session {session_id}: Saved chat history '{chat_title}'")
        # Handle polling for streaming
        if trigger == "stream-interval":
            chat_history = state.get("messages", [])
            chat_cards = []
            for msg in chat_history:
                chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
            if state.get("streaming", False):
                if state.get("stream_buffer", ""):
                    chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
                upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
                chat_history_items = [
                    html.Li(
                        dbc.Button(
                            chat["name"],
                            id={"type": "chat-history-item", "index": i},
                            color="tertiary",
                            outline=False,
                            style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(state.get("chat_histories", []))-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                        )
                    )
                    for i, chat in enumerate(state.get("chat_histories", [])[-6:])
                ]
                return (
                    upload_cards,
                    chat_history_items,
                    chat_cards,
                    "",
                    False,
                    stream_n+1,
                    "",
                    selected_history
                )
            else:
                chat_cards = []
                for msg in state.get("messages", []):
                    chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
                upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
                chat_history_items = [
                    html.Li(
                        dbc.Button(
                            chat["name"],
                            id={"type": "chat-history-item", "index": i},
                            color="tertiary",
                            outline=False,
                            style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(state.get("chat_histories", []))-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                        )
                    )
                    for i, chat in enumerate(state.get("chat_histories", [])[-6:])
                ]
                return (
                    upload_cards,
                    chat_history_items,
                    chat_cards,
                    "",
                    True,
                    0,
                    "",
                    selected_history
                )
        # Default: Build Uploads, Chat History and Chat Window
        chat_history = state.get("messages", [])
        uploads = state.get("uploads", [])
        chat_histories = state.get("chat_histories", [])
        upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in uploads]
        chat_history_items = [
            html.Li(
                dbc.Button(
                    chat["name"],
                    id={"type": "chat-history-item", "index": i},
                    color="tertiary",
                    outline=False,
                    style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(chat_histories)-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                )
            )
            for i, chat in enumerate(chat_histories[-6:])
        ]
        chat_cards = []
        for msg in chat_history:
            chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
        if state.get("streaming", False):
            if state.get("stream_buffer", ""):
                chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
            return upload_cards, chat_history_items, chat_cards, error, False, 0, "", selected_history
        # Always clear input after send
        if send_triggered:
            return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0, "", selected_history
        return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0, user_input or "", selected_history
@app_flask.after_request
def set_session_cookie(resp):
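    """Ensure every response carries a stable asktricare_session_id cookie (one-week lifetime)."""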
    sid = flask_request.cookies.get("asktricare_session_id")
    if not sid:
        sid = str(uuid.uuid4())
    resp.set_cookie("asktricare_session_id", sid, max_age=60*60*24*7, path="/")
    return resp
def cleanup_sessions(max_age_hours=48):
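    """Delete session directories whose state.json is older than max_age_hours."""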
    now = datetime.datetime.utcnow()
    for sid in os.listdir(SESSION_DIR_BASE):
        d = os.path.join(SESSION_DIR_BASE, sid)
        try:
            state_path = os.path.join(d, "state.json")
            if os.path.exists(state_path):
                with open(state_path, "r") as f:
                    st = json.load(f)
                created = st.get("created")
                if created and (now - datetime.datetime.fromisoformat(created)).total_seconds() > max_age_hours * 3600:
                    shutil.rmtree(d)
                    logger.info(f"Cleaned up session {sid}")
        except Exception as e:
            logger.error(f"Cleanup error for {sid}: {e}")
try:
    import torch
    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        logger.info("CUDA GPU detected and configured.")
except Exception as e:
    logger.warning(f"CUDA config failed: {e}")
if __name__ == '__main__':
print("Starting the Dash application...")
app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
print("Dash application has finished running.")