medbot_2 / app.py
Thanush
Enhance user information collection in app.py by emphasizing follow-up questions and refining response generation logic based on actual information turns.
a985489
raw
history blame
7.56 kB
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferMemory
import re
# Model configuration
# Repository ids passed to `from_pretrained` below: the Llama-2 chat model runs
# the conversation, and Meditron produces the medication / home-remedy suggestions.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"
# System prompt for the Llama-2 conversation model: gather patient details
# through a few follow-up questions at a time, then summarise and give general,
# non-prescription guidance.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's name, age, health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
Always begin by asking for the user's name and age if not already provided.
**IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
After collecting sufficient information (at least 4-5 exchanges, but continue up to 10 if the user keeps responding), summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.
If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.
Do NOT make specific prescriptions for prescription-only drugs.
Respond empathetically and clearly. Always be professional and thorough."""
# Prompt template for Meditron, laid out with ChatML-style <|im_start|>/<|im_end|>
# turn markers and ending at the start of the assistant turn. It is filled with
# str.format(patient_info=...), so any literal "{" or "}" added to this text
# must be doubled ("{{", "}}").
MEDITRON_PROMPT = """<|im_start|>system
You are a specialized medical assistant focusing ONLY on suggesting over-the-counter medicines and home remedies based on patient information.
Based on the following patient information, provide ONLY:
1. One specific over-the-counter medicine with proper adult dosing instructions
2. One practical home remedy that might help
3. Clear guidance on when to seek professional medical care
Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
def _load_model_pair(repo_id, start_message, done_message):
    """Load the tokenizer and fp16 causal LM for *repo_id*.

    *start_message* is printed before loading and *done_message* after, so
    progress shows up in the Space logs. Returns ``(tokenizer, model)``.
    """
    print(start_message)
    loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    print(done_message)
    return loaded_tokenizer, loaded_model


# Both models are loaded once, at import time: Llama-2 first, then Meditron.
tokenizer, model = _load_model_pair(
    LLAMA_MODEL,
    "Loading Llama-2 model...",
    "Llama-2 model loaded successfully!",
)
meditron_tokenizer, meditron_model = _load_model_pair(
    MEDITRON_MODEL,
    "Loading Meditron model...",
    "Meditron model loaded successfully!",
)

# LangChain conversation memory. This is a single module-level object, so it
# is shared by every caller in this process.
memory = ConversationBufferMemory(return_messages=True)
def build_llama2_prompt(system_prompt, messages, user_input):
    """Render a conversation in the Llama-2 chat prompt format.

    The system prompt is wrapped in ``<<SYS>>`` tags inside the first
    ``[INST]`` block. Each ``"human"`` message closes an instruction with
    ``[/INST]``, and each ``"ai"`` message closes its turn with ``</s>`` and
    opens the next ``<s>[INST]``. Messages of any other type are ignored.
    *user_input* becomes the final, still-open instruction.

    Args:
        system_prompt: Text placed between the ``<<SYS>>`` markers.
        messages: Ordered messages exposing ``.type`` and ``.content``.
        user_input: The newest user message.

    Returns:
        The prompt string, ending with ``"[/INST] "`` so generation starts
        in the assistant's turn.
    """
    turn_templates = {
        "human": "{} [/INST] ",
        "ai": "{} </s><s>[INST] ",
    }
    pieces = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for entry in messages:
        template = turn_templates.get(entry.type)
        if template is not None:
            pieces.append(template.format(entry.content))
    pieces.append(f"{user_input} [/INST] ")
    return "".join(pieces)
def get_meditron_suggestions(patient_info):
    """Generate over-the-counter medicine and home-remedy suggestions with Meditron.

    Args:
        patient_info: Free-text description of the patient, substituted into
            ``MEDITRON_PROMPT`` through ``str.format``.

    Returns:
        The generated suggestion text with surrounding whitespace removed.
        It is truncated at the first ChatML end-of-turn / new-turn marker or
        EOS token, so text the model writes after its answer (e.g. an
        invented next turn) is not shown to the user.
    """
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Same explicit padding choice as the Llama-2 call; otherwise
            # generate() warns on every call and substitutes the EOS id anyway.
            pad_token_id=meditron_tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    # Decode with special tokens kept, so the turn markers can be found
    # whether or not this tokenizer registers them as special tokens.
    suggestion = meditron_tokenizer.decode(new_tokens, skip_special_tokens=False)
    for marker in ("<|im_end|>", "<|im_start|>", meditron_tokenizer.eos_token):
        if marker:
            suggestion = suggestion.split(marker, 1)[0]
    return suggestion.strip()
