Spaces:
Sleeping
Sleeping
File size: 12,894 Bytes
b80af5b 71bcd31 9f6ac99 c4447f4 000ab02 71bcd31 6e237a4 a985489 71bcd31 6e237a4 71bcd31 5522bf8 71bcd31 c4447f4 71bcd31 d6da22c 71bcd31 a0597d0 a7f6391 d6da22c a7f6391 d6da22c 71bcd31 bdce857 71bcd31 a7f6391 000ab02 a7f6391 000ab02 a7f6391 000ab02 a7f6391 aa89cd7 a0597d0 c4447f4 a0597d0 a7f6391 000ab02 a7f6391 000ab02 a7f6391 000ab02 a7f6391 a985489 a7f6391 d6da22c a7f6391 d6da22c a7f6391 c4447f4 a7f6391 71bcd31 a7f6391 71bcd31 c4447f4 a7f6391 aa89cd7 a7f6391 aa89cd7 c4447f4 aa89cd7 b80af5b 71bcd31 6d5190c 71bcd31 a7f6391 8b29c0d a7f6391 8b29c0d 71bcd31 6d5190c b80af5b 71bcd31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 |
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferMemory
import re
# Model configuration
# Hugging Face model IDs: the Llama-2 chat model drives the conversational
# intake flow; Meditron generates the medication / home-care suggestions.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System prompt injected into every Llama-2 prompt (see build_llama2_prompt):
# instructs the model to act as a virtual doctor and gather history first.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's name, age, health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
Always begin by asking for the user's name and age if not already provided.
**IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
After collecting sufficient information (at least 4-5 exchanges, but continue up to 10 if the user keeps responding), summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.
If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.
Do NOT make specific prescriptions for prescription-only drugs.
Respond empathetically and clearly. Always be professional and thorough."""

# ChatML-style template for Meditron; {patient_info} is filled in by
# get_meditron_suggestions with the compiled patient summary.
# NOTE(review): there is no <|im_end|> after the system section before
# <|im_start|>user — confirm against Meditron's expected chat format.
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.
For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Provide specific medication recommendations with precise dosing regimens
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication
Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
# --- Module-level model loading (runs once at import time; can take minutes
# and requires enough GPU/CPU memory for two 7B models) ---
print("Loading Llama-2 model...")
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
model = AutoModelForCausalLM.from_pretrained(
    LLAMA_MODEL,
    torch_dtype=torch.float16,  # half precision to halve the memory footprint
    device_map="auto"           # let accelerate place layers on available devices
)
print("Llama-2 model loaded successfully!")

print("Loading Meditron model...")
meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
meditron_model = AutoModelForCausalLM.from_pretrained(
    MEDITRON_MODEL,
    torch_dtype=torch.float16,
    device_map="auto"
)
print("Meditron model loaded successfully!")

# Initialize LangChain memory
# NOTE(review): this buffer is module-level and therefore shared across every
# chat session served by this process — verify that is acceptable.
memory = ConversationBufferMemory(return_messages=True)
def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None):
    """Assemble a Llama-2 [INST]-format prompt string.

    Args:
        system_prompt: Text placed inside the <<SYS>> block.
        messages: Prior turns; each has a ``.type`` ("human"/"ai") and ``.content``.
        user_input: The latest user message, used when no canned follow-up applies.
        followup_stage: Optional index into the canned follow-up questions; when
            it points past the end of the list, ``user_input`` is used instead.

    Returns:
        The fully assembled prompt string.
    """
    parts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]

    # Replay the conversation so far in Llama-2 chat markup.
    for turn in messages:
        if turn.type == "human":
            parts.append(f"{turn.content} [/INST] ")
        elif turn.type == "ai":
            parts.append(f"{turn.content} </s><s>[INST] ")

    # Canned follow-up questions, one per intake stage.
    canned_questions = [
        "Can you describe your main symptoms in more detail? What exactly are you experiencing?",
        "How long have you been experiencing these symptoms? When did they first start?",
        "On a scale of 1-10, how would you rate the severity of your symptoms?",
        "Have you noticed anything that makes your symptoms better or worse? Any triggers or relief factors?",
        "Do you have any other related symptoms, such as fever, fatigue, nausea, or changes in appetite?"
    ]

    # Pick the final turn: a staged follow-up question if one applies,
    # otherwise the raw user input.
    if followup_stage is not None and followup_stage < len(canned_questions):
        closing = canned_questions[followup_stage]
    else:
        closing = user_input
    parts.append(f"{closing} [/INST] ")

    return "".join(parts)
def get_meditron_suggestions(patient_info):
    """Use Meditron model to generate medicine and remedy suggestions.

    Args:
        patient_info: Compiled patient summary text inserted into MEDITRON_PROMPT.

    Returns:
        The decoded continuation produced by the Meditron model.
    """
    filled_prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    encoded = meditron_tokenizer(filled_prompt, return_tensors="pt").to(meditron_model.device)
    # Sampled decoding; no gradients needed for inference.
    with torch.no_grad():
        generated = meditron_model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
    # Strip the prompt tokens so only the newly generated text is returned.
    prompt_length = encoded.input_ids.shape[1]
    return meditron_tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
def extract_name_age_intelligent(text):
    """Heuristically pull a name and an age out of free-form user text.

    Args:
        text: Raw user message.

    Returns:
        ``(name, age)`` where name is title-cased or ``None`` and age is a
        string in the range 1-120 or ``None``.
    """
    candidate_name = None
    candidate_age = None
    lowered = text.lower().strip()

    # Age: first pattern that yields a plausible (1-120) number wins.
    age_patterns = [
        r'(?:i am|i\'m|im|age is|aged|my age is|years old|year old)\s*(\d{1,3})',
        r'(\d{1,3})\s*(?:years old|year old|yrs old|yr old)',
        r'\b(\d{1,3})\s*(?:and|,)?\s*(?:years|yrs|y\.o\.)',
        r'(?:^|\s)(\d{1,3})(?:\s|$)',  # standalone numbers
    ]
    for age_re in age_patterns:
        hit = re.search(age_re, lowered)
        if not hit:
            continue
        value = int(hit.group(1))
        if 1 <= value <= 120:  # reasonable age range
            candidate_age = str(value)
            break

    # Name: first pattern whose capture isn't a common stop-word wins.
    stop_words = ['it', 'is', 'am', 'my', 'me', 'the', 'and', 'or', 'but', 'yes', 'no']
    name_patterns = [
        r'(?:my name is|name is|i am|i\'m|im|call me|this is)\s+([a-zA-Z][a-zA-Z\s]{1,20}?)(?:\s+and|\s+\d|\s*$)',
        r'^([a-zA-Z][a-zA-Z\s]{1,20}?)\s+(?:and|,)?\s*\d',  # name followed by number
        r'(?:^|\s)([a-zA-Z]{2,15})(?:\s+and|\s+\d)',  # simple name pattern
    ]
    for name_re in name_patterns:
        hit = re.search(name_re, lowered)
        if not hit:
            continue
        word = hit.group(1).strip().title()
        if word.lower() not in stop_words and len(word) >= 2:
            candidate_name = word
            break

    # "<name> and <age>" combos (e.g. "thanush and 23", "it thanush and im 23")
    # override both fields when matched.
    combo_patterns = [
        r'(?:it\s+)?([a-zA-Z]{2,15})\s+and\s+(?:im\s+|i\'m\s+)?(\d{1,3})',
        r'([a-zA-Z]{2,15})\s+and\s+(\d{1,3})',
    ]
    for combo_re in combo_patterns:
        hit = re.search(combo_re, lowered)
        if not hit:
            continue
        word = hit.group(1).strip().title()
        value = int(hit.group(2))
        if word.lower() not in ['it', 'is', 'am'] and 1 <= value <= 120:
            candidate_name = word
            candidate_age = str(value)
            break

    return candidate_name, candidate_age
def extract_name_age_from_all_messages(messages):
    """Scan every human turn and keep the first name and first age found.

    Args:
        messages: Conversation turns; each has ``.type`` and ``.content``.

    Returns:
        ``(name, age)`` — earliest extracted values, or ``None`` when absent.
    """
    found_name = None
    found_age = None
    for turn in messages:
        if turn.type != "human":
            continue
        candidate_name, candidate_age = extract_name_age_intelligent(turn.content)
        # Earlier messages win; later matches never overwrite.
        if found_name is None and candidate_name:
            found_name = candidate_name
        if found_age is None and candidate_age:
            found_age = candidate_age
    return found_name, found_age
def is_medical_symptom_message(text):
    """Return True when the text mentions a medical keyword (substring match).

    Used to distinguish symptom descriptions from name/age-only messages.
    """
    medical_keywords = (
        'hurt', 'pain', 'ache', 'sick', 'fever', 'cough', 'headache', 'stomach', 'throat',
        'nausea', 'dizzy', 'tired', 'fatigue', 'breathe', 'chest', 'back', 'leg', 'arm',
        'symptom', 'feel', 'suffering', 'problem', 'issue', 'uncomfortable', 'sore'
    )
    lowered = text.lower()
    for keyword in medical_keywords:
        if keyword in lowered:
            return True
    return False
@spaces.GPU
def generate_response(message, history):
    """Generate a response using both models, with full context.

    Args:
        message: Latest user message from the Gradio chat box.
        history: Gradio chat history as a list of [user, assistant] pairs.

    Returns:
        The assistant reply string; includes Meditron medication suggestions
        once enough medical detail has been gathered.
    """
    # NOTE(review): `memory` is a single module-level buffer shared by every
    # session this process serves — concurrent users will interleave turns.
    # Save the latest user message and last assistant response to memory
    if history and len(history[-1]) == 2:
        memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
    # NOTE(review): saving the current turn with output "" also appends an
    # empty AI message, which build_llama2_prompt renders as a stray
    # "</s><s>[INST] " segment — confirm this is intended.
    memory.save_context({"input": message}, {"output": ""})
    messages = memory.chat_memory.messages
    # Extract name and age from all messages
    name, age = extract_name_age_from_all_messages(messages)
    # Check what information is missing
    missing_info = []
    if not name:
        missing_info.append("your name")
    if not age:
        missing_info.append("your age")
    # If missing basic info, ask for it
    if missing_info:
        ask = "Hello! Before we discuss your health concerns, could you please tell me " + " and ".join(missing_info) + "?"
        return ask
    # Count meaningful medical information exchanges (exclude name/age only messages)
    medical_info_turns = 0
    for msg in messages:
        if msg.type == "human":
            # Count only if it's not just name/age info and contains medical content
            if is_medical_symptom_message(msg.content) or not any(keyword in msg.content.lower() for keyword in ['name', 'age', 'years', 'old', 'im', 'i am']):
                medical_info_turns += 1
    # Ensure we have at least one medical symptom mentioned
    if medical_info_turns == 0 and not is_medical_symptom_message(message):
        return f"Thank you, {name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
    # Ask up to 5 intelligent follow-up questions, then provide diagnosis and treatment
    if medical_info_turns < 5:
        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=medical_info_turns)
    else:
        # Time for final diagnosis and treatment recommendations
        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
        # NOTE(review): str.replace substitutes EVERY "[/INST] " occurrence in
        # the prompt, not just the final turn, so the assessment instruction is
        # injected after each prior user turn — likely only the last was meant.
        prompt = prompt.replace("[/INST] ", "[/INST] Based on all the information provided, please provide a comprehensive assessment including: 1) Most likely diagnosis, 2) Recommended next steps, and 3) When to seek immediate medical attention. ")
    # Generate response using Llama-2
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    # Keep special tokens in the decode so the [/INST]/</s> markers survive
    # for the split below; the last [/INST] segment is the new reply.
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
    # After 5 medical info turns, add Meditron suggestions
    if medical_info_turns >= 4:  # Start suggesting after 4+ turns
        # Compile patient information for Meditron
        patient_summary = f"Patient: {name}, Age: {age}\n\n"
        patient_summary += "Medical Information:\n"
        for msg in messages:
            if msg.type == "human" and is_medical_symptom_message(msg.content):
                patient_summary += f"- {msg.content}\n"
        patient_summary += f"\nLatest input: {message}\n"
        patient_summary += f"\nInitial Assessment: {llama_response}"
        # Get Meditron suggestions
        medicine_suggestions = get_meditron_suggestions(patient_summary)
        final_response = (
            f"{llama_response}\n\n"
            f"--- MEDICATION AND HOME CARE RECOMMENDATIONS ---\n\n"
            f"{medicine_suggestions}\n\n"
            f"**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice, especially if symptoms persist or worsen."
        )
        return final_response
    return llama_response
# Create the Gradio interface
# ChatInterface wires generate_response(message, history) into a chat UI;
# the examples are clickable starter messages shown below the chat box.
demo = gr.ChatInterface(
    fn=generate_response,
    title="🩺 AI Medical Assistant with Treatment Suggestions",
    description="Describe your symptoms and I'll gather information to provide medical insights and treatment recommendations.",
    examples=[
        "Hi, I'm Sarah and I'm 25. I have a persistent cough and sore throat.",
        "My name is John, I'm 35, and I've been having severe headaches.",
        "I'm Lisa, 28 years old, and my stomach has been hurting since yesterday."
    ],
    theme="soft"
)

# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()