Spaces:
Paused
Paused
import os | |
import threading | |
import logging | |
import uuid | |
import shutil | |
import json | |
import tempfile | |
from flask import Flask, request as flask_request, make_response | |
import dash | |
from dash import dcc, html, Input, Output, State, callback_context | |
import dash_bootstrap_components as dbc | |
import openai | |
import base64 | |
import datetime | |
from werkzeug.utils import secure_filename | |
import chromadb | |
from chromadb.config import Settings | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# --- Logging ---------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(threadName)s %(message)s",
)
logger = logging.getLogger("AskTricare")

# Flask server instance; the Dash app below is mounted on it.
app_flask = Flask(__name__)

# --- Per-session bookkeeping -----------------------------------------------
# In-memory session state and per-session locks, both keyed by session id.
SESSION_DATA = {}
SESSION_LOCKS = {}

# Session files live under the system temp directory.
SESSION_DIR_BASE = os.path.join(tempfile.gettempdir(), "asktricare_sessions")
os.makedirs(SESSION_DIR_BASE, exist_ok=True)

# --- Knowledge base locations (relative to the current working directory) ---
VECTOR_DB_DIR = os.path.join(os.getcwd(), "vector_db")
DOCS_DIR = os.path.join(os.getcwd(), "doc")
os.makedirs(DOCS_DIR, exist_ok=True)
os.makedirs(VECTOR_DB_DIR, exist_ok=True)

# --- External services -----------------------------------------------------
# The key is None when OPENAI_API_KEY is not set in the environment.
openai.api_key = os.environ.get("OPENAI_API_KEY")

chroma_client = chromadb.Client(
    Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=VECTOR_DB_DIR,
    )
)
embeddings = OpenAIEmbeddings(
    model="text-embedding-ada-002",
    openai_api_key=openai.api_key,
)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
def ingest_docs():
    """Chunk every supported file under DOCS_DIR and add the chunks to Chroma.

    Walks DOCS_DIR recursively, reads each ``.txt``/``.pdf``/``.md``/``.docx``
    file as UTF-8 text (undecodable bytes are ignored), splits it with the
    module-level ``text_splitter`` and stores the chunks in the
    ``asktricare`` collection, which is then persisted to VECTOR_DB_DIR.

    A file that cannot be read or split is logged and skipped; it does not
    abort the ingestion of the remaining files.
    """
    logger.info("Starting document ingestion...")

    supported_suffixes = ('.txt', '.pdf', '.md', '.docx')
    file_paths = []
    for root, _, files in os.walk(DOCS_DIR):
        for name in files:
            if name.lower().endswith(supported_suffixes):
                file_paths.append(os.path.join(root, name))

    documents = []
    metadatas = []
    ids = []
    for path in file_paths:
        # NOTE(review): .pdf and .docx are read as plain text here, exactly as
        # before; no format-specific extraction is performed — confirm this is
        # intended.
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as infile:
                content = infile.read()
            chunks = text_splitter.split_text(content)
        except Exception as e:
            logger.error(f"Error ingesting {path}: {e}")
            continue

        # Ids are derived from the path relative to DOCS_DIR rather than the
        # bare file name: os.walk recurses, so two files with the same name in
        # different sub-folders would otherwise produce identical chunk ids.
        relative_path = os.path.relpath(path, DOCS_DIR)
        for idx, chunk in enumerate(chunks):
            documents.append(chunk)
            metadatas.append({"source": path, "chunk": idx})
            ids.append(f"{relative_path}_{idx}")

    if not documents:
        logger.info("No new documents to ingest.")
        return

    vectordb = Chroma(
        collection_name="asktricare",
        embedding_function=embeddings,
        persist_directory=VECTOR_DB_DIR,
        client_settings=Settings(chroma_db_impl="duckdb+parquet", persist_directory=VECTOR_DB_DIR),
    )
    vectordb.add_texts(documents, metadatas=metadatas, ids=ids)
    vectordb.persist()
    logger.info(f"Ingested {len(documents)} chunks from {len(file_paths)} files.")
