LiamKhoaLe committed on
Commit
3dcd314
·
1 Parent(s): 2415f43

Update diagnosis retrieval and embedder

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. diagnosis.py +76 -0
app.py CHANGED
@@ -9,9 +9,10 @@ from fastapi.responses import JSONResponse
9
  from pymongo import MongoClient
10
  from google import genai
11
  from sentence_transformers import SentenceTransformer
 
12
  from memory import MemoryManager
13
  from translation import translate_query
14
- from sentence_transformers.util import cos_sim
15
 
16
  # ✅ Enable Logging for Debugging
17
  import logging
@@ -183,7 +184,7 @@ def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
183
  global SYMPTOM_VECTORS, SYMPTOM_DOCS
184
  # Lazy load
185
  if SYMPTOM_VECTORS is None:
186
- all_docs = list(symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1}))
187
  SYMPTOM_DOCS = all_docs
188
  SYMPTOM_VECTORS = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
189
  # Embed input
 
9
  from pymongo import MongoClient
10
  from google import genai
11
  from sentence_transformers import SentenceTransformer
12
+ from sentence_transformers.util import cos_sim
13
  from memory import MemoryManager
14
  from translation import translate_query
15
+
16
 
17
  # ✅ Enable Logging for Debugging
18
  import logging
 
184
  global SYMPTOM_VECTORS, SYMPTOM_DOCS
185
  # Lazy load
186
  if SYMPTOM_VECTORS is None:
187
+ all_docs = list(symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}))
188
  SYMPTOM_DOCS = all_docs
189
  SYMPTOM_VECTORS = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
190
  # Embed input
diagnosis.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# ✅ Google Colab: SymbiPredict Embedding + Chunking + MongoDB Upload
#
# One-shot ingestion script. It reads the SymbiPredict CSV, turns every row
# into a question/answer record, embeds the question with a
# SentenceTransformer model and uploads the de-duplicated records to the
# "symptom_diagnosis" collection of the "MedicalChatbotDB" database.
# WARNING: the target collection is dropped before the upload.

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from pymongo import MongoClient
from pymongo.errors import BulkWriteError
import hashlib
import os
from tqdm import tqdm

# ✅ Load model
model = SentenceTransformer("all-MiniLM-L6-v2")

# ✅ Load SymbiPredict
df = pd.read_csv("symbipredict_2022.csv")

# ✅ Connect to MongoDB
# NOTE: "..." is a placeholder; supply the real connection string before
# running and avoid committing credentials to the repository.
mongo_uri = "..."
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]
collection = db["symptom_diagnosis"]

# ✅ Clear old symptom-diagnosis records
print("🧹 Dropping old 'symptom_diagnosis' collection...")
collection.drop()
# Reconfirm collection is empty
if collection.count_documents({}) != 0:
    raise RuntimeError("❌ Collection not empty after drop — aborting!")

# ✅ Convert CSV rows into QA-style records with embeddings
# The column layout is the same for every row, so resolve it once: every
# column except the last is treated as a symptom flag, the last one as the
# diagnosis label.
symptom_cols = df.columns[:-1]
label_col = df.columns[-1]

# Rows with the same symptom set produce the same question text, so the
# embedding is computed once per distinct question and reused afterwards.
embedding_cache = {}

records = []
for i, row in tqdm(df.iterrows(), total=len(df)):
    # Extract symptoms present (value==1)
    symptoms = [col.replace("_", " ").strip() for col in symptom_cols if row[col] == 1]
    if not symptoms:
        continue

    label = row[label_col].strip()
    question = f"What disease is likely given these symptoms: {', '.join(symptoms)}?"
    answer = f"The patient is likely suffering from: {label}."

    # Embed question only
    if question not in embedding_cache:
        embed = model.encode(question, convert_to_numpy=True)
        embedding_cache[question] = embed.tolist()
    # The MD5 digest is used purely as a deterministic document id.
    hashkey = hashlib.md5((question + answer).encode()).hexdigest()

    records.append({
        "_id": hashkey,
        "i": int(i),
        "symptoms": symptoms,
        "prognosis": label,
        "question": question,
        "answer": answer,
        "embedding": embedding_cache[question]
    })

# ✅ Save to MongoDB
if records:
    print(f"⬆️ Uploading {len(records)} records to MongoDB...")
    # Keep only the first record seen for each _id.
    unique_ids = set()
    deduped = []
    for r in records:
        if r["_id"] not in unique_ids:
            unique_ids.add(r["_id"])
            deduped.append(r)
    try:
        # ordered=False lets the server continue past individual failures.
        collection.insert_many(deduped, ordered=False)
        print(f"✅ Inserted {len(deduped)} records without duplicates.")
    except BulkWriteError as bwe:
        inserted = bwe.details.get('nInserted', 0)
        print(f"⚠️ Inserted with some duplicate skips. Records inserted: {inserted}")
        # _ids were already de-duplicated locally and the collection was
        # dropped above, so surface the server-reported errors instead of
        # assuming they are all duplicate-key conflicts.
        write_errors = bwe.details.get('writeErrors', [])
        print(f"⚠️ {len(write_errors)} of {len(deduped)} writes failed.")
        if write_errors:
            print(f"⚠️ First write error: {write_errors[0].get('errmsg')}")
    print("✅ Upload complete.")
else:
    print("⚠️ No records to upload.")