Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,218 Bytes
b80af5b 71bcd31 9f6ac99 000ab02 71bcd31 01a984c 6e237a4 01a984c 6e237a4 01a984c 6e237a4 01a984c 6e237a4 01a984c 71bcd31 1728da9 5522bf8 1728da9 5522bf8 1728da9 71bcd31 1728da9 a7f6391 1728da9 43e5827 1728da9 43e5827 1728da9 43e5827 1728da9 43e5827 1728da9 43e5827 1728da9 f3b4260 1728da9 a7f6391 1728da9 a7f6391 1728da9 a7f6391 1728da9 f3b4260 1728da9 a7f6391 43e5827 1728da9 43e5827 1728da9 43e5827 1728da9 43e5827 1728da9 43e5827 1728da9 a7f6391 1728da9 f3b4260 1728da9 c4447f4 1728da9 6d5190c 1728da9 8b29c0d 1728da9 43e5827 8b29c0d 1728da9 6d5190c b80af5b 71bcd31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 |
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
# Model configuration
# Hugging Face model id loaded further down as `tokenizer` / `model`; that pair
# generates the conversational follow-up questions in medical_chat_response.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
# Hugging Face model id loaded further down as `meditron_tokenizer` /
# `meditron_model`; that pair produces the final assessment in
# get_meditron_diagnosis.
MEDITRON_MODEL = "epfl-llm/meditron-7b"
# System prompt that MedicalConsultationBot.build_llama_prompt places between
# the <<SYS>> ... <</SYS>> tags of every Llama request.
# NOTE(review): the prompt asks for an assessment "after 4-5 meaningful
# exchanges", but the hand-off to the diagnosis step is actually driven by the
# `medical_turns >= 4` counter in medical_chat_response, not by the model.
SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical consultation. Your role is to gather comprehensive information about the patient's condition through intelligent questioning.
**CONSULTATION APPROACH:**
- Ask thoughtful, relevant follow-up questions based on the patient's responses
- Prioritize gathering information about: symptom details, duration, severity, triggers, related symptoms, medical history, medications, and lifestyle factors
- Ask 1-2 specific questions at a time that build naturally on their previous answers
- Be empathetic, professional, and thorough in your questioning
- Adapt your questions based on the symptoms they describe
**IMPORTANT GUIDELINES:**
- Generate intelligent follow-up questions that are contextually relevant to their specific symptoms
- Don't ask generic questions - tailor each question to their particular situation
- If they mention pain, ask about location, type, and triggers
- If they mention duration, ask about progression or changes
- Build each question logically from their previous responses
After 4-5 meaningful exchanges, provide assessment and recommendations.
Do NOT make specific prescriptions for prescription-only drugs.
Always maintain a professional, caring tone throughout the consultation."""
# Template for the Meditron request. get_meditron_diagnosis fills the
# {patient_info} placeholder with str.format, so any literal brace added to
# this text must be doubled.
# NOTE(review): item 4 asks for medications "with dosing", whereas
# SYSTEM_PROMPT forbids specific prescriptions for prescription-only drugs --
# confirm which policy is intended.
MEDITRON_PROMPT = """You are a board-certified physician providing evidence-based medical assessment.
Based on the patient information provided, please:
1. Analyze the symptoms systematically
2. Provide a differential diagnosis with most likely conditions
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Suggest appropriate medications or remedies with dosing if applicable
5. Include red flags that would require urgent medical attention
6. Base recommendations on clinical guidelines
Patient Information: {patient_info}
Please provide a structured medical assessment:"""
# Load models
# Bind every model/tokenizer global up front. The loading code below is
# best-effort (errors are printed, not raised), so without these defaults a
# failed download or out-of-memory error would leave the names undefined and
# surface later as a confusing NameError inside a request handler.
tokenizer = None
model = None
meditron_tokenizer = None
meditron_model = None

print("Loading models...")
try:
    # Conversational model used for follow-up questions.
    tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        LLAMA_MODEL,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    print("Llama-2 model loaded successfully!")

    # Medical model used for the final assessment.
    meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
    if meditron_tokenizer.pad_token is None:
        meditron_tokenizer.pad_token = meditron_tokenizer.eos_token
    meditron_model = AutoModelForCausalLM.from_pretrained(
        MEDITRON_MODEL,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    print("Meditron model loaded successfully!")
except Exception as e:
    # Deliberately non-fatal: the app still starts, and whatever failed to
    # load stays None.
    print(f"Error loading models: {e}")
class MedicalConsultationBot:
    """State container for a single doctor/patient consultation.

    Tracks the patient's name and age, the list of (user, bot) exchanges,
    how many symptom-gathering turns have happened, and which stage of the
    scripted flow the conversation is in.
    """

    def __init__(self):
        self.reset_conversation()

    def reset_conversation(self):
        """Put every piece of consultation state back to its starting value."""
        # Stage order: greeting -> name -> age -> symptoms -> diagnosis
        self.stage = "greeting"
        self.medical_turns = 0
        self.patient_age = None
        self.patient_name = None
        self.conversation_history = []

    def add_to_history(self, user_message, bot_response):
        """Record one user message together with the bot's reply to it."""
        entry = {"user": user_message, "bot": bot_response}
        self.conversation_history.append(entry)

    def get_conversation_context(self):
        """Render the known patient details and all exchanges as one string."""
        pieces = []
        if self.patient_name:
            pieces.append(f"Patient Name: {self.patient_name}\n")
        if self.patient_age:
            pieces.append(f"Patient Age: {self.patient_age}\n")
        # The history heading is emitted even when there are no exchanges yet.
        pieces.append("\nConversation History:\n")
        for turn in self.conversation_history:
            pieces.append(f"Patient: {turn['user']}\n")
            pieces.append(f"Doctor: {turn['bot']}\n")
        return "".join(pieces)

    def build_llama_prompt(self, current_message):
        """Assemble the Llama-2 chat prompt for the given patient message."""
        sections = [f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"]
        # Include the running context only when it has visible content.
        context = self.get_conversation_context()
        if context.strip():
            sections.append(f"Previous conversation context:\n{context}\n\n")
        sections.append(
            f"Current patient message: {current_message}\n\nProvide a professional medical response with appropriate follow-up questions. [/INST]"
        )
        return "".join(sections)
# Global bot instance
# Created once at import time. medical_chat_response and
# generate_final_diagnosis both read and mutate this single object, so all
# calls into those functions share the same consultation state.
# NOTE(review): if the app serves several users at once they would all share
# this state -- confirm whether per-session state is required.
medical_bot = MedicalConsultationBot()
def get_meditron_diagnosis(patient_info):
    """Ask the Meditron model for an assessment of the given patient text.

    The text is inserted into MEDITRON_PROMPT, tokenized (truncated to 512
    tokens), and up to 300 new tokens are sampled. Only the newly generated
    tokens are decoded and returned. Any failure is reported as an error
    string rather than raised.
    """
    try:
        filled_prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
        encoded = meditron_tokenizer(
            filled_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        encoded = encoded.to(meditron_model.device)

        with torch.no_grad():
            generated = meditron_model.generate(
                encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_new_tokens=300,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=meditron_tokenizer.pad_token_id,
            )

        # Skip the echoed prompt tokens and keep just the continuation.
        prompt_length = encoded.input_ids.shape[1]
        continuation = generated[0][prompt_length:]
        decoded = meditron_tokenizer.decode(continuation, skip_special_tokens=True)
        return decoded.strip()
    except Exception as exc:
        return f"Error generating medical assessment: {exc}"
@spaces.GPU
def medical_chat_response(message, history):
    """Advance the scripted consultation by one user message.

    The flow is a small state machine stored on the global ``medical_bot``:
    greeting -> name -> age -> symptoms -> diagnosis -> complete.

    Args:
        message: The text the user just sent.
        history: The chat history supplied by Gradio. It is only checked for
            emptiness; an empty history is treated as the start of a new
            consultation and resets the bot state.

    Returns:
        The bot's reply as a string.
    """
    # If this is a new conversation (empty history), reset the bot
    if not history:
        medical_bot.reset_conversation()

    user_message = message.strip()
    stage = medical_bot.stage

    # Stage 1: Initial greeting and ask for name
    if stage == "greeting":
        bot_response = "Hello! I'm your AI medical assistant. Before we discuss your health concerns, could you please tell me your name?"
        medical_bot.stage = "name"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 2: Collect name and ask for age
    if stage == "name":
        medical_bot.patient_name = user_message
        bot_response = f"Nice to meet you, {medical_bot.patient_name}! Could you please tell me your age?"
        medical_bot.stage = "age"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 3: Collect age and start medical consultation
    if stage == "age":
        medical_bot.patient_age = user_message
        bot_response = f"Thank you, {medical_bot.patient_name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
        medical_bot.stage = "symptoms"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 4: Medical consultation - gather symptoms with intelligent follow-ups
    if stage == "symptoms":
        medical_bot.medical_turns += 1

        # If we've had enough turns, move to diagnosis
        if medical_bot.medical_turns >= 4:
            medical_bot.stage = "diagnosis"
            return generate_final_diagnosis(user_message)

        # Generate intelligent follow-up questions
        try:
            prompt = medical_bot.build_llama_prompt(user_message)
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            ).to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=200,
                    temperature=0.8,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id
                )
            # Decode only the newly generated tokens. Splitting the full
            # decode on '[/INST]' alone breaks once the prompt is truncated
            # at max_length: the marker is cut off and the whole prompt
            # (system instructions included) would be returned to the user.
            prompt_length = inputs.input_ids.shape[1]
            generated_text = tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )
            # Still drop anything before a marker the model may have emitted.
            bot_response = generated_text.split('[/INST]')[-1].strip()
            # Clean up the response
            bot_response = bot_response.replace('<s>', '').replace('</s>', '').strip()
        except Exception as e:
            # Best-effort: keep the conversation going with a generic
            # question, but leave a trace of what went wrong.
            print(f"Error generating follow-up question: {e}")
            bot_response = "I understand. Could you tell me more about how long you've been experiencing this and if there are any specific triggers or patterns you've noticed?"

        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 5: Final diagnosis and treatment recommendations
    if stage == "diagnosis":
        return generate_final_diagnosis(user_message)

    # Handle any questions after diagnosis
    # Match whole words only: a plain substring test for "age" also fires on
    # "dosage", "manage", "message", etc.
    lowered = user_message.lower()
    if re.search(r"\bname\b", lowered) and medical_bot.patient_name:
        return f"Your name is {medical_bot.patient_name}."
    if re.search(r"\bage\b", lowered) and medical_bot.patient_age:
        return f"You told me you are {medical_bot.patient_age} years old."
    return "Is there anything else about your health concerns I can help you with today?"
