# Hugging Face Space page residue (Space status: Sleeping)
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict

import gradio as gr
import spaces
import torch
from langgraph.graph import END, StateGraph
from transformers import AutoModelForCausalLM, AutoTokenizer
# Enhanced State Management
class MedicalState(TypedDict):
    """Consultation state passed between the LangGraph workflow nodes."""

    patient_id: str
    # Chat messages as {"role": ..., "content": ...} dicts.
    conversation_history: List[Dict]
    # Symptom keyword -> details dict (category, timestamp, source text).
    symptoms: Dict[str, Any]
    vital_questions_asked: List[str]
    medical_history: Dict
    current_medications: List[str]
    allergies: List[str]
    # Name -> severity rating (presumably on the 1-10 scale the intake
    # question asks for; not populated anywhere visible yet).
    severity_scores: Dict[str, int]
    red_flags: List[str]
    assessment_complete: bool
    # Recommendation strings, or a dict payload from emergency triage.
    suggested_actions: List[Any]
    consultation_stage: str  # intake, assessment, recommendations, emergency
# Medical Knowledge Base

# Symptom keywords grouped by body system. Patient text is lower-cased and
# searched for these as plain substrings. "chest pain" is listed under both
# "respiratory" and "cardiovascular".
MEDICAL_CATEGORIES = {
    "respiratory": [
        "cough",
        "shortness of breath",
        "chest pain",
        "wheezing",
    ],
    "gastrointestinal": [
        "nausea",
        "vomiting",
        "diarrhea",
        "stomach pain",
        "heartburn",
    ],
    "neurological": [
        "headache",
        "dizziness",
        "numbness",
        "tingling",
    ],
    "musculoskeletal": [
        "joint pain",
        "muscle pain",
        "back pain",
        "stiffness",
    ],
    "cardiovascular": [
        "chest pain",
        "palpitations",
        "swelling",
        "fatigue",
    ],
    "dermatological": [
        "rash",
        "itching",
        "skin changes",
        "wounds",
    ],
    "mental_health": [
        "anxiety",
        "depression",
        "sleep issues",
        "stress",
    ],
}

# Phrases whose presence in a message routes the consultation to emergency
# triage.
RED_FLAGS = [
    "chest pain",
    "difficulty breathing",
    "severe headache",
    "high fever",
    "blood in stool",
    "blood in urine",
    "severe abdominal pain",
    "sudden vision changes",
    "loss of consciousness",
    "severe allergic reaction",
]

