LiamKhoaLe committed
Commit 6b4f62a · 1 Parent(s): 9382e01

Update translation modules. Increase KB retrieval threshold

Files changed (2):
  1. app.py +26 -12
  2. translation.py +26 -0
app.py CHANGED
@@ -10,6 +10,7 @@ from pymongo import MongoClient
 from google import genai
 from sentence_transformers import SentenceTransformer
 from memory import MemoryManager
+from translation import translate_query
 
 # ✅ Enable Logging for Debugging
 import logging
@@ -129,8 +130,8 @@ def load_faiss_index():
         logger.error("[KB] ❌ FAISS index not found in GridFS.")
     return index
 
-# ✅ Retrieve Medical Info
-def retrieve_medical_info(query, k=5, min_sim=0.6): # Min similarity between query and kb is to be 60%
+# ✅ Retrieve Medical Info (256,916 scenarios)
+def retrieve_medical_info(query, k=5, min_sim=0.8): # Min similarity between query and kb is to be 80%
     global index
     index = load_faiss_index()
     if index is None:
@@ -140,15 +141,19 @@ def retrieve_medical_info(query, k=5, min_sim=0.6): # Min similarity between query and kb is to be 60%
     D, I = index.search(query_vec, k=k)
     # Filter by cosine threshold
     results = []
+    seen = set()  # avoid near-duplicate KB responses
     for score, idx in zip(D[0], I[0]):
         if score < min_sim:
             continue
         doc = qa_collection.find_one({"i": int(idx)})
         if doc:
-            results.append(doc.get("Doctor", "No answer available."))
-    return results if results else ["No relevant medical entries found."]
+            answer = doc.get("Doctor", "No answer available.")
+            if answer not in seen:
+                seen.add(answer)
+                results.append(answer)
 
-# ✅ Retrieve Sym-Dia Info (4962 scenario)
+
+# ✅ Retrieve Sym-Dia Info (4,962 scenarios)
 def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
     global SYMPTOM_VECTORS, SYMPTOM_DOCS
     # Lazy load
@@ -162,12 +167,17 @@ def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
     # Similarity compute
     sims = SYMPTOM_VECTORS @ qvec  # cosine
     sorted_idx = np.argsort(sims)[-top_k:][::-1]
-    # Final
-    return [
-        SYMPTOM_DOCS[i]["answer"]
-        for i in sorted_idx
-        if sims[i] >= min_sim
-    ]
+    seen_diag = set()
+    final = []  # Dedup by prognosis
+    for i in sorted_idx:
+        sim = sims[i]
+        if sim < min_sim:
+            continue
+        label = SYMPTOM_DOCS[i]["prognosis"]
+        if label not in seen_diag:
+            final.append(SYMPTOM_DOCS[i]["answer"])
+            seen_diag.add(label)
+    return final
 
 # ✅ Gemini Flash API Call
 def gemini_flash_completion(prompt, model, temperature=0.7):
@@ -186,6 +196,10 @@ class RAGMedicalChatbot:
         self.retrieve = retrieve_function
 
     def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
+        # 0. Translate query if not EN; this helps our RAG retrieval
+        if lang.upper() in {"VI", "ZH"}:
+            user_query = translate_query(user_query, lang.lower())
+
         # 1. Fetch knowledge
         ## a. KB for generic QA retrieval
         retrieved_info = self.retrieve(user_query)
@@ -205,7 +219,7 @@
             parts.append("Relevant context from prior conversation:\n" + "\n".join(context))
         # Load up guideline
         if knowledge_base:
-            parts.append(f"Medical knowledge (256,916 medical scenario): {knowledge_base}")
+            parts.append(f"Medical scenario knowledge: {knowledge_base}")
         # Symptom-Diagnosis prediction RAG
         if diagnosis_guides:
             parts.append("Symptom-based diagnosis guidance:\n" + "\n".join(diagnosis_guides))
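The heart of this commit is the tightened KB lookup: FAISS inner-product scores are treated as cosine similarities, anything below min_sim (now 0.8, up from 0.6) is dropped, and duplicate answers are filtered out before they reach the prompt. Below is a minimal, self-contained sketch of that behaviour, not the app's code: it assumes all-MiniLM-L6-v2 as a stand-in embedder, L2-normalized vectors so IndexFlatIP scores equal cosine similarity, and a small in-memory answer list in place of the MongoDB-backed qa_collection. Note that the sketch ends with an explicit return, which the reworked loop in the hunk above no longer shows.

# Sketch only. Assumptions: all-MiniLM-L6-v2 as embedder, a toy answer list
# instead of qa_collection, and unit-length vectors so IP scores are cosines.
import faiss
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
answers = [
    "Drink fluids, rest, and monitor your temperature.",
    "Drink fluids, rest, and monitor your temperature.",   # deliberate duplicate
    "See a doctor if the fever lasts more than three days.",
]
vecs = model.encode(answers, normalize_embeddings=True)
index = faiss.IndexFlatIP(vecs.shape[1])   # inner product == cosine on unit vectors
index.add(vecs)

def retrieve(query, k=5, min_sim=0.8):     # 0.8 mirrors the raised threshold
    qvec = model.encode([query], normalize_embeddings=True)
    D, I = index.search(qvec, k=min(k, len(answers)))
    results, seen = [], set()
    for score, idx in zip(D[0], I[0]):
        if idx == -1 or score < min_sim:   # FAISS pads short result lists with -1
            continue
        if answers[idx] not in seen:       # same dedup idea as retrieve_medical_info
            seen.add(answers[idx])
            results.append(answers[idx])
    return results                         # explicit return, omitted in the hunk above

print(retrieve("How should I treat a mild fever?"))

Raising the threshold trades recall for precision: fewer, but more on-topic, KB passages reach the Gemini prompt.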
 
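retrieve_diagnosis_from_symptoms now applies the same idea but keys the dedup on the prognosis label, so the prompt gets at most one answer per predicted condition. A small worked example of that selection step, with made-up SYMPTOM_DOCS entries and similarity scores:

import numpy as np

# Toy stand-ins for the globals used in app.py (illustrative values only).
SYMPTOM_DOCS = [
    {"prognosis": "Migraine", "answer": "Likely migraine; rest in a dark room."},
    {"prognosis": "Migraine", "answer": "Migraine with aura; consider triptans."},
    {"prognosis": "Tension headache", "answer": "Tension headache; hydration and OTC analgesics."},
]
sims = np.array([0.82, 0.75, 0.51])            # pretend cosine scores against the query

top_k, min_sim = 5, 0.4
sorted_idx = np.argsort(sims)[-top_k:][::-1]   # best-first order, as in the diff

seen_diag, final = set(), []
for i in sorted_idx:
    if sims[i] < min_sim:
        continue
    label = SYMPTOM_DOCS[i]["prognosis"]
    if label not in seen_diag:                 # keep one answer per distinct prognosis
        final.append(SYMPTOM_DOCS[i]["answer"])
        seen_diag.add(label)

print(final)  # ['Likely migraine; rest in a dark room.', 'Tension headache; hydration and OTC analgesics.']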
translation.py ADDED
@@ -0,0 +1,26 @@
+# translation.py
+from transformers import pipeline
+import logging
+
+logger = logging.getLogger("translation-agent")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True)  # Change INFO to DEBUG for full-ctx JSON loader
+
+# Lazy model loaders (built on first use)
+vi_en = None
+zh_en = None
+
+def translate_query(text: str, lang_code: str) -> str:
+    global vi_en, zh_en
+    if lang_code == "vi":
+        if vi_en is None:
+            vi_en = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en", device=-1)
+        result = vi_en(text, max_length=512)[0]["translation_text"]
+        logger.info(f"[Vi-En] Query in `{lang_code}` translated to: {result}")
+        return result
+    elif lang_code == "zh":
+        if zh_en is None:
+            zh_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
+        result = zh_en(text, max_length=512)[0]["translation_text"]
+        logger.info(f"[Zh-En] Query in `{lang_code}` translated to: {result}")
+        return result
+    return text
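For context, a rough usage sketch of the new module as chat() calls it. The pipelines are built lazily on the first request in each language and cached in the module-level globals, and any language code other than "vi" or "zh" falls through and returns the text unchanged. The sample sentences below are illustrative only.

# Hypothetical caller, mirroring how chat() in app.py invokes the module.
from translation import translate_query

samples = [
    ("Tôi bị đau đầu và sốt nhẹ.", "vi"),   # loads VietAI/envit5-translation on first call
    ("我最近经常头晕。", "zh"),              # loads Helsinki-NLP/opus-mt-zh-en on first call
    ("I have a sore throat.", "en"),        # unsupported code: returned as-is
]
for text, code in samples:
    print(code, "->", translate_query(text, code))

Keeping the pipelines in module-level globals means the first non-English request pays the model-load cost once per process; later calls reuse the cached pipeline.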