File size: 7,873 Bytes
57e04c3
5e69775
 
 
 
4455263
57e04c3
5e69775
4455263
 
 
57e04c3
5e69775
65d7792
5c59423
15ed85c
 
 
 
 
 
 
 
 
57e04c3
65d7792
15ed85c
57e04c3
65d7792
 
 
 
4455263
5e69775
 
4455263
65d7792
4455263
 
65d7792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e69775
 
 
4455263
 
57e04c3
 
65d7792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e69775
4455263
 
5e69775
65d7792
 
4455263
 
65d7792
 
 
 
 
 
 
 
 
5e69775
4455263
65d7792
4455263
 
 
65d7792
4455263
 
 
 
 
5e69775
4455263
8828f20
5e69775
 
 
4455263
 
 
 
 
 
 
65d7792
4455263
 
65d7792
4455263
 
 
 
 
 
65d7792
5e69775
 
65d7792
5e69775
 
4455263
5e69775
 
4455263
5e69775
 
 
 
 
 
65d7792
4455263
5e69775
 
4455263
5e69775
 
 
 
 
57e04c3
 
5e69775
 
 
1ca4ee7
 
57e04c3
 
 
 
 
 
 
1ca4ee7
57e04c3
 
 
 
 
1ca4ee7
57e04c3
1ca4ee7
57e04c3
5e69775
4455263
0123a67
5e69775
4455263
5e69775
57e04c3
 
 
 
 
 
 
 
 
 
 
 
 
 
5e69775
65d7792
8828f20
65d7792
4455263
65d7792
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# app.py
import os
import faiss
import numpy as np
import time
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pymongo import MongoClient
from google import genai
from sentence_transformers import SentenceTransformer
from memory import MemoryManager

# ✅ Logging setup for debugging
import logging

# —————— Silence Noisy Loggers ——————
# Framework, driver and model-library loggers are capped at WARNING so the
# application's own records are not drowned out.
_NOISY_LOGGERS = (
    "uvicorn.error",
    "uvicorn.access",
    "fastapi",
    "starlette",
    "pymongo",
    "gridfs",
    "sentence_transformers",
    "faiss",
    "google",
    "google.auth",
)
for _noisy_logger in _NOISY_LOGGERS:
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

# Root configuration; force=True discards handlers installed before this point.
logging.basicConfig(
    level=logging.INFO,  # Change INFO to DEBUG for full-ctx JSON loader
    format="%(asctime)s — %(name)s — %(levelname)s — %(message)s",
    force=True,
)

# Application logger, with its own threshold set to DEBUG.
logger = logging.getLogger("medical-chatbot")
logger.setLevel(logging.DEBUG)

# Announce startup on both the log stream and stdout.
_STARTUP_BANNER = "🚀 Starting Medical Chatbot API..."
logger.info(_STARTUP_BANNER)
print(_STARTUP_BANNER)

# ✅ Environment Variables
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")
gemini_flash_api_key = os.getenv("FlashAPI")

# Fail fast when any required setting is absent or empty, naming the ones
# that are missing (names only, never values).
_missing_env = [
    env_name
    for env_name, env_value in (
        ("FlashAPI", gemini_flash_api_key),
        ("MONGO_URI", mongo_uri),
        ("INDEX_URI", index_uri),
    )
    if not env_value
]
if _missing_env:
    raise ValueError(
        "❌ Missing API keys! Set them in Hugging Face Secrets. "
        f"Missing: {', '.join(_missing_env)}"
    )


def _redact_uri_credentials(uri):
    """Return *uri* with any ``user:password@`` userinfo replaced by ``***@``.

    Everything between ``://`` and the last ``@`` that precedes the first
    ``/``, ``?`` or ``#`` is masked; URIs without userinfo are returned as-is.
    """
    import re  # local import keeps this block self-contained

    return re.sub(r"://[^/?#]*@", "://***@", uri, count=1)


# Connection strings may embed credentials, so only the masked form is logged.
logger.info(f"🔎 MongoDB URI: {_redact_uri_credentials(mongo_uri)}")
logger.info(f"🔎 FAISS Index URI: {_redact_uri_credentials(index_uri)}")

# ✅ Monitor Resources Before Startup
import psutil


def check_system_resources():
    """Log current RAM, CPU and disk utilisation and warn on high readings.

    A warning is logged when RAM is above 85%, CPU above 90%, or disk usage
    of "/" above 90%. CPU is sampled with ``interval=1``.
    """
    ram_pct = psutil.virtual_memory().percent
    cpu_pct = psutil.cpu_percent(interval=1)
    disk_pct = psutil.disk_usage("/").percent

    logger.info(f"🔍 System Resources - RAM: {ram_pct}%, CPU: {cpu_pct}%, Disk: {disk_pct}%")

    # (reading, threshold, warning) triples, checked in RAM / CPU / disk order.
    checks = (
        (ram_pct, 85, "⚠️ High RAM usage detected!"),
        (cpu_pct, 90, "⚠️ High CPU usage detected!"),
        (disk_pct, 90, "⚠️ High Disk usage detected!"),
    )
    for reading, threshold, warning in checks:
        if reading > threshold:
            logger.warning(warning)