# Ingest the documents only when the vector store directory is still empty.
_vector_store_is_empty = not os.listdir(VECTOR_DB_DIR)
if _vector_store_is_empty:
    ingest_docs()

# Module-level handle on the "asktricare" collection used for similarity search.
vectordb = Chroma(
    collection_name="asktricare",
    embedding_function=embeddings,
    persist_directory=VECTOR_DB_DIR,
    client_settings=Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=VECTOR_DB_DIR,
    ),
)
def get_session_id():
    """Return the session id for the current Flask request.

    Uses the ``asktricare_session_id`` cookie when it is present and
    non-empty; otherwise a fresh random UUID4 string is generated.
    """
    cookie_value = flask_request.cookies.get("asktricare_session_id")
    if cookie_value:
        return cookie_value
    return str(uuid.uuid4())
def get_session_dir(session_id):
    """Return (creating it if needed) the directory that holds a session's files.

    Args:
        session_id: Session identifier. It originates from a cookie or a
            client-side store, so it is treated as untrusted input.

    Returns:
        The path ``SESSION_DIR_BASE/<session_id>``.

    Raises:
        ValueError: If ``session_id`` does not resolve to a direct child of
            SESSION_DIR_BASE (e.g. it is empty, ``..``, absolute, or contains
            a path separator).
    """
    session_path = os.path.join(SESSION_DIR_BASE, session_id)
    # Guard against path traversal: the resolved directory must sit directly
    # inside the session base directory.
    resolved_base = os.path.realpath(SESSION_DIR_BASE)
    resolved_path = os.path.realpath(session_path)
    if os.path.dirname(resolved_path) != resolved_base:
        raise ValueError(f"Invalid session id: {session_id!r}")
    os.makedirs(session_path, exist_ok=True)
    return session_path
def get_session_lock(session_id):
    """Return the ``threading.Lock`` associated with ``session_id``.

    The lock is created on first use and stored in SESSION_LOCKS.

    ``dict.setdefault`` is used instead of a separate "check, then insert":
    with the two-step form, two threads asking for the same new session at
    the same time could each create and return a different lock, which would
    defeat the purpose of the lock.
    """
    return SESSION_LOCKS.setdefault(session_id, threading.Lock())
def get_session_state(session_id):
    """Return the in-memory state dict for ``session_id``.

    A session seen for the first time gets a new state with empty
    ``messages`` and ``uploads`` lists and a ``created`` timestamp
    (naive UTC, ISO-8601 string).
    """
    try:
        return SESSION_DATA[session_id]
    except KeyError:
        fresh_state = {
            "messages": [],
            "uploads": [],
            "created": datetime.datetime.utcnow().isoformat()
        }
        SESSION_DATA[session_id] = fresh_state
        return fresh_state