def generate_final_diagnosis(current_message):
    """Produce the closing assessment for the consultation.

    Builds a chronological transcript of the consultation stored on the
    global ``medical_bot``, sends it to ``get_meditron_diagnosis``, wraps the
    result in a fixed template with a disclaimer, records the exchange in the
    history and moves the bot to the "complete" stage.

    Args:
        current_message: The patient's latest message, not yet in the history.

    Returns:
        The full assessment text shown to the user.
    """
    # Compile complete patient information
    header = (
        f"\nPatient Name: {medical_bot.patient_name}\n"
        f"Patient Age: {medical_bot.patient_age}\n"
        "Complete Consultation History:\n"
    )
    # Each stored exchange is a patient message followed by the doctor's
    # reply to it, so the patient line must come first to keep the transcript
    # in chronological order (same order as get_conversation_context).
    transcript = []
    for exchange in medical_bot.conversation_history:
        transcript.append(f"Patient: {exchange['user']}\n")
        transcript.append(f"Doctor: {exchange['bot']}\n")
    transcript.append(f"Patient: {current_message}\n")
    patient_info = header + "".join(transcript)

    # Get diagnosis from Meditron
    meditron_assessment = get_meditron_diagnosis(patient_info)

    # Generate comprehensive response
    final_response = f"""Thank you for providing all this information, {medical_bot.patient_name}. Based on our consultation, here is my assessment:
**MEDICAL ASSESSMENT & RECOMMENDATIONS:**
{meditron_assessment}
**IMPORTANT DISCLAIMER:** This assessment is for informational purposes only and should not replace professional medical advice. Please consult with a healthcare provider for proper diagnosis and treatment.
**NEXT STEPS:** I recommend scheduling an appointment with your primary care physician or appropriate specialist for further evaluation and personalized treatment.
Is there anything specific about this assessment you'd like me to clarify?"""

    # Record the completed exchange and close the consultation
    medical_bot.add_to_history(current_message, final_response)
    medical_bot.stage = "complete"
    return final_response
# Create Gradio interface
# The title's leading emoji had been mis-decoded (UTF-8 bytes F0 9F A9 BA read
# through a Thai single-byte codepage); it is restored to the stethoscope.
# NOTE(review): the first character of the clear_btn label looks mis-encoded in
# the same way, but the original emoji cannot be recovered from what remains,
# so it is left untouched -- replace it with the intended character.
# NOTE(review): retry_btn / undo_btn / clear_btn are not accepted by every
# Gradio release -- confirm against the Gradio version this app is pinned to.
demo = gr.ChatInterface(
    fn=medical_chat_response,
    title="🩺 AI Medical Assistant with Memory",
    description="I'm an AI medical assistant that will remember our conversation. I'll first ask for your basic information, then gather details about your symptoms through intelligent follow-up questions, and finally provide a medical assessment.",
    examples=[
        "Hello, I need medical help",
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts"
    ],
    theme="soft",
    retry_btn=None,
    undo_btn=None,
    clear_btn="๐ Start New Consultation"
)
# Start the web app only when this file is run as a script (not on import).
if __name__ == "__main__":
    demo.launch()