Spaces:
Running
Running
Commit
·
3dcd314
1
Parent(s):
2415f43
Update diagnosis retrieval and embedder
Browse files- app.py +3 -2
- diagnosis.py +76 -0
app.py
CHANGED
@@ -9,9 +9,10 @@ from fastapi.responses import JSONResponse
|
|
9 |
from pymongo import MongoClient
|
10 |
from google import genai
|
11 |
from sentence_transformers import SentenceTransformer
|
|
|
12 |
from memory import MemoryManager
|
13 |
from translation import translate_query
|
14 |
-
|
15 |
|
16 |
# ✅ Enable Logging for Debugging
|
17 |
import logging
|
@@ -183,7 +184,7 @@ def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
|
|
183 |
global SYMPTOM_VECTORS, SYMPTOM_DOCS
|
184 |
# Lazy load
|
185 |
if SYMPTOM_VECTORS is None:
|
186 |
-
all_docs = list(symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1}))
|
187 |
SYMPTOM_DOCS = all_docs
|
188 |
SYMPTOM_VECTORS = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
|
189 |
# Embed input
|
|
|
9 |
from pymongo import MongoClient
|
10 |
from google import genai
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
+
from sentence_transformers.util import cos_sim
|
13 |
from memory import MemoryManager
|
14 |
from translation import translate_query
|
15 |
+
|
16 |
|
17 |
# ✅ Enable Logging for Debugging
|
18 |
import logging
|
|
|
184 |
global SYMPTOM_VECTORS, SYMPTOM_DOCS
|
185 |
# Lazy load
|
186 |
if SYMPTOM_VECTORS is None:
|
187 |
+
all_docs = list(symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}))
|
188 |
SYMPTOM_DOCS = all_docs
|
189 |
SYMPTOM_VECTORS = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
|
190 |
# Embed input
|
diagnosis.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ✅ Google Colab: SymbiPredict Embedding + Chunking + MongoDB Upload
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
from pymongo import MongoClient
|
7 |
+
from pymongo.errors import BulkWriteError
|
8 |
+
import hashlib, os
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
# ✅ Load model
|
12 |
+
model = SentenceTransformer("all-MiniLM-L6-v2")
|
13 |
+
|
14 |
+
# ✅ Load SymbiPredict
|
15 |
+
df = pd.read_csv("symbipredict_2022.csv")
|
16 |
+
|
17 |
+
# ✅ Connect to MongoDB
|
18 |
+
mongo_uri = "..."
|
19 |
+
client = MongoClient(mongo_uri)
|
20 |
+
db = client["MedicalChatbotDB"]
|
21 |
+
collection = db["symptom_diagnosis"]
|
22 |
+
|
23 |
+
# ✅ Clear old symptom-diagnosis records
|
24 |
+
print("🧹 Dropping old 'symptom_diagnosis' collection...")
|
25 |
+
collection.drop()
|
26 |
+
# Reconfirm collection is empty
|
27 |
+
if collection.count_documents({}) != 0:
|
28 |
+
raise RuntimeError("❌ Collection not empty after drop — aborting!")
|
29 |
+
|
30 |
+
# ✅ Convert CSV rows into QA-style records with embeddings
|
31 |
+
records = []
|
32 |
+
for i, row in tqdm(df.iterrows(), total=len(df)):
|
33 |
+
symptom_cols = df.columns[:-1]
|
34 |
+
label_col = df.columns[-1]
|
35 |
+
|
36 |
+
# Extract symptoms present (value==1)
|
37 |
+
symptoms = [col.replace("_", " ").strip() for col in symptom_cols if row[col] == 1]
|
38 |
+
if not symptoms:
|
39 |
+
continue
|
40 |
+
|
41 |
+
label = row[label_col].strip()
|
42 |
+
question = f"What disease is likely given these symptoms: {', '.join(symptoms)}?"
|
43 |
+
answer = f"The patient is likely suffering from: {label}."
|
44 |
+
|
45 |
+
# Embed question only
|
46 |
+
embed = model.encode(question, convert_to_numpy=True)
|
47 |
+
hashkey = hashlib.md5((question + answer).encode()).hexdigest()
|
48 |
+
|
49 |
+
records.append({
|
50 |
+
"_id": hashkey,
|
51 |
+
"i": int(i),
|
52 |
+
"symptoms": symptoms,
|
53 |
+
"prognosis": label,
|
54 |
+
"question": question,
|
55 |
+
"answer": answer,
|
56 |
+
"embedding": embed.tolist()
|
57 |
+
})
|
58 |
+
|
59 |
+
# ✅ Save to MongoDB
|
60 |
+
if records:
|
61 |
+
print(f"⬆️ Uploading {len(records)} records to MongoDB...")
|
62 |
+
unique_ids = set()
|
63 |
+
deduped = []
|
64 |
+
for r in records:
|
65 |
+
if r["_id"] not in unique_ids:
|
66 |
+
unique_ids.add(r["_id"])
|
67 |
+
deduped.append(r)
|
68 |
+
try:
|
69 |
+
collection.insert_many(deduped, ordered=False)
|
70 |
+
print(f"✅ Inserted {len(deduped)} records without duplicates.")
|
71 |
+
except BulkWriteError as bwe:
|
72 |
+
inserted = bwe.details.get('nInserted', 0)
|
73 |
+
print(f"⚠️ Inserted with some duplicate skips. Records inserted: {inserted}")
|
74 |
+
print("✅ Upload complete.")
|
75 |
+
else:
|
76 |
+
print("⚠️ No records to upload.")
|