LiamKhoaLe commited on
Commit
4455263
·
1 Parent(s): 9f75635

Optimized FastAPI for Hugging Face Spaces

Browse files
Files changed (2) hide show
  1. Dockerfile +0 -3
  2. app.py +72 -193
Dockerfile CHANGED
@@ -22,9 +22,6 @@ ENV SENTENCE_TRANSFORMERS_HOME="/home/user/.cache/huggingface/sentence-transform
22
  RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformers && \
23
  chown -R user:user /app/model_cache /home/user/.cache/huggingface
24
 
25
- # Run the model download script
26
- RUN python /app/download_model.py
27
-
28
  # Pre-load model in a separate script
29
  RUN python /app/download_model.py && python /app/warmup.py
30
 
 
22
  RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformers && \
23
  chown -R user:user /app/model_cache /home/user/.cache/huggingface
24
 
 
 
 
25
  # Pre-load model in a separate script
26
  RUN python /app/download_model.py && python /app/warmup.py
27
 
app.py CHANGED
@@ -1,182 +1,89 @@
1
- # ==========================
2
- # Medical Chatbot Backend (Gemini Flash API + RAG) - Local Prebuilt Model with FAISS Index & Data Stored in MongoDB
3
- # ==========================
4
- """
5
- This script loads:
6
- 1) A FAISS index stored in MongoDB (in the "faiss_index" collection)
7
- 2) A local SentenceTransformer model (downloaded via snapshot_download)
8
- 3) QA data (the full dataset of 256916 QA entries) stored in MongoDB (in the "qa_data" collection)
9
-
10
- If the QA data or FAISS index are not found in MongoDB, the script loads the full dataset from Hugging Face,
11
- computes embeddings for all QA pairs (concatenating the "Patient" and "Doctor" fields), and stores both the raw QA data
12
- and the FAISS index in MongoDB.
13
-
14
- The chatbot instructs Gemini Flash to format its answer using markdown.
15
- """
16
-
17
  import os
18
  import faiss
19
  import numpy as np
20
  import time
 
21
  from fastapi import FastAPI
22
  from fastapi.responses import JSONResponse
23
- from pathlib import Path
24
- import threading
25
- from dotenv import load_dotenv
26
-
27
- # Checking status
28
- print("🚀 Starting the script...")
29
-
30
 
31
- # 🔹 Load environment variables from .env
32
- load_dotenv()
33
- gemini_flash_api_key = os.getenv("FlashAPI")
34
  mongo_uri = os.getenv("MONGO_URI")
35
  index_uri = os.getenv("INDEX_URI")
36
- # Verify ENV Variables
37
- print(f"🔎 FlashAPI Key Exists: {'✅ Yes' if os.getenv('FlashAPI') else '❌ No'}")
38
- print(f"🔎 MongoDB URI Exists: {'✅ Yes' if os.getenv('MONGO_URI') else '❌ No'}")
39
- print(f"🔎 FAISS Index URI Exists: {'✅ Yes' if os.getenv('INDEX_URI') else '❌ No'}")
40
-
41
- # Heartbeat status checker
42
- def keep_alive():
43
- while True:
44
- print("💓 Heartbeat: App is alive")
45
- time.sleep(10)
46
- threading.Thread(target=keep_alive, daemon=True).start()
47
-
48
- # 1. Environment variables to mitigate segmentation faults
49
  os.environ["OMP_NUM_THREADS"] = "1"
50
- os.environ["MKL_NUM_THREADS"] = "1"
51
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
52
 
53
- # 2. Setup Hugging Face Cloud project model cache
54
- MODEL_CACHE_DIR = "/app/model_cache"
55
- # Ensure all necessary files exist
56
- required_files = ["config.json", "pytorch_model.bin", "tokenizer.json", "1_Pooling/config.json"]
57
- for f in required_files:
58
- if not os.path.exists(os.path.join(MODEL_CACHE_DIR, f)):
59
- print(f"❌ Missing required model file: {f}")
60
- exit(1)
61
-
62
- # 3. Use the preloaded model from `warmup.py`
63
- from sentence_transformers import SentenceTransformer
64
- import torch
65
- print("📥 **Using Preloaded Embedding Model from Warm-up...**")
66
- start_time = time.time()
67
- try:
68
- embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
69
- embedding_model = embedding_model.half() # Ensure it stays quantized
70
- embedding_model.to(torch.device("cpu"))
71
- print("✅ Model ready in {:.2f} seconds.".format(time.time() - start_time))
72
- except Exception as e:
73
- print(f"❌ Error loading model: {e}")
74
- exit(1)
75
 
 
 
