File size: 12,218 Bytes
b80af5b
71bcd31
9f6ac99
 
000ab02
71bcd31
 
 
 
 
01a984c
6e237a4
01a984c
 
 
 
 
 
6e237a4
01a984c
 
 
 
 
 
6e237a4
01a984c
6e237a4
01a984c
71bcd31
1728da9
5522bf8
1728da9
 
 
5522bf8
1728da9
 
 
71bcd31
1728da9
a7f6391
1728da9
 
 
 
 
 
 
 
43e5827
1728da9
 
 
 
 
 
43e5827
1728da9
 
 
43e5827
1728da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43e5827
1728da9
 
 
 
 
 
 
 
 
 
 
 
 
 
43e5827
1728da9
 
 
 
f3b4260
1728da9
a7f6391
1728da9
 
 
a7f6391
1728da9
 
 
 
a7f6391
1728da9
f3b4260
1728da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7f6391
43e5827
1728da9
43e5827
 
1728da9
43e5827
 
 
1728da9
43e5827
 
1728da9
 
 
 
43e5827
1728da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7f6391
1728da9
 
 
 
f3b4260
1728da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4447f4
1728da9
6d5190c
1728da9
 
 
8b29c0d
1728da9
43e5827
 
 
8b29c0d
1728da9
 
 
 
6d5190c
b80af5b
 
71bcd31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re

# Model configuration
# Hugging Face model ids. LLAMA_MODEL is loaded below as `tokenizer`/`model`
# and drives the follow-up questions in medical_chat_response; MEDITRON_MODEL
# is loaded as `meditron_tokenizer`/`meditron_model` and is used by
# get_meditron_diagnosis for the final assessment.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System instructions for the Llama chat model. build_llama_prompt places this
# text verbatim between the <<SYS>> ... <</SYS>> markers of every prompt, so
# any edit here changes what the model is told on each turn.
SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical consultation. Your role is to gather comprehensive information about the patient's condition through intelligent questioning.

**CONSULTATION APPROACH:**
- Ask thoughtful, relevant follow-up questions based on the patient's responses
- Prioritize gathering information about: symptom details, duration, severity, triggers, related symptoms, medical history, medications, and lifestyle factors
- Ask 1-2 specific questions at a time that build naturally on their previous answers
- Be empathetic, professional, and thorough in your questioning
- Adapt your questions based on the symptoms they describe

**IMPORTANT GUIDELINES:**
- Generate intelligent follow-up questions that are contextually relevant to their specific symptoms
- Don't ask generic questions - tailor each question to their particular situation
- If they mention pain, ask about location, type, and triggers
- If they mention duration, ask about progression or changes
- Build each question logically from their previous responses

After 4-5 meaningful exchanges, provide assessment and recommendations.
Do NOT make specific prescriptions for prescription-only drugs.
Always maintain a professional, caring tone throughout the consultation."""

# Template for the Meditron assessment prompt. It has exactly one placeholder,
# {patient_info}, which get_meditron_diagnosis fills via str.format(); any
# other literal braces added to this text would therefore need to be doubled.
# NOTE(review): item 4 asks for medications "with dosing" while SYSTEM_PROMPT
# forbids specific prescriptions for prescription-only drugs — confirm that
# this difference between the two prompts is intended.
MEDITRON_PROMPT = """You are a board-certified physician providing evidence-based medical assessment. 

Based on the patient information provided, please:
1. Analyze the symptoms systematically
2. Provide a differential diagnosis with most likely conditions
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Suggest appropriate medications or remedies with dosing if applicable
5. Include red flags that would require urgent medical attention
6. Base recommendations on clinical guidelines

Patient Information: {patient_info}

