# Hugging Face Spaces page banner captured with this file (Space status: Sleeping).
import json
import os
import pickle
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict

import faiss
import gradio as gr
import numpy as np
import spaces  # NOTE(review): kept ahead of torch as in the original order - confirm whether the hosting runtime requires it
import torch
from langgraph.graph import END, StateGraph
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
# Enhanced State Management with RAG
class MedicalState(TypedDict):
    """Shared state dict passed between the LangGraph workflow nodes."""

    # Session identifier supplied by the caller.
    patient_id: str
    # Chat turns; the assistant appends {"role": "user", "content": ...} entries.
    conversation_history: List[Dict]
    # Detected symptoms keyed by symptom phrase. `Any` (not the builtin
    # function `any`) is the correct annotation for free-form values.
    symptoms: Dict[str, Any]
    vital_questions_asked: List[str]
    medical_history: Dict
    current_medications: List[str]
    allergies: List[str]
    severity_scores: Dict[str, int]
    # Emergency phrases found in the patient's text.
    red_flags: List[str]
    assessment_complete: bool
    # Formatted recommendation / emergency text blocks shown to the user.
    suggested_actions: List[str]
    # "intake", "emergency" or "recommendations" in the code visible here.
    consultation_stage: str
    # Knowledge-base hits, each {"condition": name, "data": entry}.
    retrieved_knowledge: List[Dict]
    confidence_scores: Dict[str, float]
