Spaces:
Running
Running
Commit
·
4455263
1
Parent(s):
9f75635
Optimized FastAPI for Hugging Face Spaces
Browse files- Dockerfile +0 -3
- app.py +72 -193
Dockerfile
CHANGED
@@ -22,9 +22,6 @@ ENV SENTENCE_TRANSFORMERS_HOME="/home/user/.cache/huggingface/sentence-transform
|
|
22 |
RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformers && \
|
23 |
chown -R user:user /app/model_cache /home/user/.cache/huggingface
|
24 |
|
25 |
-
# Run the model download script
|
26 |
-
RUN python /app/download_model.py
|
27 |
-
|
28 |
# Pre-load model in a separate script
|
29 |
RUN python /app/download_model.py && python /app/warmup.py
|
30 |
|
|
|
22 |
RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformers && \
|
23 |
chown -R user:user /app/model_cache /home/user/.cache/huggingface
|
24 |
|
|
|
|
|
|
|
25 |
# Pre-load model in a separate script
|
26 |
RUN python /app/download_model.py && python /app/warmup.py
|
27 |
|
app.py
CHANGED
@@ -1,182 +1,89 @@
|
|
1 |
-
# ==========================
|
2 |
-
# Medical Chatbot Backend (Gemini Flash API + RAG) - Local Prebuilt Model with FAISS Index & Data Stored in MongoDB
|
3 |
-
# ==========================
|
4 |
-
"""
|
5 |
-
This script loads:
|
6 |
-
1) A FAISS index stored in MongoDB (in the "faiss_index" collection)
|
7 |
-
2) A local SentenceTransformer model (downloaded via snapshot_download)
|
8 |
-
3) QA data (the full dataset of 256916 QA entries) stored in MongoDB (in the "qa_data" collection)
|
9 |
-
|
10 |
-
If the QA data or FAISS index are not found in MongoDB, the script loads the full dataset from Hugging Face,
|
11 |
-
computes embeddings for all QA pairs (concatenating the "Patient" and "Doctor" fields), and stores both the raw QA data
|
12 |
-
and the FAISS index in MongoDB.
|
13 |
-
|
14 |
-
The chatbot instructs Gemini Flash to format its answer using markdown.
|
15 |
-
"""
|
16 |
-
|
17 |
import os
|
18 |
import faiss
|
19 |
import numpy as np
|
20 |
import time
|
|
|
21 |
from fastapi import FastAPI
|
22 |
from fastapi.responses import JSONResponse
|
23 |
-
from
|
24 |
-
import
|
25 |
-
from
|
26 |
-
|
27 |
-
# Checking status
|
28 |
-
print("🚀 Starting the script...")
|
29 |
-
|
30 |
|
31 |
-
#
|
32 |
-
load_dotenv()
|
33 |
-
gemini_flash_api_key = os.getenv("FlashAPI")
|
34 |
mongo_uri = os.getenv("MONGO_URI")
|
35 |
index_uri = os.getenv("INDEX_URI")
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
#
|
42 |
-
def keep_alive():
|
43 |
-
while True:
|
44 |
-
print("💓 Heartbeat: App is alive")
|
45 |
-
time.sleep(10)
|
46 |
-
threading.Thread(target=keep_alive, daemon=True).start()
|
47 |
-
|
48 |
-
# 1. Environment variables to mitigate segmentation faults
|
49 |
os.environ["OMP_NUM_THREADS"] = "1"
|
50 |
-
os.environ["MKL_NUM_THREADS"] = "1"
|
51 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
52 |
|
53 |
-
#
|
54 |
-
|
55 |
-
# Ensure all necessary files exist
|
56 |
-
required_files = ["config.json", "pytorch_model.bin", "tokenizer.json", "1_Pooling/config.json"]
|
57 |
-
for f in required_files:
|
58 |
-
if not os.path.exists(os.path.join(MODEL_CACHE_DIR, f)):
|
59 |
-
print(f"❌ Missing required model file: {f}")
|
60 |
-
exit(1)
|
61 |
-
|
62 |
-
# 3. Use the preloaded model from `warmup.py`
|
63 |
-
from sentence_transformers import SentenceTransformer
|
64 |
-
import torch
|
65 |
-
print("📥 **Using Preloaded Embedding Model from Warm-up...**")
|
66 |
-
start_time = time.time()
|
67 |
-
try:
|
68 |
-
embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
|
69 |
-
embedding_model = embedding_model.half() # Ensure it stays quantized
|
70 |
-
embedding_model.to(torch.device("cpu"))
|
71 |
-
print("✅ Model ready in {:.2f} seconds.".format(time.time() - start_time))
|
72 |
-
except Exception as e:
|
73 |
-
print(f"❌ Error loading model: {e}")
|
74 |
-
exit(1)
|
75 |
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
80 |
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
client = MongoClient(mongo_uri)
|
88 |
-
db = client["MedicalChatbotDB"] # Use your chosen database name
|
89 |
-
qa_collection = db["qa_data"]
|
90 |
-
# 2. FAISS index client
|
91 |
-
iclient = MongoClient(index_uri)
|
92 |
-
idb = iclient["MedicalChatbotDB"] # Use your chosen database name
|
93 |
-
index_collection = idb["faiss_index_files"]
|
94 |
-
except Exception as e:
|
95 |
-
print(f"❌ MongoDB connection failed: {e}")
|
96 |
-
exit(1)
|
97 |
-
|
98 |
-
|
99 |
-
# 🔹 Load or Build QA Data in MongoDB
|
100 |
-
print("⏳ Checking MongoDB for existing QA data...")
|
101 |
-
if qa_collection.count_documents({}) == 0:
|
102 |
-
print("⚠️ QA data not found in MongoDB! Please report this to the developer.")
|
103 |
-
else:
|
104 |
-
print("✅ Loaded existing QA data from MongoDB.")
|
105 |
-
# Use an aggregation pipeline with allowDiskUse to sort by "i" without creating an index.
|
106 |
-
qa_docs = list(qa_collection.aggregate([
|
107 |
-
{"$sort": {"i": 1}},
|
108 |
-
{"$project": {"_id": 0}}
|
109 |
-
], allowDiskUse=True))
|
110 |
-
qa_data = qa_docs
|
111 |
-
print("📦 Total QA entries loaded:", len(qa_data))
|
112 |
-
|
113 |
-
|
114 |
-
# 🔹 Build or Load the FAISS Index from MongoDB using GridFS (on the separate cluster)
|
115 |
-
print("⏳ Checking GridFS for existing FAISS index...")
|
116 |
import gridfs
|
117 |
-
fs = gridfs.GridFS(idb, collection="faiss_index_files")
|
118 |
|
119 |
-
# 1. Find the FAISS index file by filename.
|
120 |
def load_faiss_index():
|
121 |
-
existing_file = fs.find_one({"filename": "faiss_index.bin"})
|
122 |
-
if existing_file is None:
|
123 |
-
print("⚠️ FAISS index not found in GridFS! Please report this to the developer.")
|
124 |
-
return None
|
125 |
-
else:
|
126 |
-
print("✅ Found FAISS index in GridFS. Loading...")
|
127 |
-
stored_index_bytes = existing_file.read()
|
128 |
-
index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
|
129 |
-
index = faiss.deserialize_index(index_bytes_np)
|
130 |
-
print("📦 FAISS index loaded from GridFS successfully!")
|
131 |
-
return index
|
132 |
-
# Load FAISS index only when needed
|
133 |
-
index = None
|
134 |
-
|
135 |
-
##---------------------------##
|
136 |
-
## INFERENCE BACK+FRONT END
|
137 |
-
##---------------------------##
|
138 |
-
|
139 |
-
|
140 |
-
# 🔹 Prepare Retrieval and Chat Logic
|
141 |
-
def retrieve_medical_info(query):
|
142 |
-
"""Retrieve relevant medical knowledge using the FAISS index."""
|
143 |
global index
|
144 |
if index is None:
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
if index is None:
|
147 |
return ["No medical information available."]
|
|
|
148 |
query_embedding = embedding_model.encode([query], convert_to_numpy=True)
|
149 |
_, idxs = index.search(query_embedding, k=3)
|
150 |
-
results = []
|
151 |
-
for i in idxs[0]:
|
152 |
-
if i < len(qa_data):
|
153 |
-
results.append(qa_data[i].get("Doctor", "No answer available."))
|
154 |
-
else:
|
155 |
-
results.append("No answer available.")
|
156 |
return results
|
157 |
|
158 |
-
|
159 |
-
# 🔹 Gemini Flash API Call
|
160 |
-
from google import genai
|
161 |
def gemini_flash_completion(prompt, model, temperature=0.7):
|
162 |
client_genai = genai.Client(api_key=gemini_flash_api_key)
|
163 |
try:
|
164 |
response = client_genai.models.generate_content(model=model, contents=prompt)
|
165 |
return response.text
|
166 |
except Exception as e:
|
167 |
-
print(f"
|
168 |
return "Error generating response from Gemini."
|
169 |
|
170 |
-
|
171 |
-
# Define a simple language mapping (modify or add more as needed)
|
172 |
-
language_map = {
|
173 |
-
"EN": "English",
|
174 |
-
"VI": "Vietnamese",
|
175 |
-
"ZH": "Chinese"
|
176 |
-
}
|
177 |
-
|
178 |
-
|
179 |
-
# 🔹 Chatbot Class
|
180 |
class RAGMedicalChatbot:
|
181 |
def __init__(self, model_name, retrieve_function):
|
182 |
self.model_name = model_name
|
@@ -185,69 +92,41 @@ class RAGMedicalChatbot:
|
|
185 |
def chat(self, user_query, lang="EN"):
|
186 |
retrieved_info = self.retrieve(user_query)
|
187 |
knowledge_base = "\n".join(retrieved_info)
|
188 |
-
# Construct prompt for Gemini Flash
|
189 |
-
prompt = (
|
190 |
-
"Please format your answer using markdown. Use **bold** for titles, *italic* for emphasis, "
|
191 |
-
"and ensure that headings and paragraphs are clearly separated.\n\n"
|
192 |
-
f"Using the following medical knowledge:\n{knowledge_base} \n(trained with 256,916 data entries).\n\n"
|
193 |
-
f"Answer the following question in a professional and medically accurate manner:\n{user_query}.\n\n"
|
194 |
-
f"Your response answer must be in {lang} language."
|
195 |
-
)
|
196 |
-
completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
|
197 |
-
return completion.strip()
|
198 |
-
|
199 |
|
200 |
-
#
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
)
|
205 |
-
print("✅ Medical chatbot is ready! 🤖")
|
206 |
|
|
|
|
|
207 |
|
208 |
-
|
209 |
-
# from fastapi.staticfiles import StaticFiles
|
210 |
-
from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
|
211 |
-
app = FastAPI(title="Medical Chatbot")
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
"https://medical-chatbot-henna.vercel.app", # ✅ Vercel frontend production URL
|
218 |
-
|
219 |
-
]
|
220 |
-
|
221 |
-
# 2. Then add the CORS middleware:
|
222 |
-
app.add_middleware(
|
223 |
-
CORSMiddleware,
|
224 |
-
allow_origins=origins, # or ["*"] to allow all
|
225 |
-
allow_credentials=True,
|
226 |
-
allow_methods=["*"],
|
227 |
-
allow_headers=["*"],
|
228 |
-
)
|
229 |
|
|
|
|
|
230 |
|
231 |
-
#
|
232 |
@app.post("/chat")
|
233 |
async def chat_endpoint(data: dict):
|
234 |
user_query = data.get("query", "")
|
235 |
-
lang = data.get("lang", "EN")
|
236 |
if not user_query:
|
237 |
return JSONResponse(content={"response": "No query provided."})
|
|
|
238 |
start_time = time.time()
|
239 |
-
response_text = chatbot.chat(user_query, lang)
|
240 |
end_time = time.time()
|
241 |
response_text += f"\n\n(Response time: {end_time - start_time:.2f} seconds)"
|
242 |
-
return JSONResponse(content={"response": response_text})
|
243 |
|
|
|
244 |
|
245 |
-
#
|
246 |
-
import uvicorn
|
247 |
if __name__ == "__main__":
|
248 |
-
|
249 |
-
|
250 |
-
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1) # Default 2 workers, cut to 1
|
251 |
-
except Exception as e:
|
252 |
-
print(f"❌ Server startup failed: {e}")
|
253 |
-
exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import faiss
|
3 |
import numpy as np
|
4 |
import time
|
5 |
+
import uvicorn
|
6 |
from fastapi import FastAPI
|
7 |
from fastapi.responses import JSONResponse
|
8 |
+
from pymongo import MongoClient
|
9 |
+
from google import genai
|
10 |
+
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
# ✅ Environment Variables
|
|
|
|
|
13 |
mongo_uri = os.getenv("MONGO_URI")
|
14 |
index_uri = os.getenv("INDEX_URI")
|
15 |
+
gemini_flash_api_key = os.getenv("FlashAPI")
|
16 |
+
|
17 |
+
if not all([gemini_flash_api_key, mongo_uri, index_uri]):
|
18 |
+
raise ValueError("❌ Missing API keys! Set them in Hugging Face Secrets.")
|
19 |
+
|
20 |
+
# ✅ Reduce Memory Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
os.environ["OMP_NUM_THREADS"] = "1"
|
|
|
22 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
23 |
|
24 |
+
# ✅ Initialize FastAPI app
|
25 |
+
app = FastAPI(title="Medical Chatbot API")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
# ✅ Use Lazy Loading for FAISS Index
|
28 |
+
index = None # Delay FAISS Index loading until first query
|
29 |
|
30 |
+
# ✅ Load SentenceTransformer Model (Quantized)
|
31 |
+
print("📥 Loading SentenceTransformer Model...")
|
32 |
+
MODEL_CACHE_DIR = "/app/model_cache"
|
33 |
+
embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
|
34 |
+
embedding_model = embedding_model.half() # Reduce memory usage
|
35 |
|
36 |
+
# ✅ Setup MongoDB Connection
|
37 |
+
client = MongoClient(mongo_uri)
|
38 |
+
db = client["MedicalChatbotDB"]
|
39 |
+
qa_collection = db["qa_data"]
|
40 |
|
41 |
+
iclient = MongoClient(index_uri)
|
42 |
+
idb = iclient["MedicalChatbotDB"]
|
43 |
+
index_collection = idb["faiss_index_files"]
|
44 |
+
|
45 |
+
# ✅ Load FAISS Index (Lazy Load)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
import gridfs
|
47 |
+
fs = gridfs.GridFS(idb, collection="faiss_index_files")
|
48 |
|
|
|
49 |
def load_faiss_index():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
global index
|
51 |
if index is None:
|
52 |
+
print("⏳ Loading FAISS index from GridFS...")
|
53 |
+
existing_file = fs.find_one({"filename": "faiss_index.bin"})
|
54 |
+
if existing_file:
|
55 |
+
stored_index_bytes = existing_file.read()
|
56 |
+
index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
|
57 |
+
index = faiss.deserialize_index(index_bytes_np)
|
58 |
+
print("✅ FAISS Index Loaded")
|
59 |
+
else:
|
60 |
+
print("❌ FAISS index not found in GridFS.")
|
61 |
+
return index
|
62 |
+
|
63 |
+
# ✅ Retrieve Medical Info
|
64 |
+
def retrieve_medical_info(query):
|
65 |
+
global index
|
66 |
+
index = load_faiss_index() # Load FAISS on demand
|
67 |
+
|
68 |
if index is None:
|
69 |
return ["No medical information available."]
|
70 |
+
|
71 |
query_embedding = embedding_model.encode([query], convert_to_numpy=True)
|
72 |
_, idxs = index.search(query_embedding, k=3)
|
73 |
+
results = [qa_collection.find_one({"i": int(i)}).get("Doctor", "No answer available.") for i in idxs[0]]
|
|
|
|
|
|
|
|
|
|
|
74 |
return results
|
75 |
|
76 |
+
# ✅ Gemini Flash API Call
|
|
|
|
|
77 |
def gemini_flash_completion(prompt, model, temperature=0.7):
|
78 |
client_genai = genai.Client(api_key=gemini_flash_api_key)
|
79 |
try:
|
80 |
response = client_genai.models.generate_content(model=model, contents=prompt)
|
81 |
return response.text
|
82 |
except Exception as e:
|
83 |
+
print(f"❌ Error calling Gemini API: {e}")
|
84 |
return "Error generating response from Gemini."
|
85 |
|
86 |
+
# ✅ Chatbot Class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
class RAGMedicalChatbot:
|
88 |
def __init__(self, model_name, retrieve_function):
|
89 |
self.model_name = model_name
|
|
|
92 |
def chat(self, user_query, lang="EN"):
|
93 |
retrieved_info = self.retrieve(user_query)
|
94 |
knowledge_base = "\n".join(retrieved_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
+
# ✅ Construct Prompt
|
97 |
+
prompt = f"""
|
98 |
+
Please format your answer using markdown.
|
99 |
+
**Bold for titles**, *italic for emphasis*, and clear headings.
|
|
|
|
|
100 |
|
101 |
+
**Medical knowledge:**
|
102 |
+
{knowledge_base}
|
103 |
|
104 |
+
**Question:** {user_query}
|
|
|
|
|
|
|
105 |
|
106 |
+
**Language Required:** {lang}
|
107 |
+
"""
|
108 |
+
completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
|
109 |
+
return completion.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
+
# ✅ Initialize Chatbot
|
112 |
+
chatbot = RAGMedicalChatbot(model_name="gemini-2.0-flash", retrieve_function=retrieve_medical_info)
|
113 |
|
114 |
+
# ✅ Chat Endpoint
|
115 |
@app.post("/chat")
|
116 |
async def chat_endpoint(data: dict):
|
117 |
user_query = data.get("query", "")
|
118 |
+
lang = data.get("lang", "EN")
|
119 |
if not user_query:
|
120 |
return JSONResponse(content={"response": "No query provided."})
|
121 |
+
|
122 |
start_time = time.time()
|
123 |
+
response_text = chatbot.chat(user_query, lang)
|
124 |
end_time = time.time()
|
125 |
response_text += f"\n\n(Response time: {end_time - start_time:.2f} seconds)"
|
|
|
126 |
|
127 |
+
return JSONResponse(content={"response": response_text})
|
128 |
|
129 |
+
# ✅ Run Uvicorn with 1 Worker
|
|
|
130 |
if __name__ == "__main__":
|
131 |
+
print("✅ Starting FastAPI Server...")
|
132 |
+
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
|
|
|
|
|
|
|
|