Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import re | |
# Model configuration
# Checkpoint identifiers handed to `from_pretrained` when the models are loaded.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System instructions for the interview phase; build_llama_prompt() places this
# text inside the <<SYS>> ... <</SYS>> section of every Llama-2 prompt.
SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical consultation. Your role is to gather comprehensive information about the patient's condition through intelligent questioning.
**CONSULTATION APPROACH:**
- Ask thoughtful, relevant follow-up questions based on the patient's responses
- Prioritize gathering information about: symptom details, duration, severity, triggers, related symptoms, medical history, medications, and lifestyle factors
- Ask 1-2 specific questions at a time that build naturally on their previous answers
- Be empathetic, professional, and thorough in your questioning
- Adapt your questions based on the symptoms they describe
**IMPORTANT GUIDELINES:**
- Generate intelligent follow-up questions that are contextually relevant to their specific symptoms
- Don't ask generic questions - tailor each question to their particular situation
- If they mention pain, ask about location, type, and triggers
- If they mention duration, ask about progression or changes
- Build each question logically from their previous responses
After 4-5 meaningful exchanges, provide assessment and recommendations.
Do NOT make specific prescriptions for prescription-only drugs.
Always maintain a professional, caring tone throughout the consultation."""

# Template for the final assessment; get_meditron_diagnosis() fills the single
# `{patient_info}` placeholder via str.format, so any other literal braces added
# here would have to be doubled.
MEDITRON_PROMPT = """You are a board-certified physician providing evidence-based medical assessment.
Based on the patient information provided, please:
1. Analyze the symptoms systematically
2. Provide a differential diagnosis with most likely conditions
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Suggest appropriate medications or remedies with dosing if applicable
5. Include red flags that would require urgent medical attention
6. Base recommendations on clinical guidelines
Patient Information: {patient_info}
Please provide a structured medical assessment:"""
# Load models
print("Loading models...")

# Keyword arguments shared by both checkpoints: half precision, automatic placement.
_MODEL_LOAD_KWARGS = {"torch_dtype": torch.float16, "device_map": "auto"}


def _load_tokenizer(model_name):
    """Return the tokenizer for ``model_name``, reusing EOS as pad when no pad token is set."""
    loaded = AutoTokenizer.from_pretrained(model_name)
    if loaded.pad_token is None:
        loaded.pad_token = loaded.eos_token
    return loaded


# Best-effort startup: a failure is reported on stdout and the app keeps going,
# leaving whichever names had not been assigned yet undefined.
try:
    tokenizer = _load_tokenizer(LLAMA_MODEL)
    model = AutoModelForCausalLM.from_pretrained(LLAMA_MODEL, **_MODEL_LOAD_KWARGS)
    print("Llama-2 model loaded successfully!")

    meditron_tokenizer = _load_tokenizer(MEDITRON_MODEL)
    meditron_model = AutoModelForCausalLM.from_pretrained(MEDITRON_MODEL, **_MODEL_LOAD_KWARGS)
    print("Meditron model loaded successfully!")
except Exception as e:
    print(f"Error loading models: {e}")
class MedicalConsultationBot:
    """Hold the state of one scripted medical consultation.

    Attributes (all assigned in ``reset_conversation``):
        conversation_history: list of ``{"user": ..., "bot": ...}`` dicts, one
            per exchange, in the order they happened.
        patient_name / patient_age: raw user answers, ``None`` until provided.
        medical_turns: number of symptom-stage messages received so far.
        stage: name of the current step of the consultation flow.
    """

    def __init__(self):
        self.reset_conversation()

    def reset_conversation(self):
        """Reset all conversation state"""
        self.conversation_history = []
        self.patient_name = None
        self.patient_age = None
        self.medical_turns = 0
        self.stage = "greeting"  # greeting -> name -> age -> symptoms -> diagnosis

    def add_to_history(self, user_message, bot_response):
        """Append one exchange (the user's message and the reply to it) to the history."""
        self.conversation_history.append({
            "user": user_message,
            "bot": bot_response,
        })

    def get_conversation_context(self):
        """Return the known patient details and transcript as one string.

        The "Conversation History" header is emitted only when at least one
        exchange exists, so a completely fresh bot yields ``""`` and callers
        can test the result for emptiness. With a non-empty history the text
        is unchanged from earlier versions of this method.
        """
        parts = []
        if self.patient_name:
            parts.append(f"Patient Name: {self.patient_name}\n")
        if self.patient_age:
            parts.append(f"Patient Age: {self.patient_age}\n")
        if self.conversation_history:
            parts.append("\nConversation History:\n")
            for exchange in self.conversation_history:
                parts.append(f"Patient: {exchange['user']}\n")
                parts.append(f"Doctor: {exchange['bot']}\n")
        return "".join(parts)

    def build_llama_prompt(self, current_message):
        """Build the single-turn Llama-2 chat prompt for ``current_message``.

        Layout: ``<s>[INST] <<SYS>>`` system prompt ``<</SYS>>``, then the
        conversation context (when there is any), then the current patient
        message and a closing ``[/INST]``.
        """
        # NOTE(review): the literal "<s>" may duplicate the BOS token if the
        # tokenizer also adds one by default - confirm against the tokenizer config.
        prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
        # Add conversation context
        context = self.get_conversation_context()
        if context.strip():
            prompt += f"Previous conversation context:\n{context}\n\n"
        prompt += f"Current patient message: {current_message}\n\nProvide a professional medical response with appropriate follow-up questions. [/INST]"
        return prompt
# Global bot instance
# Single module-level state object that medical_chat_response() and
# generate_final_diagnosis() read and mutate.
# NOTE(review): being one shared object, it presumably is shared by every
# browser session served by this process - confirm whether per-session state is needed.
medical_bot = MedicalConsultationBot()
@spaces.GPU
def _run_meditron_generation(prompt):
    """Sample a completion for ``prompt`` from Meditron and return only the new text.

    Decorated with ``spaces.GPU`` so that a ZeroGPU Space allocates a GPU for
    the duration of the call. It only reads module-level objects and returns a
    plain string.
    """
    # NOTE(review): truncation keeps the first 512 tokens, so a long
    # consultation may lose its most recent lines - confirm the limit is intended.
    inputs = meditron_tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(meditron_model.device)
    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=meditron_tokenizer.pad_token_id,
        )
    # The output sequence starts with the prompt tokens; decode only what follows them.
    prompt_length = inputs.input_ids.shape[1]
    return meditron_tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()