# Medical Knowledge Base for RAG
#
# Layout: {"conditions": {condition_key: entry}} where every entry carries
#   symptoms          - phrases used for keyword retrieval
#   treatment         - one-line summary
#   otc_medications   - list of {"name", "dose", "max_daily"} dicts
#   home_remedies     - list of short suggestions
#   when_to_seek_care - escalation guidance
# Insertion order of the conditions is significant: retrieval results follow it.
MEDICAL_KNOWLEDGE_BASE = {
    "conditions": {
        "common_cold": {
            "symptoms": [
                "runny nose",
                "cough",
                "sneezing",
                "sore throat",
                "mild fever",
            ],
            "treatment": "Rest, fluids, OTC pain relievers",
            "otc_medications": [
                {
                    "name": "Acetaminophen",
                    "dose": "500-1000mg every 4-6 hours",
                    "max_daily": "3000mg",
                },
                {
                    "name": "Ibuprofen",
                    "dose": "200-400mg every 4-6 hours",
                    "max_daily": "1200mg",
                },
            ],
            "home_remedies": [
                "Warm salt water gargle",
                "Honey and lemon tea",
                "Steam inhalation",
            ],
            "when_to_seek_care": "If symptoms worsen after 7-10 days or fever above 101.3°F",
        },
        "headache": {
            "symptoms": ["head pain", "pressure", "throbbing"],
            "treatment": "Pain relief, rest, hydration",
            "otc_medications": [
                {
                    "name": "Acetaminophen",
                    "dose": "500-1000mg every 4-6 hours",
                    "max_daily": "3000mg",
                },
                {
                    "name": "Ibuprofen",
                    "dose": "400-600mg every 6-8 hours",
                    "max_daily": "1200mg",
                },
            ],
            "home_remedies": [
                "Cold or warm compress",
                "Rest in dark room",
                "Stay hydrated",
            ],
            "when_to_seek_care": "Sudden severe headache, fever, neck stiffness, vision changes",
        },
        "stomach_pain": {
            "symptoms": ["abdominal pain", "nausea", "bloating", "cramps"],
            "treatment": "Bland diet, rest, hydration",
            "otc_medications": [
                {
                    "name": "Pepto-Bismol",
                    "dose": "525mg every 30 minutes as needed",
                    "max_daily": "8 doses",
                },
                {
                    "name": "TUMS",
                    "dose": "2-4 tablets as needed",
                    "max_daily": "15 tablets",
                },
            ],
            "home_remedies": [
                "BRAT diet",
                "Ginger tea",
                "Warm compress on stomach",
            ],
            "when_to_seek_care": "Severe pain, fever, vomiting, blood in stool",
        },
    },
}
# Symptom phrases grouped by body system; used to detect and label symptoms in
# patient text. A phrase may appear under more than one category
# ("chest pain" is listed under both respiratory and cardiovascular).
MEDICAL_CATEGORIES = {
    "respiratory": [
        "cough",
        "shortness of breath",
        "chest pain",
        "wheezing",
        "runny nose",
        "sore throat",
    ],
    "gastrointestinal": [
        "nausea",
        "vomiting",
        "diarrhea",
        "stomach pain",
        "heartburn",
        "bloating",
    ],
    "neurological": ["headache", "dizziness", "numbness", "tingling"],
    "musculoskeletal": ["joint pain", "muscle pain", "back pain", "stiffness"],
    "cardiovascular": ["chest pain", "palpitations", "swelling", "fatigue"],
    "dermatological": ["rash", "itching", "skin changes", "wounds"],
    "mental_health": ["anxiety", "depression", "sleep issues", "stress"],
}
# Phrases that, when present in the patient's text, route the conversation to
# the emergency response instead of self-care recommendations.
RED_FLAGS = [
    "chest pain",
    "difficulty breathing",
    "severe headache",
    "high fever",
    "blood in stool",
    "blood in urine",
    "severe abdominal pain",
    "sudden vision changes",
    "loss of consciousness",
    "severe allergic reaction",
]
class SimpleRAGSystem:
    """Keyword-based retrieval over a condition knowledge base.

    Args:
        knowledge_base: Optional mapping shaped like
            ``{"conditions": {name: {"symptoms": [...], ...}}}``. When omitted,
            the module-level ``MEDICAL_KNOWLEDGE_BASE`` is used.
    """

    def __init__(self, knowledge_base: Optional[Dict] = None):
        self.knowledge_base = (
            MEDICAL_KNOWLEDGE_BASE if knowledge_base is None else knowledge_base
        )
        self.setup_simple_retrieval()

    def setup_simple_retrieval(self):
        """Build the lower-cased keyword -> [condition, ...] lookup table.

        Besides each listed symptom, the condition's own name (underscores
        turned into spaces) is indexed, so a reported "headache" or
        "stomach pain" finds the entry of the same name even when its symptom
        list only uses other wording.
        """
        self.symptom_to_condition = {}
        for condition, data in self.knowledge_base["conditions"].items():
            keywords = [condition.replace("_", " ")]
            keywords.extend(data.get("symptoms", []))
            for keyword in keywords:
                bucket = self.symptom_to_condition.setdefault(keyword.lower(), [])
                if condition not in bucket:
                    bucket.append(condition)

    def retrieve_relevant_conditions(self, symptoms: List[str]) -> List[Dict]:
        """Retrieve knowledge-base conditions relevant to the given symptoms.

        A symptom matches a keyword when either contains the other
        (case-insensitive). Blank symptoms are ignored, because an empty
        string is a substring of every keyword.

        Args:
            symptoms: Symptom phrases reported by the patient.

        Returns:
            ``[{"condition": name, "data": entry}, ...]`` ordered by the number
            of input symptoms that matched each condition (most first); ties
            keep first-match order.
        """
        match_counts: Dict[str, int] = {}
        for symptom in symptoms:
            needle = symptom.strip().lower()
            if not needle:
                continue
            # Collect each condition at most once per symptom so a condition
            # with several overlapping keywords is not over-counted.
            matched: List[str] = []
            for keyword, conditions in self.symptom_to_condition.items():
                if needle in keyword or keyword in needle:
                    for condition in conditions:
                        if condition not in matched:
                            matched.append(condition)
            for condition in matched:
                match_counts[condition] = match_counts.get(condition, 0) + 1

        # sorted() is stable, so equal counts stay in first-match order.
        ranked = sorted(match_counts, key=match_counts.get, reverse=True)
        conditions_table = self.knowledge_base["conditions"]
        return [
            {"condition": name, "data": conditions_table[name]} for name in ranked
        ]