# Core intake questions, keyed by the topic each one covers.
VITAL_QUESTIONS = {
    "symptom_onset": "When did your symptoms first start?",
    "severity": "On a scale of 1-10, how severe would you rate your symptoms?",
    "triggers": "What makes your symptoms better or worse?",
    "associated_symptoms": "Are you experiencing any other symptoms?",
    "medical_history": "Do you have any chronic medical conditions?",
    "medications": "Are you currently taking any medications?",
    "allergies": "Do you have any known allergies?",
}
class EnhancedMedicalAssistant:
    """Chat-based symptom triage assistant driven by a LangGraph workflow.

    Llama-2-7b-chat generates conversational replies and Meditron-7b
    generates care suggestions. Each ``generate_response`` call handles one
    chat turn. It rebuilds a fresh ``MedicalState`` from the chat history,
    runs it through the compiled workflow, and renders a reply chosen by the
    resulting ``consultation_stage``.
    """

    def __init__(self):
        self.load_models()
        self.setup_langgraph()

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------
    def load_models(self):
        """Load the Llama-2 chat model and the Meditron model (fp16, device_map="auto")."""
        print("Loading models...")
        # Llama-2 for conversation
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        self.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-chat-hf",
            torch_dtype=torch.float16,
            device_map="auto",
        )
        # Meditron for medical suggestions
        self.meditron_tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
        self.meditron_model = AutoModelForCausalLM.from_pretrained(
            "epfl-llm/meditron-7b",
            torch_dtype=torch.float16,
            device_map="auto",
        )
        print("Models loaded successfully!")

    def setup_langgraph(self):
        """Build and compile the consultation workflow into ``self.workflow``.

        Routing:
            intake -> emergency_triage | symptom_assessment | generate_recommendations
            symptom_assessment -> risk_evaluation
            risk_evaluation -> emergency_triage | generate_recommendations | END
            emergency_triage, generate_recommendations -> END
        """
        workflow = StateGraph(MedicalState)

        # Add nodes
        workflow.add_node("intake", self.patient_intake)
        workflow.add_node("symptom_assessment", self.assess_symptoms)
        workflow.add_node("risk_evaluation", self.evaluate_risks)
        workflow.add_node("generate_recommendations", self.generate_recommendations)
        workflow.add_node("emergency_triage", self.emergency_triage)

        # Define edges
        workflow.set_entry_point("intake")
        workflow.add_conditional_edges(
            "intake",
            self.route_after_intake,
            {
                "continue_assessment": "symptom_assessment",
                "emergency": "emergency_triage",
                "complete": "generate_recommendations",
            },
        )
        workflow.add_edge("symptom_assessment", "risk_evaluation")
        workflow.add_conditional_edges(
            "risk_evaluation",
            self.route_after_risk_eval,
            {
                "emergency": "emergency_triage",
                "continue": "generate_recommendations",
                # No node can collect new patient input inside one invoke(),
                # so routing back to symptom_assessment would repeat the same
                # result until LangGraph raises GraphRecursionError. End the
                # turn instead; the reply asks the patient for the missing info.
                "need_more_info": END,
            },
        )
        workflow.add_edge("generate_recommendations", END)
        workflow.add_edge("emergency_triage", END)

        self.workflow = workflow.compile()

    # ------------------------------------------------------------------
    # Workflow nodes
    # ------------------------------------------------------------------
    def patient_intake(self, state: MedicalState) -> MedicalState:
        """Record symptoms and red flags from the latest message and pick the stage.

        The stage stays "intake" while vital questions remain unasked and the
        conversation has fewer than 6 messages; otherwise it becomes
        "assessment".
        """
        history = state["conversation_history"]
        last_message = history[-1]["content"] if history else ""

        # Extract symptoms and categorize them
        state["symptoms"].update(self.extract_symptoms(last_message))

        # Check for red flags
        state["red_flags"].extend(self.check_red_flags(last_message))

        # Determine what vital questions still need to be asked
        missing_questions = self.get_missing_vital_questions(state)
        if missing_questions and len(history) < 6:
            state["consultation_stage"] = "intake"
        else:
            state["consultation_stage"] = "assessment"
        return state

    def assess_symptoms(self, state: MedicalState) -> MedicalState:
        """Mark the assessment complete once every symptom has a severity.

        If any recorded symptom lacks a "severity" entry, the stage is set to
        "assessment" and ``assessment_complete`` is left unchanged.
        """
        if any("severity" not in details for details in state["symptoms"].values()):
            # Need to ask about severity
            state["consultation_stage"] = "assessment"
            return state
        state["assessment_complete"] = True
        return state

    def evaluate_risks(self, state: MedicalState) -> MedicalState:
        """Score urgency from red flags and severity ratings and set the stage.

        Each red flag adds 3 points. Each severity >= 8 adds 2 and each
        severity >= 6 adds 1. A total of 5 or more is an emergency.
        """
        risk_score = len(state["red_flags"]) * 3
        for severity in state["severity_scores"].values():
            if severity >= 8:
                risk_score += 2
            elif severity >= 6:
                risk_score += 1
        # Symptom duration/progression is not analysed yet.

        if risk_score >= 5:
            state["consultation_stage"] = "emergency"
        elif state["assessment_complete"]:
            state["consultation_stage"] = "recommendations"
        else:
            # Assessment still needs information: keep the "assessment" stage
            # so the reply asks for it instead of promising recommendations
            # that were never generated.
            state["consultation_stage"] = "assessment"
        return state

    def generate_recommendations(self, state: MedicalState) -> MedicalState:
        """Generate care suggestions with Meditron and set the "recommendations" stage."""
        patient_summary = self.create_patient_summary(state)
        # Use Meditron for medical recommendations
        state["suggested_actions"] = self.get_meditron_recommendations(patient_summary)
        # Set explicitly: this node is also reached straight from intake via
        # the "complete" route, where the stage is still intake/assessment.
        state["consultation_stage"] = "recommendations"
        return state

    def emergency_triage(self, state: MedicalState) -> MedicalState:
        """Store an urgent-care payload and set the "emergency" stage."""
        emergency_response = {
            "urgent_care_needed": True,
            "recommended_action": "Seek immediate medical attention",
            "reasons": state["red_flags"],
            "instructions": "Go to the nearest emergency room or call emergency services",
        }
        state["suggested_actions"] = [emergency_response]
        # Without this, the red-flag route from intake ends with the stage
        # still "intake"/"assessment" and the emergency message is never shown.
        state["consultation_stage"] = "emergency"
        return state

    # ------------------------------------------------------------------
    # Routing
    # ------------------------------------------------------------------
    def route_after_intake(self, state: MedicalState):
        """Pick the edge after intake: emergency, more assessment, or complete."""
        if state["red_flags"]:
            return "emergency"
        if len(state["vital_questions_asked"]) < 5:
            return "continue_assessment"
        return "complete"

    def route_after_risk_eval(self, state: MedicalState):
        """Pick the edge after risk evaluation: emergency, continue, or need_more_info."""
        if state["consultation_stage"] == "emergency":
            return "emergency"
        if state["assessment_complete"]:
            return "continue"
        return "need_more_info"

    # ------------------------------------------------------------------
    # Text analysis helpers
    # ------------------------------------------------------------------
    def extract_symptoms(self, text: str) -> Dict:
        """Return {symptom: details} for every MEDICAL_CATEGORIES keyword found in text.

        Matching is a case-insensitive substring test. When a keyword appears
        in several categories, the last category iterated wins.
        """
        text_lower = text.lower()
        mentioned_at = datetime.now().isoformat()
        symptoms = {}
        for category, symptom_list in MEDICAL_CATEGORIES.items():
            for symptom in symptom_list:
                if symptom in text_lower:
                    symptoms[symptom] = {
                        "category": category,
                        "mentioned_at": mentioned_at,
                        "context": text,
                    }
        return symptoms

    def check_red_flags(self, text: str) -> List[str]:
        """Return the RED_FLAGS phrases found (case-insensitively) in text."""
        text_lower = text.lower()
        return [flag for flag in RED_FLAGS if flag in text_lower]

    def get_missing_vital_questions(self, state: MedicalState) -> List[str]:
        """Return VITAL_QUESTIONS keys not yet in ``vital_questions_asked``."""
        asked = set(state["vital_questions_asked"])
        return [q for q in VITAL_QUESTIONS if q not in asked]

    def create_patient_summary(self, state: MedicalState) -> str:
        """Render the state (plus the last 3 messages) as a plain-text summary prompt."""
        recent_messages = [msg["content"] for msg in state["conversation_history"][-3:]]
        lines = [
            "Patient Summary:",
            f"Symptoms: {json.dumps(state['symptoms'], indent=2, ensure_ascii=False)}",
            f"Medical History: {state['medical_history']}",
            f"Current Medications: {state['current_medications']}",
            f"Allergies: {state['allergies']}",
            f"Severity Scores: {state['severity_scores']}",
            f"Conversation History: {recent_messages}",
        ]
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Model calls
    # ------------------------------------------------------------------
    def get_meditron_recommendations(self, patient_summary: str) -> List[str]:
        """Sample up to 400 new tokens from Meditron and return them as a one-item list."""
        prompt = "\n".join([
            "Based on the following patient information, provide:",
            "1. Specific over-the-counter medications with dosing",
            "2. Home remedies and self-care measures",
            "3. When to seek professional medical care",
            "4. Follow-up recommendations",
            "Patient Information:",
            patient_summary,
            "Response:",
        ])
        inputs = self.meditron_tokenizer(prompt, return_tensors="pt").to(self.meditron_model.device)
        with torch.no_grad():
            outputs = self.meditron_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=400,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                # Same as the Llama path: pass a pad id explicitly rather than
                # relying on generate()'s fallback.
                pad_token_id=self.meditron_tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        recommendation = self.meditron_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return [recommendation.strip()]

    def generate_llama_response(self, prompt: str) -> str:
        """Generate a reply with Llama-2-chat using the [INST] prompt format."""
        formatted_prompt = f"<s>[INST] {prompt} [/INST] "
        # The prompt already contains the BOS token "<s>"; stop the tokenizer
        # from prepending a second one.
        inputs = self.tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=300,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )
        return response.split("</s>")[0].strip()

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    @staticmethod
    def _normalize_history(history) -> List[Dict]:
        """Convert Gradio chat history into a list of {"role", "content"} dicts.

        Accepts "messages"-style entries (dicts with role/content) and legacy
        "tuples"-style entries ([user, assistant] pairs). ``None`` entries are
        skipped and non-string content is converted with ``str``.
        """
        messages = []
        for entry in history or []:
            if isinstance(entry, dict):
                pairs = [(entry.get("role", "user"), entry.get("content"))]
            else:
                pairs = zip(("user", "assistant"), entry)
            for role, content in pairs:
                if content is None:
                    continue
                text = content if isinstance(content, str) else str(content)
                messages.append({"role": role, "content": text})
        return messages

    def generate_response(self, message: str, history: List) -> str:
        """Run one consultation turn and return the assistant's reply.

        State is rebuilt from scratch on every call; only ``history`` carries
        information between turns.
        """
        conversation = self._normalize_history(history)
        conversation.append({"role": "user", "content": message})

        state = MedicalState(
            patient_id="session_001",
            conversation_history=conversation,
            symptoms={},
            vital_questions_asked=[],
            medical_history={},
            current_medications=[],
            allergies=[],
            severity_scores={},
            red_flags=[],
            assessment_complete=False,
            suggested_actions=[],
            consultation_stage="intake",
        )

        # Run through LangGraph workflow
        result = self.workflow.invoke(state)
        return self.generate_contextual_response(result, message)

    # ------------------------------------------------------------------
    # Reply formatting
    # ------------------------------------------------------------------
    def generate_contextual_response(self, state: MedicalState, user_message: str) -> str:
        """Render the reply that matches ``state["consultation_stage"]``."""
        stage = state["consultation_stage"]
        if stage == "emergency":
            return self.format_emergency_response(state)
        if stage == "intake":
            return self.format_intake_response(state, user_message)
        if stage == "assessment":
            return self.format_assessment_response(state)
        if stage == "recommendations":
            return self.format_recommendations_response(state)
        return self.format_default_response(user_message)

    def format_emergency_response(self, state: MedicalState) -> str:
        """Return the urgent-care message listing the detected red flags."""
        # The emergency stage can also come from severity scores alone, in
        # which case there are no red flags to list.
        reasons = ", ".join(state["red_flags"]) or "the severity of the symptoms you reported"
        return "\n".join([
            "🚨 URGENT MEDICAL ATTENTION NEEDED 🚨",
            "Based on your symptoms, I recommend seeking immediate medical care because:",
            reasons,
            "Please:",
            "- Go to the nearest emergency room, OR",
            "- Call emergency services (911), OR",
            "- Contact your doctor immediately",
            "This is not a diagnosis, but these symptoms warrant immediate professional evaluation.",
        ])

    def format_intake_response(self, state: MedicalState, user_message: str) -> str:
        """Ask Llama-2 for an empathetic reply with 1-2 follow-up questions."""
        prompt = "\n".join([
            f'You are a caring virtual doctor. The patient said: "{user_message}"',
            "Respond empathetically and ask 1-2 specific follow-up questions about:",
            "- Symptom details (duration, severity, triggers)",
            "- Associated symptoms",
            "- Medical history if relevant",
            "Be professional, caring, and thorough.",
        ])
        return self.generate_llama_response(prompt)

    def format_assessment_response(self, state: MedicalState) -> str:
        """Ask for the detail the assessment is still missing (symptom severity)."""
        reply = "Let me gather a bit more information to better understand your condition..."
        unrated = [name for name, details in state["symptoms"].items() if "severity" not in details]
        if unrated:
            # assess_symptoms stops at "assessment" exactly when a severity is
            # missing, so ask for it rather than leaving the patient without a
            # question to answer.
            reply += f"\n\nYou mentioned: {', '.join(unrated)}. {VITAL_QUESTIONS['severity']}"
        return reply

    def format_recommendations_response(self, state: MedicalState) -> str:
        """Wrap the suggested actions with the AI disclaimer."""
        recommendations = "\n".join(str(action) for action in state["suggested_actions"])
        return "\n".join([
            "Based on our consultation, here's my assessment and recommendations:",
            recommendations,
            "**Important Disclaimer:** I am an AI assistant, not a licensed medical professional.",
            "These suggestions are for informational purposes only. Please consult with a",
            "healthcare provider for proper diagnosis and treatment.",
        ])

    def format_default_response(self, user_message: str) -> str:
        """Fallback reply for an unrecognised stage."""
        return self.generate_llama_response(f"Respond professionally to: {user_message}")
