Spaces:
Running
Running
Commit
·
6b4f62a
1
Parent(s):
9382e01
Update translation modules; increase KB retrieval threshold
Browse files- app.py +26 -12
- translation.py +26 -0
app.py
CHANGED
@@ -10,6 +10,7 @@ from pymongo import MongoClient
|
|
10 |
from google import genai
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
from memory import MemoryManager
|
|
|
13 |
|
14 |
# ✅ Enable Logging for Debugging
|
15 |
import logging
|
@@ -129,8 +130,8 @@ def load_faiss_index():
|
|
129 |
logger.error("[KB] ❌ FAISS index not found in GridFS.")
|
130 |
return index
|
131 |
|
132 |
-
# ✅ Retrieve Medical Info
|
133 |
-
def retrieve_medical_info(query, k=5, min_sim=0.
|
134 |
global index
|
135 |
index = load_faiss_index()
|
136 |
if index is None:
|
@@ -140,15 +141,19 @@ def retrieve_medical_info(query, k=5, min_sim=0.6): # Min similarity between que
|
|
140 |
D, I = index.search(query_vec, k=k)
|
141 |
# Filter by cosine threshold
|
142 |
results = []
|
|
|
143 |
for score, idx in zip(D[0], I[0]):
|
144 |
if score < min_sim:
|
145 |
continue
|
146 |
doc = qa_collection.find_one({"i": int(idx)})
|
147 |
if doc:
|
148 |
-
|
149 |
-
|
|
|
|
|
150 |
|
151 |
-
|
|
|
152 |
def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
|
153 |
global SYMPTOM_VECTORS, SYMPTOM_DOCS
|
154 |
# Lazy load
|
@@ -162,12 +167,17 @@ def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
|
|
162 |
# Similarity compute
|
163 |
sims = SYMPTOM_VECTORS @ qvec # cosine
|
164 |
sorted_idx = np.argsort(sims)[-top_k:][::-1]
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
if
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
# ✅ Gemini Flash API Call
|
173 |
def gemini_flash_completion(prompt, model, temperature=0.7):
|
@@ -186,6 +196,10 @@ class RAGMedicalChatbot:
|
|
186 |
self.retrieve = retrieve_function
|
187 |
|
188 |
def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
|
|
|
|
|
|
|
|
|
189 |
# 1. Fetch knowledge
|
190 |
## a. KB for generic QA retrieval
|
191 |
retrieved_info = self.retrieve(user_query)
|
@@ -205,7 +219,7 @@ class RAGMedicalChatbot:
|
|
205 |
parts.append("Relevant context from prior conversation:\n" + "\n".join(context))
|
206 |
# Load up guideline
|
207 |
if knowledge_base:
|
208 |
-
parts.append(f"Medical knowledge
|
209 |
# Symptom-Diagnosis prediction RAG
|
210 |
if diagnosis_guides:
|
211 |
parts.append("Symptom-based diagnosis guidance:\n" + "\n".join(diagnosis_guides))
|
|
|
10 |
from google import genai
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
from memory import MemoryManager
|
13 |
+
from translation import translate_query
|
14 |
|
15 |
# ✅ Enable Logging for Debugging
|
16 |
import logging
|
|
|
130 |
logger.error("[KB] ❌ FAISS index not found in GridFS.")
|
131 |
return index
|
132 |
|
133 |
+
# ✅ Retrieve Medical Info (256,916 scenario)
|
134 |
+
def retrieve_medical_info(query, k=5, min_sim=0.8): # Min similarity between query and kb is to be 80%
|
135 |
global index
|
136 |
index = load_faiss_index()
|
137 |
if index is None:
|
|
|
141 |
D, I = index.search(query_vec, k=k)
|
142 |
# Filter by cosine threshold
|
143 |
results = []
|
144 |
+
seen = set() # avoid near-duplicate KB responses
|
145 |
for score, idx in zip(D[0], I[0]):
|
146 |
if score < min_sim:
|
147 |
continue
|
148 |
doc = qa_collection.find_one({"i": int(idx)})
|
149 |
if doc:
|
150 |
+
answer = doc.get("Doctor", "No answer available.")
|
151 |
+
if answer not in seen:
|
152 |
+
seen.add(answer)
|
153 |
+
results.append(answer)
|
154 |
|
155 |
+
|
156 |
+
# ✅ Retrieve Sym-Dia Info (4,962 scenario)
|
157 |
def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
|
158 |
global SYMPTOM_VECTORS, SYMPTOM_DOCS
|
159 |
# Lazy load
|
|
|
167 |
# Similarity compute
|
168 |
sims = SYMPTOM_VECTORS @ qvec # cosine
|
169 |
sorted_idx = np.argsort(sims)[-top_k:][::-1]
|
170 |
+
seen_diag = set()
|
171 |
+
final = [] # Dedup
|
172 |
+
for i in sorted_idx:
|
173 |
+
sim = sims[i]
|
174 |
+
if sim < min_sim:
|
175 |
+
continue
|
176 |
+
label = SYMPTOM_DOCS[i]["prognosis"]
|
177 |
+
if label not in seen_diag:
|
178 |
+
final.append(SYMPTOM_DOCS[i]["answer"])
|
179 |
+
seen_diag.add(label)
|
180 |
+
return final
|
181 |
|
182 |
# ✅ Gemini Flash API Call
|
183 |
def gemini_flash_completion(prompt, model, temperature=0.7):
|
|
|
196 |
self.retrieve = retrieve_function
|
197 |
|
198 |
def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
|
199 |
+
# 0. Translate query if not EN, this help our RAG system
|
200 |
+
if lang.upper() in {"VI", "ZH"}:
|
201 |
+
user_query = translate_query(user_query, lang.lower())
|
202 |
+
|
203 |
# 1. Fetch knowledge
|
204 |
## a. KB for generic QA retrieval
|
205 |
retrieved_info = self.retrieve(user_query)
|
|
|
219 |
parts.append("Relevant context from prior conversation:\n" + "\n".join(context))
|
220 |
# Load up guideline
|
221 |
if knowledge_base:
|
222 |
+
parts.append(f"Medical scenario knowledge: {knowledge_base}")
|
223 |
# Symptom-Diagnosis prediction RAG
|
224 |
if diagnosis_guides:
|
225 |
parts.append("Symptom-based diagnosis guidance:\n" + "\n".join(diagnosis_guides))
|
translation.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# translation.py
|
2 |
+
from transformers import pipeline
|
3 |
+
import logging
|
4 |
+
|
5 |
+
logger = logging.getLogger("translation-agent")
|
6 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader
|
7 |
+
|
8 |
+
# To use lazy model loader
|
9 |
+
vi_en = None
|
10 |
+
zh_en = None
|
11 |
+
|
def translate_query(text: str, lang_code: str) -> str:
    """Translate a user query into English for downstream RAG retrieval.

    Supported source languages are Vietnamese (``"vi"``) and Chinese
    (``"zh"``). The matching Hugging Face translation pipeline is built
    lazily on first use and cached in the module-level globals ``vi_en`` /
    ``zh_en`` so the models are loaded at most once per process.

    Args:
        text: Raw user query in the source language.
        lang_code: Lowercase language code (``"vi"`` or ``"zh"``). Any other
            value — including an unexpected case such as ``"VI"`` — returns
            ``text`` unchanged, matching the caller's lowercasing contract.

    Returns:
        The English translation, or ``text`` untouched for unsupported
        language codes or empty input.
    """
    global vi_en, zh_en

    # Robustness: empty input has nothing to translate, and short-circuiting
    # here avoids instantiating the (heavy) models just to translate "".
    if not text:
        return text

    # Single dispatch point replaces the previously duplicated vi/zh branches.
    if lang_code == "vi":
        if vi_en is None:  # lazy one-time model load; device=-1 forces CPU
            vi_en = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en", device=-1)
        translator = vi_en
    elif lang_code == "zh":
        if zh_en is None:
            zh_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
        translator = zh_en
    else:
        # Unsupported language: pass the query through untouched.
        return text

    result = translator(text, max_length=512)[0]["translation_text"]
    # Lazy %-formatting instead of f-strings; also fixes the previously
    # swapped direction labels ("[En-Vi]" was logged for a vi->en run).
    logger.info("[%s->En] Query translated to: %s", lang_code, result)
    return result