76
 
77
- ##---------------------------##
78
- ## EMBEDDING AND DATA RETRIEVAL
79
- ##---------------------------##
 
 
80
 
 
 
 
 
81
 
82
- # 🔹 MongoDB Connection
83
- from pymongo import MongoClient
84
- try:
85
- print("⏳ Connecting to MongoDB...")
86
- # 1. QA client
87
- client = MongoClient(mongo_uri)
88
- db = client["MedicalChatbotDB"] # Use your chosen database name
89
- qa_collection = db["qa_data"]
90
- # 2. FAISS index client
91
- iclient = MongoClient(index_uri)
92
- idb = iclient["MedicalChatbotDB"] # Use your chosen database name
93
- index_collection = idb["faiss_index_files"]
94
- except Exception as e:
95
- print(f"❌ MongoDB connection failed: {e}")
96
- exit(1)
97
-
98
-
99
- # 🔹 Load or Build QA Data in MongoDB
100
- print("⏳ Checking MongoDB for existing QA data...")
101
- if qa_collection.count_documents({}) == 0:
102
- print("⚠️ QA data not found in MongoDB! Please report this to the developer.")
103
- else:
104
- print("✅ Loaded existing QA data from MongoDB.")
105
- # Use an aggregation pipeline with allowDiskUse to sort by "i" without creating an index.
106
- qa_docs = list(qa_collection.aggregate([
107
- {"$sort": {"i": 1}},
108
- {"$project": {"_id": 0}}
109
- ], allowDiskUse=True))
110
- qa_data = qa_docs
111
- print("📦 Total QA entries loaded:", len(qa_data))
112
-
113
-
114
- # 🔹 Build or Load the FAISS Index from MongoDB using GridFS (on the separate cluster)
115
- print("⏳ Checking GridFS for existing FAISS index...")
116
  import gridfs
117
- fs = gridfs.GridFS(idb, collection="faiss_index_files") # 'idb' is connected using INDEX_URI
118
 
119
- # 1. Find the FAISS index file by filename.
120
  def load_faiss_index():
121
- existing_file = fs.find_one({"filename": "faiss_index.bin"})
122
- if existing_file is None:
123
- print("⚠️ FAISS index not found in GridFS! Please report this to the developer.")
124
- return None
125
- else:
126
- print("✅ Found FAISS index in GridFS. Loading...")
127
- stored_index_bytes = existing_file.read()
128
- index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
129
- index = faiss.deserialize_index(index_bytes_np)
130
- print("📦 FAISS index loaded from GridFS successfully!")
131
- return index
132
- # Load FAISS index only when needed
133
- index = None
134
-
135
- ##---------------------------##
136
- ## INFERENCE BACK+FRONT END
137
- ##---------------------------##
138
-
139
-
140
- # 🔹 Prepare Retrieval and Chat Logic
141
- def retrieve_medical_info(query):
142
- """Retrieve relevant medical knowledge using the FAISS index."""
143
  global index
144
  if index is None:
145
- index = load_faiss_index() # Load FAISS only on first query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  if index is None:
147
  return ["No medical information available."]
 
148
  query_embedding = embedding_model.encode([query], convert_to_numpy=True)
149
  _, idxs = index.search(query_embedding, k=3)
150
- results = []
151
- for i in idxs[0]:
152
- if i < len(qa_data):
153
- results.append(qa_data[i].get("Doctor", "No answer available."))
154
- else:
155
- results.append("No answer available.")
156
  return results
157
 
158
-
159
- # 🔹 Gemini Flash API Call
160
- from google import genai
161
  def gemini_flash_completion(prompt, model, temperature=0.7):
162
  client_genai = genai.Client(api_key=gemini_flash_api_key)
163
  try:
164
  response = client_genai.models.generate_content(model=model, contents=prompt)
165
  return response.text
166
  except Exception as e:
167
- print(f"⚠️ Error calling Gemini API: {e}")
168
  return "Error generating response from Gemini."
169
 
