# Ask Tricare -- Dash/Flask chat application with OpenAI-backed RAG.
# (Removed non-code scrape residue that preceded the imports: "Spaces:" / "Paused" / "Paused".)
| import os | |
| import threading | |
| import logging | |
| import uuid | |
| import shutil | |
| import json | |
| import tempfile | |
| import glob | |
| from flask import Flask, request as flask_request, make_response | |
| import dash | |
| from dash import dcc, html, Input, Output, State, callback_context | |
| import dash_bootstrap_components as dbc | |
| import openai | |
| import base64 | |
| import datetime | |
| from werkzeug.utils import secure_filename | |
| import numpy as np | |
| import io | |
| from pdfminer.high_level import extract_text as pdf_extract_text | |
| import docx | |
| import openpyxl | |
# Root logger config: thread name is included because streamed OpenAI replies
# are produced on background worker threads.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(threadName)s %(message)s")
logger = logging.getLogger("AskTricare")

# Flask server backing the Dash app (passed via server=... below).
app_flask = Flask(__name__)

# In-memory per-session state and locks, keyed by session id.
# NOTE(review): both dicts grow for the process lifetime; cleanup_sessions()
# below removes on-disk state only, never these entries.
SESSION_DATA = {}
SESSION_LOCKS = {}

# On-disk session spool under the system temp dir; one subdirectory per session.
SESSION_DIR_BASE = os.path.join(tempfile.gettempdir(), "asktricare_sessions")
os.makedirs(SESSION_DIR_BASE, exist_ok=True)

# OpenAI credentials come from the environment; unset leaves api_key as None.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Global (./docs) embedding index: filename -> embedding vector / chunk text.
EMBEDDING_INDEX = {}
EMBEDDING_TEXTS = {}
EMBEDDING_MODEL = "text-embedding-ada-002"
def get_session_id():
    """Return the session id from the request cookie, or mint a fresh UUID.

    NOTE(review): the generated id is only persisted if something sets the
    cookie on the response -- verify set_session_cookie is registered.
    """
    cookie_sid = flask_request.cookies.get("asktricare_session_id")
    return cookie_sid if cookie_sid else str(uuid.uuid4())
def get_session_dir(session_id):
    """Return (creating if needed) the on-disk spool directory for *session_id*."""
    session_path = os.path.join(SESSION_DIR_BASE, session_id)
    os.makedirs(session_path, exist_ok=True)
    return session_path
def get_session_lock(session_id):
    """Return the per-session threading.Lock, creating it atomically.

    Fix: the original check-then-insert sequence could race under Flask's
    threaded request handling -- two first requests for the same new session
    id could each install (and hold) a different lock. dict.setdefault makes
    create-if-missing a single atomic dict operation.
    """
    return SESSION_LOCKS.setdefault(session_id, threading.Lock())
def get_session_state(session_id):
    """Return the in-memory state dict for *session_id*, creating a default.

    Fix: uses dict.setdefault so two concurrent first requests for the same
    session cannot clobber each other's freshly created state (the original
    check-then-assign sequence could, under threaded Flask).
    """
    return SESSION_DATA.setdefault(session_id, {
        "messages": [],       # full transcript: {"role", "content"} dicts
        "uploads": [],        # uploaded-file records: {"name", "is_img", "path"}
        "created": datetime.datetime.utcnow().isoformat(),
        "streaming": False,   # True while a background OpenAI stream runs
        "stream_buffer": "",  # partial assistant reply accumulated so far
        "chat_histories": []  # saved dialogs: {"name", "dialog"}
    })
def save_session_state(session_id):
    """Persist the session's in-memory state to <session_dir>/state.json."""
    state_path = os.path.join(get_session_dir(session_id), "state.json")
    with open(state_path, "w") as fh:
        json.dump(get_session_state(session_id), fh)
def load_session_state(session_id):
    """Load persisted state from disk into SESSION_DATA, if a snapshot exists.

    Fix: a corrupt or half-written state.json is now logged and ignored
    instead of raising -- this function runs on every callback, so one bad
    file previously broke every subsequent request for that session.
    """
    path = os.path.join(get_session_dir(session_id), "state.json")
    if not os.path.exists(path):
        return
    try:
        with open(path, "r") as f:
            SESSION_DATA[session_id] = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        logger.error(f"Session {session_id}: could not load state.json: {e}")
def load_system_prompt():
    """Read ./system_prompt.txt; fall back to a built-in default prompt."""
    prompt_path = os.path.join(os.getcwd(), "system_prompt.txt")
    try:
        with open(prompt_path, "r", encoding="utf-8") as fh:
            return fh.read().strip()
    except Exception as e:
        # Missing/unreadable file is expected in some deployments; log and
        # fall back rather than fail.
        logger.error(f"Failed to load system prompt: {e}")
        return "You are Ask Tricare, a helpful assistant for TRICARE health benefits. Respond conversationally, and cite relevant sources when possible. If you do not know, say so."
