File size: 12,894 Bytes
b80af5b
71bcd31
9f6ac99
 
c4447f4
000ab02
71bcd31
 
 
 
 
6e237a4
 
 
 
a985489
71bcd31
 
 
 
 
 
 
6e237a4
 
 
 
 
 
 
71bcd31
 
 
5522bf8
 
 
 
 
 
 
 
 
 
 
 
71bcd31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4447f4
 
71bcd31
d6da22c
71bcd31
a0597d0
 
 
 
 
a7f6391
d6da22c
 
 
a7f6391
 
 
 
 
d6da22c
 
 
 
 
 
 
71bcd31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdce857
 
71bcd31
 
 
a7f6391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
000ab02
a7f6391
000ab02
 
a7f6391
 
 
 
 
 
000ab02
 
a7f6391
 
 
 
 
 
 
 
 
 
 
aa89cd7
 
a0597d0
c4447f4
 
 
 
 
a0597d0
a7f6391
 
 
 
 
000ab02
 
 
 
 
a7f6391
 
000ab02
a7f6391
000ab02
a7f6391
 
 
a985489
 
a7f6391
 
 
 
 
 
 
 
 
 
 
d6da22c
a7f6391
d6da22c
a7f6391
c4447f4
a7f6391
71bcd31
 
 
 
 
 
 
 
 
 
 
a7f6391
71bcd31
 
c4447f4
a7f6391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa89cd7
 
a7f6391
 
 
aa89cd7
 
c4447f4
aa89cd7
b80af5b
71bcd31
6d5190c
71bcd31
a7f6391
 
8b29c0d
a7f6391
 
 
8b29c0d
71bcd31
6d5190c
b80af5b
 
71bcd31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferMemory
import re

# Model configuration
# Hugging Face Hub repo ids: the conversational model and the treatment-suggestion model.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System prompt for the Llama-2 conversation. It sets the product policy:
# gather history, then give OTC / home-remedy advice only (no prescription-only drugs).
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's name, age, health condition, symptoms, medical history, medications, lifestyle, and other relevant data.

Always begin by asking for the user's name and age if not already provided.

**IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies

After collecting sufficient information (at least 4-5 exchanges, but continue up to 10 if the user keeps responding), summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.

If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.

Do NOT make specific prescriptions for prescription-only drugs.