# Initialize the enhanced medical assistant once at import time. Its
# constructor loads both models, so every chat turn reuses them.
medical_assistant = EnhancedMedicalAssistant()


@spaces.GPU
def chat_interface(message, history):
    """Gradio chat callback: run one consultation turn and return the reply.

    ``spaces`` is imported by this module for Hugging Face ZeroGPU. The
    ``@spaces.GPU`` decorator attaches a GPU for the duration of the call on
    ZeroGPU hardware and is documented as having no effect elsewhere.
    """
    return medical_assistant.generate_response(message, history)
# Create Gradio interface.
# The title/description emoji were mis-encoded (UTF-8 read as cp1253:
# "π₯", "β οΈ"); restored to the intended characters.
demo = gr.ChatInterface(
    fn=chat_interface,
    title="🏥 Advanced Medical AI Assistant",
    description=(
        "\n"
        "I'm an AI medical assistant that can help assess your symptoms and provide guidance.\n"
        "I'll ask relevant questions to better understand your condition and provide appropriate recommendations.\n"
        "⚠️ **Important**: I'm not a replacement for professional medical care. "
        "Always consult healthcare providers for serious concerns.\n"
    ),
    examples=[
        "I've been having severe chest pain for the last hour",
        "I have a persistent cough that's been going on for 2 weeks",
        "I'm experiencing nausea and stomach pain after eating",
        "I have a headache and feel dizzy",
    ],
    theme="soft",
    css="""
.message.user { background-color: #e3f2fd; }
.message.bot { background-color: #f1f8e9; }
""",
)
if __name__ == "__main__":
    # Start the Gradio server when run as a script. share=True additionally
    # asks Gradio for a public share link.
    launch_options = {"share": True}
    demo.launch(**launch_options)