LiamKhoaLe commited on
Commit
57e04c3
·
1 Parent(s): 15ed85c

Populate LRU cache to allow recent-chat retrieval. Add memory handler

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +42 -31
  3. memory.py +43 -0
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: docker
7
  sdk_version: latest
8
  pinned: false
9
  license: apache-2.0
10
- short_description: Medical Chatbot, with FAISS, Gemini Flash, and MongoDB
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
7
  sdk_version: latest
8
  pinned: false
9
  license: apache-2.0
10
+ short_description: MedicalChatbot - FAISS RAG, Gemini Flash, MongoDB vDB, LRU cache
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,13 +1,15 @@
 
1
  import os
2
  import faiss
3
  import numpy as np
4
  import time
5
  import uvicorn
6
- from fastapi import FastAPI
7
  from fastapi.responses import JSONResponse
8
  from pymongo import MongoClient
9
  from google import genai
10
  from sentence_transformers import SentenceTransformer
 
11
 
12
  # ✅ Enable Logging for Debugging
13
  import logging
@@ -20,9 +22,10 @@ for name in [
20
  "google", "google.auth",
21
  ]:
22
  logging.getLogger(name).setLevel(logging.WARNING)
23
- logging.basicConfig(level=logging.DEBUG, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True)
24
  logger = logging.getLogger("medical-chatbot")
25
  logger.setLevel(logging.DEBUG)
 
26
  # Debug Start
27
  logger.info("🚀 Starting Medical Chatbot API...")
28
  print("🚀 Starting Medical Chatbot API...")
@@ -59,6 +62,8 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
59
 
60
  # ✅ Initialize FastAPI app
61
  app = FastAPI(title="Medical Chatbot API")
 
 
62
  from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
63
  # Define the origins
64
  origins = [
@@ -153,44 +158,50 @@ class RAGMedicalChatbot:
153
  self.model_name = model_name
154
  self.retrieve = retrieve_function
155
 
156
- def chat(self, user_query, lang="EN"):
 
157
  retrieved_info = self.retrieve(user_query)
158
  knowledge_base = "\n".join(retrieved_info)
159
 
160
- # Construct Prompt
161
- prompt = f"""
162
- You are a medical chatbot, designed to answer medical questions.
163
-
164
- Please format your answer using markdown.
165
- **Bold for titles**, *italic for emphasis*, and clear headings.
166
-
167
- **Medical knowledge (trained with 256,916 data entries):**
168
- {knowledge_base}
169
-
170
- **Question:** {user_query}
171
-
172
- **Language Required:** {lang}
173
- """
174
- completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
175
- return completion.strip()
 
 
 
176
 
177
  # ✅ Initialize Chatbot
178
  chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash-preview-04-17", retrieve_function=retrieve_medical_info)
179
 
180
  # ✅ Chat Endpoint
181
  @app.post("/chat")
182
- async def chat_endpoint(data: dict):
183
- user_query = data.get("query", "")
184
- lang = data.get("lang", "EN")
185
- if not user_query:
186
- return JSONResponse(content={"response": "No query provided."})
187
- # Output parameter
188
- start_time = time.time()
189
- response_text = chatbot.chat(user_query, lang)
190
- end_time = time.time()
191
- response_text += f"\n\n(Response time: {end_time - start_time:.2f} seconds)"
192
- # Send JSON response
193
- return JSONResponse(content={"response": response_text})
 
 
194
 
195
  # ✅ Run Uvicorn
196
  if __name__ == "__main__":
 
1
+ # app.py
2
  import os
3
  import faiss
4
  import numpy as np
5
  import time
6
  import uvicorn
7
+ from fastapi import FastAPI, Request
8
  from fastapi.responses import JSONResponse
9
  from pymongo import MongoClient
10
  from google import genai
11
  from sentence_transformers import SentenceTransformer
12
+ from memory import MemoryManager
13
 
14
  # ✅ Enable Logging for Debugging
15
  import logging
 
22
  "google", "google.auth",
23
  ]:
24
  logging.getLogger(name).setLevel(logging.WARNING)
25
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader
26
  logger = logging.getLogger("medical-chatbot")
27
  logger.setLevel(logging.DEBUG)
28
+
29
  # Debug Start
30
  logger.info("🚀 Starting Medical Chatbot API...")
31
  print("🚀 Starting Medical Chatbot API...")
 
62
 
63
  # ✅ Initialize FastAPI app
64
  app = FastAPI(title="Medical Chatbot API")
65
+ memory = MemoryManager()
66
+
67
  from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
68
  # Define the origins