check_system_resources()

# ✅ Reduce memory usage: one OpenMP thread, tokenizer parallelism off
# NOTE(review): faiss and numpy are imported earlier in this file, so their
# native runtimes may already have read OMP_NUM_THREADS — confirm this takes
# effect, or move the assignment above those imports.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# ✅ Initialize FastAPI app and the shared memory manager
app = FastAPI(title="Medical Chatbot API")
memory = MemoryManager()

# CORS: browsers may only call this API from the origins listed below.
from fastapi.middleware.cors import CORSMiddleware

origins = [
    # Vite dev server
    "http://localhost:5173",
    # Another vercel local dev
    "http://localhost:3000",
    # ✅ Vercel frontend production URL
    "https://medical-chatbot-henna.vercel.app",
]

# Settings for the CORS middleware (use ["*"] for allow_origins to allow all).
_cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# ✅ Use Lazy Loading for FAISS Index
index = None  # Delay FAISS Index loading until first query

# ✅ Load SentenceTransformer Model (Quantized/Halved)
logger.info("📥 Loading SentenceTransformer Model...")
print("📥 Loading SentenceTransformer Model...")
MODEL_CACHE_DIR = "/app/model_cache"
try:
    embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
    embedding_model = embedding_model.half()  # Reduce memory
except Exception as e:
    logger.error(f"❌ Model Loading Failed: {e}")
    # `exit()` is an interactive helper injected by the `site` module and is
    # not guaranteed to exist (e.g. under `python -S`); raise SystemExit
    # directly to terminate with status 1.
    raise SystemExit(1) from e
else:
    # Success messages live outside the try so only model loading is guarded.
    logger.info("✅ Model Loaded Successfully.")
    print("✅ Model Loaded Successfully.")


# ✅ Setup MongoDB Connection
import gridfs

# Both connections use a database with this name.
_DB_NAME = "MedicalChatbotDB"

# QA data (connection from MONGO_URI)
client = MongoClient(mongo_uri)
db = client[_DB_NAME]
qa_collection = db["qa_data"]

# FAISS Index data (connection from INDEX_URI)
iclient = MongoClient(index_uri)
idb = iclient[_DB_NAME]
index_collection = idb["faiss_index_files"]

# ✅ GridFS handle used to lazy-load the FAISS index
fs = gridfs.GridFS(idb, collection="faiss_index_files")

def load_faiss_index():
    """Return the module-level FAISS index, loading it from GridFS if needed.

    On the first call the file named ``faiss_index.bin`` is read from GridFS
    and deserialized into the global ``index``. When that file is missing, an
    error is reported and the (still unset) ``index`` is returned.
    """
    global index

    # Already loaded: nothing to do.
    if index is not None:
        return index

    print("⏳ Loading FAISS index from GridFS...")
    grid_file = fs.find_one({"filename": "faiss_index.bin"})
    if not grid_file:
        print("❌ FAISS index not found in GridFS.")
        logger.error("❌ FAISS index not found in GridFS.")
        return index

    # FAISS deserializes from a uint8 array view over the stored bytes.
    serialized = np.frombuffer(grid_file.read(), dtype='uint8')
    index = faiss.deserialize_index(serialized)
    print("✅ FAISS Index Loaded")
    logger.info("✅ FAISS Index Loaded")
    return index

# ✅ Retrieve Medical Info
def retrieve_medical_info(query):
    """Return the stored "Doctor" answers nearest to *query* (at most three).

    The query is embedded, searched against the lazily loaded FAISS index, and
    each hit is looked up in the QA collection by its ``i`` field.

    Returns:
        A non-empty list of strings. When the index is unavailable or the
        search yields no usable hit, the list holds a single
        "No medical information available." entry. A hit whose document or
        "Doctor" field is missing contributes "No answer available.".
    """
    global index
    index = load_faiss_index()  # Load FAISS on demand
    # N/A question
    if index is None:
        return ["No medical information available."]

    # The embedding model is converted with .half() at load time; cast the
    # query vector to float32, the dtype FAISS searches with.
    query_embedding = np.asarray(
        embedding_model.encode([query], convert_to_numpy=True),
        dtype="float32",
    )
    _, idxs = index.search(query_embedding, k=3)

    results = []
    for raw_id in idxs[0]:
        # FAISS fills unused result slots with -1 when fewer than k
        # neighbours exist; those are not real document ids.
        if raw_id < 0:
            continue
        doc = qa_collection.find_one({"i": int(raw_id)})
        # find_one returns None when no document matches the id.
        answer = doc.get("Doctor") if doc else None
        results.append(answer or "No answer available.")
    return results or ["No medical information available."]

