Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -6,8 +6,14 @@ from langgraph.graph import StateGraph, END
|
|
6 |
from typing import TypedDict, List, Dict, Optional
|
7 |
from datetime import datetime
|
8 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
# Enhanced State Management
|
11 |
class MedicalState(TypedDict):
|
12 |
patient_id: str
|
13 |
conversation_history: List[Dict]
|
@@ -20,12 +26,49 @@ class MedicalState(TypedDict):
|
|
20 |
red_flags: List[str]
|
21 |
assessment_complete: bool
|
22 |
suggested_actions: List[str]
|
23 |
-
consultation_stage: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
# Medical Knowledge Base
|
26 |
MEDICAL_CATEGORIES = {
|
27 |
-
"respiratory": ["cough", "shortness of breath", "chest pain", "wheezing"],
|
28 |
-
"gastrointestinal": ["nausea", "vomiting", "diarrhea", "stomach pain", "heartburn"],
|
29 |
"neurological": ["headache", "dizziness", "numbness", "tingling"],
|
30 |
"musculoskeletal": ["joint pain", "muscle pain", "back pain", "stiffness"],
|
31 |
"cardiovascular": ["chest pain", "palpitations", "swelling", "fatigue"],
|
@@ -39,71 +82,99 @@ RED_FLAGS = [
|
|
39 |
"sudden vision changes", "loss of consciousness", "severe allergic reaction"
|
40 |
]
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
class EnhancedMedicalAssistant:
|
53 |
def __init__(self):
|
54 |
self.load_models()
|
|
|
55 |
self.setup_langgraph()
|
|
|
56 |
|
57 |
def load_models(self):
|
58 |
"""Load the AI models"""
|
59 |
print("Loading models...")
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
def setup_langgraph(self):
|
78 |
-
"""Setup LangGraph workflow"""
|
79 |
workflow = StateGraph(MedicalState)
|
80 |
|
81 |
-
# Add nodes
|
82 |
workflow.add_node("intake", self.patient_intake)
|
83 |
-
workflow.add_node("symptom_assessment", self.assess_symptoms)
|
84 |
-
workflow.add_node("risk_evaluation", self.evaluate_risks)
|
85 |
workflow.add_node("generate_recommendations", self.generate_recommendations)
|
86 |
workflow.add_node("emergency_triage", self.emergency_triage)
|
87 |
|
88 |
-
# Define edges
|
89 |
workflow.set_entry_point("intake")
|
90 |
workflow.add_conditional_edges(
|
91 |
"intake",
|
92 |
self.route_after_intake,
|
93 |
{
|
94 |
-
"continue_assessment": "symptom_assessment",
|
95 |
"emergency": "emergency_triage",
|
96 |
-
"
|
97 |
-
}
|
98 |
-
)
|
99 |
-
workflow.add_edge("symptom_assessment", "risk_evaluation")
|
100 |
-
workflow.add_conditional_edges(
|
101 |
-
"risk_evaluation",
|
102 |
-
self.route_after_risk_eval,
|
103 |
-
{
|
104 |
-
"emergency": "emergency_triage",
|
105 |
-
"continue": "generate_recommendations",
|
106 |
-
"need_more_info": "symptom_assessment"
|
107 |
}
|
108 |
)
|
109 |
workflow.add_edge("generate_recommendations", END)
|
@@ -112,59 +183,25 @@ class EnhancedMedicalAssistant:
|
|
112 |
self.workflow = workflow.compile()
|
113 |
|
114 |
def patient_intake(self, state: MedicalState) -> MedicalState:
|
115 |
-
"""
|
116 |
last_message = state["conversation_history"][-1]["content"] if state["conversation_history"] else ""
|
117 |
|
118 |
-
# Extract symptoms
|
119 |
detected_symptoms = self.extract_symptoms(last_message)
|
120 |
state["symptoms"].update(detected_symptoms)
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
# Check for red flags
|
123 |
red_flags = self.check_red_flags(last_message)
|
124 |
-
|
125 |
-
state["red_flags"].extend(red_flags)
|
126 |
-
|
127 |
-
# Determine what vital questions still need to be asked
|
128 |
-
missing_questions = self.get_missing_vital_questions(state)
|
129 |
-
|
130 |
-
if missing_questions and len(state["conversation_history"]) < 6:
|
131 |
-
state["consultation_stage"] = "intake"
|
132 |
-
return state
|
133 |
-
else:
|
134 |
-
state["consultation_stage"] = "assessment"
|
135 |
-
return state
|
136 |
-
|
137 |
-
def assess_symptoms(self, state: MedicalState) -> MedicalState:
|
138 |
-
"""Detailed symptom assessment"""
|
139 |
-
# Analyze symptom patterns and severity
|
140 |
-
for symptom, details in state["symptoms"].items():
|
141 |
-
if "severity" not in details:
|
142 |
-
# Need to ask about severity
|
143 |
-
state["consultation_stage"] = "assessment"
|
144 |
-
return state
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
def evaluate_risks(self, state: MedicalState) -> MedicalState:
|
150 |
-
"""Evaluate patient risks and urgency"""
|
151 |
-
risk_score = 0
|
152 |
-
|
153 |
-
# Check red flags
|
154 |
-
if state["red_flags"]:
|
155 |
-
risk_score += len(state["red_flags"]) * 3
|
156 |
-
|
157 |
-
# Check severity scores
|
158 |
-
for severity in state["severity_scores"].values():
|
159 |
-
if severity >= 8:
|
160 |
-
risk_score += 2
|
161 |
-
elif severity >= 6:
|
162 |
-
risk_score += 1
|
163 |
-
|
164 |
-
# Check symptom duration and progression
|
165 |
-
# (Implementation would analyze timeline)
|
166 |
-
|
167 |
-
if risk_score >= 5:
|
168 |
state["consultation_stage"] = "emergency"
|
169 |
else:
|
170 |
state["consultation_stage"] = "recommendations"
|
@@ -172,23 +209,82 @@ class EnhancedMedicalAssistant:
|
|
172 |
return state
|
173 |
|
174 |
def generate_recommendations(self, state: MedicalState) -> MedicalState:
|
175 |
-
"""Generate
|
176 |
-
|
177 |
-
|
178 |
-
# Use Meditron for medical recommendations
|
179 |
-
recommendations = self.get_meditron_recommendations(patient_summary)
|
180 |
state["suggested_actions"] = recommendations
|
181 |
-
|
182 |
return state
|
183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
def emergency_triage(self, state: MedicalState) -> MedicalState:
|
185 |
"""Handle emergency situations"""
|
186 |
-
emergency_response =
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
state["suggested_actions"] = [emergency_response]
|
193 |
return state
|
194 |
|
@@ -196,19 +292,8 @@ class EnhancedMedicalAssistant:
|
|
196 |
"""Route decision after intake"""
|
197 |
if state["red_flags"]:
|
198 |
return "emergency"
|
199 |
-
elif len(state["vital_questions_asked"]) < 5:
|
200 |
-
return "continue_assessment"
|
201 |
else:
|
202 |
-
return "
|
203 |
-
|
204 |
-
def route_after_risk_eval(self, state: MedicalState):
|
205 |
-
"""Route decision after risk evaluation"""
|
206 |
-
if state["consultation_stage"] == "emergency":
|
207 |
-
return "emergency"
|
208 |
-
elif state["assessment_complete"]:
|
209 |
-
return "continue"
|
210 |
-
else:
|
211 |
-
return "need_more_info"
|
212 |
|
213 |
def extract_symptoms(self, text: str) -> Dict:
|
214 |
"""Extract and categorize symptoms from patient text"""
|
@@ -237,62 +322,18 @@ class EnhancedMedicalAssistant:
|
|
237 |
|
238 |
return found_flags
|
239 |
|
240 |
-
def get_missing_vital_questions(self, state: MedicalState) -> List[str]:
|
241 |
-
"""Determine which vital questions haven't been asked"""
|
242 |
-
asked = state["vital_questions_asked"]
|
243 |
-
return [q for q in VITAL_QUESTIONS.keys() if q not in asked]
|
244 |
-
|
245 |
-
def create_patient_summary(self, state: MedicalState) -> str:
|
246 |
-
"""Create a comprehensive patient summary"""
|
247 |
-
summary = f"""
|
248 |
-
Patient Summary:
|
249 |
-
Symptoms: {json.dumps(state['symptoms'], indent=2)}
|
250 |
-
Medical History: {state['medical_history']}
|
251 |
-
Current Medications: {state['current_medications']}
|
252 |
-
Allergies: {state['allergies']}
|
253 |
-
Severity Scores: {state['severity_scores']}
|
254 |
-
Conversation History: {[msg['content'] for msg in state['conversation_history'][-3:]]}
|
255 |
-
"""
|
256 |
-
return summary
|
257 |
-
|
258 |
-
def get_meditron_recommendations(self, patient_summary: str) -> List[str]:
|
259 |
-
"""Get medical recommendations using Meditron model"""
|
260 |
-
prompt = f"""
|
261 |
-
Based on the following patient information, provide:
|
262 |
-
1. Specific over-the-counter medications with dosing
|
263 |
-
2. Home remedies and self-care measures
|
264 |
-
3. When to seek professional medical care
|
265 |
-
4. Follow-up recommendations
|
266 |
-
|
267 |
-
Patient Information:
|
268 |
-
{patient_summary}
|
269 |
-
|
270 |
-
Response:"""
|
271 |
-
|
272 |
-
inputs = self.meditron_tokenizer(prompt, return_tensors="pt").to(self.meditron_model.device)
|
273 |
-
|
274 |
-
with torch.no_grad():
|
275 |
-
outputs = self.meditron_model.generate(
|
276 |
-
inputs.input_ids,
|
277 |
-
attention_mask=inputs.attention_mask,
|
278 |
-
max_new_tokens=400,
|
279 |
-
temperature=0.7,
|
280 |
-
top_p=0.9,
|
281 |
-
do_sample=True
|
282 |
-
)
|
283 |
-
|
284 |
-
recommendation = self.meditron_tokenizer.decode(
|
285 |
-
outputs[0][inputs.input_ids.shape[1]:],
|
286 |
-
skip_special_tokens=True
|
287 |
-
)
|
288 |
-
|
289 |
-
return [recommendation]
|
290 |
-
|
291 |
def generate_response(self, message: str, history: List) -> str:
|
292 |
"""Main response generation function"""
|
293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
state = MedicalState(
|
295 |
-
patient_id=
|
296 |
conversation_history=history + [{"role": "user", "content": message}],
|
297 |
symptoms={},
|
298 |
vital_questions_asked=[],
|
@@ -303,130 +344,172 @@ class EnhancedMedicalAssistant:
|
|
303 |
red_flags=[],
|
304 |
assessment_complete=False,
|
305 |
suggested_actions=[],
|
306 |
-
consultation_stage="intake"
|
|
|
|
|
307 |
)
|
308 |
|
309 |
-
#
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
return self.format_emergency_response(state)
|
321 |
-
elif state["consultation_stage"] == "intake":
|
322 |
-
return self.format_intake_response(state, user_message)
|
323 |
-
elif state["consultation_stage"] == "assessment":
|
324 |
-
return self.format_assessment_response(state)
|
325 |
-
elif state["consultation_stage"] == "recommendations":
|
326 |
-
return self.format_recommendations_response(state)
|
327 |
-
else:
|
328 |
-
return self.format_default_response(user_message)
|
329 |
-
|
330 |
-
def format_emergency_response(self, state: MedicalState) -> str:
|
331 |
-
"""Format emergency response"""
|
332 |
-
return f"""
|
333 |
-
🚨 URGENT MEDICAL ATTENTION NEEDED 🚨
|
334 |
-
|
335 |
-
Based on your symptoms, I recommend seeking immediate medical care because:
|
336 |
-
{', '.join(state['red_flags'])}
|
337 |
-
|
338 |
-
Please:
|
339 |
-
- Go to the nearest emergency room, OR
|
340 |
-
- Call emergency services (911), OR
|
341 |
-
- Contact your doctor immediately
|
342 |
-
|
343 |
-
This is not a diagnosis, but these symptoms warrant immediate professional evaluation.
|
344 |
-
"""
|
345 |
|
346 |
-
def
|
347 |
-
"""
|
348 |
-
#
|
349 |
-
|
350 |
-
|
351 |
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
|
360 |
return self.generate_llama_response(prompt)
|
361 |
|
362 |
-
def format_assessment_response(self, state: MedicalState) -> str:
|
363 |
-
"""Format detailed assessment response"""
|
364 |
-
return "Let me gather a bit more information to better understand your condition..."
|
365 |
-
|
366 |
-
def format_recommendations_response(self, state: MedicalState) -> str:
|
367 |
-
"""Format final recommendations"""
|
368 |
-
recommendations = "\n".join(state["suggested_actions"])
|
369 |
-
return f"""
|
370 |
-
Based on our consultation, here's my assessment and recommendations:
|
371 |
-
|
372 |
-
{recommendations}
|
373 |
-
|
374 |
-
**Important Disclaimer:** I am an AI assistant, not a licensed medical professional.
|
375 |
-
These suggestions are for informational purposes only. Please consult with a
|
376 |
-
healthcare provider for proper diagnosis and treatment.
|
377 |
-
"""
|
378 |
-
|
379 |
-
def format_default_response(self, user_message: str) -> str:
|
380 |
-
"""Format default response"""
|
381 |
-
return self.generate_llama_response(f"Respond professionally to: {user_message}")
|
382 |
-
|
383 |
def generate_llama_response(self, prompt: str) -> str:
|
384 |
-
"""Generate response using Llama-2"""
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
inputs.
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
|
402 |
-
# Initialize the
|
403 |
medical_assistant = EnhancedMedicalAssistant()
|
404 |
|
405 |
-
@spaces.GPU
|
406 |
def chat_interface(message, history):
|
407 |
"""Gradio chat interface"""
|
408 |
-
|
|
|
|
|
|
|
|
|
409 |
|
410 |
# Create Gradio interface
|
411 |
demo = gr.ChatInterface(
|
412 |
fn=chat_interface,
|
413 |
-
title="🏥
|
414 |
description="""
|
415 |
-
I'm an AI medical assistant
|
416 |
-
I
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
⚠️ **Important**: I'm not a replacement for professional medical care. Always consult healthcare providers for serious concerns.
|
419 |
""",
|
420 |
examples=[
|
421 |
-
"I
|
422 |
-
"I
|
423 |
-
"
|
424 |
-
"I have
|
425 |
],
|
426 |
theme="soft",
|
427 |
css="""
|
428 |
-
.message.user {
|
429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
"""
|
431 |
)
|
432 |
|
|
|
6 |
from typing import TypedDict, List, Dict, Optional
|
7 |
from datetime import datetime
|
8 |
import json
|
9 |
+
import re
|
10 |
+
import numpy as np
|
11 |
+
from sentence_transformers import SentenceTransformer
|
12 |
+
import faiss
|
13 |
+
import pickle
|
14 |
+
import os
|
15 |
|
16 |
+
# Enhanced State Management with RAG
|
17 |
class MedicalState(TypedDict):
|
18 |
patient_id: str
|
19 |
conversation_history: List[Dict]
|
|
|
26 |
red_flags: List[str]
|
27 |
assessment_complete: bool
|
28 |
suggested_actions: List[str]
|
29 |
+
consultation_stage: str
|
30 |
+
retrieved_knowledge: List[Dict]
|
31 |
+
confidence_scores: Dict[str, float]
|
32 |
+
|
33 |
+
# Medical Knowledge Base for RAG.
# Each entry maps a condition to its typical symptoms, a one-line treatment
# summary, OTC medication dosing (with daily maximums), home remedies, and
# "when_to_seek_care" escalation guidance. The SimpleRAGSystem builds its
# symptom->condition index from the "symptoms" lists below.
MEDICAL_KNOWLEDGE_BASE = {
    "conditions": {
        "common_cold": {
            "symptoms": ["runny nose", "cough", "sneezing", "sore throat", "mild fever"],
            "treatment": "Rest, fluids, OTC pain relievers",
            "otc_medications": [
                {"name": "Acetaminophen", "dose": "500-1000mg every 4-6 hours", "max_daily": "3000mg"},
                {"name": "Ibuprofen", "dose": "200-400mg every 4-6 hours", "max_daily": "1200mg"},
            ],
            "home_remedies": ["Warm salt water gargle", "Honey and lemon tea", "Steam inhalation"],
            "when_to_seek_care": "If symptoms worsen after 7-10 days or fever above 101.3°F",
        },
        "headache": {
            "symptoms": ["head pain", "pressure", "throbbing"],
            "treatment": "Pain relief, rest, hydration",
            "otc_medications": [
                {"name": "Acetaminophen", "dose": "500-1000mg every 4-6 hours", "max_daily": "3000mg"},
                {"name": "Ibuprofen", "dose": "400-600mg every 6-8 hours", "max_daily": "1200mg"},
            ],
            "home_remedies": ["Cold or warm compress", "Rest in dark room", "Stay hydrated"],
            "when_to_seek_care": "Sudden severe headache, fever, neck stiffness, vision changes",
        },
        "stomach_pain": {
            "symptoms": ["abdominal pain", "nausea", "bloating", "cramps"],
            "treatment": "Bland diet, rest, hydration",
            "otc_medications": [
                {"name": "Pepto-Bismol", "dose": "525mg every 30 minutes as needed", "max_daily": "8 doses"},
                {"name": "TUMS", "dose": "2-4 tablets as needed", "max_daily": "15 tablets"},
            ],
            "home_remedies": ["BRAT diet", "Ginger tea", "Warm compress on stomach"],
            "when_to_seek_care": "Severe pain, fever, vomiting, blood in stool",
        },
    }
}
|
68 |
|
|
|
69 |
MEDICAL_CATEGORIES = {
|
70 |
+
"respiratory": ["cough", "shortness of breath", "chest pain", "wheezing", "runny nose", "sore throat"],
|
71 |
+
"gastrointestinal": ["nausea", "vomiting", "diarrhea", "stomach pain", "heartburn", "bloating"],
|
72 |
"neurological": ["headache", "dizziness", "numbness", "tingling"],
|
73 |
"musculoskeletal": ["joint pain", "muscle pain", "back pain", "stiffness"],
|
74 |
"cardiovascular": ["chest pain", "palpitations", "swelling", "fatigue"],
|
|
|
82 |
"sudden vision changes", "loss of consciousness", "severe allergic reaction"
|
83 |
]
|
84 |
|
85 |
+
class SimpleRAGSystem:
    """Lightweight keyword-based retriever over MEDICAL_KNOWLEDGE_BASE.

    Builds an inverted index (symptom -> conditions) at construction time and
    answers symptom queries with exact and substring matching. No embeddings
    or vector store — deliberately simple and dependency-free.
    """

    def __init__(self):
        self.knowledge_base = MEDICAL_KNOWLEDGE_BASE
        self.setup_simple_retrieval()

    def setup_simple_retrieval(self):
        """Setup simple keyword-based retrieval system.

        Populates ``self.symptom_to_condition`` mapping each known symptom
        string to the list of conditions that list it.
        """
        index: Dict[str, List[str]] = {}
        for name, info in self.knowledge_base["conditions"].items():
            for symptom in info["symptoms"]:
                index.setdefault(symptom, []).append(name)
        self.symptom_to_condition = index

    def retrieve_relevant_conditions(self, symptoms: List[str]) -> List[Dict]:
        """Retrieve relevant medical conditions based on symptoms.

        Matching is case-insensitive; a condition is relevant when any of the
        patient's symptoms equals a known symptom or is a substring match in
        either direction (e.g. "pain" vs "head pain"). Each condition appears
        at most once, in first-match order.
        """
        matched: Dict[str, Dict] = {}
        conditions = self.knowledge_base["conditions"]

        for reported in symptoms:
            needle = reported.lower()

            # Exact hit in the inverted index.
            for name in self.symptom_to_condition.get(needle, []):
                matched.setdefault(name, conditions[name])

            # Substring match in either direction.
            for kb_symptom, names in self.symptom_to_condition.items():
                if needle in kb_symptom or kb_symptom in needle:
                    for name in names:
                        matched.setdefault(name, conditions[name])

        return [{"condition": name, "data": data} for name, data in matched.items()]
|
121 |
|
122 |
class EnhancedMedicalAssistant:
|
123 |
def __init__(self):
|
124 |
self.load_models()
|
125 |
+
self.rag_system = SimpleRAGSystem()
|
126 |
self.setup_langgraph()
|
127 |
+
self.conversation_count = {}
|
128 |
|
129 |
def load_models(self):
|
130 |
"""Load the AI models"""
|
131 |
print("Loading models...")
|
132 |
+
try:
|
133 |
+
# Llama-2 for conversation
|
134 |
+
self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
135 |
+
if self.tokenizer.pad_token is None:
|
136 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
137 |
+
|
138 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
139 |
+
"meta-llama/Llama-2-7b-chat-hf",
|
140 |
+
torch_dtype=torch.float16,
|
141 |
+
device_map="auto"
|
142 |
+
)
|
143 |
+
|
144 |
+
# Meditron for medical suggestions
|
145 |
+
self.meditron_tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
|
146 |
+
if self.meditron_tokenizer.pad_token is None:
|
147 |
+
self.meditron_tokenizer.pad_token = self.meditron_tokenizer.eos_token
|
148 |
+
|
149 |
+
self.meditron_model = AutoModelForCausalLM.from_pretrained(
|
150 |
+
"epfl-llm/meditron-7b",
|
151 |
+
torch_dtype=torch.float16,
|
152 |
+
device_map="auto"
|
153 |
+
)
|
154 |
+
print("Models loaded successfully!")
|
155 |
+
except Exception as e:
|
156 |
+
print(f"Error loading models: {e}")
|
157 |
+
# Fallback - use only one model
|
158 |
+
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
159 |
+
self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
160 |
+
self.meditron_tokenizer = self.tokenizer
|
161 |
+
self.meditron_model = self.model
|
162 |
|
163 |
def setup_langgraph(self):
|
164 |
+
"""Setup simplified LangGraph workflow"""
|
165 |
workflow = StateGraph(MedicalState)
|
166 |
|
|
|
167 |
workflow.add_node("intake", self.patient_intake)
|
|
|
|
|
168 |
workflow.add_node("generate_recommendations", self.generate_recommendations)
|
169 |
workflow.add_node("emergency_triage", self.emergency_triage)
|
170 |
|
|
|
171 |
workflow.set_entry_point("intake")
|
172 |
workflow.add_conditional_edges(
|
173 |
"intake",
|
174 |
self.route_after_intake,
|
175 |
{
|
|
|
176 |
"emergency": "emergency_triage",
|
177 |
+
"recommendations": "generate_recommendations"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
}
|
179 |
)
|
180 |
workflow.add_edge("generate_recommendations", END)
|
|
|
183 |
self.workflow = workflow.compile()
|
184 |
|
185 |
def patient_intake(self, state: MedicalState) -> MedicalState:
|
186 |
+
"""Enhanced patient intake with RAG"""
|
187 |
last_message = state["conversation_history"][-1]["content"] if state["conversation_history"] else ""
|
188 |
|
189 |
+
# Extract symptoms
|
190 |
detected_symptoms = self.extract_symptoms(last_message)
|
191 |
state["symptoms"].update(detected_symptoms)
|
192 |
|
193 |
+
# Use RAG to get relevant medical knowledge
|
194 |
+
if detected_symptoms:
|
195 |
+
symptom_names = list(detected_symptoms.keys())
|
196 |
+
relevant_conditions = self.rag_system.retrieve_relevant_conditions(symptom_names)
|
197 |
+
state["retrieved_knowledge"] = relevant_conditions
|
198 |
+
|
199 |
# Check for red flags
|
200 |
red_flags = self.check_red_flags(last_message)
|
201 |
+
state["red_flags"].extend(red_flags)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
+
# Determine consultation stage
|
204 |
+
if red_flags:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
state["consultation_stage"] = "emergency"
|
206 |
else:
|
207 |
state["consultation_stage"] = "recommendations"
|
|
|
209 |
return state
|
210 |
|
211 |
def generate_recommendations(self, state: MedicalState) -> MedicalState:
|
212 |
+
"""Generate RAG-enhanced recommendations"""
|
213 |
+
# Create structured recommendations from RAG knowledge
|
214 |
+
recommendations = self.create_structured_recommendations(state)
|
|
|
|
|
215 |
state["suggested_actions"] = recommendations
|
|
|
216 |
return state
|
217 |
|
218 |
+
def create_structured_recommendations(self, state: MedicalState) -> List[str]:
|
219 |
+
"""Create structured recommendations using RAG knowledge"""
|
220 |
+
recommendations = []
|
221 |
+
|
222 |
+
if not state["retrieved_knowledge"]:
|
223 |
+
recommendations.append("I need more specific information about your symptoms to provide targeted recommendations.")
|
224 |
+
return recommendations
|
225 |
+
|
226 |
+
# Process each relevant condition
|
227 |
+
for knowledge_item in state["retrieved_knowledge"][:2]: # Limit to top 2 conditions
|
228 |
+
condition = knowledge_item["condition"]
|
229 |
+
data = knowledge_item["data"]
|
230 |
+
|
231 |
+
# Format condition information
|
232 |
+
condition_info = f"\n**Possible Condition: {condition.replace('_', ' ').title()}**\n"
|
233 |
+
|
234 |
+
# Add medications
|
235 |
+
if "otc_medications" in data:
|
236 |
+
condition_info += "\n**💊 Over-the-Counter Medications:**\n"
|
237 |
+
for med in data["otc_medications"]:
|
238 |
+
condition_info += f"• **{med['name']}**: {med['dose']} (Max daily: {med['max_daily']})\n"
|
239 |
+
|
240 |
+
# Add home remedies
|
241 |
+
if "home_remedies" in data:
|
242 |
+
condition_info += "\n**🏠 Home Remedies:**\n"
|
243 |
+
for remedy in data["home_remedies"]:
|
244 |
+
condition_info += f"• {remedy}\n"
|
245 |
+
|
246 |
+
# Add when to seek care
|
247 |
+
if "when_to_seek_care" in data:
|
248 |
+
condition_info += f"\n**⚠️ Seek Medical Care If:** {data['when_to_seek_care']}\n"
|
249 |
+
|
250 |
+
recommendations.append(condition_info)
|
251 |
+
|
252 |
+
# Add general advice
|
253 |
+
recommendations.append("""
|
254 |
+
**📋 General Recommendations:**
|
255 |
+
• Monitor your symptoms for any changes
|
256 |
+
• Stay hydrated and get adequate rest
|
257 |
+
• Follow medication instructions carefully
|
258 |
+
• Don't exceed recommended dosages
|
259 |
+
|
260 |
+
**🚨 Emergency Warning Signs:**
|
261 |
+
• Severe worsening of symptoms
|
262 |
+
• High fever (>101.3°F/38.5°C)
|
263 |
+
• Difficulty breathing
|
264 |
+
• Severe pain
|
265 |
+
• Signs of dehydration
|
266 |
+
""")
|
267 |
+
|
268 |
+
return recommendations
|
269 |
+
|
270 |
def emergency_triage(self, state: MedicalState) -> MedicalState:
|
271 |
"""Handle emergency situations"""
|
272 |
+
emergency_response = f"""
|
273 |
+
🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨
|
274 |
+
|
275 |
+
Based on your symptoms, I strongly recommend seeking immediate medical care because you mentioned: {', '.join(state['red_flags'])}
|
276 |
+
|
277 |
+
**Immediate Actions:**
|
278 |
+
• Go to the nearest emergency room, OR
|
279 |
+
• Call emergency services (911), OR
|
280 |
+
• Contact your doctor immediately
|
281 |
+
|
282 |
+
**Why This is Urgent:**
|
283 |
+
These symptoms can indicate serious conditions that require professional medical evaluation and treatment.
|
284 |
+
|
285 |
+
⚠️ **Disclaimer:** This is not a medical diagnosis, but these symptoms warrant immediate professional assessment.
|
286 |
+
"""
|
287 |
+
|
288 |
state["suggested_actions"] = [emergency_response]
|
289 |
return state
|
290 |
|
|
|
292 |
"""Route decision after intake"""
|
293 |
if state["red_flags"]:
|
294 |
return "emergency"
|
|
|
|
|
295 |
else:
|
296 |
+
return "recommendations"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
def extract_symptoms(self, text: str) -> Dict:
|
299 |
"""Extract and categorize symptoms from patient text"""
|
|
|
322 |
|
323 |
return found_flags
|
324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
def generate_response(self, message: str, history: List) -> str:
|
326 |
"""Main response generation function"""
|
327 |
+
session_id = "default_session"
|
328 |
+
|
329 |
+
# Track conversation count
|
330 |
+
if session_id not in self.conversation_count:
|
331 |
+
self.conversation_count[session_id] = 0
|
332 |
+
self.conversation_count[session_id] += 1
|
333 |
+
|
334 |
+
# Initialize state
|
335 |
state = MedicalState(
|
336 |
+
patient_id=session_id,
|
337 |
conversation_history=history + [{"role": "user", "content": message}],
|
338 |
symptoms={},
|
339 |
vital_questions_asked=[],
|
|
|
344 |
red_flags=[],
|
345 |
assessment_complete=False,
|
346 |
suggested_actions=[],
|
347 |
+
consultation_stage="intake",
|
348 |
+
retrieved_knowledge=[],
|
349 |
+
confidence_scores={}
|
350 |
)
|
351 |
|
352 |
+
# For first few messages, do conversational intake
|
353 |
+
if self.conversation_count[session_id] <= 3:
|
354 |
+
return self.generate_conversational_response(message, history)
|
355 |
+
|
356 |
+
# After gathering info, run workflow for recommendations
|
357 |
+
try:
|
358 |
+
result = self.workflow.invoke(state)
|
359 |
+
return self.format_final_response(result)
|
360 |
+
except Exception as e:
|
361 |
+
print(f"Workflow error: {e}")
|
362 |
+
return self.generate_conversational_response(message, history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
|
364 |
+
def generate_conversational_response(self, message: str, history: List) -> str:
|
365 |
+
"""Generate conversational response for intake phase"""
|
366 |
+
# Extract symptoms for context
|
367 |
+
symptoms = self.extract_symptoms(message)
|
368 |
+
red_flags = self.check_red_flags(message)
|
369 |
|
370 |
+
# Handle emergencies immediately
|
371 |
+
if red_flags:
|
372 |
+
return f"""
|
373 |
+
🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨
|
374 |
+
|
375 |
+
I notice you mentioned: {', '.join(red_flags)}
|
376 |
+
|
377 |
+
Please seek immediate medical care:
|
378 |
+
• Go to the nearest emergency room
|
379 |
+
• Call emergency services (911)
|
380 |
+
• Contact your doctor immediately
|
381 |
+
|
382 |
+
These symptoms require professional medical evaluation right away.
|
383 |
+
"""
|
384 |
+
|
385 |
+
# Generate contextual questions based on symptoms
|
386 |
+
if symptoms:
|
387 |
+
symptom_names = list(symptoms.keys())
|
388 |
+
prompt = f"""
|
389 |
+
You are a caring medical assistant. The patient mentioned these symptoms: {', '.join(symptom_names)}.
|
390 |
+
|
391 |
+
Respond empathetically and ask 1-2 relevant follow-up questions about:
|
392 |
+
- How long they've had these symptoms
|
393 |
+
- Severity (mild, moderate, severe)
|
394 |
+
- What makes it better or worse
|
395 |
+
- Any other symptoms they're experiencing
|
396 |
+
|
397 |
+
Be professional, caring, and concise. Don't provide treatment advice yet.
|
398 |
+
"""
|
399 |
+
else:
|
400 |
+
prompt = f"""
|
401 |
+
You are a caring medical assistant. The patient said: "{message}"
|
402 |
+
|
403 |
+
Respond empathetically and ask relevant questions to understand their health concern better.
|
404 |
+
Be professional and caring.
|
405 |
+
"""
|
406 |
|
407 |
return self.generate_llama_response(prompt)
|
408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
def generate_llama_response(self, prompt: str) -> str:
|
410 |
+
"""Generate response using Llama-2 with better formatting"""
|
411 |
+
try:
|
412 |
+
formatted_prompt = f"<s>[INST] {prompt} [/INST]"
|
413 |
+
inputs = self.tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
|
414 |
+
|
415 |
+
if torch.cuda.is_available():
|
416 |
+
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
417 |
+
|
418 |
+
with torch.no_grad():
|
419 |
+
outputs = self.model.generate(
|
420 |
+
**inputs,
|
421 |
+
max_new_tokens=200,
|
422 |
+
temperature=0.7,
|
423 |
+
top_p=0.9,
|
424 |
+
do_sample=True,
|
425 |
+
pad_token_id=self.tokenizer.eos_token_id
|
426 |
+
)
|
427 |
+
|
428 |
+
# Decode response
|
429 |
+
response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
|
430 |
+
|
431 |
+
# Clean up the response
|
432 |
+
response = response.split('</s>')[0].strip()
|
433 |
+
response = response.replace('<s>', '').replace('[INST]', '').replace('[/INST]', '').strip()
|
434 |
+
|
435 |
+
# Remove any XML-like tags
|
436 |
+
response = re.sub(r'<[^>]+>', '', response)
|
437 |
+
|
438 |
+
return response if response else "I understand your concern. Can you tell me more about what you're experiencing?"
|
439 |
+
|
440 |
+
except Exception as e:
|
441 |
+
print(f"Error generating response: {e}")
|
442 |
+
return "I understand your concern. Can you tell me more about your symptoms?"
|
443 |
+
|
444 |
+
def format_final_response(self, state: MedicalState) -> str:
|
445 |
+
"""Format the final response with recommendations"""
|
446 |
+
if state["consultation_stage"] == "emergency":
|
447 |
+
return state["suggested_actions"][0] if state["suggested_actions"] else "Please seek immediate medical attention."
|
448 |
+
|
449 |
+
# Format recommendations nicely
|
450 |
+
if state["suggested_actions"]:
|
451 |
+
response = "## 🏥 Medical Assessment & Recommendations\n\n"
|
452 |
+
response += "Based on our conversation, here's what I recommend:\n"
|
453 |
+
|
454 |
+
for action in state["suggested_actions"]:
|
455 |
+
response += f"{action}\n"
|
456 |
+
|
457 |
+
response += "\n---\n"
|
458 |
+
response += "**Important Disclaimer:** I'm an AI assistant providing general health information. "
|
459 |
+
response += "This is not a substitute for professional medical advice, diagnosis, or treatment. "
|
460 |
+
response += "Always consult with qualified healthcare providers for medical concerns."
|
461 |
+
|
462 |
+
return response
|
463 |
+
else:
|
464 |
+
return "Please provide more details about your symptoms so I can offer better guidance."
|
465 |
|
466 |
+
# Initialize the medical assistant once at import time — model loading and
# workflow compilation happen here, not per request.
medical_assistant = EnhancedMedicalAssistant()

@spaces.GPU
def chat_interface(message, history):
    """Gradio chat interface"""
    # Surface failures as chat text so the UI never hard-crashes a request.
    try:
        return medical_assistant.generate_response(message, history)
    except Exception as exc:
        print(f"Chat interface error: {exc}")
        return f"I apologize, but I encountered an error. Please try rephrasing your question. Error: {str(exc)}"
|
477 |
|
478 |
# Create Gradio interface: a soft-themed chat UI over chat_interface with
# example prompts and lightly styled user/bot message bubbles.
demo = gr.ChatInterface(
    fn=chat_interface,
    title="🏥 Medical AI Assistant with medRAG",
    description="""
    I'm an AI medical assistant powered by medical knowledge retrieval (medRAG).
    I can help assess your symptoms and provide evidence-based recommendations.

    **How it works:**
    1. Tell me about your symptoms
    2. I'll ask follow-up questions
    3. I'll provide personalized recommendations based on medical knowledge

    ⚠️ **Important**: I'm not a replacement for professional medical care. Always consult healthcare providers for serious concerns.
    """,
    examples=[
        "I have a bad cough and sore throat",
        "I've been having headaches for the past few days",
        "My stomach has been hurting after meals",
        "I have chest pain and trouble breathing",
    ],
    theme="soft",
    css="""
    .message.user {
        background-color: #e3f2fd;
        border-radius: 10px;
        padding: 10px;
        margin: 5px;
    }
    .message.bot {
        background-color: #f1f8e9;
        border-radius: 10px;
        padding: 10px;
        margin: 5px;
    }
    """,
)
|
515 |
|