medbot_2 / app.py
techindia2025's picture
Update app.py
1728da9 verified
raw
history blame
12.2 kB
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
# Model configuration
# Hugging Face Hub model ids: the chat model drives the symptom interview,
# the medical model writes the final assessment.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System prompt given to the Llama-2 chat model during the symptom-gathering stage.
SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical consultation. Your role is to gather comprehensive information about the patient's condition through intelligent questioning.
**CONSULTATION APPROACH:**
- Ask thoughtful, relevant follow-up questions based on the patient's responses
- Prioritize gathering information about: symptom details, duration, severity, triggers, related symptoms, medical history, medications, and lifestyle factors
- Ask 1-2 specific questions at a time that build naturally on their previous answers
- Be empathetic, professional, and thorough in your questioning
- Adapt your questions based on the symptoms they describe
**IMPORTANT GUIDELINES:**
- Generate intelligent follow-up questions that are contextually relevant to their specific symptoms
- Don't ask generic questions - tailor each question to their particular situation
- If they mention pain, ask about location, type, and triggers
- If they mention duration, ask about progression or changes
- Build each question logically from their previous responses
After 4-5 meaningful exchanges, provide assessment and recommendations.
Do NOT make specific prescriptions for prescription-only drugs.
Always maintain a professional, caring tone throughout the consultation."""

# Prompt template for the Meditron assessment; `{patient_info}` is filled
# with str.format. Item 4 carries the same "no prescription-only drugs" rule
# as SYSTEM_PROMPT, because this is the stage that actually issues the
# treatment recommendations shown to the patient.
MEDITRON_PROMPT = """You are a board-certified physician providing evidence-based medical assessment.
Based on the patient information provided, please:
1. Analyze the symptoms systematically
2. Provide a differential diagnosis with most likely conditions
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Suggest appropriate over-the-counter medications or home remedies with dosing if applicable; do NOT prescribe prescription-only drugs
5. Include red flags that would require urgent medical attention
6. Base recommendations on clinical guidelines
Patient Information: {patient_info}
Please provide a structured medical assessment:"""
def _load_model(model_id):
    """Load a tokenizer / causal-LM pair from the Hugging Face Hub.

    Args:
        model_id: Hub repository id of the model.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if anything
        fails. Failures are printed instead of raised, so one unavailable
        model does not prevent the other from loading. Downstream callers
        wrap generation in ``try``/``except`` and degrade gracefully.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_id)
        if tok.pad_token is None:
            # No pad token defined: reuse EOS so pad_token_id is usable in generate().
            tok.pad_token = tok.eos_token
        mdl = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    except Exception as e:
        print(f"Error loading model {model_id}: {e}")
        return None, None
    return tok, mdl


# Load models (each independently, so a failure of one does not skip the other;
# on failure the names are bound to None rather than left undefined).
print("Loading models...")
tokenizer, model = _load_model(LLAMA_MODEL)
if model is not None:
    print("Llama-2 model loaded successfully!")
meditron_tokenizer, meditron_model = _load_model(MEDITRON_MODEL)
if meditron_model is not None:
    print("Meditron model loaded successfully!")
class MedicalConsultationBot:
    """Mutable state for one scripted medical consultation.

    Tracks the patient's name and age, the list of (patient, bot) exchanges,
    how many symptom-stage turns have happened, and the current stage.
    """

    def __init__(self):
        self.reset_conversation()

    def reset_conversation(self):
        """Reset all conversation state back to the initial greeting stage."""
        self.conversation_history = []  # list of {"user": str, "bot": str}
        self.patient_name = None
        self.patient_age = None
        self.medical_turns = 0  # number of symptom-stage messages received
        # greeting -> name -> age -> symptoms -> diagnosis -> complete
        self.stage = "greeting"

    def add_to_history(self, user_message, bot_response):
        """Append one exchange (patient message, then bot reply) to the history."""
        self.conversation_history.append({
            "user": user_message,
            "bot": bot_response,
        })

    def get_conversation_context(self):
        """Return the known name/age and the full exchange history as text.

        Name and age lines are included only when set; the
        "Conversation History:" header is always present.
        """
        parts = []
        if self.patient_name:
            parts.append(f"Patient Name: {self.patient_name}\n")
        if self.patient_age:
            parts.append(f"Patient Age: {self.patient_age}\n")
        parts.append("\nConversation History:\n")
        for exchange in self.conversation_history:
            parts.append(f"Patient: {exchange['user']}\n")
            parts.append(f"Doctor: {exchange['bot']}\n")
        return "".join(parts)

    def build_llama_prompt(self, current_message):
        """Build a Llama-2-chat style prompt with the conversation context.

        The text starts at ``[INST]`` without a literal ``<s>``. The Llama
        tokenizer prepends the BOS token itself when encoding, so embedding
        ``<s>`` in the text produced a duplicated BOS.

        Args:
            current_message: The patient's latest message.

        Returns:
            The full prompt string, ending with ``[/INST]``.
        """
        # get_conversation_context() always contains the history header, so
        # the context section is always included.
        context = self.get_conversation_context()
        return (
            f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
            f"Previous conversation context:\n{context}\n\n"
            f"Current patient message: {current_message}\n\n"
            "Provide a professional medical response with appropriate follow-up questions. [/INST]"
        )