def save_session_state(session_id):
    """Write the session's in-memory state to ``state.json`` in its directory.

    The JSON is first written to a temporary file in the same directory and
    then moved over ``state.json`` with ``os.replace``, so a reader (or a
    crash in the middle of the write) never sees a half-written file.

    Raises:
        OSError: If the file cannot be written or replaced.
        TypeError: If the state contains values ``json.dump`` cannot encode.
    """
    state = get_session_state(session_id)
    session_dir = get_session_dir(session_id)
    target_path = os.path.join(session_dir, "state.json")
    # Process id + thread id keep concurrent writers from sharing a temp file.
    tmp_path = f"{target_path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(state, f)
        os.replace(tmp_path, target_path)
    except BaseException:
        # Do not leave a stray temp file behind, then let the error propagate.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_session_state(session_id):
    """Load ``state.json`` for ``session_id`` into SESSION_DATA, if it exists.

    When the file is missing nothing happens. When it exists but cannot be
    read or does not contain a JSON object, the problem is logged and the
    current in-memory state (if any) is left untouched, so one damaged file
    does not make every later request for that session fail.
    """
    path = os.path.join(get_session_dir(session_id), "state.json")
    try:
        with open(path, "r") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        # No persisted state yet for this session.
        return
    except (OSError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass.
        logger.error(f"Session {session_id}: could not load state from {path}: {e}")
        return

    if not isinstance(loaded, dict):
        logger.error(f"Session {session_id}: ignoring state in {path}: not a JSON object")
        return
    SESSION_DATA[session_id] = loaded
def load_system_prompt():
    """Return the system prompt used for chat completions.

    The prompt is read from ``system_prompt.txt`` in the current working
    directory and stripped of surrounding whitespace. If reading fails for
    any reason, the failure is logged and a built-in default prompt is
    returned instead.
    """
    prompt_path = os.path.join(os.getcwd(), "system_prompt.txt")
    try:
        with open(prompt_path, "r", encoding="utf-8") as prompt_file:
            raw_prompt = prompt_file.read()
    except Exception as e:
        logger.error(f"Failed to load system prompt: {e}")
        return "You are Ask Tricare, a helpful assistant for TRICARE health benefits. Respond conversationally, and cite relevant sources when possible. If you do not know, say so."
    return raw_prompt.strip()
# Dash application served by the Flask instance created above.
app = dash.Dash(
    __name__,
    server=app_flask,
    # Components referenced by callbacks are rendered later (the navbar and
    # main pane are filled in by a callback), so ids may be absent at start-up.
    suppress_callback_exceptions=True,
    external_stylesheets=[
        dbc.themes.BOOTSTRAP,
        "/assets/custom.css",
    ],
    update_title="Ask Tricare",
)
def chat_message_card(msg, is_user):
    """Build the card that shows one chat message.

    Args:
        msg: Message text; rendered with ``white-space: pre-wrap``.
        is_user: True for a user message (primary colour, inverse text,
            aligned to the end), False for an assistant message (secondary
            colour, aligned to the start).

    Returns:
        A ``dbc.Card`` with an avatar and the message text.
    """
    align = "end" if is_user else "start"
    color = "primary" if is_user else "secondary"
    # NOTE(review): the avatar literals in the file appeared to be mis-decoded
    # emoji bytes; they are restored here as escapes (person / robot) so the
    # source no longer depends on the file's encoding — confirm the glyphs.
    avatar = "\U0001F9D1" if is_user else "\U0001F916"

    content_row = html.Div(
        [
            html.Span(avatar, style={"fontSize": "2rem"}),
            html.Span(msg, style={"whiteSpace": "pre-wrap", "marginLeft": "0.75rem"}),
        ],
        style={"display": "flex", "alignItems": "center", "justifyContent": align},
    )
    return dbc.Card(
        dbc.CardBody([content_row]),
        className="mb-2 ms-3 me-3",
        color=color,
        inverse=is_user,
        style={"maxWidth": "80%", "alignSelf": f"flex-{align}"},
    )
def uploaded_file_card(filename, is_img):
    """Build the small card that lists one uploaded file.

    Args:
        filename: Name shown on the card.
        is_img: True to show the picture icon, False for the document icon.

    Returns:
        A ``dbc.Card`` containing the icon and the file name.
    """
    # NOTE(review): the icon literals in the file appeared to be mis-decoded
    # emoji bytes. The image icon is restored as "framed picture"; the
    # document icon could not be recovered unambiguously and is assumed to be
    # "page facing up" — confirm.
    icon = "\U0001F5BC\uFE0F" if is_img else "\U0001F4C4"
    return dbc.Card(
        dbc.CardBody([
            html.Span(icon, style={"fontSize": "2rem", "marginRight": "0.5rem"}),
            html.Span(filename),
        ]),
        className="mb-2",
        color="tertiary",
    )
def disclaimer_card():
    """Build the static disclaimer card shown at the top of the left navbar."""
    title = html.H5("Disclaimer", className="card-title")
    notice = html.P(
        "This information is not private. Do not send PII or PHI. For official guidance visit the Tricare website.",
        style={"fontSize": "0.95rem"},
    )
    card_body = dbc.CardBody([title, notice])
    return dbc.Card(card_body, className="mb-2")
def left_navbar(session_id, chat_history, uploads):
    """Build the left-hand navigation pane.

    The pane contains the title, the disclaimer, the upload control, one card
    per uploaded file and a preview of the last six chat messages.

    Args:
        session_id: Accepted for interface compatibility; not used here.
        chat_history: List of message dicts with ``role`` and ``content`` keys.
        uploads: List of upload dicts with ``name`` and ``is_img`` keys.

    Returns:
        An ``html.Div`` holding the whole pane.
    """
    preview_limit = 40

    upload_cards = [
        uploaded_file_card(os.path.basename(item["name"]), item["is_img"])
        for item in uploads
    ]

    history_items = []
    for msg in chat_history[-6:]:
        line = msg['role'] + ": " + msg['content']
        # The ellipsis must follow the string that is actually cut (role
        # prefix included); testing only the content length left truncated
        # lines without a "..." marker.
        preview = line[:preview_limit] + ("..." if len(line) > preview_limit else "")
        history_items.append(html.Li(html.Span(preview, style={"fontSize": "0.92rem"})))

    return html.Div([
        html.Div([
            html.H3("Ask Tricare", className="mb-3 mt-3", style={"fontWeight": "bold"}),
            disclaimer_card(),
            dcc.Upload(
                id="file-upload",
                children=dbc.Button("Upload Document/Image", color="secondary", className="mb-2", style={"width": "100%"}),
                multiple=True,
                style={"width": "100%"}
            ),
            html.Div(upload_cards, id="upload-list"),
            html.Hr(),
            html.H5("Chat History", className="mb-2"),
            html.Ul(history_items, style={"listStyle": "none", "paddingLeft": "0"}),
        ], style={"padding": "1rem"})
    ], style={"backgroundColor": "#f8f9fa", "height": "100vh", "overflowY": "auto"})
def right_main(chat_history, loading, error):
    """Build the main chat pane: message cards, input box, send button, error line.

    Args:
        chat_history: List of message dicts; only entries whose ``role`` is
            ``"user"`` or ``"assistant"`` are rendered.
        loading: Accepted but not used by this function.
        error: Text shown in the error area under the input (may be empty).

    Returns:
        An ``html.Div`` holding the whole pane.
    """
    chat_cards = [
        chat_message_card(entry['content'], is_user=(entry['role'] == "user"))
        for entry in chat_history
        if entry['role'] in ("user", "assistant")
    ]

    chat_window = html.Div(
        chat_cards,
        id="chat-window",
        style={"minHeight": "60vh", "display": "flex", "flexDirection": "column", "justifyContent": "flex-end"},
    )
    question_box = dcc.Textarea(
        id="user-input",
        placeholder="Type your question...",
        style={"width": "100%", "height": "60px", "resize": "vertical", "wordWrap": "break-word"},
        wrap="soft",
        maxLength=1000,
        autoFocus=True
    )
    send_button = dbc.Button(
        "Send",
        id="send-btn",
        color="primary",
        className="mt-2",
        style={"float": "right", "minWidth": "100px"},
    )
    input_area = html.Div([question_box, send_button], style={"marginTop": "1rem"})
    error_area = html.Div(error, id="error-message", style={"color": "#bb2124", "marginTop": "0.5rem"})

    chat_card = dbc.Card(
        [dbc.CardBody([chat_window, input_area, error_area])],
        className="mt-3",
    )
    spinner = dcc.Loading(
        id="loading",
        type="default",
        fullscreen=False,
        style={"position": "absolute", "top": "5%", "left": "50%"},
    )
    return html.Div(
        [chat_card, spinner],
        style={"padding": "1rem", "backgroundColor": "#fff", "height": "100vh", "overflowY": "auto"},
    )
# Static page skeleton: a client-side store for the session id, the URL
# component, and two empty containers (fixed left navbar, scrolling main pane).
app.layout = html.Div([
    dcc.Store(id="session-id", storage_type="local"),
    dcc.Location(id="url"),
    html.Div(
        [
            html.Div(
                id='left-navbar',
                style={"width": "30vw", "height": "100vh", "position": "fixed", "left": 0, "top": 0, "zIndex": 2, "overflowY": "auto"},
            ),
            html.Div(
                id='right-main',
                style={"marginLeft": "30vw", "width": "70vw", "overflowY": "auto"},
            ),
        ],
        style={"display": "flex"},
    ),
])
# NOTE(review): no callback decorator is visible on this function in this view
# of the file — confirm how it is registered with the Dash app.
def assign_session_id(_):
    """Resolve the session id for the current request and prepare its state.

    Ensures the session directory exists, loads any persisted state into
    memory, logs the id and returns it. The single argument is ignored.
    """
    session_id = get_session_id()
    get_session_dir(session_id)  # called for its side effect: creates the directory
    load_session_state(session_id)
    logger.info(f"Assigned session id: {session_id}")
    return session_id
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")


def _store_uploads(session_id, state, file_contents, file_names):
    """Decode uploaded payloads, write them to the session directory, record them.

    Each payload is expected to look like ``<header>,<base64 data>``. A payload
    that cannot be split or decoded is skipped and reported instead of
    aborting the whole batch.

    Returns:
        An error message naming the files that could not be stored, or ``""``.
    """
    if not isinstance(file_contents, list):
        file_contents = [file_contents]
        file_names = [file_names]

    session_dir = get_session_dir(session_id)
    stored = []
    failed = []
    for content, original_name in zip(file_contents, file_names):
        try:
            _header, encoded = content.split(',', 1)
            payload = base64.b64decode(encoded)
        except (ValueError, AttributeError) as e:
            # binascii.Error (bad base64) is a ValueError subclass.
            logger.warning(f"Session {session_id}: could not decode upload {original_name!r}: {e}")
            failed.append(str(original_name))
            continue

        ext = os.path.splitext(original_name)[1].lower()
        is_img = ext in _IMAGE_EXTENSIONS
        stamp = datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S')
        fname = secure_filename(f"{stamp}_{original_name}")
        fp = os.path.join(session_dir, fname)
        with open(fp, "wb") as f:
            f.write(payload)
        stored.append({"name": fname, "is_img": is_img, "path": fp})

    state.setdefault("uploads", []).extend(stored)
    save_session_state(session_id)
    logger.info(f"Session {session_id}: Uploaded files {[u['name'] for u in stored]}")

    if failed:
        return f"Error: could not read uploaded file(s): {', '.join(failed)}"
    return ""


def _answer_question(session_id, state, user_input):
    """Append the user's message, ask the model for a reply, and persist the state.

    Up to three chunks from the vector store are attached as a trailing system
    message; a failing vector search is logged and the question is sent
    without reference material.

    Returns:
        ``""`` on success, otherwise an ``"Error: ..."`` message for the UI.
    """
    history = state.setdefault("messages", [])
    history.append({"role": "user", "content": user_input})
    error = ""
    try:
        docs = []
        try:
            hits = vectordb.similarity_search(user_input, k=3)
            docs = [hit.page_content for hit in hits]
        except Exception as e:
            logger.warning(f"Vector search failed: {e}")
        context = "\n\n".join(docs)

        messages = [{"role": "system", "content": load_system_prompt()}]
        messages.extend({"role": m["role"], "content": m["content"]} for m in history)
        if context.strip():
            messages.append({"role": "system", "content": f"Relevant reference material:\n{context}"})

        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=700,
            temperature=0.2,
        )
        reply = response.choices[0].message.content
        history.append({"role": "assistant", "content": reply})
        logger.info(f"Session {session_id}: User: {user_input} | Assistant: {reply}")
    except Exception as e:
        # Top-level boundary for the chat request: surface the failure in the UI.
        error = f"Error: {e}"
        logger.error(f"Session {session_id}: {error}")
    save_session_state(session_id)
    return error