Respond empathetically and clearly. Always be professional and thorough."""

# ChatML-style prompt for Meditron; ``{patient_info}`` is filled with str.format().
# Each turn is closed with <|im_end|> (the system turn previously was not).
# Item 4 follows the same OTC-only policy as SYSTEM_PROMPT.
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.

For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Suggest over-the-counter medications with standard label dosing and practical home remedies; do NOT prescribe prescription-only drugs
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication

Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""

def _load_causal_lm(repo_id, label):
    """Load a tokenizer and a float16 causal LM for ``repo_id``.

    Prints a progress line before and after loading, using ``label`` as the
    human-readable model name. The model is loaded with ``device_map="auto"``.

    Returns:
        ``(tokenizer, model)``.
    """
    print(f"Loading {label} model...")
    loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    print(f"{label} model loaded successfully!")
    return loaded_tokenizer, loaded_model


# Conversation model (Llama-2 chat) and treatment-suggestion model (Meditron),
# both loaded once when this module is imported.
tokenizer, model = _load_causal_lm(LLAMA_MODEL, "Llama-2")
meditron_tokenizer, meditron_model = _load_causal_lm(MEDITRON_MODEL, "Meditron")

# Initialize LangChain memory (a single module-level instance for the whole process).
memory = ConversationBufferMemory(return_messages=True)

# Canned follow-up questions, indexed by ``followup_stage`` in build_llama2_prompt().
_FOLLOWUP_QUESTIONS = (
    "Can you describe your main symptoms in more detail? What exactly are you experiencing?",
    "How long have you been experiencing these symptoms? When did they first start?",
    "On a scale of 1-10, how would you rate the severity of your symptoms?",
    "Have you noticed anything that makes your symptoms better or worse? Any triggers or relief factors?",
    "Do you have any other related symptoms, such as fever, fatigue, nausea, or changes in appetite?",
)


def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None):
    """Render a Llama-2-chat prompt string.

    Args:
        system_prompt: Text placed in the ``<<SYS>>`` block of the first turn.
        messages: Prior conversation in order; objects exposing ``type``
            (``"human"`` or ``"ai"``) and ``content``. Other types are ignored.
            A human message is expected to be followed by an ai message.
        user_input: The user's latest message. It always occupies the final
            ``[INST]`` turn.
        followup_stage: Optional index into ``_FOLLOWUP_QUESTIONS``. When it is
            in range, the final turn also carries an instruction asking the
            model to put that question to the patient. Out-of-range, negative
            or ``None`` values add nothing.

    Returns:
        The prompt, ending in ``" [/INST] "`` so generation continues as the
        assistant.
    """
    parts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for msg in messages:
        if msg.type == "human":
            parts.append(f"{msg.content} [/INST] ")
        elif msg.type == "ai":
            parts.append(f"{msg.content} </s><s>[INST] ")

    final_turn = user_input
    if followup_stage is not None and 0 <= followup_stage < len(_FOLLOWUP_QUESTIONS):
        # The question is the doctor's line: keep the patient's own words in the
        # user turn and ask the model to pose the question, instead of replacing
        # the patient's message with it.
        question = _FOLLOWUP_QUESTIONS[followup_stage]
        final_turn = (
            f"{user_input}\n\n"
            f'[After responding, ask the patient this follow-up question: "{question}"]'
        )
    parts.append(f"{final_turn} [/INST] ")
    return "".join(parts)

def get_meditron_suggestions(patient_info):
    """Use Meditron model to generate medicine and remedy suggestions.

    Args:
        patient_info: Free-text patient summary substituted into
            ``MEDITRON_PROMPT``.

    Returns:
        Only the newly generated text (prompt tokens excluded), cut at the first
        ChatML turn marker and stripped of surrounding whitespace.
    """
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)

    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Explicit pad id, matching the Llama-2 generate() call.
            pad_token_id=meditron_tokenizer.eos_token_id,
        )

    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    suggestion = meditron_tokenizer.decode(new_tokens, skip_special_tokens=True)
    # NOTE(review): the ChatML markers may not be registered special tokens for
    # this tokenizer (confirm), in which case skip_special_tokens keeps them.
    # Cut at the first one so a hallucinated next turn is not shown to the user.
    suggestion = re.split(r"<\|im_(?:start|end)\|>", suggestion, maxsplit=1)[0]
    return suggestion.strip()

# Age phrasings, tried in order; each pattern captures the age digits in group 1.
_AGE_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # "my age is 25", "age: 25", "aged 25"
        r"\b(?:my age is|age is|aged|age)\s*[:=]?\s*(\d{1,3})\b",
        # "i'm 25", "i am 25", "im 25" -- but not "i'm 3 days in" or "i'm 100% sure"
        r"\b(?:i am|i'm|im)\s*(\d{1,3})\b"
        r"(?!\s*(?:days?|weeks?|months?|hours?|hrs?|minutes?|mins?|times?|feet|ft|cm|kg|lbs?|%|/))",
        # "25 years old", "25 yrs old", "25-year-old"
        r"\b(\d{1,3})[\s-]*(?:years?|yrs?)[\s-]*old\b",
        # "25 yo", "25 y.o.", "25 y/o"
        r"\b(\d{1,3})\s*y\s*[./]?\s*o\b",
        # a bare reply to the age question: "25", "sarah, 25", "sarah and 25"
        r"^\s*(?:[^\W\d_]+\s*(?:,|and)?\s*)?(\d{1,3})\s*[.!]?\s*$",
    )
)

# Name phrasings, tried in order; each pattern captures a candidate name in group 1.
# ``[^\W\d_]+`` is one word of (Unicode) letters, so names like "José" work.
_NAME_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # explicit introduction, up to two words, ending at punctuation, "and",
        # a number or the end: "my name is ann lee and ...", "call me josé, ..."
        r"\b(?:my name is|name is|name's|call me|this is)\s+"
        r"([^\W\d_]+(?:\s+[^\W\d_]+)?)(?=\s*(?:[,.!?;:]|and\b|\d|$))",
        # first word after an introduction: "my name is sarah i have ...", "i'm sarah"
        r"\b(?:my name is|name is|name's|call me|i am|i'm|im)\s+([^\W\d_]+)",
        # a name leading straight into a number: "sarah, 25", "it's thanush and im 23"
        r"^(?:it'?s?\s+)?([^\W\d_]+)\s*(?:,|and)?\s*(?:i am|i'm|im)?\s*\d",
        # a bare one- or two-word reply to the name question: "sarah", "sarah connor."
        r"^\s*([^\W\d_]+(?:\s+[^\W\d_]+)?)\s*[.!]?\s*$",
    )
)

# Words that the patterns above can capture but that are not names.
_NOT_A_NAME = frozenset({
    # pronouns, articles, fillers (includes the original filter list)
    "it", "its", "is", "am", "my", "me", "the", "and", "or", "but", "yes", "no",
    "i", "im", "a", "an", "not", "so", "very", "really", "just", "also", "still",
    "here", "there", "in", "on", "at", "from", "with", "about", "been", "being",
    "hi", "hello", "hey", "ok", "okay", "thanks", "thank", "please", "sorry",
    "name", "age", "aged", "years", "old",
    # typical "i'm <state>" continuations
    "having", "feeling", "getting", "going", "doing", "suffering", "experiencing",
    "sick", "ill", "unwell", "fine", "good", "well", "bad", "better", "worse",
    "tired", "worried", "scared", "afraid", "concerned",
    # common one-word symptom replies
    "pain", "fever", "cough", "headache", "nausea", "sore", "dizzy", "fatigue",
})


def _first_match(patterns, text, accept):
    """Return group 1 of the first match, over ``patterns`` in order, that ``accept`` approves."""
    for pattern in patterns:
        for match in pattern.finditer(text):
            candidate = match.group(1).strip()
            if accept(candidate):
                return candidate
    return None


def _is_plausible_age(candidate):
    """Accept ages in the range 1-120."""
    return 1 <= int(candidate) <= 120


def _is_plausible_name(candidate):
    """Reject one-letter captures and captures containing a known non-name word."""
    return len(candidate) >= 2 and not any(word in _NOT_A_NAME for word in candidate.split())


def extract_name_age_intelligent(text):
    """Intelligently extract name and age from user input using multiple patterns.

    Matching is heuristic. Numbers followed by units such as days/weeks/% are
    not taken as ages. A captured word is only taken as a name when it follows
    an introduction ("my name is", "i'm", "call me"), leads straight into a
    number ("sarah, 25"), or is the whole 1-2 word message ("Sarah").

    Args:
        text: One user message (may be empty or ``None``).

    Returns:
        ``(name, age)``: ``name`` is title-cased, ``age`` is a decimal string,
        and either is ``None`` when not found.
    """
    if not text:
        return None, None
    # Lower-case for matching; normalise the typographic apostrophe ("i’m").
    text_lower = text.lower().strip().replace("\u2019", "'")

    age = _first_match(_AGE_PATTERNS, text_lower, _is_plausible_age)
    if age is not None:
        age = str(int(age))  # normalise e.g. "025" -> "25"

    name = _first_match(_NAME_PATTERNS, text_lower, _is_plausible_name)
    if name is not None:
        name = " ".join(name.split()).title()

    return name, age

def extract_name_age_from_all_messages(messages):
    """Return the earliest name and the earliest age the user has given.

    Args:
        messages: Conversation in order; objects exposing ``type`` and
            ``content``. Only entries whose ``type`` is ``"human"`` are scanned.

    Returns:
        ``(name, age)``. Each is the first non-empty value found, scanning in
        message order, or ``None`` if none was found.
    """
    found_name = None
    found_age = None
    user_texts = (entry.content for entry in messages if entry.type == "human")
    for user_text in user_texts:
        candidate_name, candidate_age = extract_name_age_intelligent(user_text)
        # Earlier messages win: only fill a slot that is still empty.
        if not found_name and candidate_name:
            found_name = candidate_name
        if not found_age and candidate_age:
            found_age = candidate_age
    return found_name, found_age

def is_medical_symptom_message(text):
    """Check if the message contains medical symptoms rather than just name/age.

    This is a case-insensitive substring test against a keyword list. Stems such
    as ``"breath"``, ``"nause"`` and ``"dizz"`` also cover "breathing",
    "nauseous" and "dizziness", which the full words used to miss. Because it is
    a substring test, short keywords can also hit unrelated words (e.g. "arm"
    in "warm").

    Args:
        text: One user message; empty or ``None`` counts as non-medical.

    Returns:
        True if any keyword occurs in ``text``.
    """
    if not text:
        return False
    medical_keywords = (
        'hurt', 'pain', 'ache', 'sick', 'fever', 'cough', 'headache', 'stomach', 'throat',
        'nause', 'dizz', 'tired', 'fatigue', 'breath', 'chest', 'back', 'leg', 'arm',
        'symptom', 'feel', 'suffering', 'problem', 'issue', 'uncomfortable', 'sore',
        'vomit', 'swell', 'swollen', 'bleed', 'diarrh',
    )
    text_lower = text.lower()
    return any(keyword in text_lower for keyword in medical_keywords)

# Number of symptom-bearing user turns after which the bot stops asking
# follow-up questions and gives the final assessment plus Meditron suggestions.
_FOLLOWUP_TURN_LIMIT = 5

# Topic to steer the next follow-up question toward, indexed by the number of
# symptom-bearing user turns seen so far.
_FOLLOWUP_TOPICS = (
    "a more detailed description of the main symptoms",
    "when the symptoms started and how long they have lasted",
    "how severe the symptoms are on a scale of 1-10",
    "anything that makes the symptoms better or worse",
    "related symptoms such as fever, fatigue, nausea, or changes in appetite",
)

# Appended to the patient's final message to request the closing assessment.
_ASSESSMENT_REQUEST = (
    "Based on all the information provided, please provide a comprehensive "
    "assessment including: 1) Most likely diagnosis, 2) Recommended next steps, "
    "and 3) When to seek immediate medical attention."
)

# Whole-word markers of a name/age-only introduction. Word boundaries stop
# "cold", "manage" or "time" from matching "old", "age" or "im".
_INTRO_ONLY_RE = re.compile(r"\b(?:name|age|aged|years|old|im|i'm|i am)\b")

# Role names in dict-style history mapped to the message types used in this file.
_ROLE_TO_TYPE = {"user": "human", "assistant": "ai"}


class _ChatMessage:
    """Minimal message record exposing the ``type`` / ``content`` attributes
    that build_llama2_prompt() and the extraction helpers read."""

    __slots__ = ("type", "content")

    def __init__(self, kind, content):
        self.type = kind
        self.content = content


def _history_to_messages(history):
    """Convert chat history into an ordered list of ``_ChatMessage``.

    Accepts ``[user, bot]`` pairs (the shape this app originally assumed) and
    ``{"role": ..., "content": ...}`` dicts (Gradio's "messages" format, an
    assumption to verify against the installed Gradio version). Entries whose
    content is not a non-empty string (e.g. ``None``) are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            entries = [(_ROLE_TO_TYPE.get(turn.get("role")), turn.get("content"))]
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            entries = [("human", turn[0]), ("ai", turn[1])]
        else:
            continue
        for kind, content in entries:
            if kind and isinstance(content, str) and content:
                messages.append(_ChatMessage(kind, content))
    return messages


