Spaces:
Running
Running
# ==========================
# Medical Chatbot Backend (Gemini Flash API + RAG) - Local Prebuilt Model with FAISS Index & Data Stored in MongoDB
# ==========================
"""
This script loads:
1) A FAISS index stored in MongoDB GridFS (the "faiss_index_files" bucket on the INDEX_URI cluster),
   fetched lazily on the first query
2) A local SentenceTransformer model (all-MiniLM-L6-v2, pre-downloaded into /app/model_cache via snapshot_download)
3) QA data (the full dataset of 256,916 QA entries) stored in MongoDB (in the "qa_data" collection)
If the QA data or the FAISS index is missing from MongoDB, the script only logs a warning asking for it to be
reported; it does not rebuild them. Building them (embedding every QA pair by concatenating its "Patient" and
"Doctor" fields, then storing the raw QA data and the FAISS index in MongoDB) happens outside this script.
The chatbot instructs Gemini Flash to format its answer using markdown.
"""
import asyncio
import os
import sys
import threading
import time
from pathlib import Path

import faiss
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import JSONResponse
# Startup banner
print("🚀 Starting the script...")

# 🔹 Pull settings from a local .env file into the process environment,
# then read the three values this service depends on.
load_dotenv()
gemini_flash_api_key = os.getenv("FlashAPI")
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")

# Report presence only (never the secret values themselves).
print(f"🔎 FlashAPI Key Exists: {'✅ Yes' if gemini_flash_api_key else '❌ No'}")
print(f"🔎 MongoDB URI Exists: {'✅ Yes' if mongo_uri else '❌ No'}")
print(f"🔎 FAISS Index URI Exists: {'✅ Yes' if index_uri else '❌ No'}")
# Background heartbeat: logs a line every 10 seconds for the life of the
# process. The thread is a daemon, so it never keeps the interpreter alive.
def keep_alive():
    """Print a heartbeat message every 10 seconds, forever."""
    heartbeat_interval = 10
    while True:
        print("💓 Heartbeat: App is alive")
        time.sleep(heartbeat_interval)


heartbeat_thread = threading.Thread(target=keep_alive, daemon=True)
heartbeat_thread.start()
# 1. Cap native thread pools to mitigate segmentation faults.
# Native runtimes typically read these when they initialise, so they reliably
# affect libraries imported after this point (e.g. torch, pulled in by
# sentence_transformers below).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# faiss (and numpy) were imported at the top of the file, before these were
# set, so their OpenMP/BLAS runtimes may already have chosen a thread count.
# Apply the limit to faiss explicitly so it actually takes effect there.
faiss.omp_set_num_threads(1)
# 2. Local Hugging Face model cache directory
MODEL_CACHE_DIR = "/app/model_cache"

# Log the cache layout to help diagnose missing-file problems.
print("\n📂 LLM Model Structure (Application Level):")
for root, dirs, files in os.walk(MODEL_CACHE_DIR):
    print(f"📁 {root}/")
    for file in files:
        print(f" 📄 {file}")

# Ensure all necessary files exist (checked at the cache root).
# sys.exit is used instead of the interactive-only site builtin exit().
required_files = ["config.json", "pytorch_model.bin", "tokenizer.json", "1_Pooling/config.json"]
for f in required_files:
    if not os.path.exists(os.path.join(MODEL_CACHE_DIR, f)):
        print(f"❌ Missing required model file: {f}")
        sys.exit(1)

# Locate the all-MiniLM-L6-v2 snapshot inside the HF cache layout.
snapshots_path = os.path.join(MODEL_CACHE_DIR, "models--sentence-transformers--all-MiniLM-L6-v2/snapshots")
if not os.path.exists(snapshots_path):
    print("❌ No snapshots directory found! Reload ...")
    sys.exit(1)
snapshot_folders = os.listdir(snapshots_path)
if not snapshot_folders:
    print("❌ No snapshot folder found!")
    sys.exit(1)
# NOTE(review): os.listdir order is arbitrary; if several snapshots exist,
# whichever is listed first is used.
model_loc = os.path.join(snapshots_path, snapshot_folders[0])
print(f"✅ Found model snapshot at {model_loc}")
# 3. Load the SentenceTransformer embedding model from the local snapshot (CPU only)
from sentence_transformers import SentenceTransformer
print("📥 **Loading Embedding Model...**")
# perf_counter is the monotonic clock intended for measuring durations.
start_time = time.perf_counter()
try:
    embedding_model = SentenceTransformer(model_loc, device="cpu")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    sys.exit(1)