# NOTE(review): no callback decorator is visible on this function in this view
# of the file — confirm how it is registered with the Dash app.
def main_callback(session_id, send_clicks, file_contents, file_names, user_input, right_children, left_children):
    """Handle uploads and chat submissions, then re-render both panes.

    Args:
        session_id: Session id supplied by the client; when empty, one is
            resolved with ``get_session_id()``.
        send_clicks: Not read; the triggering component is taken from
            ``callback_context``.
        file_contents: One upload payload or a list of them.
        file_names: File name(s) matching ``file_contents``.
        user_input: Text typed by the user.
        right_children: Not used.
        left_children: Not used.

    Returns:
        ``(left, right)`` — the rebuilt left navbar and main chat pane.
    """
    triggered = callback_context.triggered
    trigger = triggered[0]['prop_id'].split('.')[0] if triggered else ""
    if not session_id:
        session_id = get_session_id()

    # All reads and writes of this session's state happen under its lock.
    with get_session_lock(session_id):
        load_session_state(session_id)
        state = get_session_state(session_id)
        error = ""

        if trigger == "file-upload" and file_contents and file_names:
            error = _store_uploads(session_id, state, file_contents, file_names)
        if trigger == "send-btn" and user_input and user_input.strip():
            error = _answer_question(session_id, state, user_input)

        chat_history = state.get("messages", [])
        uploads = state.get("uploads", [])
        left = left_navbar(session_id, chat_history, uploads)
        # The request is finished by the time the pane is rendered, so the
        # loading flag is always False here.
        right = right_main(chat_history, False, error)
    return left, right