170
-
171
- # Define a simple language mapping (modify or add more as needed)
172
- language_map = {
173
- "EN": "English",
174
- "VI": "Vietnamese",
175
- "ZH": "Chinese"
176
- }
177
-
178
-
179
- # 🔹 Chatbot Class
180
  class RAGMedicalChatbot:
181
  def __init__(self, model_name, retrieve_function):
182
  self.model_name = model_name
@@ -185,69 +92,41 @@ class RAGMedicalChatbot:
185
  def chat(self, user_query, lang="EN"):
186
  retrieved_info = self.retrieve(user_query)
187
  knowledge_base = "\n".join(retrieved_info)
188
- # Construct prompt for Gemini Flash
189
- prompt = (
190
- "Please format your answer using markdown. Use **bold** for titles, *italic* for emphasis, "
191
- "and ensure that headings and paragraphs are clearly separated.\n\n"
192
- f"Using the following medical knowledge:\n{knowledge_base} \n(trained with 256,916 data entries).\n\n"
193
- f"Answer the following question in a professional and medically accurate manner:\n{user_query}.\n\n"
194
- f"Your response answer must be in {lang} language."
195
- )
196
- completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
197
- return completion.strip()
198
-
199
 
200
- # 🔹 Model Class (change to others if needed)
201
- chatbot = RAGMedicalChatbot(
202
- model_name="gemini-2.0-flash",
203
- retrieve_function=retrieve_medical_info
204
- )
205
- print("✅ Medical chatbot is ready! 🤖")
206
 
 
 
207
 
208
- # 🔹 FastAPI Server
209
- # from fastapi.staticfiles import StaticFiles
210
- from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
211
- app = FastAPI(title="Medical Chatbot")
212
 
213
- # 1. Define the origins
214
- origins = [
215
- "http://localhost:5173", # Vite dev server
216
- "http://localhost:3000", # Another vercel dev server
217
- "https://medical-chatbot-henna.vercel.app", # ✅ Vercel frontend production URL
218
-
219
- ]
220
-
221
- # 2. Then add the CORS middleware:
222
- app.add_middleware(
223
- CORSMiddleware,
224
- allow_origins=origins, # or ["*"] to allow all
225
- allow_credentials=True,
226
- allow_methods=["*"],
227
- allow_headers=["*"],
228
- )
229
 
 
 
230
 
231
- # 🔹 Chat route
232
  @app.post("/chat")
233
  async def chat_endpoint(data: dict):
234
  user_query = data.get("query", "")
235
- lang = data.get("lang", "EN") # Expect a language code from the request
236
  if not user_query:
237
  return JSONResponse(content={"response": "No query provided."})
 
238
  start_time = time.time()
239
- response_text = chatbot.chat(user_query, lang) # Pass language selection
240
  end_time = time.time()
241
  response_text += f"\n\n(Response time: {end_time - start_time:.2f} seconds)"
242
- return JSONResponse(content={"response": response_text})
243
 
 
244
 
245
- # 🔹 Main Execution
246
- import uvicorn
247
  if __name__ == "__main__":