def embed_docs_folder():
    """Embed every *.txt / *.md / *.pdf under ./docs into the global RAG index.

    Fix: PDFs were previously opened in text mode (open(..., "r",
    errors="ignore")), which feeds raw PDF bytes to the embedding API and
    produces useless vectors; they are now extracted with pdfminer's
    extract_text, matching how user uploads are handled.

    Files already in EMBEDDING_INDEX are skipped, so repeated calls are
    cheap. Per-file failures are logged and do not abort the scan.
    """
    global EMBEDDING_INDEX, EMBEDDING_TEXTS
    docs_folder = os.path.join(os.getcwd(), "docs")
    if not os.path.isdir(docs_folder):
        logger.warning(f"Docs folder '{docs_folder}' does not exist. Skipping embedding.")
        return
    doc_files = []
    for ext in ("*.txt", "*.md", "*.pdf"):
        doc_files.extend(glob.glob(os.path.join(docs_folder, ext)))
    for doc_path in doc_files:
        fname = os.path.basename(doc_path)
        if fname in EMBEDDING_INDEX:
            continue  # already embedded
        try:
            if doc_path.lower().endswith(".pdf"):
                text = pdf_extract_text(doc_path)
            else:
                with open(doc_path, "r", encoding="utf-8", errors="ignore") as f:
                    text = f.read()
            if not text.strip():
                continue  # nothing worth embedding
            # Single-chunk index: only the first 4000 chars are embedded.
            chunk = text[:4000]
            response = openai.Embedding.create(
                input=[chunk],
                model=EMBEDDING_MODEL
            )
            embedding = response['data'][0]['embedding']
            EMBEDDING_INDEX[fname] = embedding
            EMBEDDING_TEXTS[fname] = chunk
            logger.info(f"Embedded doc: {fname}")
        except Exception as e:
            logger.error(f"Embedding failed for {fname}: {e}")
# Build the global index once at import time so the first request can use it.
embed_docs_folder()
def embed_user_doc(session_id, filename, text):
    """Embed an uploaded document's text and append it to the session's
    user_embeds.json store (parallel lists: embeddings / texts / filenames).

    Fix: both log messages previously printed the literal placeholder
    "(unknown)" even though the uploaded file's name was available.
    """
    session_dir = get_session_dir(session_id)
    if not text.strip():
        return  # nothing worth embedding
    try:
        # Single-chunk embedding, consistent with embed_docs_folder.
        chunk = text[:4000]
        response = openai.Embedding.create(
            input=[chunk],
            model=EMBEDDING_MODEL
        )
        embedding = response['data'][0]['embedding']
        user_embeds_path = os.path.join(session_dir, "user_embeds.json")
        if os.path.exists(user_embeds_path):
            with open(user_embeds_path, "r") as f:
                user_embeds = json.load(f)
        else:
            user_embeds = {"embeddings": [], "texts": [], "filenames": []}
        user_embeds["embeddings"].append(embedding)
        user_embeds["texts"].append(chunk)
        user_embeds["filenames"].append(filename)
        with open(user_embeds_path, "w") as f:
            json.dump(user_embeds, f)
        logger.info(f"Session {session_id}: Embedded user doc {filename}")
    except Exception as e:
        logger.error(f"Session {session_id}: Failed to embed user doc {filename}: {e}")
def get_user_embeddings(session_id):
    """Load this session's uploaded-document embeddings from disk.

    Returns (embedding matrix as np.ndarray, text chunks, filenames);
    all three are empty when nothing has been embedded yet.
    """
    store_path = os.path.join(get_session_dir(session_id), "user_embeds.json")
    if not os.path.exists(store_path):
        return np.array([]), [], []
    with open(store_path, "r") as fh:
        store = json.load(fh)
    return (
        np.array(store.get("embeddings", [])),
        store.get("texts", []),
        store.get("filenames", []),
    )
def semantic_search(query, embed_matrix, texts, filenames, top_k=2):
    """Rank *texts* by cosine similarity between *query* and *embed_matrix*.

    Returns up to *top_k* dicts {"filename", "text", "score"} in descending
    score order; returns [] on empty input or any API/numeric failure.
    """
    if len(embed_matrix) == 0:
        return []
    try:
        q_embed = openai.Embedding.create(input=[query], model=EMBEDDING_MODEL)["data"][0]["embedding"]
        query_vec = np.array(q_embed)
        matrix = np.array(embed_matrix)
        # Cosine similarity; the 1e-8 term guards against a zero norm.
        scores = np.dot(matrix, query_vec) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec) + 1e-8)
        best = np.argsort(scores)[::-1][:top_k]
        return [
            {"filename": filenames[i], "text": texts[i], "score": float(scores[i])}
            for i in best
        ]
    except Exception as e:
        logger.error(f"Semantic search error: {e}")
        return []