# NOTE(review): no Flask hook decorator is visible on this function in this
# view of the file — confirm how it is registered.
def set_session_cookie(resp):
    """Attach the ``asktricare_session_id`` cookie to ``resp`` and return it.

    The incoming cookie value is reused when present and non-empty; otherwise
    a new UUID4 string is generated. The cookie is (re)set on every call with
    a seven-day lifetime and path ``/``.
    """
    session_cookie = flask_request.cookies.get("asktricare_session_id") or str(uuid.uuid4())
    one_week_seconds = 60*60*24*7
    resp.set_cookie("asktricare_session_id", session_cookie, max_age=one_week_seconds, path="/")
    return resp
def cleanup_sessions(max_age_hours=48):
    """Delete session directories whose recorded creation time is too old.

    For every entry under SESSION_DIR_BASE that has a ``state.json`` with a
    ``created`` timestamp older than ``max_age_hours``, the directory is
    removed and the matching in-memory state is dropped. Entries without a
    ``state.json`` or without a ``created`` value are left alone.

    Errors for one session are logged and do not stop the cleanup of others.

    Args:
        max_age_hours: Maximum age, in hours, before a session is removed.
    """
    now = datetime.datetime.utcnow()
    max_age_seconds = max_age_hours * 3600
    for sid in os.listdir(SESSION_DIR_BASE):
        session_dir = os.path.join(SESSION_DIR_BASE, sid)
        state_path = os.path.join(session_dir, "state.json")
        try:
            if not os.path.exists(state_path):
                continue
            with open(state_path, "r") as f:
                created = json.load(f).get("created")
            if not created:
                continue
            age_seconds = (now - datetime.datetime.fromisoformat(created)).total_seconds()
            if age_seconds > max_age_seconds:
                shutil.rmtree(session_dir)
                # Also forget the in-memory copy; otherwise the expired
                # session would be written back to disk on its next request.
                SESSION_DATA.pop(sid, None)
                logger.info(f"Cleaned up session {sid}")
        except Exception as e:
            logger.error(f"Cleanup error for {sid}: {e}")
# Optional GPU configuration. Any failure here (including torch not being
# importable) is logged as a warning and otherwise ignored.
try:
    import torch

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        logger.info("CUDA GPU detected and configured.")
except Exception as e:
    logger.warning(f"CUDA config failed: {e}")
if __name__ == '__main__':
    # The server listens on all interfaces, so the interactive debugger must
    # not be on by default: it allows code execution from the browser. Debug
    # mode is opt-in through the ASKTRICARE_DEBUG environment variable.
    debug_enabled = os.environ.get("ASKTRICARE_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    print("Starting the Dash application...")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860, threaded=True)
    print("Dash application has finished running.")