techindia2025 commited on
Commit
0fc7323
·
verified ·
1 Parent(s): bdce857

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +351 -268
app.py CHANGED
@@ -6,8 +6,14 @@ from langgraph.graph import StateGraph, END
6
  from typing import TypedDict, List, Dict, Optional
7
  from datetime import datetime
8
  import json
 
 
 
 
 
 
9
 
10
- # Enhanced State Management
11
  class MedicalState(TypedDict):
12
  patient_id: str
13
  conversation_history: List[Dict]
@@ -20,12 +26,49 @@ class MedicalState(TypedDict):
20
  red_flags: List[str]
21
  assessment_complete: bool
22
  suggested_actions: List[str]
23
- consultation_stage: str # intake, assessment, summary, recommendations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # Medical Knowledge Base
26
  MEDICAL_CATEGORIES = {
27
- "respiratory": ["cough", "shortness of breath", "chest pain", "wheezing"],
28
- "gastrointestinal": ["nausea", "vomiting", "diarrhea", "stomach pain", "heartburn"],
29
  "neurological": ["headache", "dizziness", "numbness", "tingling"],
30
  "musculoskeletal": ["joint pain", "muscle pain", "back pain", "stiffness"],
31
  "cardiovascular": ["chest pain", "palpitations", "swelling", "fatigue"],
@@ -39,71 +82,99 @@ RED_FLAGS = [
39
  "sudden vision changes", "loss of consciousness", "severe allergic reaction"
40
  ]
41
 
42
- VITAL_QUESTIONS = {
43
- "symptom_onset": "When did your symptoms first start?",
44
- "severity": "On a scale of 1-10, how severe would you rate your symptoms?",
45
- "triggers": "What makes your symptoms better or worse?",
46
- "associated_symptoms": "Are you experiencing any other symptoms?",
47
- "medical_history": "Do you have any chronic medical conditions?",
48
- "medications": "Are you currently taking any medications?",
49
- "allergies": "Do you have any known allergies?"
50
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  class EnhancedMedicalAssistant:
53
  def __init__(self):
54
  self.load_models()
 
55
  self.setup_langgraph()
 
56
 
57
  def load_models(self):
58
  """Load the AI models"""
59
  print("Loading models...")
60
- # Llama-2 for conversation
61
- self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
62
- self.model = AutoModelForCausalLM.from_pretrained(
63
- "meta-llama/Llama-2-7b-chat-hf",
64
- torch_dtype=torch.float16,
65
- device_map="auto"
66
- )
67
-
68
- # Meditron for medical suggestions
69
- self.meditron_tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
70
- self.meditron_model = AutoModelForCausalLM.from_pretrained(
71
- "epfl-llm/meditron-7b",
72
- torch_dtype=torch.float16,
73
- device_map="auto"
74
- )
75
- print("Models loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def setup_langgraph(self):
78
- """Setup LangGraph workflow"""
79
  workflow = StateGraph(MedicalState)
80
 
81
- # Add nodes
82
  workflow.add_node("intake", self.patient_intake)
83
- workflow.add_node("symptom_assessment", self.assess_symptoms)
84
- workflow.add_node("risk_evaluation", self.evaluate_risks)
85
  workflow.add_node("generate_recommendations", self.generate_recommendations)
86
  workflow.add_node("emergency_triage", self.emergency_triage)
87
 
88
- # Define edges
89
  workflow.set_entry_point("intake")
90
  workflow.add_conditional_edges(
91
  "intake",
92
  self.route_after_intake,
93
  {
94
- "continue_assessment": "symptom_assessment",
95
  "emergency": "emergency_triage",
96
- "complete": "generate_recommendations"
97
- }
98
- )
99
- workflow.add_edge("symptom_assessment", "risk_evaluation")
100
- workflow.add_conditional_edges(
101
- "risk_evaluation",
102
- self.route_after_risk_eval,
103
- {
104
- "emergency": "emergency_triage",
105
- "continue": "generate_recommendations",
106
- "need_more_info": "symptom_assessment"
107
  }
108
  )
109
  workflow.add_edge("generate_recommendations", END)
@@ -112,59 +183,25 @@ class EnhancedMedicalAssistant:
112
  self.workflow = workflow.compile()
113
 
114
  def patient_intake(self, state: MedicalState) -> MedicalState:
115
- """Initial patient intake and basic information gathering"""
116
  last_message = state["conversation_history"][-1]["content"] if state["conversation_history"] else ""
117
 
118
- # Extract symptoms and categorize them
119
  detected_symptoms = self.extract_symptoms(last_message)
120
  state["symptoms"].update(detected_symptoms)
121
 
 
 
 
 
 
 
122
  # Check for red flags
123
  red_flags = self.check_red_flags(last_message)
124
- if red_flags:
125
- state["red_flags"].extend(red_flags)
126
-
127
- # Determine what vital questions still need to be asked
128
- missing_questions = self.get_missing_vital_questions(state)
129
-
130
- if missing_questions and len(state["conversation_history"]) < 6:
131
- state["consultation_stage"] = "intake"
132
- return state
133
- else:
134
- state["consultation_stage"] = "assessment"
135
- return state
136
-
137
- def assess_symptoms(self, state: MedicalState) -> MedicalState:
138
- """Detailed symptom assessment"""
139
- # Analyze symptom patterns and severity
140
- for symptom, details in state["symptoms"].items():
141
- if "severity" not in details:
142
- # Need to ask about severity
143
- state["consultation_stage"] = "assessment"
144
- return state
145
 
146
- state["assessment_complete"] = True
147
- return state
148
-
149
- def evaluate_risks(self, state: MedicalState) -> MedicalState:
150
- """Evaluate patient risks and urgency"""
151
- risk_score = 0
152
-
153
- # Check red flags
154
- if state["red_flags"]:
155
- risk_score += len(state["red_flags"]) * 3
156
-
157
- # Check severity scores
158
- for severity in state["severity_scores"].values():
159
- if severity >= 8:
160
- risk_score += 2
161
- elif severity >= 6:
162
- risk_score += 1
163
-
164
- # Check symptom duration and progression
165
- # (Implementation would analyze timeline)
166
-
167
- if risk_score >= 5:
168
  state["consultation_stage"] = "emergency"
169
  else:
170
  state["consultation_stage"] = "recommendations"
@@ -172,23 +209,82 @@ class EnhancedMedicalAssistant:
172
  return state
173
 
174
  def generate_recommendations(self, state: MedicalState) -> MedicalState:
175
- """Generate treatment recommendations and care suggestions"""
176
- patient_summary = self.create_patient_summary(state)
177
-
178
- # Use Meditron for medical recommendations
179
- recommendations = self.get_meditron_recommendations(patient_summary)
180
  state["suggested_actions"] = recommendations
181
-
182
  return state
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def emergency_triage(self, state: MedicalState) -> MedicalState:
185
  """Handle emergency situations"""
186
- emergency_response = {
187
- "urgent_care_needed": True,
188
- "recommended_action": "Seek immediate medical attention",
189
- "reasons": state["red_flags"],
190
- "instructions": "Go to the nearest emergency room or call emergency services"
191
- }
 
 
 
 
 
 
 
 
 
 
192
  state["suggested_actions"] = [emergency_response]
193
  return state
194
 
@@ -196,19 +292,8 @@ class EnhancedMedicalAssistant:
196
  """Route decision after intake"""
197
  if state["red_flags"]:
198
  return "emergency"
199
- elif len(state["vital_questions_asked"]) < 5:
200
- return "continue_assessment"
201
  else:
202
- return "complete"
203
-
204
- def route_after_risk_eval(self, state: MedicalState):
205
- """Route decision after risk evaluation"""
206
- if state["consultation_stage"] == "emergency":
207
- return "emergency"
208
- elif state["assessment_complete"]:
209
- return "continue"
210
- else:
211
- return "need_more_info"
212
 
213
  def extract_symptoms(self, text: str) -> Dict:
214
  """Extract and categorize symptoms from patient text"""
@@ -237,62 +322,18 @@ class EnhancedMedicalAssistant:
237
 
238
  return found_flags
239
 
240
- def get_missing_vital_questions(self, state: MedicalState) -> List[str]:
241
- """Determine which vital questions haven't been asked"""
242
- asked = state["vital_questions_asked"]
243
- return [q for q in VITAL_QUESTIONS.keys() if q not in asked]
244
-
245
- def create_patient_summary(self, state: MedicalState) -> str:
246
- """Create a comprehensive patient summary"""
247
- summary = f"""
248
- Patient Summary:
249
- Symptoms: {json.dumps(state['symptoms'], indent=2)}
250
- Medical History: {state['medical_history']}
251
- Current Medications: {state['current_medications']}
252
- Allergies: {state['allergies']}
253
- Severity Scores: {state['severity_scores']}
254
- Conversation History: {[msg['content'] for msg in state['conversation_history'][-3:]]}
255
- """
256
- return summary
257
-
258
- def get_meditron_recommendations(self, patient_summary: str) -> List[str]:
259
- """Get medical recommendations using Meditron model"""
260
- prompt = f"""
261
- Based on the following patient information, provide:
262
- 1. Specific over-the-counter medications with dosing
263
- 2. Home remedies and self-care measures
264
- 3. When to seek professional medical care
265
- 4. Follow-up recommendations
266
-
267
- Patient Information:
268
- {patient_summary}
269
-
270
- Response:"""
271
-
272
- inputs = self.meditron_tokenizer(prompt, return_tensors="pt").to(self.meditron_model.device)
273
-
274
- with torch.no_grad():
275
- outputs = self.meditron_model.generate(
276
- inputs.input_ids,
277
- attention_mask=inputs.attention_mask,
278
- max_new_tokens=400,
279
- temperature=0.7,
280
- top_p=0.9,
281
- do_sample=True
282
- )
283
-
284
- recommendation = self.meditron_tokenizer.decode(
285
- outputs[0][inputs.input_ids.shape[1]:],
286
- skip_special_tokens=True
287
- )
288
-
289
- return [recommendation]
290
-
291
  def generate_response(self, message: str, history: List) -> str:
292
  """Main response generation function"""
293
- # Initialize or update state
 
 
 
 
 
 
 
294
  state = MedicalState(
295
- patient_id="session_001",
296
  conversation_history=history + [{"role": "user", "content": message}],
297
  symptoms={},
298
  vital_questions_asked=[],
@@ -303,130 +344,172 @@ class EnhancedMedicalAssistant:
303
  red_flags=[],
304
  assessment_complete=False,
305
  suggested_actions=[],
306
- consultation_stage="intake"
 
 
307
  )
308
 
309
- # Run through LangGraph workflow
310
- result = self.workflow.invoke(state)
311
-
312
- # Generate contextual response
313
- response = self.generate_contextual_response(result, message)
314
-
315
- return response
316
-
317
- def generate_contextual_response(self, state: MedicalState, user_message: str) -> str:
318
- """Generate a contextual response based on the current state"""
319
- if state["consultation_stage"] == "emergency":
320
- return self.format_emergency_response(state)
321
- elif state["consultation_stage"] == "intake":
322
- return self.format_intake_response(state, user_message)
323
- elif state["consultation_stage"] == "assessment":
324
- return self.format_assessment_response(state)
325
- elif state["consultation_stage"] == "recommendations":
326
- return self.format_recommendations_response(state)
327
- else:
328
- return self.format_default_response(user_message)
329
-
330
- def format_emergency_response(self, state: MedicalState) -> str:
331
- """Format emergency response"""
332
- return f"""
333
- 🚨 URGENT MEDICAL ATTENTION NEEDED 🚨
334
-
335
- Based on your symptoms, I recommend seeking immediate medical care because:
336
- {', '.join(state['red_flags'])}
337
-
338
- Please:
339
- - Go to the nearest emergency room, OR
340
- - Call emergency services (911), OR
341
- - Contact your doctor immediately
342
-
343
- This is not a diagnosis, but these symptoms warrant immediate professional evaluation.
344
- """
345
 
346
- def format_intake_response(self, state: MedicalState, user_message: str) -> str:
347
- """Format intake response with follow-up questions"""
348
- # Use Llama-2 to generate empathetic response
349
- prompt = f"""
350
- You are a caring virtual doctor. The patient said: "{user_message}"
351
 
352
- Respond empathetically and ask 1-2 specific follow-up questions about:
353
- - Symptom details (duration, severity, triggers)
354
- - Associated symptoms
355
- - Medical history if relevant
356
-
357
- Be professional, caring, and thorough.
358
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
  return self.generate_llama_response(prompt)
361
 
362
- def format_assessment_response(self, state: MedicalState) -> str:
363
- """Format detailed assessment response"""
364
- return "Let me gather a bit more information to better understand your condition..."
365
-
366
- def format_recommendations_response(self, state: MedicalState) -> str:
367
- """Format final recommendations"""
368
- recommendations = "\n".join(state["suggested_actions"])
369
- return f"""
370
- Based on our consultation, here's my assessment and recommendations:
371
-
372
- {recommendations}
373
-
374
- **Important Disclaimer:** I am an AI assistant, not a licensed medical professional.
375
- These suggestions are for informational purposes only. Please consult with a
376
- healthcare provider for proper diagnosis and treatment.
377
- """
378
-
379
- def format_default_response(self, user_message: str) -> str:
380
- """Format default response"""
381
- return self.generate_llama_response(f"Respond professionally to: {user_message}")
382
-
383
  def generate_llama_response(self, prompt: str) -> str:
384
- """Generate response using Llama-2"""
385
- formatted_prompt = f"<s>[INST] {prompt} [/INST] "
386
- inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
387
-
388
- with torch.no_grad():
389
- outputs = self.model.generate(
390
- inputs.input_ids,
391
- attention_mask=inputs.attention_mask,
392
- max_new_tokens=300,
393
- temperature=0.7,
394
- top_p=0.9,
395
- do_sample=True,
396
- pad_token_id=self.tokenizer.eos_token_id
397
- )
398
-
399
- response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
400
- return response.split('</s>')[0].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
- # Initialize the enhanced medical assistant
403
  medical_assistant = EnhancedMedicalAssistant()
404
 
405
- @spaces.GPU
406
  def chat_interface(message, history):
407
  """Gradio chat interface"""
408
- return medical_assistant.generate_response(message, history)
 
 
 
 
409
 
410
  # Create Gradio interface
411
  demo = gr.ChatInterface(
412
  fn=chat_interface,
413
- title="🏥 Advanced Medical AI Assistant",
414
  description="""
415
- I'm an AI medical assistant that can help assess your symptoms and provide guidance.
416
- I'll ask relevant questions to better understand your condition and provide appropriate recommendations.
 
 
 
 
 
417
 
418
  ⚠️ **Important**: I'm not a replacement for professional medical care. Always consult healthcare providers for serious concerns.
419
  """,
420
  examples=[
421
- "I've been having severe chest pain for the last hour",
422
- "I have a persistent cough that's been going on for 2 weeks",
423
- "I'm experiencing nausea and stomach pain after eating",
424
- "I have a headache and feel dizzy"
425
  ],
426
  theme="soft",
427
  css="""
428
- .message.user { background-color: #e3f2fd; }
429
- .message.bot { background-color: #f1f8e9; }
 
 
 
 
 
 
 
 
 
 
430
  """
431
  )
432
 
 
6
  from typing import TypedDict, List, Dict, Optional
7
  from datetime import datetime
8
  import json
9
+ import re
10
+ import numpy as np
11
+ from sentence_transformers import SentenceTransformer
12
+ import faiss
13
+ import pickle
14
+ import os
15
 
16
+ # Enhanced State Management with RAG
17
  class MedicalState(TypedDict):
18
  patient_id: str
19
  conversation_history: List[Dict]
 
26
  red_flags: List[str]
27
  assessment_complete: bool
28
  suggested_actions: List[str]
29
+ consultation_stage: str
30
+ retrieved_knowledge: List[Dict]
31
+ confidence_scores: Dict[str, float]
32
+
33
+ # Medical Knowledge Base for RAG
34
+ MEDICAL_KNOWLEDGE_BASE = {
35
+ "conditions": {
36
+ "common_cold": {
37
+ "symptoms": ["runny nose", "cough", "sneezing", "sore throat", "mild fever"],
38
+ "treatment": "Rest, fluids, OTC pain relievers",
39
+ "otc_medications": [
40
+ {"name": "Acetaminophen", "dose": "500-1000mg every 4-6 hours", "max_daily": "3000mg"},
41
+ {"name": "Ibuprofen", "dose": "200-400mg every 4-6 hours", "max_daily": "1200mg"}
42
+ ],
43
+ "home_remedies": ["Warm salt water gargle", "Honey and lemon tea", "Steam inhalation"],
44
+ "when_to_seek_care": "If symptoms worsen after 7-10 days or fever above 101.3°F"
45
+ },
46
+ "headache": {
47
+ "symptoms": ["head pain", "pressure", "throbbing"],
48
+ "treatment": "Pain relief, rest, hydration",
49
+ "otc_medications": [
50
+ {"name": "Acetaminophen", "dose": "500-1000mg every 4-6 hours", "max_daily": "3000mg"},
51
+ {"name": "Ibuprofen", "dose": "400-600mg every 6-8 hours", "max_daily": "1200mg"}
52
+ ],
53
+ "home_remedies": ["Cold or warm compress", "Rest in dark room", "Stay hydrated"],
54
+ "when_to_seek_care": "Sudden severe headache, fever, neck stiffness, vision changes"
55
+ },
56
+ "stomach_pain": {
57
+ "symptoms": ["abdominal pain", "nausea", "bloating", "cramps"],
58
+ "treatment": "Bland diet, rest, hydration",
59
+ "otc_medications": [
60
+ {"name": "Pepto-Bismol", "dose": "525mg every 30 minutes as needed", "max_daily": "8 doses"},
61
+ {"name": "TUMS", "dose": "2-4 tablets as needed", "max_daily": "15 tablets"}
62
+ ],
63
+ "home_remedies": ["BRAT diet", "Ginger tea", "Warm compress on stomach"],
64
+ "when_to_seek_care": "Severe pain, fever, vomiting, blood in stool"
65
+ }
66
+ }
67
+ }
68
 
 
69
  MEDICAL_CATEGORIES = {
70
+ "respiratory": ["cough", "shortness of breath", "chest pain", "wheezing", "runny nose", "sore throat"],
71
+ "gastrointestinal": ["nausea", "vomiting", "diarrhea", "stomach pain", "heartburn", "bloating"],
72
  "neurological": ["headache", "dizziness", "numbness", "tingling"],
73
  "musculoskeletal": ["joint pain", "muscle pain", "back pain", "stiffness"],
74
  "cardiovascular": ["chest pain", "palpitations", "swelling", "fatigue"],
 
82
  "sudden vision changes", "loss of consciousness", "severe allergic reaction"
83
  ]
84
 
85
+ class SimpleRAGSystem:
86
+ def __init__(self):
87
+ self.knowledge_base = MEDICAL_KNOWLEDGE_BASE
88
+ self.setup_simple_retrieval()
89
+
90
+ def setup_simple_retrieval(self):
91
+ """Setup simple keyword-based retrieval system"""
92
+ self.symptom_to_condition = {}
93
+
94
+ for condition, data in self.knowledge_base["conditions"].items():
95
+ for symptom in data["symptoms"]:
96
+ if symptom not in self.symptom_to_condition:
97
+ self.symptom_to_condition[symptom] = []
98
+ self.symptom_to_condition[symptom].append(condition)
99
+
100
+ def retrieve_relevant_conditions(self, symptoms: List[str]) -> List[Dict]:
101
+ """Retrieve relevant medical conditions based on symptoms"""
102
+ relevant_conditions = {}
103
+
104
+ for symptom in symptoms:
105
+ symptom_lower = symptom.lower()
106
+
107
+ # Direct match
108
+ if symptom_lower in self.symptom_to_condition:
109
+ for condition in self.symptom_to_condition[symptom_lower]:
110
+ if condition not in relevant_conditions:
111
+ relevant_conditions[condition] = self.knowledge_base["conditions"][condition]
112
+
113
+ # Partial match
114
+ for kb_symptom, conditions in self.symptom_to_condition.items():
115
+ if symptom_lower in kb_symptom or kb_symptom in symptom_lower:
116
+ for condition in conditions:
117
+ if condition not in relevant_conditions:
118
+ relevant_conditions[condition] = self.knowledge_base["conditions"][condition]
119
+
120
+ return [{"condition": k, "data": v} for k, v in relevant_conditions.items()]
121
 
122
  class EnhancedMedicalAssistant:
123
  def __init__(self):
124
  self.load_models()
125
+ self.rag_system = SimpleRAGSystem()
126
  self.setup_langgraph()
127
+ self.conversation_count = {}
128
 
129
  def load_models(self):
130
  """Load the AI models"""
131
  print("Loading models...")
132
+ try:
133
+ # Llama-2 for conversation
134
+ self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
135
+ if self.tokenizer.pad_token is None:
136
+ self.tokenizer.pad_token = self.tokenizer.eos_token
137
+
138
+ self.model = AutoModelForCausalLM.from_pretrained(
139
+ "meta-llama/Llama-2-7b-chat-hf",
140
+ torch_dtype=torch.float16,
141
+ device_map="auto"
142
+ )
143
+
144
+ # Meditron for medical suggestions
145
+ self.meditron_tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
146
+ if self.meditron_tokenizer.pad_token is None:
147
+ self.meditron_tokenizer.pad_token = self.meditron_tokenizer.eos_token
148
+
149
+ self.meditron_model = AutoModelForCausalLM.from_pretrained(
150
+ "epfl-llm/meditron-7b",
151
+ torch_dtype=torch.float16,
152
+ device_map="auto"
153
+ )
154
+ print("Models loaded successfully!")
155
+ except Exception as e:
156
+ print(f"Error loading models: {e}")
157
+ # Fallback - use only one model
158
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
159
+ self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
160
+ self.meditron_tokenizer = self.tokenizer
161
+ self.meditron_model = self.model
162
 
163
  def setup_langgraph(self):
164
+ """Setup simplified LangGraph workflow"""
165
  workflow = StateGraph(MedicalState)
166
 
 
167
  workflow.add_node("intake", self.patient_intake)
 
 
168
  workflow.add_node("generate_recommendations", self.generate_recommendations)
169
  workflow.add_node("emergency_triage", self.emergency_triage)
170
 
 
171
  workflow.set_entry_point("intake")
172
  workflow.add_conditional_edges(
173
  "intake",
174
  self.route_after_intake,
175
  {
 
176
  "emergency": "emergency_triage",
177
+ "recommendations": "generate_recommendations"
 
 
 
 
 
 
 
 
 
 
178
  }
179
  )
180
  workflow.add_edge("generate_recommendations", END)
 
183
  self.workflow = workflow.compile()
184
 
185
  def patient_intake(self, state: MedicalState) -> MedicalState:
186
+ """Enhanced patient intake with RAG"""
187
  last_message = state["conversation_history"][-1]["content"] if state["conversation_history"] else ""
188
 
189
+ # Extract symptoms
190
  detected_symptoms = self.extract_symptoms(last_message)
191
  state["symptoms"].update(detected_symptoms)
192
 
193
+ # Use RAG to get relevant medical knowledge
194
+ if detected_symptoms:
195
+ symptom_names = list(detected_symptoms.keys())
196
+ relevant_conditions = self.rag_system.retrieve_relevant_conditions(symptom_names)
197
+ state["retrieved_knowledge"] = relevant_conditions
198
+
199
  # Check for red flags
200
  red_flags = self.check_red_flags(last_message)
201
+ state["red_flags"].extend(red_flags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ # Determine consultation stage
204
+ if red_flags:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  state["consultation_stage"] = "emergency"
206
  else:
207
  state["consultation_stage"] = "recommendations"
 
209
  return state
210
 
211
  def generate_recommendations(self, state: MedicalState) -> MedicalState:
212
+ """Generate RAG-enhanced recommendations"""
213
+ # Create structured recommendations from RAG knowledge
214
+ recommendations = self.create_structured_recommendations(state)
 
 
215
  state["suggested_actions"] = recommendations
 
216
  return state
217
 
218
+ def create_structured_recommendations(self, state: MedicalState) -> List[str]:
219
+ """Create structured recommendations using RAG knowledge"""
220
+ recommendations = []
221
+
222
+ if not state["retrieved_knowledge"]:
223
+ recommendations.append("I need more specific information about your symptoms to provide targeted recommendations.")
224
+ return recommendations
225
+
226
+ # Process each relevant condition
227
+ for knowledge_item in state["retrieved_knowledge"][:2]: # Limit to top 2 conditions
228
+ condition = knowledge_item["condition"]
229
+ data = knowledge_item["data"]
230
+
231
+ # Format condition information
232
+ condition_info = f"\n**Possible Condition: {condition.replace('_', ' ').title()}**\n"
233
+
234
+ # Add medications
235
+ if "otc_medications" in data:
236
+ condition_info += "\n**💊 Over-the-Counter Medications:**\n"
237
+ for med in data["otc_medications"]:
238
+ condition_info += f"• **{med['name']}**: {med['dose']} (Max daily: {med['max_daily']})\n"
239
+
240
+ # Add home remedies
241
+ if "home_remedies" in data:
242
+ condition_info += "\n**🏠 Home Remedies:**\n"
243
+ for remedy in data["home_remedies"]:
244
+ condition_info += f"• {remedy}\n"
245
+
246
+ # Add when to seek care
247
+ if "when_to_seek_care" in data:
248
+ condition_info += f"\n**⚠️ Seek Medical Care If:** {data['when_to_seek_care']}\n"
249
+
250
+ recommendations.append(condition_info)
251
+
252
+ # Add general advice
253
+ recommendations.append("""
254
+ **📋 General Recommendations:**
255
+ • Monitor your symptoms for any changes
256
+ • Stay hydrated and get adequate rest
257
+ • Follow medication instructions carefully
258
+ • Don't exceed recommended dosages
259
+
260
+ **🚨 Emergency Warning Signs:**
261
+ • Severe worsening of symptoms
262
+ • High fever (>101.3°F/38.5°C)
263
+ • Difficulty breathing
264
+ • Severe pain
265
+ • Signs of dehydration
266
+ """)
267
+
268
+ return recommendations
269
+
270
  def emergency_triage(self, state: MedicalState) -> MedicalState:
271
  """Handle emergency situations"""
272
+ emergency_response = f"""
273
+ 🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨
274
+
275
+ Based on your symptoms, I strongly recommend seeking immediate medical care because you mentioned: {', '.join(state['red_flags'])}
276
+
277
+ **Immediate Actions:**
278
+ • Go to the nearest emergency room, OR
279
+ • Call emergency services (911), OR
280
+ • Contact your doctor immediately
281
+
282
+ **Why This is Urgent:**
283
+ These symptoms can indicate serious conditions that require professional medical evaluation and treatment.
284
+
285
+ ⚠️ **Disclaimer:** This is not a medical diagnosis, but these symptoms warrant immediate professional assessment.
286
+ """
287
+
288
  state["suggested_actions"] = [emergency_response]
289
  return state
290
 
 
292
  """Route decision after intake"""
293
  if state["red_flags"]:
294
  return "emergency"
 
 
295
  else:
296
+ return "recommendations"
 
 
 
 
 
 
 
 
 
297
 
298
  def extract_symptoms(self, text: str) -> Dict:
299
  """Extract and categorize symptoms from patient text"""
 
322
 
323
  return found_flags
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  def generate_response(self, message: str, history: List) -> str:
326
  """Main response generation function"""
327
+ session_id = "default_session"
328
+
329
+ # Track conversation count
330
+ if session_id not in self.conversation_count:
331
+ self.conversation_count[session_id] = 0
332
+ self.conversation_count[session_id] += 1
333
+
334
+ # Initialize state
335
  state = MedicalState(
336
+ patient_id=session_id,
337
  conversation_history=history + [{"role": "user", "content": message}],
338
  symptoms={},
339
  vital_questions_asked=[],
 
344
  red_flags=[],
345
  assessment_complete=False,
346
  suggested_actions=[],
347
+ consultation_stage="intake",
348
+ retrieved_knowledge=[],
349
+ confidence_scores={}
350
  )
351
 
352
+ # For first few messages, do conversational intake
353
+ if self.conversation_count[session_id] <= 3:
354
+ return self.generate_conversational_response(message, history)
355
+
356
+ # After gathering info, run workflow for recommendations
357
+ try:
358
+ result = self.workflow.invoke(state)
359
+ return self.format_final_response(result)
360
+ except Exception as e:
361
+ print(f"Workflow error: {e}")
362
+ return self.generate_conversational_response(message, history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
+ def generate_conversational_response(self, message: str, history: List) -> str:
365
+ """Generate conversational response for intake phase"""
366
+ # Extract symptoms for context
367
+ symptoms = self.extract_symptoms(message)
368
+ red_flags = self.check_red_flags(message)
369
 
370
+ # Handle emergencies immediately
371
+ if red_flags:
372
+ return f"""
373
+ 🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨
374
+
375
+ I notice you mentioned: {', '.join(red_flags)}
376
+
377
+ Please seek immediate medical care:
378
+ • Go to the nearest emergency room
379
+ • Call emergency services (911)
380
+ • Contact your doctor immediately
381
+
382
+ These symptoms require professional medical evaluation right away.
383
+ """
384
+
385
+ # Generate contextual questions based on symptoms
386
+ if symptoms:
387
+ symptom_names = list(symptoms.keys())
388
+ prompt = f"""
389
+ You are a caring medical assistant. The patient mentioned these symptoms: {', '.join(symptom_names)}.
390
+
391
+ Respond empathetically and ask 1-2 relevant follow-up questions about:
392
+ - How long they've had these symptoms
393
+ - Severity (mild, moderate, severe)
394
+ - What makes it better or worse
395
+ - Any other symptoms they're experiencing
396
+
397
+ Be professional, caring, and concise. Don't provide treatment advice yet.
398
+ """
399
+ else:
400
+ prompt = f"""
401
+ You are a caring medical assistant. The patient said: "{message}"
402
+
403
+ Respond empathetically and ask relevant questions to understand their health concern better.
404
+ Be professional and caring.
405
+ """
406
 
407
  return self.generate_llama_response(prompt)
408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  def generate_llama_response(self, prompt: str) -> str:
410
+ """Generate response using Llama-2 with better formatting"""
411
+ try:
412
+ formatted_prompt = f"<s>[INST] {prompt} [/INST]"
413
+ inputs = self.tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
414
+
415
+ if torch.cuda.is_available():
416
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
417
+
418
+ with torch.no_grad():
419
+ outputs = self.model.generate(
420
+ **inputs,
421
+ max_new_tokens=200,
422
+ temperature=0.7,
423
+ top_p=0.9,
424
+ do_sample=True,
425
+ pad_token_id=self.tokenizer.eos_token_id
426
+ )
427
+
428
+ # Decode response
429
+ response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
430
+
431
+ # Clean up the response
432
+ response = response.split('</s>')[0].strip()
433
+ response = response.replace('<s>', '').replace('[INST]', '').replace('[/INST]', '').strip()
434
+
435
+ # Remove any XML-like tags
436
+ response = re.sub(r'<[^>]+>', '', response)
437
+
438
+ return response if response else "I understand your concern. Can you tell me more about what you're experiencing?"
439
+
440
+ except Exception as e:
441
+ print(f"Error generating response: {e}")
442
+ return "I understand your concern. Can you tell me more about your symptoms?"
443
+
444
+ def format_final_response(self, state: MedicalState) -> str:
445
+ """Format the final response with recommendations"""
446
+ if state["consultation_stage"] == "emergency":
447
+ return state["suggested_actions"][0] if state["suggested_actions"] else "Please seek immediate medical attention."
448
+
449
+ # Format recommendations nicely
450
+ if state["suggested_actions"]:
451
+ response = "## 🏥 Medical Assessment & Recommendations\n\n"
452
+ response += "Based on our conversation, here's what I recommend:\n"
453
+
454
+ for action in state["suggested_actions"]:
455
+ response += f"{action}\n"
456
+
457
+ response += "\n---\n"
458
+ response += "**Important Disclaimer:** I'm an AI assistant providing general health information. "
459
+ response += "This is not a substitute for professional medical advice, diagnosis, or treatment. "
460
+ response += "Always consult with qualified healthcare providers for medical concerns."
461
+
462
+ return response
463
+ else:
464
+ return "Please provide more details about your symptoms so I can offer better guidance."
465
 
466
+ # Initialize the medical assistant
467
  medical_assistant = EnhancedMedicalAssistant()
468
 
469
+ @spaces.GPU
470
  def chat_interface(message, history):
471
  """Gradio chat interface"""
472
+ try:
473
+ return medical_assistant.generate_response(message, history)
474
+ except Exception as e:
475
+ print(f"Chat interface error: {e}")
476
+ return f"I apologize, but I encountered an error. Please try rephrasing your question. Error: {str(e)}"
477
 
478
  # Create Gradio interface
479
  demo = gr.ChatInterface(
480
  fn=chat_interface,
481
+ title="🏥 Medical AI Assistant with medRAG",
482
  description="""
483
+ I'm an AI medical assistant powered by medical knowledge retrieval (medRAG).
484
+ I can help assess your symptoms and provide evidence-based recommendations.
485
+
486
+ **How it works:**
487
+ 1. Tell me about your symptoms
488
+ 2. I'll ask follow-up questions
489
+ 3. I'll provide personalized recommendations based on medical knowledge
490
 
491
  ⚠️ **Important**: I'm not a replacement for professional medical care. Always consult healthcare providers for serious concerns.
492
  """,
493
  examples=[
494
+ "I have a bad cough and sore throat",
495
+ "I've been having headaches for the past few days",
496
+ "My stomach has been hurting after meals",
497
+ "I have chest pain and trouble breathing"
498
  ],
499
  theme="soft",
500
  css="""
501
+ .message.user {
502
+ background-color: #e3f2fd;
503
+ border-radius: 10px;
504
+ padding: 10px;
505
+ margin: 5px;
506
+ }
507
+ .message.bot {
508
+ background-color: #f1f8e9;
509
+ border-radius: 10px;
510
+ padding: 10px;
511
+ margin: 5px;
512
+ }
513
  """
514
  )
515