# Dash application mounted on the Flask server above.
# suppress_callback_exceptions is needed because some component ids
# (the pattern-matching chat-history buttons) are created at runtime.
app = dash.Dash(
    __name__,
    server=app_flask,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.BOOTSTRAP, "/assets/custom.css"],
    update_title="Ask Tricare"
)
def chat_message_card(msg, is_user):
    """Render one chat bubble.

    User messages: right-aligned primary card with a person avatar;
    assistant messages: left-aligned secondary card with a robot avatar.

    Fix: the className literal carried a spurious f-string prefix with no
    placeholders (lint F541).
    """
    align = "end" if is_user else "start"
    color = "primary" if is_user else "secondary"
    avatar = "🧑" if is_user else "🤖"
    return dbc.Card(
        dbc.CardBody([
            html.Div([
                html.Span(avatar, style={"fontSize": "2rem"}),
                html.Span(msg, style={"whiteSpace": "pre-wrap", "marginLeft": "0.75rem", "overflowWrap": "break-word", "wordBreak": "break-word"})
            ], style={"display": "flex", "alignItems": "center", "justifyContent": align})
        ]),
        className="mb-2 ms-3 me-3",
        color=color,
        inverse=is_user,
        style={"maxWidth": "80%", "alignSelf": f"flex-{align}"}
    )
def uploaded_file_card(filename, is_img):
    """Render a small card listing one uploaded file with a type icon.

    Fix: dropped the unused local `ext` (computed but never read).
    NOTE(review): "tertiary" is not a standard Bootstrap theme color, so
    dbc likely falls back to default styling; kept as-is to preserve the
    current appearance.
    """
    icon = "🖼️" if is_img else "📄"
    return dbc.Card(
        dbc.CardBody([
            html.Span(icon, style={"fontSize": "2rem", "marginRight": "0.5rem"}),
            html.Span(filename)
        ]),
        className="mb-2",
        color="tertiary"
    )
def disclaimer_card():
    """Static privacy disclaimer shown at the top of the left navbar."""
    body = dbc.CardBody([
        html.H5("Disclaimer", className="card-title"),
        html.P(
            "This information is not private. Do not send PII or PHI. For official guidance visit the Tricare website.",
            style={"fontSize": "0.95rem"},
        ),
    ])
    return dbc.Card(body, className="mb-2")
def left_navbar_static():
    """Build the fixed left column: title, disclaimer, upload control, and
    the dynamically filled upload list and chat-history list."""
    upload_button = dbc.Button(
        "Upload Document/Image",
        color="secondary",
        className="mb-2",
        style={"width": "100%"},
    )
    children = [
        html.H3("Ask Tricare", className="mb-3 mt-3", style={"fontWeight": "bold"}),
        disclaimer_card(),
        dcc.Upload(
            id="file-upload",
            children=upload_button,
            multiple=True,
            style={"width": "100%"},
        ),
        html.Div(id="upload-list"),
        html.Hr(),
        html.H5("Chat History", className="mb-2"),
        html.Ul(id="chat-history-list", style={"listStyle": "none", "paddingLeft": "0"}),
    ]
    return html.Div(children, style={"padding": "1rem", "backgroundColor": "#f8f9fa", "height": "100vh", "overflowY": "auto"})
def chat_box_card():
    """Card containing the scrollable chat window (filled by callbacks)."""
    window_style = {
        "height": "60vh",
        "overflowY": "scroll",
        "overflowX": "auto",
        "display": "flex",
        "flexDirection": "column",
        "justifyContent": "flex-end",
        "backgroundColor": "#fff",
        "padding": "0.5rem",
        "borderRadius": "0.5rem",
    }
    card_style = {"height": "62vh", "overflowY": "scroll", "overflowX": "auto"}
    return dbc.Card(
        dbc.CardBody([html.Div(id="chat-window", style=window_style)]),
        className="mt-3",
        style=card_style,
    )
def user_input_card():
    """Card with the question textarea, Send / New Chat buttons, the hidden
    button "clicked" by the Enter-key handler, and assorted dcc.Store helpers."""
    textarea = dcc.Textarea(
        id="user-input",
        placeholder="Type your question...",
        style={"width": "100%", "height": "60px", "resize": "vertical", "wordWrap": "break-word"},
        wrap="soft",
        maxLength=1000,
        n_blur=0,
    )
    buttons = html.Div(
        [
            dbc.Button("Send", id="send-btn", color="primary", className="mt-2 me-2", style={"minWidth": "100px"}),
            dbc.Button("New Chat", id="new-chat-btn", color="secondary", className="mt-2", style={"minWidth": "110px"}),
        ],
        style={"float": "right", "display": "flex", "gap": "0.5rem"},
    )
    input_row = html.Div(
        [
            textarea,
            dcc.Store(id="enter-triggered", data=False),
            buttons,
            dcc.Store(id="user-input-store", data="", storage_type="session"),
            # Invisible button the clientside Enter handler clicks.
            html.Button(id='hidden-send', style={'display': 'none'}),
        ],
        style={"marginTop": "1rem"},
    )
    return dbc.Card(
        dbc.CardBody([
            input_row,
            html.Div(id="error-message", style={"color": "#bb2124", "marginTop": "0.5rem"}),
            dcc.Store(id="should-clear-input", data=False),
        ])
    )