print("✅ Model loaded successfully in {:.2f} seconds.".format(time.perf_counter() - start_time))
##---------------------------##
## EMBEDDING AND DATA RETRIEVAL
##---------------------------##
# 🔹 MongoDB Connection
from pymongo import MongoClient
try:
    print("⏳ Connecting to MongoDB...")
    # MongoClient(None) silently falls back to localhost, so refuse to start
    # without both configured URIs.
    if not mongo_uri or not index_uri:
        raise ValueError("MONGO_URI and INDEX_URI must both be set")
    # 1. QA client
    client = MongoClient(mongo_uri)
    db = client["MedicalChatbotDB"]  # Use your chosen database name
    qa_collection = db["qa_data"]
    # 2. FAISS index client (separate cluster)
    iclient = MongoClient(index_uri)
    idb = iclient["MedicalChatbotDB"]  # Use your chosen database name
    index_collection = idb["faiss_index_files"]
    # MongoClient connects lazily; ping both servers now so an unreachable
    # cluster is reported here instead of failing on first use.
    client.admin.command("ping")
    iclient.admin.command("ping")
except Exception as e:
    print(f"❌ MongoDB connection failed: {e}")
    sys.exit(1)
# 🔹 Load QA Data from MongoDB (this script does not build it)
print("⏳ Checking MongoDB for existing QA data...")
# Default to an empty list so retrieval degrades to "No answer available."
# instead of raising NameError when the collection is empty.
qa_data = []
if qa_collection.count_documents({}) == 0:
    print("⚠️ QA data not found in MongoDB! Please report this to the developer.")
else:
    print("✅ Loaded existing QA data from MongoDB.")
    # Sort by "i" via an aggregation with allowDiskUse (no index required).
    # Presumably list position must match the FAISS vector ids — TODO confirm.
    qa_data = list(qa_collection.aggregate(
        [
            {"$sort": {"i": 1}},
            {"$project": {"_id": 0}},
        ],
        allowDiskUse=True,
    ))
    print("📦 Total QA entries loaded:", len(qa_data))
# 🔹 FAISS index storage: GridFS on the separate index cluster.
# The index itself is fetched lazily by load_faiss_index() on first query.
print("⏳ Checking GridFS for existing FAISS index...")
import gridfs
# `idb` is the database from the INDEX_URI client; files live in the "faiss_index_files" bucket.
fs = gridfs.GridFS(database=idb, collection="faiss_index_files")
# Fetch and deserialize the FAISS index stored in GridFS as "faiss_index.bin".
def load_faiss_index():
    """Read "faiss_index.bin" from GridFS and deserialize it into a FAISS index.

    Returns:
        The deserialized index, or None (after logging a warning) when no
        file with that name is stored.
    """
    grid_file = fs.find_one({"filename": "faiss_index.bin"})
    if grid_file is None:
        print("⚠️ FAISS index not found in GridFS! Please report this to the developer.")
        return None

    print("✅ Found FAISS index in GridFS. Loading...")
    # faiss.deserialize_index takes the serialized bytes as a uint8 array.
    raw_bytes = np.frombuffer(grid_file.read(), dtype='uint8')
    loaded_index = faiss.deserialize_index(raw_bytes)
    print("📦 FAISS index loaded from GridFS successfully!")
    return loaded_index
# Load FAISS index only when needed (lazily, on the first query)
index = None
# Serializes the lazy load so concurrent first requests don't each fetch and
# deserialize the index from GridFS.
_index_lock = threading.Lock()
##---------------------------##
## INFERENCE BACK+FRONT END
##---------------------------##
# 🔹 Prepare Retrieval and Chat Logic
def retrieve_medical_info(query, k=3):
    """Return the "Doctor" answers of the ``k`` QA entries nearest to ``query``.

    The FAISS index is loaded from GridFS on first use and cached in the
    module-level ``index``.

    Args:
        query: User question text to embed and search for.
        k: Number of nearest neighbours to retrieve (default 3).

    Returns:
        A list with one string per neighbour slot ("No answer available." for
        slots without a usable match), or ["No medical information available."]
        if the index cannot be loaded.
    """
    global index
    if index is None:
        with _index_lock:
            # Re-check under the lock: another thread may have loaded it meanwhile.
            if index is None:
                index = load_faiss_index()
    if index is None:
        return ["No medical information available."]
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    _, idxs = index.search(query_embedding, k=k)
    results = []
    for i in idxs[0]:
        # FAISS pads with -1 when fewer than k neighbours exist; without the
        # lower bound, qa_data[-1] would silently return the last entry.
        if 0 <= i < len(qa_data):
            results.append(qa_data[i].get("Doctor", "No answer available."))
        else:
            results.append("No answer available.")
    return results