# ✅ Gemini Flash API Call
def gemini_flash_completion(prompt, model, temperature=0.7):
    """Generate a completion for *prompt* with the given Gemini *model*.

    Args:
        prompt: Text sent as the request contents.
        model: Gemini model identifier.
        temperature: Sampling temperature forwarded in the request config.

    Returns:
        The response text, or "Error generating response from Gemini." when
        the client cannot be created, the call fails, or the response carries
        no text.
    """
    try:
        # Client creation is guarded as well, so a bad key or transport error
        # degrades to the fallback message instead of propagating.
        client_genai = genai.Client(api_key=gemini_flash_api_key)
        response = client_genai.models.generate_content(
            model=model,
            contents=prompt,
            # Previously the temperature argument was accepted but never sent.
            config={"temperature": temperature},
        )
        text = response.text
    except Exception as e:
        logger.error(f"❌ Error calling Gemini API: {e}")
        print(f"❌ Error calling Gemini API: {e}")
        return "Error generating response from Gemini."

    # Guard against a response without text so callers always get a str.
    if text is None:
        logger.error("❌ Error calling Gemini API: response contained no text")
        return "Error generating response from Gemini."
    return text

# ✅ Chatbot Class
class RAGMedicalChatbot:
    """Chatbot that answers from retrieved knowledge plus prior conversation.

    Attributes are set from the constructor arguments:
      model_name -- model identifier passed to ``gemini_flash_completion``
      retrieve   -- callable taking the user query and returning text snippets
    """

    def __init__(self, model_name, retrieve_function):
        self.retrieve = retrieve_function
        self.model_name = model_name

    def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
        """Answer *user_query* for *user_id* and record the exchange.

        Returns the generated response with surrounding whitespace stripped.
        The exchange is stored in memory only when *user_id* is truthy.
        """
        # 1. Retrieved knowledge, one snippet per line
        knowledge_base = "\n".join(self.retrieve(user_query))

        # 2. Nearest 3 chunks from the user's short-term memory
        context = memory.get_relevant_chunks(user_id, user_query, top_k=3)

        # 3. Prompt sections; the history section appears only when there is context
        history_section = (
            ["Relevant context from prior conversation:\n" + "\n".join(context)]
            if context
            else []
        )
        prompt = "\n\n".join(
            [
                "You are a medical chatbot, designed to answer medical questions.",
                "Please format your answer using MarkDown.",
                "**Bold for titles**, *italic for emphasis*, and clear headings.",
                *history_section,
                f"Medical knowledge (256,916 medical scenario): {knowledge_base}",
                f"Question: {user_query}",
                f"Language: {lang}",
            ]
        )

        response = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)

        # 4. Store exchange + chunking
        if user_id:
            memory.add_exchange(user_id, user_query, response, lang=lang)
        return response.strip()

# ✅ Initialize Chatbot (single module-level instance)
chatbot = RAGMedicalChatbot(
    model_name="gemini-2.5-flash-preview-04-17",
    retrieve_function=retrieve_medical_info,
)

# ✅ Chat Endpoint
@app.post("/chat")
async def chat_endpoint(req: Request):
    """Handle POST /chat.

    Expects a JSON object with optional ``user_id`` (default "anonymous"),
    ``query`` and ``lang`` (default "EN") keys.

    Returns a JSON object with a single ``response`` key:
      * HTTP 400 and "Invalid JSON body." when the body is not a JSON object;
      * "No query provided." when ``query`` is missing, blank or not a string;
      * otherwise the chatbot answer followed by the response time.
    """
    # Malformed JSON raises a ValueError subclass; report it as a client error
    # rather than letting it surface as a 500.
    try:
        body = await req.json()
    except ValueError:
        return JSONResponse({"response": "Invalid JSON body."}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"response": "Invalid JSON body."}, status_code=400)

    user_id = body.get("user_id", "anonymous")
    raw_query = body.get("query", "")
    # A non-string query (null, number, ...) is treated as missing.
    query = raw_query.strip() if isinstance(raw_query, str) else ""
    lang = body.get("lang", "EN")
    # Error
    if not query:
        return JSONResponse({"response": "No query provided."})

    # NOTE(review): chatbot.chat is a synchronous call made from this async
    # handler — confirm whether it should be moved off the event loop.
    start = time.perf_counter()  # monotonic clock for elapsed-time measurement
    answer = chatbot.chat(user_id, query, lang)
    elapsed = time.perf_counter() - start
    # Final
    return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})


# ✅ Run Uvicorn when this file is executed as a script
if __name__ == "__main__":
    logger.info("✅ Starting FastAPI Server...")
    print("✅ Starting FastAPI Server...")
    try:
        # Serve on all interfaces, port 7860, with uvicorn's own log level at debug.
        uvicorn.run(app, host="0.0.0.0", port=7860, log_level="debug")
    except Exception as e:
        logger.error(f"❌ Server Startup Failed: {e}")
        # `exit()` is an interactive helper from the `site` module; raise
        # SystemExit directly to terminate with status 1.
        raise SystemExit(1) from e