def right_main_static():
    """Main (right) column: chat window, input card, loading spinner, and the
    polling interval used while a reply streams."""
    spinner = dcc.Loading(id="loading", type="default", fullscreen=False, style={"position": "absolute", "top": "5%", "left": "50%"})
    # Interval starts disabled; callback outputs toggle it during streaming.
    poller = dcc.Interval(id="stream-interval", interval=400, n_intervals=0, disabled=True, max_intervals=1000)
    return html.Div(
        [chat_box_card(), user_input_card(), spinner, poller, dcc.Store(id="client-question", data="")],
        style={"padding": "1rem", "backgroundColor": "#fff", "height": "100vh", "overflowY": "auto"},
    )
# Top-level page layout: fixed 30vw left navbar + 70vw main column, plus the
# Store/Location components shared by the callbacks.
app.layout = html.Div([
    dcc.Store(id="session-id", storage_type="local"),  # session id kept in localStorage
    dcc.Location(id="url"),
    dcc.Store(id="selected-history", data=None),  # index of the loaded saved chat
    html.Div([
        html.Div(left_navbar_static(), id='left-navbar', style={"width": "30vw", "height": "100vh", "position": "fixed", "left": 0, "top": 0, "zIndex": 2, "overflowY": "auto"}),
        html.Div(right_main_static(), id='right-main', style={"marginLeft": "30vw", "width": "70vw", "overflowY": "auto"})
    ], style={"display": "flex"}),
    dcc.Store(id="clear-input", data=False),
    dcc.Store(id="scroll-bottom", data=0),
    dcc.Store(id="enter-pressed", data=False)
])
# Clientside callback: installs (once, guarded by a window-level flag) a
# keydown handler on the textarea so plain Enter submits via the hidden
# button while Shift+Enter still inserts a newline. It is wired to n_blur
# presumably just to get an initial invocation -- the handler then lives on
# the DOM node; confirm this fires before first use.
app.clientside_callback(
    """
    function(n, value) {
        var ta = document.getElementById('user-input');
        if (!ta) return window.dash_clientside.no_update;
        if (!window._asktricare_enter_handler) {
            ta.addEventListener('keydown', function(e) {
                if (e.key === 'Enter' && !e.shiftKey) {
                    e.preventDefault();
                    var btn = document.getElementById('hidden-send');
                    if (btn) btn.click();
                }
            });
            window._asktricare_enter_handler = true;
        }
        return window.dash_clientside.no_update;
    }
    """,
    Output('enter-pressed', 'data'),
    Input('user-input', 'n_blur'),
    State('user-input', 'value')
)
| def _is_supported_doc(filename): | |
| ext = os.path.splitext(filename)[1].lower() | |
| return ext in [".txt", ".pdf", ".md", ".docx", ".xlsx"] | |
| def _extract_text_from_upload(filepath, ext): | |
| try: | |
| if ext in [".txt", ".md"]: | |
| with open(filepath, "r", encoding="utf-8", errors="ignore") as f: | |
| text = f.read() | |
| return text | |
| elif ext == ".pdf": | |
| try: | |
| text = pdf_extract_text(filepath) | |
| return text | |
| except Exception as e: | |
| logger.error(f"Error reading PDF {filepath}: {e}") | |
| return "" | |
| elif ext == ".docx": | |
| try: | |
| doc = docx.Document(filepath) | |
| paragraphs = [p.text for p in doc.paragraphs if p.text.strip()] | |
| return "\n".join(paragraphs) | |
| except Exception as e: | |
| logger.error(f"Error reading DOCX {filepath}: {e}") | |
| return "" | |
| elif ext == ".xlsx": | |
| try: | |
| wb = openpyxl.load_workbook(filepath, read_only=True, data_only=True) | |
| text_rows = [] | |
| for ws in wb.worksheets: | |
| for row in ws.iter_rows(values_only=True): | |
| row_strs = [str(cell) for cell in row if cell is not None] | |
| if any(row_strs): | |
| text_rows.append("\t".join(row_strs)) | |
| return "\n".join(text_rows) | |
| except Exception as e: | |
| logger.error(f"Error reading XLSX {filepath}: {e}") | |
| return "" | |
| else: | |
| return "" | |
| except Exception as e: | |
| logger.error(f"Error extracting text from {filepath}: {e}") | |
| return "" | |
def assign_session_id(_):
    """Resolve the caller's session id, ensure its spool directory exists,
    and pull any persisted state into memory.

    NOTE(review): no @app.callback decorator is visible here -- presumably
    registration lives elsewhere or was lost in extraction; verify.
    """
    sid = get_session_id()
    get_session_dir(sid)
    load_session_state(sid)
    logger.info(f"Assigned session id: {sid}")
    return sid