def get_meditron_diagnosis(patient_info):
    """Use Meditron model to generate medical assessment.

    Args:
        patient_info: Text inserted into the ``{patient_info}`` slot of
            ``MEDITRON_PROMPT``.

    Returns:
        The generated assessment text. This function never raises: on any
        failure it returns a string starting with
        "Error generating medical assessment:" instead.
    """
    # The try/except lives outside the decorated helper so that failures raised
    # by the decorator's wrapper are converted to the error string as well.
    try:
        prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
        return _run_meditron_generation(prompt)
    except Exception as e:
        return f"Error generating medical assessment: {str(e)}"
# Canned question used whenever Llama-2 cannot produce a usable follow-up.
_FALLBACK_FOLLOW_UP = (
    "I understand. Could you tell me more about how long you've been experiencing this "
    "and if there are any specific triggers or patterns you've noticed?"
)


@spaces.GPU
def _generate_followup_reply(prompt):
    """Sample Llama-2's answer to ``prompt`` and return only the newly generated text.

    Decorated with ``spaces.GPU`` so that a ZeroGPU Space allocates a GPU for
    the call. It deliberately touches no conversation state: it reads the
    module-level model/tokenizer and returns a plain string.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            temperature=0.8,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the tokens after the prompt. Decoding the whole sequence and
    # splitting on "[/INST]" returns the entire prompt whenever truncation has
    # cut that marker off the end of the input.
    prompt_length = inputs.input_ids.shape[1]
    reply = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # Still drop anything up to a marker the model itself may have emitted.
    reply = reply.split('[/INST]')[-1].strip()
    # Clean up the response
    return reply.replace('<s>', '').replace('</s>', '').strip()


def medical_chat_response(message, history):
    """Main chat response function with proper state management.

    Drives the consultation held in the module-level ``medical_bot``:
    greeting -> name -> age -> symptoms (Llama-2 follow-up questions) ->
    diagnosis (``generate_final_diagnosis``) -> complete (simple recall answers).

    Args:
        message: The user's latest message.
        history: Chat history supplied by the UI; only its emptiness is used,
            as the signal that a new consultation has started.

    Returns:
        The reply text to display.
    """
    # If this is a new conversation (empty history), reset the bot
    if not history:
        medical_bot.reset_conversation()

    user_message = message.strip()
    stage = medical_bot.stage

    # Stage 1: Initial greeting and ask for name
    if stage == "greeting":
        bot_response = "Hello! I'm your AI medical assistant. Before we discuss your health concerns, could you please tell me your name?"
        medical_bot.stage = "name"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 2: Collect name and ask for age
    if stage == "name":
        medical_bot.patient_name = user_message
        bot_response = f"Nice to meet you, {medical_bot.patient_name}! Could you please tell me your age?"
        medical_bot.stage = "age"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 3: Collect age and start medical consultation
    if stage == "age":
        medical_bot.patient_age = user_message
        bot_response = f"Thank you, {medical_bot.patient_name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
        medical_bot.stage = "symptoms"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 4: Medical consultation - gather symptoms with intelligent follow-ups
    if stage == "symptoms":
        medical_bot.medical_turns += 1
        # If we've had enough turns, move to diagnosis
        if medical_bot.medical_turns >= 4:
            medical_bot.stage = "diagnosis"
            return generate_final_diagnosis(user_message)

        # Generate intelligent follow-up questions
        try:
            bot_response = _generate_followup_reply(
                medical_bot.build_llama_prompt(user_message)
            )
        except Exception as e:
            # Best effort: keep the consultation going, but leave a trace of the failure.
            print(f"Error generating follow-up question: {e}")
            bot_response = ""
        if not bot_response:
            bot_response = _FALLBACK_FOLLOW_UP
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 5: Final diagnosis and treatment recommendations
    if stage == "diagnosis":
        return generate_final_diagnosis(user_message)

    # Handle any questions after diagnosis.
    # Whole-word matching, so that e.g. "dosage" or "message" is not taken
    # for a question about the patient's age.
    lowered = user_message.lower()
    if re.search(r"\bname\b", lowered) and medical_bot.patient_name:
        return f"Your name is {medical_bot.patient_name}."
    if re.search(r"\bage\b", lowered) and medical_bot.patient_age:
        return f"You told me you are {medical_bot.patient_age} years old."
    return "Is there anything else about your health concerns I can help you with today?"
def generate_final_diagnosis(current_message):
    """Generate the final assessment message and close the consultation.

    Builds a chronological transcript from the module-level ``medical_bot``,
    asks ``get_meditron_diagnosis`` for an assessment, wraps it with a
    disclaimer and next steps, records the exchange in the history and sets
    the stage to "complete".

    Args:
        current_message: The patient's latest message, not yet in the history.

    Returns:
        The full reply text shown to the user.
    """
    # Compile complete patient information
    transcript = [
        f"\nPatient Name: {medical_bot.patient_name}\n",
        f"Patient Age: {medical_bot.patient_age}\n",
        "Complete Consultation History:\n",
    ]
    # Each stored exchange is (user message, reply to it), so the patient line
    # must come first; the reverse order pairs every answer with the wrong question.
    for exchange in medical_bot.conversation_history:
        transcript.append(f"Patient: {exchange['user']}\n")
        transcript.append(f"Doctor: {exchange['bot']}\n")
    transcript.append(f"Patient: {current_message}\n")
    patient_info = "".join(transcript)

    # Get diagnosis from Meditron
    meditron_assessment = get_meditron_diagnosis(patient_info)

    # Generate comprehensive response
    final_response = (
        f"Thank you for providing all this information, {medical_bot.patient_name}. "
        "Based on our consultation, here is my assessment:\n"
        "**MEDICAL ASSESSMENT & RECOMMENDATIONS:**\n"
        f"{meditron_assessment}\n"
        "**IMPORTANT DISCLAIMER:** This assessment is for informational purposes only "
        "and should not replace professional medical advice. Please consult with a "
        "healthcare provider for proper diagnosis and treatment.\n"
        "**NEXT STEPS:** I recommend scheduling an appointment with your primary care "
        "physician or appropriate specialist for further evaluation and personalized treatment.\n"
        "Is there anything specific about this assessment you'd like me to clarify?"
    )

    # Record the closing exchange and mark the consultation as finished
    medical_bot.add_to_history(current_message, final_response)
    medical_bot.stage = "complete"
    return final_response
# Create Gradio interface
# NOTE(review): retry_btn / undo_btn / clear_btn look like Gradio 4.x
# ChatInterface arguments - confirm the pinned gradio version before upgrading.
demo = gr.ChatInterface(
    fn=medical_chat_response,
    # The emoji in the title was mis-decoded text (stethoscope, U+1FA7A).
    title="🩺 AI Medical Assistant with Memory",
    description="I'm an AI medical assistant that will remember our conversation. I'll first ask for your basic information, then gather details about your symptoms through intelligent follow-up questions, and finally provide a medical assessment.",
    examples=[
        "Hello, I need medical help",
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts",
    ],
    theme="soft",
    retry_btn=None,
    undo_btn=None,
    # NOTE(review): the original emoji here was unrecoverable mojibake; the
    # "restart" arrows are assumed - replace if a different icon was intended.
    clear_btn="🔄 Start New Consultation",
)

if __name__ == "__main__":
    demo.launch()