Please provide a structured medical assessment:"""

# Load models
print("Loading models...")

# Pre-bind every handle so that a failed load leaves the names defined (as
# None) instead of unbound. The request handlers wrap their model calls in
# try/except, so a None handle still ends up on their fallback/error path,
# but the module itself never carries undefined globals.
tokenizer = None
model = None
meditron_tokenizer = None
meditron_model = None


def _load_causal_lm(model_name):
    """Load the tokenizer and fp16 causal LM for *model_name*.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer is guaranteed to have a
        pad token: when the checkpoint defines none, the EOS token is reused.

    Raises:
        Whatever ``from_pretrained`` raises (missing weights, auth, OOM, ...).
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    if loaded_tokenizer.pad_token is None:
        # generate() is later called with pad_token_id, so one must exist.
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return loaded_tokenizer, loaded_model


try:
    tokenizer, model = _load_causal_lm(LLAMA_MODEL)
    print("Llama-2 model loaded successfully!")

    meditron_tokenizer, meditron_model = _load_causal_lm(MEDITRON_MODEL)
    print("Meditron model loaded successfully!")
except Exception as e:
    # Best effort by design: the UI still starts and the handlers degrade to
    # their canned fallback / error text when a model is unavailable.
    print(f"Error loading models: {e}")

class MedicalConsultationBot:
    """Hold the state of one consultation and build model prompts from it.

    Attributes:
        conversation_history: list of ``{"user": ..., "bot": ...}`` dicts, one
            per exchange, in the order they were added.
        patient_name: value stored as the patient's name, or None.
        patient_age: value stored as the patient's age, or None.
        medical_turns: turn counter; reset to 0 here and advanced by callers.
        stage: name of the current consultation stage; starts at "greeting".
    """

    def __init__(self):
        self.reset_conversation()

    def reset_conversation(self):
        """Reset all conversation state to the start of a new consultation."""
        self.conversation_history = []
        self.patient_name = None
        self.patient_age = None
        self.medical_turns = 0
        # greeting -> name -> age -> symptoms -> diagnosis -> complete
        self.stage = "greeting"

    def add_to_history(self, user_message, bot_response):
        """Append one exchange (patient message and the reply to it)."""
        self.conversation_history.append({
            "user": user_message,
            "bot": bot_response,
        })

    def get_conversation_context(self):
        """Return the known patient details and the full transcript as text.

        Name and age lines are included only when set. The
        "Conversation History:" header is always present, so the result is
        never empty even for a fresh conversation.
        """
        parts = []
        if self.patient_name:
            parts.append(f"Patient Name: {self.patient_name}\n")
        if self.patient_age:
            parts.append(f"Patient Age: {self.patient_age}\n")

        parts.append("\nConversation History:\n")
        for exchange in self.conversation_history:
            parts.append(f"Patient: {exchange['user']}\n")
            parts.append(f"Doctor: {exchange['bot']}\n")

        return "".join(parts)

    def build_llama_prompt(self, current_message):
        """Build the Llama-2 chat prompt for the patient's current message.

        Args:
            current_message: the patient's latest message.

        Returns:
            A single ``[INST] <<SYS>> ... <</SYS>> ... [/INST]`` turn holding
            the system prompt, the conversation context and the new message.
        """
        # No literal "<s>" here: the caller tokenizes this text with the
        # tokenizer's default special-token handling, which already prepends
        # BOS. Spelling it out as well would yield two BOS tokens.
        context = self.get_conversation_context()

        # The context always contains at least its header line, so it is
        # included unconditionally.
        return (
            f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
            f"Previous conversation context:\n{context}\n\n"
            f"Current patient message: {current_message}\n\n"
            "Provide a professional medical response with appropriate "
            "follow-up questions. [/INST]"
        )

# Global bot instance
# Created once at import time. medical_chat_response and
# generate_final_diagnosis both read and mutate this same object, so its
# state (stage, patient details, history) persists across calls until
# reset_conversation() runs.
# NOTE(review): being module-level, this state is presumably shared by every
# browser session served by the process — confirm whether per-session state
# is needed.
medical_bot = MedicalConsultationBot()

