# ==========================
# Medical Chatbot Backend (Gemini Flash API + RAG) - Local Prebuilt Model with FAISS Index & Data Stored in MongoDB
# ==========================
"""
This script loads:
  1) A FAISS index stored in MongoDB (in the "faiss_index" collection)
  2) A local SentenceTransformer model (downloaded via snapshot_download)
  3) QA data (the full dataset of 256916 QA entries) stored in MongoDB (in the "qa_data" collection)

If the QA data or FAISS index are not found in MongoDB, the script loads the full dataset from Hugging Face,
computes embeddings for all QA pairs (concatenating the "Patient" and "Doctor" fields), and stores both the raw QA data
and the FAISS index in MongoDB.

The chatbot instructs Gemini Flash to format its answer using markdown.
"""

import os
import faiss
import numpy as np
import time
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pathlib import Path
import threading
from dotenv import load_dotenv

# Checking status
print("🚀 Starting the script...")


# 🔹 Load environment variables from .env
load_dotenv()
gemini_flash_api_key = os.getenv("FlashAPI")
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")
# Verify ENV Variables
print(f"🔎 FlashAPI Key Exists: {'✅ Yes' if os.getenv('FlashAPI') else '❌ No'}")
print(f"🔎 MongoDB URI Exists: {'✅ Yes' if os.getenv('MONGO_URI') else '❌ No'}")
print(f"🔎 FAISS Index URI Exists: {'✅ Yes' if os.getenv('INDEX_URI') else '❌ No'}")

# Heartbeat status checker
def keep_alive():
    while True:
        print("💓 Heartbeat: App is alive")
        time.sleep(10)
threading.Thread(target=keep_alive, daemon=True).start()

# 1. Environment variables to mitigate segmentation faults
# (Note: thread-count limits are most reliable when set before faiss/numpy are first imported,
# so prefer setting these in the container/runtime environment as well.)
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# 2. Set up the Hugging Face model cache prebuilt into the deployment
MODEL_CACHE_DIR = "/app/model_cache"
# Verify the cache structure
print("\n📂 Embedding Model Cache Structure (Application Level):")
for root, dirs, files in os.walk(MODEL_CACHE_DIR):
    print(f"📁 {root}/")
    for file in files:
        print(f"  📄 {file}")
# Locate the model snapshot inside the Hugging Face cache layout
snapshots_path = os.path.join(MODEL_CACHE_DIR, "models--sentence-transformers--all-MiniLM-L6-v2/snapshots")
if os.path.exists(snapshots_path):
    snapshot_folders = os.listdir(snapshots_path)
    if snapshot_folders:
        model_loc = os.path.join(snapshots_path, snapshot_folders[0])
        print(f"✅ Found model snapshot at {model_loc}")
    else:
        print("❌ No snapshot folder found!")
        exit(1)
else:
    print("❌ No snapshots directory found! Re-download the model cache.")
    exit(1)
# Ensure all necessary files exist inside the snapshot folder
# (checked under model_loc, since the Hugging Face cache nests files one level down)
required_files = ["config.json", "pytorch_model.bin", "tokenizer.json", "1_Pooling/config.json"]
for f in required_files:
    if not os.path.exists(os.path.join(model_loc, f)):
        print(f"❌ Missing required model file: {f}")
        exit(1)

# 3. Load the model to application
from sentence_transformers import SentenceTransformer
print("📥 **Loading Embedding Model...**")
start_time = time.time()
try:
    embedding_model = SentenceTransformer(model_loc, device="cpu")
    print("✅ Model loaded successfully in {:.2f} seconds.".format(time.time() - start_time))
except Exception as e:
    print(f"❌ Error loading model: {e}")
    exit(1)
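
# (Sanity note: all-MiniLM-L6-v2 maps text to 384-dimensional float32 vectors, e.g.
#   embedding_model.encode(["hello"], convert_to_numpy=True)  # -> shape (1, 384)
# which is the dimensionality the FAISS index below must match.)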


##---------------------------##
## EMBEDDING AND DATA RETRIEVAL
##---------------------------##


# 🔹 MongoDB Connection
from pymongo import MongoClient
try:
    print("⏳ Connecting to MongoDB...")
    # 1. QA client
    client = MongoClient(mongo_uri)
    db = client["MedicalChatbotDB"]  # Use your chosen database name
    qa_collection = db["qa_data"]
    # 2. FAISS index client
    iclient = MongoClient(index_uri)
    idb = iclient["MedicalChatbotDB"]  # Use your chosen database name
    index_collection = idb["faiss_index_files"]
except Exception as e:
    print(f"❌ MongoDB connection failed: {e}")
    exit(1)


# 🔹 Load QA Data from MongoDB
print("⏳ Checking MongoDB for existing QA data...")
qa_data = []  # Fallback: empty knowledge base if the collection is missing
if qa_collection.count_documents({}) == 0:
    print("⚠️ QA data not found in MongoDB! Please report this to the developer.")
else:
    print("✅ Loaded existing QA data from MongoDB.")
    # Use an aggregation pipeline with allowDiskUse to sort by "i" without creating an index.
    qa_data = list(qa_collection.aggregate([
            {"$sort": {"i": 1}},
            {"$project": {"_id": 0}}
        ], allowDiskUse=True))
    print("📦 Total QA entries loaded:", len(qa_data))


# 🔹 Build or Load the FAISS Index from MongoDB using GridFS (on the separate cluster) 
print("⏳ Checking GridFS for existing FAISS index...")
import gridfs
fs = gridfs.GridFS(idb, collection="faiss_index_files")  # 'idb' is connected using INDEX_URI

# 1. Find the FAISS index file by filename.
def load_faiss_index():
    existing_file = fs.find_one({"filename": "faiss_index.bin"})
    if existing_file is None:
        print("⚠️ FAISS index not found in GridFS! Please report this to the developer.")
        return None
    else:
        print("✅ Found FAISS index in GridFS. Loading...")
        stored_index_bytes = existing_file.read()
        index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
        index = faiss.deserialize_index(index_bytes_np)
        print("📦 FAISS index loaded from GridFS successfully!")
        return index
# Load FAISS index only when needed
index = None

##---------------------------##
## INFERENCE BACK+FRONT END
##---------------------------##


# 🔹 Prepare Retrieval and Chat Logic
def retrieve_medical_info(query):
    """Retrieve relevant medical knowledge using the FAISS index."""
    global index
    if index is None:
        index = load_faiss_index()  # Load FAISS only on first query
    if index is None:
        return ["No medical information available."]
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    _, idxs = index.search(query_embedding, k=3)
    results = []
    for i in idxs[0]:
        if 0 <= i < len(qa_data):  # FAISS returns -1 when fewer than k neighbours exist
            results.append(qa_data[i].get("Doctor", "No answer available."))
        else:
            results.append("No answer available.")
    return results
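
# Example (hypothetical query): retrieve_medical_info("What causes high blood pressure?")
# returns up to three "Doctor" answers from the nearest-neighbour QA pairs.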


# 🔹 Gemini Flash API Call
from google import genai
from google.genai import types

def gemini_flash_completion(prompt, model, temperature=0.7):
    client_genai = genai.Client(api_key=gemini_flash_api_key)
    try:
        response = client_genai.models.generate_content(
            model=model,
            contents=prompt,
            config=types.GenerateContentConfig(temperature=temperature),  # Pass temperature through to the API
        )
        return response.text
    except Exception as e:
        print(f"⚠️ Error calling Gemini API: {e}")
        return "Error generating response from Gemini."


# Define a simple language mapping (modify or add more as needed)
language_map = {
    "EN": "English",
    "VI": "Vietnamese",
    "ZH": "Chinese"
}


# 🔹 Chatbot Class 
class RAGMedicalChatbot:
    def __init__(self, model_name, retrieve_function):
        self.model_name = model_name
        self.retrieve = retrieve_function

    def chat(self, user_query, lang="EN"):
        retrieved_info = self.retrieve(user_query)
        knowledge_base = "\n".join(retrieved_info)
        language_name = language_map.get(lang.upper(), "English")  # Map the code (e.g. "EN") to a full name
        # Construct prompt for Gemini Flash
        prompt = (
            "Please format your answer using markdown. Use **bold** for titles, *italic* for emphasis, "
            "and ensure that headings and paragraphs are clearly separated.\n\n"
            f"Using the following medical knowledge:\n{knowledge_base}\n(trained with 256,916 data entries).\n\n"
            f"Answer the following question in a professional and medically accurate manner:\n{user_query}\n\n"
            f"Your answer must be written in {language_name}."
        )
        completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
        return completion.strip()


# 🔹 Chatbot instance (swap the model name for another Gemini model if needed)
chatbot = RAGMedicalChatbot(
    model_name="gemini-2.0-flash",
    retrieve_function=retrieve_medical_info
)
print("✅ Medical chatbot is ready! 🤖")


# 🔹 FastAPI Server
# from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin
app = FastAPI(title="Medical Chatbot")

# 1. Define the allowed origins
origins = [
    "http://localhost:5173",                     # Vite dev server
    "http://localhost:3000",                     # Local React/Next.js dev server
    "https://medical-chatbot-henna.vercel.app",  # ✅ Vercel frontend production URL
]

# 2. Then add the CORS middleware:
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,   # Note: a ["*"] wildcard cannot be combined with allow_credentials=True
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# 🔹 Chat route
@app.post("/chat")
async def chat_endpoint(data: dict):
    user_query = data.get("query", "")
    lang = data.get("lang", "EN")  # Expect a language code from the request
    if not user_query:
        return JSONResponse(content={"response": "No query provided."})
    start_time = time.time()
    response_text = chatbot.chat(user_query, lang)  # Pass language selection
    end_time = time.time()
    response_text += f"\n\n(Response time: {end_time - start_time:.2f} seconds)"
    return JSONResponse(content={"response": response_text})
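
# Example request (assuming the server runs locally on port 7860):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What are common symptoms of the flu?", "lang": "EN"}'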


# 🔹 Main Execution
import uvicorn
if __name__ == "__main__":
    try:
        print("✅ Server is starting...")
        uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)  # Single worker so the model and FAISS index are loaded only once
    except Exception as e:
        print(f"❌ Server startup failed: {e}")
        exit(1)