Commit 9382e01 · Parent(s): 8bc48fc
Update RAG for symptom diagnosis
app.py
CHANGED
@@ -94,6 +94,9 @@ except Exception as e:
     logger.error(f"❌ Model Loading Failed: {e}")
     exit(1)
 
+# Cache in-memory vectors (optional — useful for <10k rows)
+SYMPTOM_VECTORS = None
+SYMPTOM_DOCS = None
 
 # ✅ Setup MongoDB Connection
 # QA data
@@ -104,6 +107,9 @@ qa_collection = db["qa_data"]
 iclient = MongoClient(index_uri)
 idb = iclient["MedicalChatbotDB"]
 index_collection = idb["faiss_index_files"]
+# Symptom Diagnosis data
+symptom_client = MongoClient(mongo_uri)
+symptom_col = symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]
 
 # ✅ Load FAISS Index (Lazy Load)
 import gridfs
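The retrieval helper added in the next hunk projects question, answer, and embedding from this new collection, so each symptom_diagnosis document is assumed to carry a precomputed embedding alongside its QA pair. A hypothetical sketch of writing one such document (the field values are illustrative; the embedding dimension follows whatever embedding_model produces):

# Illustrative only: assumed shape of one symptom_diagnosis document.
symptom_text = "I have a fever, a stiff neck, and sensitivity to light."
example_doc = {
    "question": symptom_text,  # free-text symptom description
    "answer": "These symptoms can accompany meningitis; seek urgent care.",  # diagnosis guidance
    "embedding": embedding_model.encode(symptom_text, convert_to_numpy=True).tolist(),  # list of floats
}
symptom_col.insert_one(example_doc)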
@@ -142,6 +148,26 @@ def retrieve_medical_info(query, k=5, min_sim=0.6): # Min similarity between que
         results.append(doc.get("Doctor", "No answer available."))
     return results if results else ["No relevant medical entries found."]
 
+# ✅ Retrieve Sym-Dia Info (4962 scenario)
+def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
+    global SYMPTOM_VECTORS, SYMPTOM_DOCS
+    # Lazy load
+    if SYMPTOM_VECTORS is None:
+        all_docs = list(symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1}))
+        SYMPTOM_DOCS = all_docs
+        SYMPTOM_VECTORS = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
+    # Embed input
+    qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
+    qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
+    # Similarity compute
+    sims = SYMPTOM_VECTORS @ qvec  # cosine
+    sorted_idx = np.argsort(sims)[-top_k:][::-1]
+    # Final
+    return [
+        SYMPTOM_DOCS[i]["answer"]
+        for i in sorted_idx
+        if sims[i] >= min_sim
+    ]
 
 # ✅ Gemini Flash API Call
 def gemini_flash_completion(prompt, model, temperature=0.7):
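One caveat on the new retrieval helper above: SYMPTOM_VECTORS @ qvec is a true cosine similarity only if the stored embeddings were L2-normalized when they were written. The query vector is normalized here, but the cached rows are used as-is. A minimal guard, assuming nothing about how the stored embeddings were produced:

# Sketch: normalize the cached rows once after the lazy load, so the dot
# product with the already-normalized query vector is a genuine cosine.
import numpy as np

def normalize_rows(mat: np.ndarray) -> np.ndarray:
    # Divide each row by its L2 norm; the epsilon avoids division by zero.
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    return mat / (norms + 1e-9)

# Applied once, right after SYMPTOM_VECTORS is built:
# SYMPTOM_VECTORS = normalize_rows(SYMPTOM_VECTORS)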
@@ -161,8 +187,11 @@ class RAGMedicalChatbot:
 
     def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
         # 1. Fetch knowledge
+        ## a. KB for generic QA retrieval
         retrieved_info = self.retrieve(user_query)
         knowledge_base = "\n".join(retrieved_info)
+        ## b. Diagnosis RAG from symptom query
+        diagnosis_guides = retrieve_diagnosis_from_symptoms(user_query)  # smart matcher
 
         # 2. Use relevant chunks from short-term memory FAISS index (nearest 3 chunks)
         context = memory.get_relevant_chunks(user_id, user_query, top_k=3)
@@ -177,6 +206,9 @@ class RAGMedicalChatbot:
         # Load up guideline
         if knowledge_base:
             parts.append(f"Medical knowledge (256,916 medical scenario): {knowledge_base}")
+        # Symptom-Diagnosis prediction RAG
+        if diagnosis_guides:
+            parts.append("Symptom-based diagnosis guidance:\n" + "\n".join(diagnosis_guides))
         parts.append(f"Question: {user_query}")
         parts.append(f"Language: {lang}")
         prompt = "\n\n".join(parts)
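The new path can be smoke-tested on its own before exercising the full chat flow. A hypothetical check, assuming the module above is importable as app:

# Hypothetical smoke test for the new retrieval path (module name assumed).
from app import retrieve_diagnosis_from_symptoms

guides = retrieve_diagnosis_from_symptoms(
    "persistent dry cough, mild fever, fatigue for two weeks",
    top_k=3,      # ask for fewer candidates than the default 5
    min_sim=0.4,  # same similarity floor the chat path uses
)
for guide in guides:
    print("-", guide)  # each entry is one matched 'answer' that cleared the floor

An empty list simply means no stored scenario cleared min_sim, in which case chat omits the "Symptom-based diagnosis guidance" section from the prompt.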