# Age statements: "I'm 34", "I am 34", "age is 34", "aged 34", "age: 34",
# "34 years old" / "34-year-old", and "34 y/o" / "34yo". Word boundaries keep
# fragments such as "damage is 5" or the digits of "1000" from matching.
_AGE_PATTERNS = (
    re.compile(
        r"\b(?:i['’]m\b|i\s+am\b|age\s+is\b|aged\b|age\s*:)\s*(\d{1,3})\b",
        re.IGNORECASE,
    ),
    re.compile(r"\b(\d{1,3})[\s-]*(?:years?|yrs?)[\s-]*old\b", re.IGNORECASE),
    re.compile(r"\b(\d{1,3})\s*y/?o\b", re.IGNORECASE),
)

# Explicit introductions ("my name is ana", "my name's Ana", "Name: Ana").
_NAME_EXPLICIT_RE = re.compile(
    r"\b(?:my\s+name\s+is\s+|my\s+name['’]s\s+|name\s*:\s*)"
    r"([^\W\d_]+(?:['’-][^\W\d_]+)*)",
    re.IGNORECASE,
)

# Self-introductions ("I'm Ana", "I am Ana"). Candidates are only accepted
# when capitalised (see _self_introduced_name), because "I'm feeling sick"
# or "I am having pain" would otherwise be read as names.
_NAME_SELF_INTRO_RE = re.compile(
    r"\b(?:i['’]m|i\s+am)\s+([^\W\d_]+(?:['’-][^\W\d_]+)*)",
    re.IGNORECASE,
)


def _first_group(patterns, text):
    """Return capture group 1 of the first pattern in *patterns* matching *text*, else None."""
    for pattern in patterns:
        match = pattern.search(text)
        if match:
            return match.group(1)
    return None


def _self_introduced_name(text):
    """Return the first capitalised word following "I'm"/"I am" in *text*, else None."""
    for match in _NAME_SELF_INTRO_RE.finditer(text):
        candidate = match.group(1)
        if candidate[0].isupper():
            return candidate
    return None


def extract_name_age(messages):
    """Heuristically extract the user's name and age from a conversation.

    Only messages whose ``type`` is ``"human"`` and whose ``content`` is a
    string are inspected, in order. The first name found and the first age
    found are kept.

    Args:
        messages: Iterable of objects exposing ``.type`` and ``.content``
            (e.g. LangChain chat messages).

    Returns:
        ``(name, age)``. Each value is a string, or ``None`` if it was not found.
    """
    name, age = None, None
    for msg in messages:
        if msg.type != "human" or not isinstance(msg.content, str):
            continue
        text = msg.content
        if age is None:
            age = _first_group(_AGE_PATTERNS, text)
        if name is None:
            explicit = _NAME_EXPLICIT_RE.search(text)
            name = explicit.group(1) if explicit else _self_introduced_name(text)
        if name is not None and age is not None:
            break
    return name, age
# Prefix of the reply that asks for missing identity details. It is kept as a
# constant so later turns can recognise user replies to this question.
_NAME_AGE_ASK_PREFIX = "Before we continue, could you please tell me "

# Instruction added to the user's turn when it is time to summarise.
_SUMMARY_INSTRUCTION = (
    "Now summarize what you've learned and suggest when professional care may be needed."
)

# The summary (plus Meditron suggestions) is produced on the turn where the
# number of information-bearing user turns first reaches this value.
_SUMMARY_AFTER_INFO_TURNS = 4

# Whole-word name/age keywords. Word boundaries stop "headache" or "damage"
# from matching; "I'm"/"I am" are deliberately absent because most symptom
# descriptions contain them.
_IDENTITY_KEYWORDS_RE = re.compile(r"\b(?:name|age|aged|years?\s+old)\b", re.IGNORECASE)


class _Message:
    """Minimal chat message exposing the ``type`` ("human"/"ai") and ``content``
    attributes that the prompt/extraction helpers read from LangChain messages."""

    __slots__ = ("type", "content")

    def __init__(self, msg_type, content):
        self.type = msg_type
        self.content = content