# Global bot instance
# NOTE(review): this single module-level instance holds the consultation state
# for every chat session served by this process. medical_chat_response resets
# it whenever a request arrives with an empty history, so concurrent users
# would share (and clobber) each other's name, age and history. Consider
# per-session state (e.g. gr.State or keying on the session) — confirm how the
# Space is deployed.
medical_bot = MedicalConsultationBot()
def get_meditron_diagnosis(patient_info, max_new_tokens=300, context_window=2048):
    """Generate a structured medical assessment with the Meditron model.

    Args:
        patient_info: Free-text consultation summary inserted into
            ``MEDITRON_PROMPT``.
        max_new_tokens: Maximum number of tokens to generate.
        context_window: Total token budget for prompt plus generation.
            2048 is assumed to be Meditron-7B's context length — TODO confirm
            against the model card.

    Returns:
        The generated assessment text (prompt excluded). If anything fails,
        returns a string starting with "Error generating medical assessment:".
    """
    try:
        prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
        inputs = meditron_tokenizer(
            prompt,
            return_tensors="pt",
            # Leave room for the generated tokens inside the context window.
            # The previous 512-token cap could be exceeded by a full
            # consultation transcript, and right-side truncation then dropped
            # the latest symptoms and the prompt's closing instruction.
            max_length=max(context_window - max_new_tokens, 1),
            truncation=True,
        ).to(meditron_model.device)
        with torch.no_grad():
            outputs = meditron_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=meditron_tokenizer.pad_token_id,
            )
        # Decode only the continuation, not the echoed prompt tokens.
        return meditron_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        ).strip()
    except Exception as e:
        print(f"Meditron generation failed: {e}")
        return f"Error generating medical assessment: {str(e)}"
