LiamKhoaLe commited on
Commit
9382e01
·
1 Parent(s): 8bc48fc

Upd RAG for sym-diagnosis

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py CHANGED
@@ -94,6 +94,9 @@ except Exception as e:
94
  logger.error(f"❌ Model Loading Failed: {e}")
95
  exit(1)
96
 
 
 
 
97
 
98
  # ✅ Setup MongoDB Connection
99
  # QA data
@@ -104,6 +107,9 @@ qa_collection = db["qa_data"]
104
  iclient = MongoClient(index_uri)
105
  idb = iclient["MedicalChatbotDB"]
106
  index_collection = idb["faiss_index_files"]
 
 
 
107
 
108
  # ✅ Load FAISS Index (Lazy Load)
109
  import gridfs
@@ -142,6 +148,26 @@ def retrieve_medical_info(query, k=5, min_sim=0.6): # Min similarity between que
142
  results.append(doc.get("Doctor", "No answer available."))
143
  return results if results else ["No relevant medical entries found."]
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  # ✅ Gemini Flash API Call
147
  def gemini_flash_completion(prompt, model, temperature=0.7):
@@ -161,8 +187,11 @@ class RAGMedicalChatbot:
161
 
162
  def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
163
  # 1. Fetch knowledge
 
164
  retrieved_info = self.retrieve(user_query)
165
  knowledge_base = "\n".join(retrieved_info)
 
 
166
 
167
  # 2. Use relevant chunks from short-term memory FAISS index (nearest 3 chunks)
168
  context = memory.get_relevant_chunks(user_id, user_query, top_k=3)
@@ -177,6 +206,9 @@ class RAGMedicalChatbot:
177
  # Load up guideline
178
  if knowledge_base:
179
  parts.append(f"Medical knowledge (256,916 medical scenario): {knowledge_base}")
 
 
 
180
  parts.append(f"Question: {user_query}")
181
  parts.append(f"Language: {lang}")
182
  prompt = "\n\n".join(parts)
 
94
  logger.error(f"❌ Model Loading Failed: {e}")
95
  exit(1)
96
 
97
+ # Cache in-memory vectors (optional — useful for <10k rows)
98
+ SYMPTOM_VECTORS = None
99
+ SYMPTOM_DOCS = None
100
 
101
  # ✅ Setup MongoDB Connection
102
  # QA data
 
107
  iclient = MongoClient(index_uri)
108
  idb = iclient["MedicalChatbotDB"]
109
  index_collection = idb["faiss_index_files"]
110
+ # Symptom Diagnosis data
111
+ symptom_client = MongoClient(mongo_uri)
112
+ symptom_col = symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]
113
 
114
  # ✅ Load FAISS Index (Lazy Load)
115
  import gridfs
 
148
  results.append(doc.get("Doctor", "No answer available."))
149
  return results if results else ["No relevant medical entries found."]
150
 
151
# ✅ Retrieve Sym-Dia Info (4,962 symptom→diagnosis scenarios)
def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
    """Return up to ``top_k`` diagnosis answers for a free-text symptom query.

    Matches the query against the ``symptom_diagnosis`` collection by cosine
    similarity of sentence embeddings; results are ordered most-similar first
    and filtered to ``sims >= min_sim``. Returns ``[]`` when the collection is
    empty or nothing clears the threshold.
    """
    global SYMPTOM_VECTORS, SYMPTOM_DOCS
    # Lazy-load the whole collection into memory once (comment at module level
    # says this cache is intended for <10k rows, so this is acceptable).
    if SYMPTOM_VECTORS is None:
        SYMPTOM_DOCS = list(
            symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1})
        )
        SYMPTOM_VECTORS = np.array(
            [doc["embedding"] for doc in SYMPTOM_DOCS], dtype=np.float32
        )
        if SYMPTOM_VECTORS.size:
            # Normalize every stored row so the dot product below is a true
            # cosine similarity even if the DB embeddings are not unit-norm.
            # (Original code normalized only the query vector.)
            norms = np.linalg.norm(SYMPTOM_VECTORS, axis=1, keepdims=True)
            SYMPTOM_VECTORS = SYMPTOM_VECTORS / (norms + 1e-9)
    # Empty collection: nothing to match against. Also avoids the shape error
    # the original raised from `(0,) @ (dim,)` on an empty matrix.
    if SYMPTOM_VECTORS.size == 0:
        return []
    # Embed the query and normalize it (epsilon guards divide-by-zero).
    qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
    qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
    # Cosine similarity against every stored scenario; take the top_k best,
    # highest-first, then drop anything under the similarity floor.
    sims = SYMPTOM_VECTORS @ qvec
    sorted_idx = np.argsort(sims)[-top_k:][::-1]
    return [SYMPTOM_DOCS[i]["answer"] for i in sorted_idx if sims[i] >= min_sim]
171
 
172
  # ✅ Gemini Flash API Call
173
  def gemini_flash_completion(prompt, model, temperature=0.7):
 
187
 
188
  def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
189
  # 1. Fetch knowledge
190
+ ## a. KB for generic QA retrieval
191
  retrieved_info = self.retrieve(user_query)
192
  knowledge_base = "\n".join(retrieved_info)
193
+ ## b. Diagnosis RAG from symptom query
194
+ diagnosis_guides = retrieve_diagnosis_from_symptoms(user_query) # smart matcher
195
 
196
  # 2. Use relevant chunks from short-term memory FAISS index (nearest 3 chunks)
197
  context = memory.get_relevant_chunks(user_id, user_query, top_k=3)
 
206
  # Load up guideline
207
  if knowledge_base:
208
  parts.append(f"Medical knowledge (256,916 medical scenario): {knowledge_base}")
209
+ # Symptom-Diagnosis prediction RAG
210
+ if diagnosis_guides:
211
+ parts.append("Symptom-based diagnosis guidance:\n" + "\n".join(diagnosis_guides))
212
  parts.append(f"Question: {user_query}")
213
  parts.append(f"Language: {lang}")
214
  prompt = "\n\n".join(parts)