def get_meditron_diagnosis(patient_info):
    """Generate a medical assessment for *patient_info* with Meditron.

    The patient information is substituted into MEDITRON_PROMPT, tokenized
    (truncated to 512 tokens) and sampled for up to 300 new tokens. Only the
    newly generated tokens are decoded and returned, stripped of surrounding
    whitespace.

    Any exception raised along the way is caught and reported as an
    "Error generating medical assessment: ..." string instead of propagating.
    """
    try:
        filled_prompt = MEDITRON_PROMPT.format(patient_info=patient_info)

        encoded = meditron_tokenizer(
            filled_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        encoded = encoded.to(meditron_model.device)

        # Number of prompt tokens; everything after this index is new output.
        prompt_length = encoded.input_ids.shape[1]

        sampling_options = {
            "attention_mask": encoded.attention_mask,
            "max_new_tokens": 300,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "pad_token_id": meditron_tokenizer.pad_token_id,
        }
        with torch.no_grad():
            generated = meditron_model.generate(
                encoded.input_ids,
                **sampling_options,
            )

        completion_ids = generated[0][prompt_length:]
        assessment = meditron_tokenizer.decode(
            completion_ids,
            skip_special_tokens=True,
        )
        return assessment.strip()
    except Exception as e:
        return f"Error generating medical assessment: {str(e)}"

# Canned follow-up used when the Llama model fails or produces no text.
_FALLBACK_FOLLOW_UP = (
    "I understand. Could you tell me more about how long you've been "
    "experiencing this and if there are any specific triggers or patterns "
    "you've noticed?"
)

# Whole-word matchers for post-diagnosis questions. A plain substring test
# would also fire on words such as "dosage", "manage" or "message".
_NAME_QUESTION = re.compile(r"\bname\b", re.IGNORECASE)
_AGE_QUESTION = re.compile(r"\bage\b", re.IGNORECASE)


def _generate_follow_up(user_message):
    """Ask the Llama chat model for the next follow-up question.

    Returns the generated text with surrounding whitespace and any literal
    ``<s>``/``</s>`` markers removed; the result may be an empty string.
    Exceptions from tokenization or generation propagate to the caller.
    """
    prompt = medical_bot.build_llama_prompt(user_message)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            temperature=0.8,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens. Decoding prompt + output and
    # splitting on "[/INST]" echoes the whole prompt back to the user when
    # truncation has removed the trailing marker from a long prompt.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    bot_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Clean up any special-token text that survived decoding
    return bot_response.replace('<s>', '').replace('</s>', '').strip()


@spaces.GPU
def medical_chat_response(message, history):
    """Produce the assistant's reply for one chat turn.

    The reply depends on the stage stored on the global ``medical_bot``:
    greeting -> name -> age -> symptoms (model-generated follow-up questions)
    -> diagnosis (via ``generate_final_diagnosis``) -> post-diagnosis
    questions.

    Args:
        message: the user's latest message.
        history: chat history supplied by the chat UI; only its emptiness is
            inspected, to detect the start of a new conversation.

    Returns:
        The reply text to show to the user.
    """
    # If this is a new conversation (empty history), reset the bot
    if not history:
        medical_bot.reset_conversation()

    user_message = message.strip()
    stage = medical_bot.stage

    # Stage 1: Initial greeting and ask for name
    if stage == "greeting":
        bot_response = "Hello! I'm your AI medical assistant. Before we discuss your health concerns, could you please tell me your name?"
        medical_bot.stage = "name"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 2: Collect name and ask for age
    if stage == "name":
        medical_bot.patient_name = user_message
        bot_response = f"Nice to meet you, {medical_bot.patient_name}! Could you please tell me your age?"
        medical_bot.stage = "age"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 3: Collect age and start medical consultation
    if stage == "age":
        medical_bot.patient_age = user_message
        bot_response = f"Thank you, {medical_bot.patient_name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
        medical_bot.stage = "symptoms"
        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 4: Medical consultation - gather symptoms with intelligent follow-ups
    if stage == "symptoms":
        medical_bot.medical_turns += 1

        # If we've had enough turns, move to diagnosis
        if medical_bot.medical_turns >= 4:
            medical_bot.stage = "diagnosis"
            return generate_final_diagnosis(user_message)

        try:
            bot_response = _generate_follow_up(user_message)
        except Exception as e:
            # Best effort: keep the consultation going, but leave a trace of
            # the failure instead of discarding it silently.
            print(f"Error generating follow-up question: {e}")
            bot_response = ""

        # An empty generation would show up as a blank chat bubble.
        if not bot_response:
            bot_response = _FALLBACK_FOLLOW_UP

        medical_bot.add_to_history(user_message, bot_response)
        return bot_response

    # Stage 5: Final diagnosis and treatment recommendations. Only reached if
    # an earlier diagnosis attempt left the stage at "diagnosis".
    if stage == "diagnosis":
        return generate_final_diagnosis(user_message)

    # Handle any questions after diagnosis
    if _NAME_QUESTION.search(user_message) and medical_bot.patient_name:
        return f"Your name is {medical_bot.patient_name}."
    if _AGE_QUESTION.search(user_message) and medical_bot.patient_age:
        return f"You told me you are {medical_bot.patient_age} years old."
    return "Is there anything else about your health concerns I can help you with today?"

def generate_final_diagnosis(current_message):
    """Build the final assessment reply and close the consultation.

    The stored conversation plus *current_message* is rendered as a
    chronological transcript, passed to ``get_meditron_diagnosis``, and the
    returned assessment is wrapped in a fixed template with a disclaimer and
    next steps. The exchange is then appended to the bot's history and the
    stage is set to "complete".

    Args:
        current_message: the patient's latest message, not yet in history.

    Returns:
        The full reply text shown to the patient.
    """
    # Each history entry holds a patient message and the reply to it, so the
    # patient line must precede the doctor line to keep the transcript in
    # chronological order (the same order get_conversation_context uses).
    transcript = []
    for exchange in medical_bot.conversation_history:
        transcript.append(f"Patient: {exchange['user']}\n")
        transcript.append(f"Doctor: {exchange['bot']}\n")
    transcript.append(f"Patient: {current_message}\n")

    # Compile complete patient information
    patient_info = (
        f"\nPatient Name: {medical_bot.patient_name}\n"
        f"Patient Age: {medical_bot.patient_age}\n"
        "\nComplete Consultation History:\n"
        + "".join(transcript)
    )

    # Get diagnosis from Meditron
    meditron_assessment = get_meditron_diagnosis(patient_info)

    # Generate comprehensive response
    final_response = f"""Thank you for providing all this information, {medical_bot.patient_name}. Based on our consultation, here is my assessment:

**MEDICAL ASSESSMENT & RECOMMENDATIONS:**

{meditron_assessment}

**IMPORTANT DISCLAIMER:** This assessment is for informational purposes only and should not replace professional medical advice. Please consult with a healthcare provider for proper diagnosis and treatment.

**NEXT STEPS:** I recommend scheduling an appointment with your primary care physician or appropriate specialist for further evaluation and personalized treatment.

Is there anything specific about this assessment you'd like me to clarify?"""

    # Record the completed exchange only once the reply exists, so the
    # history never holds a half-filled entry.
    medical_bot.add_to_history(current_message, final_response)
    medical_bot.stage = "complete"

    return final_response

# Create Gradio interface
# Chat UI wired to medical_chat_response. The retry and undo buttons are
# disabled, and the clear button is relabelled as the way to start over; the
# handler resets the consultation whenever it receives an empty history.
# The title and clear-button labels use real emoji characters: the previous
# text was mis-decoded UTF-8 that rendered as stray Thai/Latin symbols.
demo = gr.ChatInterface(
    fn=medical_chat_response,
    title="🩺 AI Medical Assistant with Memory",
    description="I'm an AI medical assistant that will remember our conversation. I'll first ask for your basic information, then gather details about your symptoms through intelligent follow-up questions, and finally provide a medical assessment.",
    examples=[
        "Hello, I need medical help",
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts"
    ],
    theme="soft",
    retry_btn=None,
    undo_btn=None,
    clear_btn="🔄 Start New Consultation"
)

if __name__ == "__main__":
    demo.launch()