def _history_to_messages(history):
    """Convert Gradio chat history into an ordered list of ``_Message`` objects.

    Both history layouts are accepted: ``[user_text, bot_text]`` pairs and
    ``{"role": "user"|"assistant", "content": ...}`` dicts. Entries whose
    content is not a string (``None`` placeholders, file payloads) are skipped.
    """
    role_to_type = {"user": "human", "assistant": "ai"}
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            turns = [(role_to_type.get(entry.get("role")), entry.get("content"))]
        else:
            user_text, bot_text = (tuple(entry) + (None, None))[:2]
            turns = [("human", user_text), ("ai", bot_text)]
        for msg_type, content in turns:
            if msg_type and isinstance(content, str):
                messages.append(_Message(msg_type, content))
    return messages


def _is_identity_turn(text, previous_ai_text):
    """Return True if a user turn mainly supplies name/age rather than medical details.

    A turn counts as identity information if it answers our name/age question,
    or if it mentions name/age keywords as whole words.
    """
    if previous_ai_text and previous_ai_text.startswith(_NAME_AGE_ASK_PREFIX):
        return True
    return bool(_IDENTITY_KEYWORDS_RE.search(text))


def _info_messages(messages):
    """Return the contents of user messages that carry medical information, in order."""
    info = []
    previous_ai_text = None
    for msg in messages:
        if msg.type == "ai":
            previous_ai_text = msg.content
        elif msg.type == "human":
            if not _is_identity_turn(msg.content, previous_ai_text):
                info.append(msg.content)
            previous_ai_text = None
    return info


@spaces.GPU
def generate_response(message, history):
    """Generate the assistant's reply for one chat turn.

    Args:
        message: The user's newest message.
        history: This session's chat history as supplied by Gradio (either
            ``[user, bot]`` pairs or role/content dicts), excluding *message*.

    Returns:
        One of three replies:

        * a request for the user's name and/or age while either is missing;
        * otherwise, the Llama-2 reply;
        * on the turn where the number of information-bearing user turns
          first reaches ``_SUMMARY_AFTER_INFO_TURNS``, a Llama-2 summary
          followed by Meditron medication and home-care suggestions.
    """
    # Rebuild the context from this session's history on every call. A
    # module-level memory would be shared by every user of the app.
    past_messages = _history_to_messages(history)
    conversation = past_messages + [_Message("human", message)]

    name, age = extract_name_age(conversation)
    missing_info = []
    if not name:
        missing_info.append("your name")
    if not age:
        missing_info.append("your age")
    if missing_info:
        return _NAME_AGE_ASK_PREFIX + " and ".join(missing_info) + "?"

    info_messages = _info_messages(conversation)
    # Summarise exactly once: on the turn where the count crosses the threshold.
    is_summary_turn = (
        len(_info_messages(past_messages)) < _SUMMARY_AFTER_INFO_TURNS <= len(info_messages)
    )

    user_turn = message
    if is_summary_turn:
        # The instruction belongs inside the user's [INST] block; text placed
        # after "[/INST]" would become part of the assistant's visible answer.
        user_turn = f"{message}\n\n{_SUMMARY_INSTRUCTION}"
    prompt = build_llama2_prompt(SYSTEM_PROMPT, past_messages, user_turn)

    # The prompt already starts with "<s>", so don't let the tokenizer add a second BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, and cut off any "[INST]" user
    # turn the model might start writing on its own.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    llama_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    llama_response = llama_response.split("[INST]")[0].strip()

    if not is_summary_turn:
        return llama_response

    # Age is included because the suggestions involve medicine dosing.
    patient_info = "\n".join([f"Patient age: {age}", *info_messages])
    patient_info += "\n\nSummary: " + llama_response
    medicine_suggestions = get_meditron_suggestions(patient_info)
    return (
        f"{llama_response}\n\n"
        f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
        f"{medicine_suggestions}"
    )
# Sample prompts offered beneath the chat box.
_EXAMPLE_PROMPTS = [
    "I have a cough and my throat hurts",
    "I've been having headaches for a week",
    "My stomach has been hurting since yesterday",
]

# Chat UI: each turn is handled by generate_response(message, history), and
# its return value is displayed as the assistant's reply.
demo = gr.ChatInterface(
    fn=generate_response,
    title="Medical Assistant with Medicine Suggestions",
    description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
    examples=_EXAMPLE_PROMPTS,
    theme="soft",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()