# ==========================
# Medical Chatbot Backend (Gemini Flash API + RAG) - Local Prebuilt Model with FAISS Index & Data Stored in MongoDB
# ==========================
"""
This script loads:
1) A FAISS index stored in MongoDB via GridFS (in the "faiss_index_files" collection)
2) A local SentenceTransformer model (downloaded via snapshot_download)
3) QA data (the full dataset of 256,916 QA entries) stored in MongoDB (in the "qa_data" collection)
If the QA data or the FAISS index is not found in MongoDB, the script loads the full dataset from Hugging Face,
computes embeddings for all QA pairs (concatenating the "Patient" and "Doctor" fields), and stores both the raw QA data
and the FAISS index in MongoDB.
The chatbot instructs Gemini Flash to format its answer using markdown.
"""
import os
import faiss
import numpy as np
import gc
import time
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pathlib import Path
# import streamlit as st
# import threading
# import requests
from dotenv import load_dotenv
# 🔹 Load environment variables from .env
load_dotenv()
gemini_flash_api_key = os.getenv("FlashAPI")
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")
# 🔹 Load Streamlit secrets from .toml
# gemini_flash_api_key = st.secrets["general"]["FlashAPI"]
# mongo_uri = st.secrets["general"]["MONGO_URI"]
# index_uri = st.secrets["general"]["INDEX_URI"]
if not gemini_flash_api_key:
    raise ValueError("❌ Gemini Flash API key (FlashAPI) is missing!")
    # st.error("❌ Gemini Flash API key (FlashAPI) is missing!")
    # st.stop()  # Prevent the app from running without necessary API keys
if not mongo_uri:
    raise ValueError("❌ MongoDB URI (MONGO_URI) is missing!")
    # st.error("❌ MongoDB URI (MONGO_URI) is missing!")
    # st.stop()  # Prevent the app from running without necessary API keys
if not index_uri:
    raise ValueError("❌ INDEX_URI for FAISS index cluster is missing!")
    # st.error("❌ INDEX_URI for FAISS index cluster is missing!")
    # st.stop()  # Prevent the app from running without necessary API keys
# 1. Environment variables to mitigate segmentation faults
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# 2a) Setup local project model cache
# project_dir = "./AutoGenRAGMedicalChatbot"
# os.makedirs(project_dir, exist_ok=True)
# huggingface_cache_dir = os.path.join(project_dir, "huggingface_models")
# os.environ["HF_HOME"] = huggingface_cache_dir  # Use this folder for HF cache
# 2. Setup Hugging Face Cloud project model cache
MODEL_CACHE_DIR = "/app/model_cache"
# Verify structure
print("\n📂 LLM Model Structure:")
for root, dirs, files in os.walk(MODEL_CACHE_DIR):
    print(f"📁 {root}/")
    for file in files:
        print(f" 📄 {file}")
# Check if the required model files exist
if os.path.exists(os.path.join(MODEL_CACHE_DIR, "config.json")):
    print(f"✅ Found cached model at {MODEL_CACHE_DIR}")
    model_loc = MODEL_CACHE_DIR
else:
    print(f"❌ Model not found in {MODEL_CACHE_DIR}. Critical error!")
    exit(1)  # Exit since the model is missing
# 3. Load the model to application
from sentence_transformers import SentenceTransformer
print("📥 **Loading Embedding Model...**")
# st.write("📥 **Loading Embedding Model...**")
embedding_model = SentenceTransformer(model_loc, device="cpu")
# 🔹 MongoDB Setup
from pymongo import MongoClient
# 1. QA client
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]  # Use your chosen database name
qa_collection = db["qa_data"]
# 2. FAISS index client
iclient = MongoClient(index_uri)
idb = iclient["MedicalChatbotDB"]  # Use your chosen database name
index_collection = idb["faiss_index_files"]
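# Optional connectivity check (a minimal sketch, not part of the original flow): pinging
# both clusters up front surfaces bad URIs before the longer data/index loading steps below.
for cluster_name, cluster_client in (("QA cluster", client), ("Index cluster", iclient)):
    try:
        cluster_client.admin.command("ping")  # standard MongoDB ping via pymongo
        print(f"✅ {cluster_name} reachable.")
    except Exception as e:
        print(f"⚠️ {cluster_name} ping failed: {e}")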
##---------------------------##
## EMBEDDING AND DATA RETRIEVAL
##---------------------------##
# 🔹 Load or Build QA Data in MongoDB
print("⏳ Checking MongoDB for existing QA data...")
# st.write("⏳ Checking MongoDB for existing QA data...")
if qa_collection.count_documents({}) == 0:
    print("⚠️ QA data not found in MongoDB. Loading dataset from Hugging Face...")
    # st.write("⚠️ QA data not found in MongoDB. Loading dataset from Hugging Face...")
    from datasets import load_dataset
    dataset = load_dataset("ruslanmv/ai-medical-chatbot")  # default HF cache (local cache dir above is commented out)
    df = dataset["train"].to_pandas()[["Patient", "Doctor"]]
    # Add an index column "i" to preserve order.
    df["i"] = range(len(df))
    qa_data = df.to_dict("records")
    # Insert in batches (e.g., batches of 1000) to avoid document size limits.
    batch_size = 1000
    for i in range(0, len(qa_data), batch_size):
        qa_collection.insert_many(qa_data[i:i+batch_size])
    print(f"📦 QA data stored in MongoDB. Total entries: {len(qa_data)}")
    # st.write(f"📦 QA data stored in MongoDB. Total entries: {len(qa_data)}")
else:
    print("✅ Loaded existing QA data from MongoDB.")
    # st.write("✅ Loaded existing QA data from MongoDB.")
    # Use an aggregation pipeline with allowDiskUse to sort by "i" without creating an index.
    qa_docs = list(qa_collection.aggregate([
        {"$sort": {"i": 1}},
        {"$project": {"_id": 0}}
    ], allowDiskUse=True))
    qa_data = qa_docs
print("📦 Total QA entries loaded:", len(qa_data))
# st.write("📦 Total QA entries loaded:", len(qa_data))
# 🔹 Build or Load the FAISS Index from MongoDB using GridFS (on the separate cluster)
print("⏳ Checking GridFS for existing FAISS index...")
# st.write("⏳ Checking GridFS for existing FAISS index...")
import gridfs
fs = gridfs.GridFS(idb, collection="faiss_index_files")  # 'idb' is connected using INDEX_URI
# 1. Find the FAISS index file by filename.
existing_file = fs.find_one({"filename": "faiss_index.bin"})
if existing_file is None:
    print("⚠️ FAISS index not found in GridFS. Building FAISS index from QA data...")
    # st.write("⚠️ FAISS index not found in GridFS. Building FAISS index from QA data...")
    # Compute embeddings for each QA pair by concatenating "Patient" and "Doctor" fields.
    texts = [item.get("Patient", "") + " " + item.get("Doctor", "") for item in qa_data]
    batch_size = 512  # Adjust as needed
    embeddings_list = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        batch_embeddings = embedding_model.encode(batch, convert_to_numpy=True).astype(np.float32)
        embeddings_list.append(batch_embeddings)
        print(f"Encoded batch {i} to {i + len(batch)}")
        # st.write(f"Encoded batch {i} to {i + len(batch)}")
    embeddings = np.vstack(embeddings_list)
    dim = embeddings.shape[1]
    # Create a FAISS index (using IndexHNSWFlat; or use IVFPQ for compression)
    index = faiss.IndexHNSWFlat(dim, 32)
    index.add(embeddings)
    print("FAISS index built. Total vectors:", index.ntotal)
    # Serialize the index
    index_bytes = faiss.serialize_index(index)
    index_data = np.frombuffer(index_bytes, dtype='uint8').tobytes()
    # Store in GridFS (this bypasses the 16 MB document limit)
    file_id = fs.put(index_data, filename="faiss_index.bin")
    print("📦 FAISS index built and stored in GridFS with file_id:", file_id)
    # st.write("📦 FAISS index built and stored in GridFS with file_id:", file_id)
    del embeddings
    gc.collect()
else:
    print("✅ Found FAISS index in GridFS. Loading...")
    # st.write("✅ Found FAISS index in GridFS. Loading...")
    stored_index_bytes = existing_file.read()
    index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
    index = faiss.deserialize_index(index_bytes_np)
    print("📦 FAISS index loaded from GridFS successfully!")
    # st.write("📦 FAISS index loaded from GridFS successfully!")
##---------------------------##
## INFERENCE BACK+FRONT END
##---------------------------##
# 🔹 Prepare Retrieval and Chat Logic
def retrieve_medical_info(query):
    """Retrieve relevant medical knowledge using the FAISS index."""
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    _, idxs = index.search(query_embedding, k=3)
    results = []
    for i in idxs[0]:
        # FAISS pads missing neighbours with -1, so guard both bounds.
        if 0 <= i < len(qa_data):
            results.append(qa_data[i].get("Doctor", "No answer available."))
        else:
            results.append("No answer available.")
    return results
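# Illustrative call (commented out; the query text is invented for demonstration):
#   retrieve_medical_info("What should I do about recurring lower back pain?")
# returns a list of up to three "Doctor" answers from the nearest stored QA pairs.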
# 🔹 Gemini Flash API Call
from google import genai
def gemini_flash_completion(prompt, model, temperature=0.7):
    # NOTE: temperature is accepted here but not currently forwarded to the API call.
    client_genai = genai.Client(api_key=gemini_flash_api_key)
    try:
        response = client_genai.models.generate_content(model=model, contents=prompt)
        return response.text
    except Exception as e:
        print(f"⚠️ Error calling Gemini API: {e}")
        # st.error(f"⚠️ Error calling Gemini API: {e}")
        return "Error generating response from Gemini."
# Define a simple language mapping (modify or add more as needed)
language_map = {
    "EN": "English",
    "VI": "Vietnamese",
    "ZH": "Chinese"
}
# 🔹 Chatbot Class
class RAGMedicalChatbot:
    def __init__(self, model_name, retrieve_function):
        self.model_name = model_name
        self.retrieve = retrieve_function

    def chat(self, user_query, lang="EN"):
        retrieved_info = self.retrieve(user_query)
        knowledge_base = "\n".join(retrieved_info)
        # Resolve the language code (e.g. "EN") to a full language name for the prompt.
        language_name = language_map.get(lang, "English")
        # Construct prompt for Gemini Flash
        prompt = (
            "Please format your answer using markdown. Use **bold** for titles, *italic* for emphasis, "
            "and ensure that headings and paragraphs are clearly separated.\n\n"
            f"Using the following medical knowledge:\n{knowledge_base}\n(trained with 256,916 data entries).\n\n"
            f"Answer the following question in a professional and medically accurate manner:\n{user_query}\n\n"
            f"Your response must be in {language_name}."
        )
        completion = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
        return completion.strip()

# 🔹 Chatbot instance (change the model to others if needed)
chatbot = RAGMedicalChatbot(
    model_name="gemini-2.0-flash",
    retrieve_function=retrieve_medical_info
)
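# Example direct call (commented out; the question text is invented for illustration):
#   answer_md = chatbot.chat("What are common symptoms of iron deficiency?", lang="EN")
#   print(answer_md)  # markdown-formatted answer from Gemini Flash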
print("✅ Medical chatbot is ready! 🤖") | |
# st.success("✅ Medical chatbot is ready! 🤖") | |
# 🔹 FastAPI Server | |
# from fastapi.staticfiles import StaticFiles | |
from fastapi.middleware.cors import CORSMiddleware # Bypassing CORS origin | |
app = FastAPI(title="Medical Chatbot") | |
# 1. Define the origins | |
origins = [ | |
"http://localhost:5173", # Vite dev server | |
"http://localhost:3000", # Another vercel dev server | |
"https://medical-chatbot-henna.vercel.app", # ✅ Vercel frontend production URL | |
] | |
# 2. Then add the CORS middleware: | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=origins, # or ["*"] to allow all | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# (02/03/2025) Move static files UI to Vercel | |
# 3. Mount static files (make sure the "static" folder exists and contains your images) | |
# app.mount("/static", StaticFiles(directory="static"), name="static") | |
# 4. Get statics template route | |
# @app.get("/", response_class=HTMLResponse) | |
# async def get_home(): | |
# return HTML_CONTENT | |
# 🔹 Chat route | |
async def chat_endpoint(data: dict): | |
user_query = data.get("query", "") | |
lang = data.get("lang", "EN") # Expect a language code from the request | |
if not user_query: | |
return JSONResponse(content={"response": "No query provided."}) | |
start_time = time.time() | |
response_text = chatbot.chat(user_query, lang) # Pass language selection | |
end_time = time.time() | |
response_text += f"\n\n(Response time: {end_time - start_time:.2f} seconds)" | |
return JSONResponse(content={"response": response_text}) | |
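# Illustrative request against the route above (the port assumes the uvicorn run at the
# bottom of this file; the query text is invented):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What are common causes of chest pain?", "lang": "EN"}'
# Expected shape of the reply: {"response": "<markdown answer>\n\n(Response time: X.XX seconds)"}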
# 🔹 Main Execution
# 1. On Streamlit (free-tier allowance 1GB)
# 🌐 Start FastAPI server in a separate thread
# def run_fastapi():
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)
# threading.Thread(target=run_fastapi, daemon=True).start()
# # 🔍 Streamlit UI for Testing
# st.title("🩺 Medical Chatbot API")
# st.info("This is a **FastAPI Backend running on Streamlit Cloud**")
# user_query = st.text_input("Enter your medical question:")
# selected_lang = st.selectbox("Select Language:", ["English (EN)", "Vietnamese (VI)", "Chinese (ZH)"])
# if st.button("Ask Doctor Bot"):
#     lang_code = selected_lang.split("(")[-1].strip(")")
#     st.markdown("🤖 **DocBot is thinking...**")
#     # a) API request to FastAPI
#     response = requests.post("http://127.0.0.1:8000/chat", json={"query": user_query, "lang": lang_code})
#     response_json = response.json()
#     # b) Display response
#     st.markdown(response_json["response"])
# 2. On Render (free-tier allowance 512MB)
# if __name__ == "__main__":
#     import uvicorn
#     print("\n🩺 Starting Medical Chatbot FastAPI server...\n")
#     # 🌐 Start app
#     uvicorn.run(app, host="0.0.0.0", port=8000)
# 3. On Hugging Face with Gradio (limited API request)
import uvicorn
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)