# 🔹 Gemini Flash API Call
from google import genai
def gemini_flash_completion(prompt, model, temperature=0.7):
    """Send ``prompt`` to a Gemini model and return the generated text.

    Args:
        prompt: Full prompt text.
        model: Gemini model name, e.g. "gemini-2.0-flash".
        temperature: Sampling temperature forwarded to the API
            (previously accepted but ignored).

    Returns:
        The response text, or "Error generating response from Gemini." if the
        call fails or returns no text.
    """
    try:
        # Built inside the try so a missing/invalid API key is reported like any
        # other API failure instead of escaping to the caller.
        client_genai = genai.Client(api_key=gemini_flash_api_key)
        response = client_genai.models.generate_content(
            model=model,
            contents=prompt,
            config={"temperature": temperature},
        )
        text = response.text
    except Exception as e:
        print(f"⚠️ Error calling Gemini API: {e}")
        return "Error generating response from Gemini."
    # response.text can be None (e.g. no text part came back); never hand None
    # to callers that call .strip() on the result.
    if not text:
        print("⚠️ Gemini API returned no text.")
        return "Error generating response from Gemini."
    return text
# Language code -> human-readable language name used in the prompt
# (extend with more codes as needed).
language_map = dict(
    EN="English",
    VI="Vietnamese",
    ZH="Chinese",
)
# 🔹 Chatbot Class
class RAGMedicalChatbot:
    """Retrieval-augmented chatbot: fetch related QA answers, then ask Gemini.

    Args:
        model_name: Gemini model name passed to ``gemini_flash_completion``.
        retrieve_function: Callable taking the user query and returning an
            iterable of context strings.
    """

    def __init__(self, model_name, retrieve_function):
        self.model_name = model_name
        self.retrieve = retrieve_function

    def chat(self, user_query, lang="EN"):
        """Answer ``user_query`` in markdown, in the language selected by ``lang``.

        ``lang`` is a language code looked up case-insensitively in
        ``language_map`` (e.g. "VI" -> "Vietnamese"); unknown codes are put
        into the prompt unchanged. Returns the stripped completion text, or ""
        if no completion came back.
        """
        retrieved_info = self.retrieve(user_query)
        knowledge_base = "\n".join(retrieved_info)
        # Give the model a language name rather than a bare code like "VI".
        language = language_map.get(str(lang).upper(), lang)
        # Construct prompt for Gemini Flash
        prompt = (
            "Please format your answer using markdown. Use **bold** for titles, *italic* for emphasis, "
            "and ensure that headings and paragraphs are clearly separated.\n\n"
            f"Using the following medical knowledge:\n{knowledge_base} \n(trained with 256,916 data entries).\n\n"
            f"Answer the following question in a professional and medically accurate manner:\n{user_query}.\n\n"
            f"Your response answer must be in {language} language."
        )
        completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
        # Tolerate a None completion so .strip() cannot raise.
        return (completion or "").strip()
# 🔹 Chatbot instance — swap the model name here to use a different Gemini model.
chatbot = RAGMedicalChatbot(
    retrieve_function=retrieve_medical_info,
    model_name="gemini-2.0-flash",
)
print("✅ Medical chatbot is ready! 🤖")
# 🔹 FastAPI Server
from fastapi.middleware.cors import CORSMiddleware  # lets the browser frontends call this API cross-origin

app = FastAPI(title="Medical Chatbot")

# 1. Browser origins permitted to make cross-origin requests
origins = [
    "http://localhost:5173",  # Vite dev server
    "http://localhost:3000",  # alternative local dev server
    "https://medical-chatbot-henna.vercel.app",  # Vercel production frontend
]

# 2. Register CORS handling for those origins (or pass ["*"] to allow any origin)
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
# 🔹 Chat route (previously defined without a decorator, so it was never registered)
@app.post("/chat")
async def chat_endpoint(data: dict):
    """Answer a chat request.

    Expects a JSON object with "query" (question text) and an optional "lang"
    language code (default "EN"). Responds with {"response": <answer followed
    by the elapsed time>}, or {"response": "No query provided."} when the
    query is missing, blank or not a string.
    """
    user_query = data.get("query", "")
    lang = data.get("lang", "EN")  # Expect a language code from the request
    if not isinstance(user_query, str) or not user_query.strip():
        return JSONResponse(content={"response": "No query provided."})
    start_time = time.perf_counter()
    # chatbot.chat does blocking work (embedding, FAISS search, Gemini HTTP
    # call); run it in a worker thread so the event loop keeps serving others.
    response_text = await asyncio.to_thread(chatbot.chat, user_query, lang)
    elapsed = time.perf_counter() - start_time
    response_text += f"\n\n(Response time: {elapsed:.2f} seconds)"
    return JSONResponse(content={"response": response_text})
# 🔹 Main Execution
import uvicorn
if __name__ == "__main__":
    try:
        print("✅ Server is starting...")
        # Single worker: uvicorn can only spawn multiple workers from an import
        # string (not an app object), and each worker process would hold its
        # own copy of the model and QA data.
        uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
    except Exception as e:
        print(f"❌ Server startup failed: {e}")
        sys.exit(1)