class EnhancedMedicalAssistant:
    """Chat assistant combining an LLM intake phase with a RAG workflow.

    The first ``INTAKE_TURNS`` user turns of a conversation get conversational
    follow-up questions from the language model; later turns run the LangGraph
    workflow (intake -> emergency triage or recommendations).
    """

    # Number of user turns answered conversationally before recommendations.
    INTAKE_TURNS = 3

    def __init__(self):
        self.load_models()
        self.rag_system = SimpleRAGSystem()
        self.setup_langgraph()
        # Latest turn number per session id; kept for callers that inspect it.
        self.conversation_count = {}

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------
    def load_models(self):
        """Load the AI models, falling back to a small model on any failure."""
        print("Loading models...")
        try:
            # Llama-2 for conversation
            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Llama-2-7b-chat-hf",
                torch_dtype=torch.float16,
                device_map="auto",
            )
            # Meditron for medical suggestions
            # NOTE(review): no method of this class calls the Meditron model or
            # tokenizer; loading it may be unnecessary - confirm before removing.
            self.meditron_tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
            if self.meditron_tokenizer.pad_token is None:
                self.meditron_tokenizer.pad_token = self.meditron_tokenizer.eos_token
            self.meditron_model = AutoModelForCausalLM.from_pretrained(
                "epfl-llm/meditron-7b",
                torch_dtype=torch.float16,
                device_map="auto",
            )
            print("Models loaded successfully!")
        except Exception as e:
            # Deliberate best-effort: any load failure switches to one small
            # model shared by both roles.
            print(f"Error loading models: {e}")
            # Fallback - use only one model
            self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
            self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
            self.meditron_tokenizer = self.tokenizer
            self.meditron_model = self.model

    def setup_langgraph(self):
        """Build and compile the intake -> (emergency | recommendations) graph."""
        workflow = StateGraph(MedicalState)
        workflow.add_node("intake", self.patient_intake)
        workflow.add_node("generate_recommendations", self.generate_recommendations)
        workflow.add_node("emergency_triage", self.emergency_triage)
        workflow.set_entry_point("intake")
        workflow.add_conditional_edges(
            "intake",
            self.route_after_intake,
            {
                "emergency": "emergency_triage",
                "recommendations": "generate_recommendations",
            },
        )
        workflow.add_edge("generate_recommendations", END)
        workflow.add_edge("emergency_triage", END)
        self.workflow = workflow.compile()

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _block(*lines: str) -> str:
        """Join lines into a text block that starts and ends with a newline."""
        return "\n" + "\n".join(lines) + "\n"

    @staticmethod
    def _mentions(term: str, text_lower: str) -> bool:
        """Return True when `term` occurs in `text_lower` at the start of a word.

        Anchoring on the word start keeps plurals and inflections ("headaches",
        "coughing") while rejecting hits buried inside unrelated words
        ("rash" in "crash", "stress" in "distress").
        """
        pattern = r"(?<!\w)" + re.escape(term.lower())
        return re.search(pattern, text_lower) is not None

    @staticmethod
    def _user_messages(history) -> List[str]:
        """Extract the user's text turns from a chat history.

        Accepts dict-style entries (``{"role": ..., "content": ...}``; a missing
        role is treated as the user) and pair-style entries
        (``[user_text, bot_text]``). Non-text content is skipped.
        """
        messages = []
        for turn in history or []:
            if isinstance(turn, dict):
                content = turn.get("content")
                if turn.get("role", "user") == "user" and isinstance(content, str):
                    messages.append(content)
            elif isinstance(turn, (list, tuple)) and turn and isinstance(turn[0], str):
                messages.append(turn[0])
        return messages

    # ------------------------------------------------------------------
    # Workflow nodes
    # ------------------------------------------------------------------
    def patient_intake(self, state: "MedicalState") -> "MedicalState":
        """Collect symptoms, retrieve knowledge and flag emergencies.

        Symptoms are gathered from every user turn in the conversation so that
        details given during the earlier question phase are not lost; red flags
        are checked on the latest user message only.
        """
        user_messages = self._user_messages(state["conversation_history"])
        last_message = user_messages[-1] if user_messages else ""

        # Extract symptoms
        for text in user_messages:
            state["symptoms"].update(self.extract_symptoms(text))

        # Use RAG to get relevant medical knowledge
        if state["symptoms"]:
            symptom_names = list(state["symptoms"])
            state["retrieved_knowledge"] = self.rag_system.retrieve_relevant_conditions(symptom_names)

        # Check for red flags
        state["red_flags"].extend(self.check_red_flags(last_message))

        # Determine consultation stage from the same data the router uses, so
        # the stage and the chosen branch can never disagree.
        state["consultation_stage"] = "emergency" if state["red_flags"] else "recommendations"
        return state

    def generate_recommendations(self, state: "MedicalState") -> "MedicalState":
        """Store the structured recommendations in ``suggested_actions``."""
        state["suggested_actions"] = self.create_structured_recommendations(state)
        return state

    def create_structured_recommendations(self, state: "MedicalState") -> List[str]:
        """Create formatted recommendation blocks from the retrieved knowledge.

        Returns:
            One markdown block per condition (at most two), followed by a
            general-advice block; or a single request for more detail when no
            knowledge was retrieved.
        """
        if not state["retrieved_knowledge"]:
            return [
                "I need more specific information about your symptoms to provide targeted recommendations."
            ]

        recommendations = []
        # Process each relevant condition
        for knowledge_item in state["retrieved_knowledge"][:2]:  # Limit to top 2 conditions
            condition = knowledge_item["condition"]
            data = knowledge_item["data"]

            parts = [f"\n**Possible Condition: {condition.replace('_', ' ').title()}**\n"]
            # Add medications
            if "otc_medications" in data:
                parts.append("\n**💊 Over-the-Counter Medications:**\n")
                parts.extend(
                    f"• **{med['name']}**: {med['dose']} (Max daily: {med['max_daily']})\n"
                    for med in data["otc_medications"]
                )
            # Add home remedies
            if "home_remedies" in data:
                parts.append("\n**🏠 Home Remedies:**\n")
                parts.extend(f"• {remedy}\n" for remedy in data["home_remedies"])
            # Add when to seek care
            if "when_to_seek_care" in data:
                parts.append(f"\n**⚠️ Seek Medical Care If:** {data['when_to_seek_care']}\n")
            recommendations.append("".join(parts))

        # Add general advice
        recommendations.append(
            self._block(
                "**📋 General Recommendations:**",
                "• Monitor your symptoms for any changes",
                "• Stay hydrated and get adequate rest",
                "• Follow medication instructions carefully",
                "• Don't exceed recommended dosages",
                "**🚨 Emergency Warning Signs:**",
                "• Severe worsening of symptoms",
                "• High fever (>101.3°F/38.5°C)",
                "• Difficulty breathing",
                "• Severe pain",
                "• Signs of dehydration",
            )
        )
        return recommendations

    def emergency_triage(self, state: "MedicalState") -> "MedicalState":
        """Replace ``suggested_actions`` with an urgent-care message."""
        emergency_response = self._block(
            "🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨",
            "Based on your symptoms, I strongly recommend seeking immediate medical care "
            f"because you mentioned: {', '.join(state['red_flags'])}",
            "**Immediate Actions:**",
            "• Go to the nearest emergency room, OR",
            "• Call emergency services (911), OR",
            "• Contact your doctor immediately",
            "**Why This is Urgent:**",
            "These symptoms can indicate serious conditions that require professional medical evaluation and treatment.",
            "⚠️ **Disclaimer:** This is not a medical diagnosis, but these symptoms warrant immediate professional assessment.",
        )
        state["suggested_actions"] = [emergency_response]
        return state

    def route_after_intake(self, state: "MedicalState"):
        """Return the branch key: "emergency" if any red flag, else "recommendations"."""
        return "emergency" if state["red_flags"] else "recommendations"

    # ------------------------------------------------------------------
    # Text analysis
    # ------------------------------------------------------------------
    def extract_symptoms(self, text: str) -> Dict:
        """Extract and categorize symptoms from patient text.

        Returns:
            ``{symptom: {"category", "mentioned_at", "context"}}``. A symptom
            listed under several categories keeps the last category iterated.
        """
        symptoms = {}
        text_lower = text.lower()
        for category, symptom_list in MEDICAL_CATEGORIES.items():
            for symptom in symptom_list:
                if self._mentions(symptom, text_lower):
                    symptoms[symptom] = {
                        "category": category,
                        "mentioned_at": datetime.now().isoformat(),
                        "context": text,
                    }
        return symptoms

    def check_red_flags(self, text: str) -> List[str]:
        """Return the emergency phrases from ``RED_FLAGS`` found in `text`."""
        text_lower = text.lower()
        return [flag for flag in RED_FLAGS if self._mentions(flag, text_lower)]

    # ------------------------------------------------------------------
    # Conversation entry points
    # ------------------------------------------------------------------
    def generate_response(self, message: str, history: List) -> str:
        """Main response generation function.

        The turn number is derived from `history`, so each conversation gets
        its own intake phase instead of sharing one process-wide counter.

        Args:
            message: The user's latest message.
            history: Prior turns of this conversation (dict- or pair-style).

        Returns:
            The assistant's reply text.
        """
        session_id = "default_session"
        history = list(history or [])

        # Track conversation count (prior user turns + the current one)
        turn_number = len(self._user_messages(history)) + 1
        self.conversation_count[session_id] = turn_number

        # For first few messages, do conversational intake
        if turn_number <= self.INTAKE_TURNS:
            return self.generate_conversational_response(message, history)

        # Initialize state
        state = MedicalState(
            patient_id=session_id,
            conversation_history=history + [{"role": "user", "content": message}],
            symptoms={},
            vital_questions_asked=[],
            medical_history={},
            current_medications=[],
            allergies=[],
            severity_scores={},
            red_flags=[],
            assessment_complete=False,
            suggested_actions=[],
            consultation_stage="intake",
            retrieved_knowledge=[],
            confidence_scores={},
        )

        # After gathering info, run workflow for recommendations
        try:
            result = self.workflow.invoke(state)
            return self.format_final_response(result)
        except Exception as e:
            # Best-effort: a workflow failure degrades to a conversational reply.
            print(f"Workflow error: {e}")
            return self.generate_conversational_response(message, history)

    def generate_conversational_response(self, message: str, history: List) -> str:
        """Generate conversational response for intake phase.

        Red flags in `message` short-circuit to an urgent-care notice without
        calling the model. `history` is accepted for interface symmetry and is
        not read here.
        """
        # Extract symptoms for context
        symptoms = self.extract_symptoms(message)
        red_flags = self.check_red_flags(message)

        # Handle emergencies immediately
        if red_flags:
            return self._block(
                "🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨",
                f"I notice you mentioned: {', '.join(red_flags)}",
                "Please seek immediate medical care:",
                "• Go to the nearest emergency room",
                "• Call emergency services (911)",
                "• Contact your doctor immediately",
                "These symptoms require professional medical evaluation right away.",
            )

        # Generate contextual questions based on symptoms
        if symptoms:
            symptom_names = list(symptoms)
            prompt = self._block(
                "You are a caring medical assistant. The patient mentioned these symptoms: "
                f"{', '.join(symptom_names)}.",
                "Respond empathetically and ask 1-2 relevant follow-up questions about:",
                "- How long they've had these symptoms",
                "- Severity (mild, moderate, severe)",
                "- What makes it better or worse",
                "- Any other symptoms they're experiencing",
                "Be professional, caring, and concise. Don't provide treatment advice yet.",
            )
        else:
            prompt = self._block(
                f'You are a caring medical assistant. The patient said: "{message}"',
                "Respond empathetically and ask relevant questions to understand their health concern better.",
                "Be professional and caring.",
            )
        return self.generate_llama_response(prompt)

    def generate_llama_response(self, prompt: str) -> str:
        """Generate a reply with the conversation model.

        Returns:
            The cleaned model output, or a generic follow-up question when the
            output is empty or generation raises.
        """
        try:
            # No literal "<s>": the tokenizer call below adds special tokens
            # itself, so spelling BOS out would duplicate it (and would be plain
            # text noise for the fallback model).
            formatted_prompt = f"[INST] {prompt} [/INST]"
            inputs = self.tokenizer(
                formatted_prompt, return_tensors="pt", truncation=True, max_length=512
            )
            # Always place inputs on the model's device; a no-op on CPU.
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            prompt_length = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens
            response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

            # Clean up the response
            response = response.split("</s>")[0].strip()
            for marker in ("<s>", "[INST]", "[/INST]"):
                response = response.replace(marker, "")
            response = response.strip()
            # Remove any XML-like tags
            response = re.sub(r"<[^>]+>", "", response)

            if response:
                return response
            return "I understand your concern. Can you tell me more about what you're experiencing?"
        except Exception as e:
            print(f"Error generating response: {e}")
            return "I understand your concern. Can you tell me more about your symptoms?"

    def format_final_response(self, state: "MedicalState") -> str:
        """Format the final response with recommendations.

        Emergency states return the first suggested action verbatim; otherwise
        the actions are wrapped in a heading and a disclaimer.
        """
        actions = state["suggested_actions"]

        if state["consultation_stage"] == "emergency":
            return actions[0] if actions else "Please seek immediate medical attention."

        if not actions:
            return "Please provide more details about your symptoms so I can offer better guidance."

        # Format recommendations nicely
        pieces = [
            "## 🏥 Medical Assessment & Recommendations\n\n",
            "Based on our conversation, here's what I recommend:\n",
        ]
        pieces.extend(f"{action}\n" for action in actions)
        pieces.extend(
            [
                "\n---\n",
                "**Important Disclaimer:** I'm an AI assistant providing general health information. ",
                "This is not a substitute for professional medical advice, diagnosis, or treatment. ",
                "Always consult with qualified healthcare providers for medical concerns.",
            ]
        )
        return "".join(pieces)