248
- try:
249
- print(" Server is starting...")
250
- uvicorn.run(app, host="0.0.0.0", port=7860, workers=1) # Default 2 workers, cut to 1
251
- except Exception as e:
252
- print(f"❌ Server startup failed: {e}")
253
- exit(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import faiss
3
  import numpy as np
4
  import time
5
+ import uvicorn
6
  from fastapi import FastAPI
7
  from fastapi.responses import JSONResponse
8
+ from pymongo import MongoClient
9
+ from google import genai
10
+ from sentence_transformers import SentenceTransformer
 
 
 
 
11
 
12
+ # Environment Variables
 
 
13
  mongo_uri = os.getenv("MONGO_URI")
14
  index_uri = os.getenv("INDEX_URI")
15
+ gemini_flash_api_key = os.getenv("FlashAPI")
16
+
17
+ if not all([gemini_flash_api_key, mongo_uri, index_uri]):
18
+ raise ValueError(" Missing API keys! Set them in Hugging Face Secrets.")
19
+
20
+ # Reduce Memory Usage
 
 
 
 
 
 
 
21
  os.environ["OMP_NUM_THREADS"] = "1"
 
22
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
 
24
+ # Initialize FastAPI app
25
+ app = FastAPI(title="Medical Chatbot API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # ✅ Use Lazy Loading for FAISS Index
28
+ index = None # Delay FAISS Index loading until first query
29
 
30
+ # ✅ Load SentenceTransformer Model (Quantized)
31
+ print("📥 Loading SentenceTransformer Model...")
32
+ MODEL_CACHE_DIR = "/app/model_cache"
33
+ embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
34
+ embedding_model = embedding_model.half() # Reduce memory usage
35
 
36
+ # ✅ Setup MongoDB Connection
37
+ client = MongoClient(mongo_uri)
38
+ db = client["MedicalChatbotDB"]
39
+ qa_collection = db["qa_data"]
40
 
41
+ iclient = MongoClient(index_uri)
42
+ idb = iclient["MedicalChatbotDB"]
43
+ index_collection = idb["faiss_index_files"]
44
+
45
+ # Load FAISS Index (Lazy Load)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  import gridfs
47
+ fs = gridfs.GridFS(idb, collection="faiss_index_files")
48
 
 
49
  def load_faiss_index():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  global index
51
  if index is None:
52
+ print("⏳ Loading FAISS index from GridFS...")
53
+ existing_file = fs.find_one({"filename": "faiss_index.bin"})
54
+ if existing_file:
55
+ stored_index_bytes = existing_file.read()
56
+ index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
57
+ index = faiss.deserialize_index(index_bytes_np)
58
+ print("✅ FAISS Index Loaded")
59
+ else:
60
+ print("❌ FAISS index not found in GridFS.")
61
+ return index
62
+
63
+ # ✅ Retrieve Medical Info
64
+ def retrieve_medical_info(query):
65
+ global index
66
+ index = load_faiss_index() # Load FAISS on demand
67
+
68
  if index is None:
69
  return ["No medical information available."]
70
+
71
  query_embedding = embedding_model.encode([query], convert_to_numpy=True)
72
  _, idxs = index.search(query_embedding, k=3)
73
+ results = [qa_collection.find_one({"i": int(i)}).get("Doctor", "No answer available.") for i in idxs[0]]
 
 
 
 
 
74
  return results
75
 
76
+ # ✅ Gemini Flash API Call
 
 
77
  def gemini_flash_completion(prompt, model, temperature=0.7):
78
  client_genai = genai.Client(api_key=gemini_flash_api_key)
79
  try:
80
  response = client_genai.models.generate_content(model=model, contents=prompt)
81
  return response.text
82
  except Exception as e:
83
+ print(f" Error calling Gemini API: {e}")
84
  return "Error generating response from Gemini."
85
 
86
+ # ✅ Chatbot Class
 
 
 
 
 
 
 
 
 
87
  class RAGMedicalChatbot:
88
  def __init__(self, model_name, retrieve_function):
89
  self.model_name = model_name
 
92
  def chat(self, user_query, lang="EN"):
93
  retrieved_info = self.retrieve(user_query)
94
  knowledge_base = "\n".join(retrieved_info)
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Construct Prompt
97
+ prompt = f"""
98
+ Please format your answer using markdown.
99
+ **Bold for titles**, *italic for emphasis*, and clear headings.
 
 
100
 
101
+ **Medical knowledge:**
102
+ {knowledge_base}
103
 
104
+ **Question:** {user_query}
 
 
 
105
 
106
+ **Language Required:** {lang}
107
+ """
108
+ completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
109
+ return completion.strip()
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ # ✅ Initialize Chatbot
112
+ chatbot = RAGMedicalChatbot(model_name="gemini-2.0-flash", retrieve_function=retrieve_medical_info)
113
 
114
+ # Chat Endpoint
115
  @app.post("/chat")
116
  async def chat_endpoint(data: dict):
117
  user_query = data.get("query", "")
118
+ lang = data.get("lang", "EN")
119
  if not user_query:
120
  return JSONResponse(content={"response": "No query provided."})
121
+
122
  start_time = time.time()
123
+ response_text = chatbot.chat(user_query, lang)
124
  end_time = time.time()
125
  response_text += f"\n\n(Response time: {end_time - start_time:.2f} seconds)"
 
126
 
127
+ return JSONResponse(content={"response": response_text})
128
 
129
+ # Run Uvicorn with 1 Worker
 
130
  if __name__ == "__main__":
131
+ print("✅ Starting FastAPI Server...")
132
+ uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)