@spaces.GPU
def generate_response(message, history):
    """Generate a response using both models, with full context.

    The conversation is rebuilt from ``history`` on every call. No module-level
    conversation state is read or written, so one user's messages cannot leak
    into another user's chat, and earlier turns are not stored twice.

    Flow:
    1. Ask for name and age until both are known.
    2. Ask for symptoms.
    3. Ask follow-up questions until ``_FOLLOWUP_TURN_LIMIT`` symptom-bearing
       turns are collected.
    4. Reply with a Llama-2 assessment followed by Meditron medication and
       home-care suggestions.

    Args:
        message: The user's latest message.
        history: Prior turns as passed by ``gr.ChatInterface``.

    Returns:
        The reply text to display.
    """
    past_messages = _history_to_messages(history)
    all_messages = past_messages + [_ChatMessage("human", message)]

    # Extract name and age from all user messages (earliest mention wins)
    name, age = extract_name_age_from_all_messages(all_messages)

    # Check what information is missing
    missing_info = []
    if not name:
        missing_info.append("your name")
    if not age:
        missing_info.append("your age")

    # If missing basic info, ask for it
    if missing_info:
        return (
            "Hello! Before we discuss your health concerns, could you please tell me "
            + " and ".join(missing_info)
            + "?"
        )

    # Count user turns that carry medical information: any turn with a symptom
    # keyword, plus any turn that is not just a name/age introduction.
    medical_info_turns = sum(
        1
        for msg in all_messages
        if msg.type == "human"
        and (
            is_medical_symptom_message(msg.content)
            or not _INTRO_ONLY_RE.search(msg.content.lower())
        )
    )

    # Ensure we have at least one medical symptom mentioned
    if medical_info_turns == 0 and not is_medical_symptom_message(message):
        return f"Thank you, {name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."

    final_stage = medical_info_turns >= _FOLLOWUP_TURN_LIMIT
    if final_stage:
        user_turn = f"{message}\n\n{_ASSESSMENT_REQUEST}"
    else:
        topic = _FOLLOWUP_TOPICS[min(medical_info_turns, len(_FOLLOWUP_TOPICS) - 1)]
        user_turn = (
            f"{message}\n\n"
            f"[Note to the doctor: after responding, ask one follow-up question about {topic}.]"
        )
    # Extra instructions go into the final user turn only, instead of being
    # pasted after every "[/INST]" (where they also leaked into the reply).
    prompt = build_llama2_prompt(SYSTEM_PROMPT, past_messages, user_turn)

    # The prompt already begins with "<s>"; don't let the tokenizer add a second BOS.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, so no prompt text reaches the user.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    llama_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Treatment suggestions belong with the final assessment, not with a follow-up question.
    if not final_stage:
        return llama_response

    # Compile patient information for Meditron (earlier symptom messages, then the latest input)
    symptom_lines = "".join(
        f"- {msg.content}\n"
        for msg in past_messages
        if msg.type == "human" and is_medical_symptom_message(msg.content)
    )
    patient_summary = (
        f"Patient: {name}, Age: {age}\n\n"
        "Medical Information:\n"
        f"{symptom_lines}"
        f"\nLatest input: {message}\n"
        f"\nInitial Assessment: {llama_response}"
    )

    medicine_suggestions = get_meditron_suggestions(patient_summary)

    return (
        f"{llama_response}\n\n"
        f"--- MEDICATION AND HOME CARE RECOMMENDATIONS ---\n\n"
        f"{medicine_suggestions}\n\n"
        f"**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice, especially if symptoms persist or worsen."
    )

# Example opening messages shown under the chat box; each gives a name, an age and a symptom.
_EXAMPLE_MESSAGES = [
    "Hi, I'm Sarah and I'm 25. I have a persistent cough and sore throat.",
    "My name is John, I'm 35, and I've been having severe headaches.",
    "I'm Lisa, 28 years old, and my stomach has been hurting since yesterday.",
]

# Chat UI: every user turn is handled by generate_response.
demo = gr.ChatInterface(
    fn=generate_response,
    examples=_EXAMPLE_MESSAGES,
    title="🩺 AI Medical Assistant with Treatment Suggestions",
    description=(
        "Describe your symptoms and I'll gather information to provide "
        "medical insights and treatment recommendations."
    ),
    theme="soft",
)


def _main():
    """Start the Gradio app."""
    demo.launch()


if __name__ == "__main__":
    _main()