def main_callback(session_id, send_clicks, file_contents, new_chat_clicks, stream_n, chat_history_clicks, hidden_send_clicks, file_names, user_input, selected_history, chat_history_list_children):
    """Single mega-callback driving the whole chat UI.

    Depending on which input fired, it: loads a saved chat, handles file
    uploads (save + text extraction + embedding + automatic assistant
    reply), sends a typed question (streamed OpenAI reply on a background
    thread), archives the chat on "New Chat", or re-renders on each
    stream-interval tick.

    Returns an 8-tuple, in output order: (upload cards, chat-history <li>
    items, chat message cards, error text, stream-interval `disabled`
    flag, interval counter, textarea value, selected-history index).

    NOTE(review): no @app.callback decorator is visible in this chunk;
    presumably registration happens elsewhere or was lost in extraction.
    """
    # Component id that fired this invocation ("" on initial page load).
    trigger = callback_context.triggered[0]['prop_id'].split('.')[0] if callback_context.triggered else ""
    session_id = session_id or get_session_id()
    session_lock = get_session_lock(session_id)
    with session_lock:
        load_session_state(session_id)
        state = get_session_state(session_id)
        error = ""
        start_streaming = False
        chat_histories = state.get("chat_histories", [])
        uploads = state.get("uploads", [])

        # ---- 1) Saved-chat selection (pattern-matching history buttons) ----
        history_index_clicked = None
        if trigger == "" and chat_history_clicks:
            for idx, n in enumerate(chat_history_clicks):
                if n:
                    history_index_clicked = idx
                    break
        elif trigger.startswith("{\"type\":\"chat-history-item\""):
            for idx, n in enumerate(chat_history_clicks):
                if n:
                    history_index_clicked = idx
                    break
        if history_index_clicked is not None and history_index_clicked < len(chat_histories):
            # Swap the live transcript for the chosen saved dialog, then re-render.
            selected_dialog = chat_histories[history_index_clicked]["dialog"]
            state["messages"] = selected_dialog.copy()
            state["stream_buffer"] = ""
            state["streaming"] = False
            save_session_state(session_id)
            logger.info(f"Session {session_id}: Loaded chat history index {history_index_clicked}")
            chat_cards = []
            for msg in state["messages"]:
                chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
            upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
            chat_history_items = [
                html.Li(
                    dbc.Button(
                        chat["name"],
                        id={"type": "chat-history-item", "index": i},
                        color="tertiary",
                        outline=False,
                        style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(chat_histories)-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                    )
                )
                for i, chat in enumerate(chat_histories[-6:])
            ]
            return (
                upload_cards,
                chat_history_items,
                chat_cards,
                "",
                True,  # keep the stream interval disabled
                0,
                "",
                history_index_clicked
            )

        # ---- 2) File upload handling ----
        file_was_uploaded_and_sent = False  # NOTE(review): assigned but never read
        file_upload_message = None          # NOTE(review): assigned but never read
        doc_texts_to_send = []
        if trigger == "file-upload" and file_contents and file_names:
            uploads = []
            file_upload_messages = []
            # dcc.Upload delivers scalars for a single file; normalize to lists.
            if not isinstance(file_contents, list):
                file_contents = [file_contents]
                file_names = [file_names]
            for c, n in zip(file_contents, file_names):
                # Contents arrive as "data:<mime>;base64,<payload>".
                header, data = c.split(',', 1)
                ext = os.path.splitext(n)[1].lower()
                is_img = ext in [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
                # Timestamp prefix avoids collisions; secure_filename sanitizes.
                fname = secure_filename(f"{datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S')}_{n}")
                session_dir = get_session_dir(session_id)
                fp = os.path.join(session_dir, fname)
                with open(fp, "wb") as f:
                    f.write(base64.b64decode(data))
                uploads.append({"name": fname, "is_img": is_img, "path": fp})
                if _is_supported_doc(n) and not is_img:
                    text = _extract_text_from_upload(fp, ext)
                    if text.strip():
                        embed_user_doc(session_id, fname, text)
                        logger.info(f"Session {session_id}: Uploaded doc '{n}' embedded for user vector store")
                        preview = text[:1000]
                        file_upload_messages.append({
                            "role": "user",
                            "content": f"[Document uploaded: {n}]\n{preview if preview.strip() else '[No text extracted]'}"
                        })
                        doc_texts_to_send.append(text.strip())
                    else:
                        file_upload_messages.append({
                            "role": "user",
                            "content": f"[Document uploaded: {n}]\n[No text extracted]"
                        })
                elif is_img:
                    file_upload_messages.append({
                        "role": "user",
                        "content": f"[Image uploaded: {n}]"
                    })
                else:
                    file_upload_messages.append({
                        "role": "user",
                        "content": f"[File uploaded: {n}]"
                    })
            state["uploads"].extend(uploads)
            # Add all file upload messages to chat
            for msg in file_upload_messages:
                state["messages"].append(msg)
            save_session_state(session_id)
            logger.info(f"Session {session_id}: Uploaded files {[u['name'] for u in uploads]}")
            # If text was extracted, send it to OpenAI as a question for assistant reply
            if doc_texts_to_send:
                doc_question = "\n\n".join(doc_texts_to_send)
                state["messages"].append({"role": "user", "content": doc_question})
                state["streaming"] = True
                state["stream_buffer"] = ""
                save_session_state(session_id)
                def run_stream_for_doc(session_id, messages, doc_question):
                    # Background worker: RAG lookup + streamed ChatCompletion for
                    # the uploaded document text.
                    # NOTE(review): near-duplicate of run_stream below -- candidate
                    # for extraction into one shared helper.
                    try:
                        system_prompt = load_system_prompt()
                        rag_chunks = []
                        try:
                            # Search the global docs index...
                            global_embeds = []
                            global_texts = []
                            global_fnames = []
                            for fname, emb in EMBEDDING_INDEX.items():
                                global_embeds.append(emb)
                                global_texts.append(EMBEDDING_TEXTS[fname])
                                global_fnames.append(fname)
                            global_rag = semantic_search(doc_question, global_embeds, global_texts, global_fnames, top_k=2)
                            if global_rag:
                                for r in global_rag:
                                    rag_chunks.append(f"Global doc [{r['filename']}]:\n{r['text'][:1000]}")
                            # ...and this session's uploaded-document index.
                            user_embeds, user_texts, user_fnames = get_user_embeddings(session_id)
                            user_rag = semantic_search(doc_question, user_embeds, user_texts, user_fnames, top_k=2)
                            if user_rag:
                                for r in user_rag:
                                    rag_chunks.append(f"User upload [{r['filename']}]:\n{r['text'][:1000]}")
                        except Exception as e:
                            logger.error(f"Session {session_id}: RAG error (doc upload): {e}")
                        context_block = ""
                        if rag_chunks:
                            context_block = "The following sources may help answer the question:\n\n" + "\n\n".join(rag_chunks) + "\n\n"
                        msg_list = [{"role": "system", "content": system_prompt}]
                        if context_block:
                            msg_list.append({"role": "system", "content": context_block})
                        for m in messages:
                            msg_list.append({"role": m["role"], "content": m["content"]})
                        response = openai.ChatCompletion.create(
                            model="gpt-3.5-turbo",
                            messages=msg_list,
                            max_tokens=700,
                            temperature=0.2,
                            stream=True,
                        )
                        reply = ""
                        # Accumulate streamed deltas; persist the partial reply on
                        # each chunk so the stream-interval tick can display it.
                        for chunk in response:
                            delta = chunk["choices"][0]["delta"]
                            content = delta.get("content", "")
                            if content:
                                reply += content
                                session_lock = get_session_lock(session_id)
                                with session_lock:
                                    load_session_state(session_id)
                                    state = get_session_state(session_id)
                                    state["stream_buffer"] = reply
                                    save_session_state(session_id)
                        # Stream finished: commit the reply as a real assistant message.
                        session_lock = get_session_lock(session_id)
                        with session_lock:
                            load_session_state(session_id)
                            state = get_session_state(session_id)
                            state["messages"].append({"role": "assistant", "content": reply})
                            state["stream_buffer"] = ""
                            state["streaming"] = False
                            save_session_state(session_id)
                        logger.info(f"Session {session_id}: Assistant responded to doc upload")
                    except Exception as e:
                        # On any failure clear the streaming flags so the UI stops polling.
                        session_lock = get_session_lock(session_id)
                        with session_lock:
                            load_session_state(session_id)
                            state = get_session_state(session_id)
                            state["streaming"] = False
                            state["stream_buffer"] = ""
                            save_session_state(session_id)
                        logger.error(f"Session {session_id}: Streaming error for doc upload: {e}")
                threading.Thread(target=run_stream_for_doc, args=(session_id, list(state["messages"]), doc_question), daemon=True).start()
                start_streaming = True

        # ---- 3) Send button (or the hidden Enter-key button) ----
        send_triggered = False
        if trigger == "send-btn" or trigger == "hidden-send":
            send_triggered = True
        if send_triggered and user_input and user_input.strip():
            question = user_input.strip()
            state["messages"].append({"role": "user", "content": question})
            state["streaming"] = True
            state["stream_buffer"] = ""
            save_session_state(session_id)
            def run_stream(session_id, messages, question):
                # Background worker: RAG lookup + streamed ChatCompletion for a
                # typed question (mirrors run_stream_for_doc above).
                try:
                    system_prompt = load_system_prompt()
                    rag_chunks = []
                    try:
                        # Search the global docs index...
                        global_embeds = []
                        global_texts = []
                        global_fnames = []
                        for fname, emb in EMBEDDING_INDEX.items():
                            global_embeds.append(emb)
                            global_texts.append(EMBEDDING_TEXTS[fname])
                            global_fnames.append(fname)
                        global_rag = semantic_search(question, global_embeds, global_texts, global_fnames, top_k=2)
                        if global_rag:
                            for r in global_rag:
                                rag_chunks.append(f"Global doc [{r['filename']}]:\n{r['text'][:1000]}")
                        # ...and this session's uploaded-document index.
                        user_embeds, user_texts, user_fnames = get_user_embeddings(session_id)
                        user_rag = semantic_search(question, user_embeds, user_texts, user_fnames, top_k=2)
                        if user_rag:
                            for r in user_rag:
                                rag_chunks.append(f"User upload [{r['filename']}]:\n{r['text'][:1000]}")
                    except Exception as e:
                        logger.error(f"Session {session_id}: RAG error: {e}")
                    context_block = ""
                    if rag_chunks:
                        context_block = "The following sources may help answer the question:\n\n" + "\n\n".join(rag_chunks) + "\n\n"
                    msg_list = [{"role": "system", "content": system_prompt}]
                    if context_block:
                        msg_list.append({"role": "system", "content": context_block})
                    for m in messages:
                        msg_list.append({"role": m["role"], "content": m["content"]})
                    response = openai.ChatCompletion.create(
                        model="gpt-3.5-turbo",
                        messages=msg_list,
                        max_tokens=700,
                        temperature=0.2,
                        stream=True,
                    )
                    reply = ""
                    # Accumulate streamed deltas; persist the partial reply on
                    # each chunk so the stream-interval tick can display it.
                    for chunk in response:
                        delta = chunk["choices"][0]["delta"]
                        content = delta.get("content", "")
                        if content:
                            reply += content
                            session_lock = get_session_lock(session_id)
                            with session_lock:
                                load_session_state(session_id)
                                state = get_session_state(session_id)
                                state["stream_buffer"] = reply
                                save_session_state(session_id)
                    # Stream finished: commit the reply as a real assistant message.
                    session_lock = get_session_lock(session_id)
                    with session_lock:
                        load_session_state(session_id)
                        state = get_session_state(session_id)
                        state["messages"].append({"role": "assistant", "content": reply})
                        state["stream_buffer"] = ""
                        state["streaming"] = False
                        save_session_state(session_id)
                    logger.info(f"Session {session_id}: User: {question} | Assistant: {reply}")
                except Exception as e:
                    # On any failure clear the streaming flags so the UI stops polling.
                    session_lock = get_session_lock(session_id)
                    with session_lock:
                        load_session_state(session_id)
                        state = get_session_state(session_id)
                        state["streaming"] = False
                        state["stream_buffer"] = ""
                        save_session_state(session_id)
                    logger.error(f"Session {session_id}: Streaming error: {e}")
            threading.Thread(target=run_stream, args=(session_id, list(state["messages"]), question), daemon=True).start()
            start_streaming = True

        # ---- 4) New Chat: archive the current transcript and clear the window ----
        if trigger == "new-chat-btn":
            chat_dialog = list(state.get("messages", []))
            if not chat_dialog:
                error = "No chat to save. Start chatting!"
            else:
                chat_title = "Chat " + datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M")
                state.setdefault("chat_histories", []).append({
                    "name": chat_title,
                    "dialog": chat_dialog
                })
                state["messages"] = []
                state["stream_buffer"] = ""
                state["streaming"] = False
                save_session_state(session_id)
                logger.info(f"Session {session_id}: Saved chat history '{chat_title}'")

        # ---- 5) Stream-interval tick: re-render while a reply is streaming ----
        if trigger == "stream-interval":
            chat_history = state.get("messages", [])
            chat_cards = []
            for msg in chat_history:
                chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
            if state.get("streaming", False):
                # Show the partial reply as a provisional assistant bubble.
                if state.get("stream_buffer", ""):
                    chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
                upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
                chat_history_items = [
                    html.Li(
                        dbc.Button(
                            chat["name"],
                            id={"type": "chat-history-item", "index": i},
                            color="tertiary",
                            outline=False,
                            style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(state.get("chat_histories", []) )-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                        )
                    )
                    for i, chat in enumerate(state.get("chat_histories", [])[-6:])
                ]
                return (
                    upload_cards,
                    chat_history_items,
                    chat_cards,
                    "",
                    False,  # keep the interval enabled (still streaming)
                    stream_n+1,
                    "",
                    selected_history
                )
            else:
                # Stream finished: final render with the interval switched off.
                chat_cards = []
                for msg in state.get("messages", []):
                    chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
                upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in state.get("uploads", [])]
                chat_history_items = [
                    html.Li(
                        dbc.Button(
                            chat["name"],
                            id={"type": "chat-history-item", "index": i},
                            color="tertiary",
                            outline=False,
                            style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(state.get("chat_histories", []))-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                        )
                    )
                    for i, chat in enumerate(state.get("chat_histories", [])[-6:])
                ]
                return (
                    upload_cards,
                    chat_history_items,
                    chat_cards,
                    "",
                    True,  # disable polling
                    0,
                    "",
                    selected_history
                )

        # ---- 6) Default render (initial load, send, upload, new-chat paths) ----
        chat_history = state.get("messages", [])
        uploads = state.get("uploads", [])
        chat_histories = state.get("chat_histories", [])
        upload_cards = [uploaded_file_card(os.path.basename(f["name"]), f["is_img"]) for f in uploads]
        chat_history_items = [
            html.Li(
                dbc.Button(
                    chat["name"],
                    id={"type": "chat-history-item", "index": i},
                    color="tertiary",
                    outline=False,
                    style={"width": "100%", "textAlign": "left", "fontWeight": "bold" if i == len(chat_histories)-1 else "normal", "fontSize": "0.92rem", "marginBottom": "0.2rem"}
                )
            )
            for i, chat in enumerate(chat_histories[-6:])
        ]
        chat_cards = []
        for msg in chat_history:
            chat_cards.append(chat_message_card(msg['content'], is_user=(msg['role'] == "user")))
        if state.get("streaming", False):
            if state.get("stream_buffer", ""):
                chat_cards.append(chat_message_card(state["stream_buffer"], is_user=False))
            # Streaming just started: enable the interval (disabled=False).
            return upload_cards, chat_history_items, chat_cards, error, False, 0, "", selected_history
        if send_triggered:
            # Send clicked but nothing streamed (e.g. empty input): clear the textarea.
            return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0, "", selected_history
        # Fallback: preserve whatever the user has typed in the textarea.
        return upload_cards, chat_history_items, chat_cards, error, (not state.get("streaming", False)), 0, user_input or "", selected_history