# Initialize the medical assistant
# Built once at import time; constructing it loads the models.
medical_assistant = EnhancedMedicalAssistant()


def chat_interface(message, history):
    """Gradio chat callback: delegate to the shared assistant.

    Any exception is logged to stdout and turned into an apology message that
    includes the error text, so the UI always receives a string.
    """
    try:
        reply = medical_assistant.generate_response(message, history)
    except Exception as e:
        print(f"Chat interface error: {e}")
        reply = f"I apologize, but I encountered an error. Please try rephrasing your question. Error: {str(e)}"
    return reply
# Create Gradio interface
# The long literals are named first so the ChatInterface call stays readable.
_CHAT_DESCRIPTION = """
I'm an AI medical assistant powered by medical knowledge retrieval (medRAG).
I can help assess your symptoms and provide evidence-based recommendations.
**How it works:**
1. Tell me about your symptoms
2. I'll ask follow-up questions
3. I'll provide personalized recommendations based on medical knowledge
⚠️ **Important**: I'm not a replacement for professional medical care. Always consult healthcare providers for serious concerns.
"""

# Starter prompts offered in the UI.
_CHAT_EXAMPLES = [
    "I have a bad cough and sore throat",
    "I've been having headaches for the past few days",
    "My stomach has been hurting after meals",
    "I have chest pain and trouble breathing",
]

# Bubble styling for user and bot messages.
_CHAT_CSS = """
.message.user {
background-color: #e3f2fd;
border-radius: 10px;
padding: 10px;
margin: 5px;
}
.message.bot {
background-color: #f1f8e9;
border-radius: 10px;
padding: 10px;
margin: 5px;
}
"""

demo = gr.ChatInterface(
    fn=chat_interface,
    title="🏥 Medical AI Assistant with medRAG",
    description=_CHAT_DESCRIPTION,
    examples=_CHAT_EXAMPLES,
    theme="soft",
    css=_CHAT_CSS,
)

if __name__ == "__main__":
    demo.launch(share=True)