69
  origins = [
 
158
  self.model_name = model_name
159
  self.retrieve = retrieve_function
160
 
161
+ def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
162
+ # 1. Fetch knowledge
163
  retrieved_info = self.retrieve(user_query)
164
  knowledge_base = "\n".join(retrieved_info)
165
 
166
+ # 2. Fetch recent context (last 3 chats)
167
+ context = memory.get_context(user_id, num_turns=3)
168
+
169
+ # 3. Build prompt parts
170
+ parts = ["You are a medical chatbot, designed to answer medical questions."]
171
+ parts.append("Please format your answer using MarkDown.")
172
+ parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.")
173
+ # Historical chat retrieval case
174
+ if context:
175
+ parts.append(f"Previous conversation:\n{context}")
176
+ parts.append(f"Medical knowledge (256,916 medical scenario): {knowledge_base}")
177
+ parts.append(f"Question: {user_query}")
178
+ parts.append(f"Language: {lang}")
179
+ prompt = "\n\n".join(parts)
180
+ response = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
181
+ # Add to STM caching by id+req+res
182
+ if user_id:
183
+ memory.add_exchange(user_id, user_query, response)
184
+ return response.strip()
185
 
186
  # ✅ Initialize Chatbot
187
  chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash-preview-04-17", retrieve_function=retrieve_medical_info)
188
 
189
  # ✅ Chat Endpoint
190
  @app.post("/chat")
191
async def chat_endpoint(req: Request):
    """Handle POST /chat.

    Expects a JSON body: {"user_id": str, "query": str, "lang": str};
    all fields optional, `user_id` defaults to "anonymous", `lang` to "EN".
    Returns {"response": answer + a response-time footer}.
    """
    # Malformed JSON previously escaped as an unhandled 500; answer 400 instead.
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"response": "Invalid JSON body."}, status_code=400)
    if not isinstance(body, dict):
        body = {}

    # Coerce to str defensively: a non-string "query" (e.g. a number)
    # would otherwise crash on .strip().
    user_id = str(body.get("user_id") or "anonymous")
    query = str(body.get("query") or "").strip()
    lang = str(body.get("lang") or "EN")

    # Error
    if not query:
        return JSONResponse({"response": "No query provided."})

    start = time.time()
    answer = chatbot.chat(user_id, query, lang)
    elapsed = time.time() - start
    # Final
    return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})
204
+
205
 
206
  # ✅ Run Uvicorn
207
  if __name__ == "__main__":
memory.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # memory.py
2
+
3
+ from collections import defaultdict, deque
4
+ from typing import List, Tuple
5
+
6
class MemoryManager:
    """
    In-memory LRU cache of short-term chat history:
      • max_users: total distinct user_ids cached
      • history_per_user: number of exchanges to keep per user

    Eviction is least-recently-ACTIVE: every add_exchange marks the user
    as most recent. (The original queue was pure FIFO, so a busy user
    could be evicted before an idle one — contradicting the advertised LRU.)
    """
    def __init__(self, max_users: int = 1000, history_per_user: int = 10):
        # user_id -> deque of (query, response); the deque itself drops
        # the oldest turn once history_per_user is exceeded.
        self.memory = defaultdict(lambda: deque(maxlen=history_per_user))
        # Recency order: leftmost = least recently active. maxlen is kept
        # for backward compatibility; eviction is done explicitly below so
        # a user is never dropped from the queue while still in `memory`.
        self.user_queue = deque(maxlen=max_users)

    def add_exchange(self, user_id: str, query: str, response: str):
        """
        Append a (query, response) pair to the user's history.
        Marks the user most-recently active; evicts the least-recently
        active user (and all their history) when capacity is reached.
        """
        if user_id in self.memory:
            # LRU touch: move an existing user to the most-recent end.
            self.user_queue.remove(user_id)
        elif len(self.user_queue) >= self.user_queue.maxlen:
            # Cache full: drop the least-recently-active user entirely.
            oldest = self.user_queue.popleft()
            del self.memory[oldest]
        self.user_queue.append(user_id)

        self.memory[user_id].append((query.strip(), response.strip()))

    def get_context(self, user_id: str, num_turns: int = 3) -> str:
        """
        Return the last `num_turns` as a single formatted string:
            User: question
            Bot: answer
        Returns "" for an unknown user (uses .get so the defaultdict
        does not create a phantom entry).
        """
        turns = list(self.memory.get(user_id, []))[-num_turns:]
        return "\n".join(f"User: {q}\nBot: {r}" for q, r in turns)

    def reset(self, user_id: str):
        """Clear all history (and the recency entry) for a given user_id."""
        if user_id in self.memory:
            del self.memory[user_id]
            self.user_queue.remove(user_id)