Spaces:
Running
Running
Commit
·
57e04c3
1
Parent(s):
15ed85c
Populate LRU cache to allow recent-chat retrieval; add memory handler
Browse files
README.md
CHANGED
@@ -7,7 +7,7 @@ sdk: docker
|
|
7 |
sdk_version: latest
|
8 |
pinned: false
|
9 |
license: apache-2.0
|
10 |
-
short_description:
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
7 |
sdk_version: latest
|
8 |
pinned: false
|
9 |
license: apache-2.0
|
10 |
+
short_description: MedicalChatbot - FAISS RAG, Gemini Flash, MongoDB vDB, LRU cache
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,13 +1,15 @@
|
|
|
|
1 |
import os
|
2 |
import faiss
|
3 |
import numpy as np
|
4 |
import time
|
5 |
import uvicorn
|
6 |
-
from fastapi import FastAPI
|
7 |
from fastapi.responses import JSONResponse
|
8 |
from pymongo import MongoClient
|
9 |
from google import genai
|
10 |
from sentence_transformers import SentenceTransformer
|
|
|
11 |
|
12 |
# ✅ Enable Logging for Debugging
|
13 |
import logging
|
@@ -20,9 +22,10 @@ for name in [
|
|
20 |
"google", "google.auth",
|
21 |
]:
|
22 |
logging.getLogger(name).setLevel(logging.WARNING)
|
23 |
-
logging.basicConfig(level=logging.
|
24 |
logger = logging.getLogger("medical-chatbot")
|
25 |
logger.setLevel(logging.DEBUG)
|
|
|
26 |
# Debug Start
|
27 |
logger.info("🚀 Starting Medical Chatbot API...")
|
28 |
print("🚀 Starting Medical Chatbot API...")
|
@@ -59,6 +62,8 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
59 |
|
60 |
# ✅ Initialize FastAPI app
|
61 |
app = FastAPI(title="Medical Chatbot API")
|
|
|
|
|
62 |
from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
|
63 |
# Define the origins
|
64 |
origins = [
|
@@ -153,44 +158,50 @@ class RAGMedicalChatbot:
|
|
153 |
self.model_name = model_name
|
154 |
self.retrieve = retrieve_function
|
155 |
|
156 |
-
def chat(self, user_query, lang="EN"):
|
|
|
157 |
retrieved_info = self.retrieve(user_query)
|
158 |
knowledge_base = "\n".join(retrieved_info)
|
159 |
|
160 |
-
#
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
""
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
176 |
|
177 |
# ✅ Initialize Chatbot
|
178 |
chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash-preview-04-17", retrieve_function=retrieve_medical_info)
|
179 |
|
180 |
# ✅ Chat Endpoint
|
181 |
@app.post("/chat")
|
182 |
-
async def chat_endpoint(
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
#
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
194 |
|
195 |
# ✅ Run Uvicorn
|
196 |
if __name__ == "__main__":
|
|
|
1 |
+
# app.py
|
2 |
import os
|
3 |
import faiss
|
4 |
import numpy as np
|
5 |
import time
|
6 |
import uvicorn
|
7 |
+
from fastapi import FastAPI, Request
|
8 |
from fastapi.responses import JSONResponse
|
9 |
from pymongo import MongoClient
|
10 |
from google import genai
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
+
from memory import MemoryManager
|
13 |
|
14 |
# ✅ Enable Logging for Debugging
|
15 |
import logging
|
|
|
22 |
"google", "google.auth",
|
23 |
]:
|
24 |
logging.getLogger(name).setLevel(logging.WARNING)
|
25 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader
|
26 |
logger = logging.getLogger("medical-chatbot")
|
27 |
logger.setLevel(logging.DEBUG)
|
28 |
+
|
29 |
# Debug Start
|
30 |
logger.info("🚀 Starting Medical Chatbot API...")
|
31 |
print("🚀 Starting Medical Chatbot API...")
|
|
|
62 |
|
63 |
# ✅ Initialize FastAPI app
|
64 |
app = FastAPI(title="Medical Chatbot API")
|
65 |
+
memory = MemoryManager()
|
66 |
+
|
67 |
from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
|
68 |
# Define the origins
|
69 |
origins = [
|
|
|
158 |
self.model_name = model_name
|
159 |
self.retrieve = retrieve_function
|
160 |
|
161 |
+
def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
    """Answer one user query: RAG lookup + short-term memory context + Gemini call.

    Returns the model's reply, stripped of surrounding whitespace.
    NOTE(review): relies on module-level `memory` (MemoryManager) and
    `gemini_flash_completion` defined elsewhere in app.py.
    """
    # 1. Knowledge lookup via the injected retrieval function.
    knowledge = "\n".join(self.retrieve(user_query))

    # 2. Last few cached exchanges for this user ("" when none).
    history = memory.get_context(user_id, num_turns=3)

    # 3. Assemble the prompt from fixed instructions + optional history + query.
    sections = [
        "You are a medical chatbot, designed to answer medical questions.",
        "Please format your answer using MarkDown.",
        "**Bold for titles**, *italic for emphasis*, and clear headings.",
    ]
    if history:
        sections.append(f"Previous conversation:\n{history}")
    sections.extend(
        [
            f"Medical knowledge (256,916 medical scenario): {knowledge}",
            f"Question: {user_query}",
            f"Language: {lang}",
        ]
    )
    prompt = "\n\n".join(sections)

    reply = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)

    # Cache this exchange so follow-up questions get conversational context.
    if user_id:
        memory.add_exchange(user_id, user_query, reply)
    return reply.strip()
185 |
|
186 |
# ✅ Initialize Chatbot
|
187 |
chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash-preview-04-17", retrieve_function=retrieve_medical_info)
|
188 |
|
189 |
# ✅ Chat Endpoint
|
190 |
@app.post("/chat")
async def chat_endpoint(req: Request):
    """POST /chat — JSON body {user_id?, query, lang?} → {"response": answer + timing}.

    Missing/blank query returns a message in the same response shape so the
    client parses it uniformly. NOTE(review): relies on module-level `chatbot`.
    """
    payload = await req.json()
    query = payload.get("query", "").strip()
    # Error: nothing to answer.
    if not query:
        return JSONResponse({"response": "No query provided."})

    user_id = payload.get("user_id", "anonymous")
    lang = payload.get("lang", "EN")

    started = time.time()
    answer = chatbot.chat(user_id, query, lang)
    took = time.time() - started

    # Final: append wall-clock response time for the UI.
    return JSONResponse({"response": f"{answer}\n\n(Response time: {took:.2f}s)"})
204 |
+
|
205 |
|
206 |
# ✅ Run Uvicorn
|
207 |
if __name__ == "__main__":
|
memory.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# memory.py

from collections import defaultdict, deque
from typing import List, Tuple


class MemoryManager:
    """In-memory LRU cache of short-term chat history.

    Args:
        max_users: maximum number of distinct user_ids cached at once.
        history_per_user: number of (query, response) exchanges kept per user.

    Eviction is true LRU: every add_exchange refreshes that user's recency,
    and when capacity is exceeded the least-recently-active user is dropped
    together with their history. (The original implementation used a plain
    FIFO queue that never refreshed existing users, so an active user could
    be evicted before an idle one; its deque(maxlen=...) could also silently
    drop a queue entry without deleting the matching history — a leak path.)
    """

    def __init__(self, max_users: int = 1000, history_per_user: int = 10):
        self.max_users = max_users
        # Per-user ring buffer of (query, response); oldest exchange drops first.
        self.memory = defaultdict(lambda: deque(maxlen=history_per_user))
        # Recency order: leftmost = least recently active user.
        self.user_queue = deque()

    def add_exchange(self, user_id: str, query: str, response: str) -> None:
        """Append a (query, response) pair to the user's history.

        Marks the user as most recently active; evicts the least-recently-active
        user (and all their history) once max_users would be exceeded.
        """
        if user_id in self.memory:
            # Refresh recency so eviction order reflects actual activity.
            self.user_queue.remove(user_id)
        elif len(self.user_queue) >= self.max_users:
            # Cache full: drop the least-recently-active user entirely.
            oldest = self.user_queue.popleft()
            del self.memory[oldest]
        self.user_queue.append(user_id)
        self.memory[user_id].append((query.strip(), response.strip()))

    def get_context(self, user_id: str, num_turns: int = 3) -> str:
        """Return the last `num_turns` exchanges as a single formatted string:

            User: question
            Bot: answer

        Unknown users yield "". Does NOT refresh recency (read-only).
        """
        turns = list(self.memory.get(user_id, []))[-num_turns:]
        return "\n".join(f"User: {q}\nBot: {r}" for q, r in turns)

    def reset(self, user_id: str) -> None:
        """Clear all history and recency tracking for a given user_id."""
        if user_id in self.memory:
            del self.memory[user_id]
            self.user_queue.remove(user_id)
|