def _generate_llama_reply(prompt, max_new_tokens=200, context_window=4096):
    """Run the Llama-2 chat model on *prompt* and return only the new text.

    Args:
        prompt: A fully formatted Llama-2 chat prompt.
        max_new_tokens: Maximum number of tokens to generate.
        context_window: Total token budget for prompt plus generation.
            4096 is assumed to be Llama-2's context length — TODO confirm.

    Returns:
        The generated continuation, with stray ``<s>``/``</s>`` removed and
        surrounding whitespace stripped.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        # The previous 1024-token cap could be exceeded once the history grew.
        # Right-side truncation then cut off the current message and "[/INST]".
        max_length=context_window - max_new_tokens,
        truncation=True,
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Slice off the prompt tokens rather than splitting the decoded text on
    # "[/INST]". The split echoed the whole prompt whenever the marker was lost.
    generated = outputs[0][inputs.input_ids.shape[1]:]
    reply = tokenizer.decode(generated, skip_special_tokens=True)
    return reply.replace("<s>", "").replace("</s>", "").strip()


@spaces.GPU
def medical_chat_response(message, history):
    """Gradio chat callback driving the scripted consultation.

    Stages: greeting -> name -> age -> symptoms (Llama follow-up questions) ->
    final assessment via generate_final_diagnosis -> complete.

    Args:
        message: The user's latest chat message.
        history: Prior chat turns supplied by Gradio. Only its emptiness is
            used, to detect a new conversation.

    Returns:
        The bot's reply text.
    """
    # NOTE(review): state lives in the module-level `medical_bot`, which is
    # shared by all sessions — see the note where it is created.
    fallback_question = (
        "I understand. Could you tell me more about how long you've been "
        "experiencing this and if there are any specific triggers or patterns "
        "you've noticed?"
    )

    # An empty history means a new (or cleared) chat: start from scratch.
    if not history:
        medical_bot.reset_conversation()

    user_message = message.strip()
    stage = medical_bot.stage

    if stage == "greeting":
        # The first message is acknowledged only by asking for the name.
        bot_response = "Hello! I'm your AI medical assistant. Before we discuss your health concerns, could you please tell me your name?"
        medical_bot.stage = "name"
    elif stage == "name":
        medical_bot.patient_name = user_message
        bot_response = f"Nice to meet you, {medical_bot.patient_name}! Could you please tell me your age?"
        medical_bot.stage = "age"
    elif stage == "age":
        medical_bot.patient_age = user_message
        bot_response = f"Thank you, {medical_bot.patient_name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
        medical_bot.stage = "symptoms"
    elif stage == "symptoms":
        medical_bot.medical_turns += 1
        # The 4th symptom-stage message triggers the final assessment;
        # generate_final_diagnosis records history itself.
        if medical_bot.medical_turns >= 4:
            medical_bot.stage = "diagnosis"
            return generate_final_diagnosis(user_message)
        try:
            bot_response = _generate_llama_reply(
                medical_bot.build_llama_prompt(user_message)
            )
        except Exception as e:
            # Best-effort: keep the consultation going with a canned question.
            print(f"Llama generation failed, using fallback question: {e}")
            bot_response = ""
        if not bot_response:
            bot_response = fallback_question
    elif stage == "diagnosis":
        return generate_final_diagnosis(user_message)
    else:
        # After the assessment: answer simple recall questions only
        # (these exchanges are not stored in history).
        lowered = user_message.lower()
        if "name" in lowered and medical_bot.patient_name:
            return f"Your name is {medical_bot.patient_name}."
        if "age" in lowered and medical_bot.patient_age:
            return f"You told me you are {medical_bot.patient_age} years old."
        return "Is there anything else about your health concerns I can help you with today?"

    medical_bot.add_to_history(user_message, bot_response)
    return bot_response
def generate_final_diagnosis(current_message):
    """Produce the closing assessment and mark the consultation complete.

    Builds a transcript from the global ``medical_bot`` state plus
    *current_message*, asks Meditron for an assessment, and wraps it with a
    disclaimer. It then records the exchange in history and sets the stage
    to "complete".

    Args:
        current_message: The patient's latest message.

    Returns:
        The full assessment text shown to the patient.
    """
    # Each stored exchange is (patient message -> bot reply), so the patient
    # line comes first, matching get_conversation_context(). Printing
    # "Doctor" first paired every answer with the *next* question.
    transcript = "".join(
        f"Patient: {exchange['user']}\nDoctor: {exchange['bot']}\n"
        for exchange in medical_bot.conversation_history
    )
    patient_info = f"""
Patient Name: {medical_bot.patient_name}
Patient Age: {medical_bot.patient_age}
Complete Consultation History:
"""
    patient_info += transcript
    patient_info += f"Patient: {current_message}\n"

    # get_meditron_diagnosis returns an error string instead of raising.
    meditron_assessment = get_meditron_diagnosis(patient_info)

    final_response = f"""Thank you for providing all this information, {medical_bot.patient_name}. Based on our consultation, here is my assessment:
**MEDICAL ASSESSMENT & RECOMMENDATIONS:**
{meditron_assessment}
**IMPORTANT DISCLAIMER:** This assessment is for informational purposes only and should not replace professional medical advice. Please consult with a healthcare provider for proper diagnosis and treatment.
**NEXT STEPS:** I recommend scheduling an appointment with your primary care physician or appropriate specialist for further evaluation and personalized treatment.
Is there anything specific about this assessment you'd like me to clarify?"""

    medical_bot.add_to_history(current_message, final_response)
    medical_bot.stage = "complete"
    return final_response
# Create Gradio interface
# The title and clear-button label previously contained mojibake
# ("๐Ÿฉบ" / "๐Ÿ”„"), i.e. the UTF-8 bytes of 🩺 and 🔄 decoded with the
# wrong codec. They are restored to the intended emoji.
demo = gr.ChatInterface(
    fn=medical_chat_response,
    title="🩺 AI Medical Assistant with Memory",
    description="I'm an AI medical assistant that will remember our conversation. I'll first ask for your basic information, then gather details about your symptoms through intelligent follow-up questions, and finally provide a medical assessment.",
    examples=[
        "Hello, I need medical help",
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts",
    ],
    theme="soft",
    # NOTE(review): retry_btn / undo_btn / clear_btn are gr.ChatInterface
    # arguments in Gradio 4.x and were removed in Gradio 5 — confirm the
    # pinned gradio version before upgrading.
    retry_btn=None,
    undo_btn=None,
    clear_btn="🔄 Start New Consultation",
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()