@app_flask.after_request
def set_session_cookie(resp):
    """Flask after-request hook: ensure every response carries the
    asktricare_session_id cookie (7-day lifetime, whole site).

    Fix: the function was defined but never registered, so the cookie was
    never actually set and get_session_id() minted a fresh UUID on every
    cookieless request -- sessions never persisted across requests.
    Registering it as an after_request hook makes the session id stick.
    """
    sid = flask_request.cookies.get("asktricare_session_id")
    if not sid:
        sid = str(uuid.uuid4())
    resp.set_cookie("asktricare_session_id", sid, max_age=60*60*24*7, path="/")
    return resp
def cleanup_sessions(max_age_hours=48):
    """Delete on-disk session directories whose state.json "created"
    timestamp is older than *max_age_hours*.

    Errors are logged per session and never abort the sweep. Directories
    without a state.json (or without a "created" field) are left alone.
    """
    cutoff_seconds = max_age_hours * 3600
    now = datetime.datetime.utcnow()
    for sid in os.listdir(SESSION_DIR_BASE):
        session_dir = os.path.join(SESSION_DIR_BASE, sid)
        try:
            state_path = os.path.join(session_dir, "state.json")
            if not os.path.exists(state_path):
                continue
            with open(state_path, "r") as fh:
                stored = json.load(fh)
            created = stored.get("created")
            if not created:
                continue
            age = (now - datetime.datetime.fromisoformat(created)).total_seconds()
            if age > cutoff_seconds:
                shutil.rmtree(session_dir)
                logger.info(f"Cleaned up session {sid}")
        except Exception as e:
            logger.error(f"Cleanup error for {sid}: {e}")
# Optional CUDA setup: if torch is importable and a GPU is present, make CUDA
# float tensors the default. Wrapped in try/except so machines without torch
# or without CUDA just log a warning and continue.
# NOTE(review): torch is not used anywhere else in this chunk, and
# set_default_tensor_type is deprecated in modern torch -- confirm this
# block is still needed at all.
try:
    import torch
    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        logger.info("CUDA GPU detected and configured.")
except Exception as e:
    logger.warning(f"CUDA config failed: {e}")
if __name__ == '__main__':
    # Development entry point: Dash's built-in server on all interfaces,
    # port 7860 (the Hugging Face Spaces default port).
    # NOTE(review): debug=True enables the Werkzeug reloader/debugger --
    # do not ship this configuration to production.
    print("Starting the Dash application...")
    app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
    print("Dash application has finished running.")