import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Model configuration - using only Me-LLaMA 13B-chat
ME_LLAMA_MODEL = "clinicalnlplab/me-llama-13b-chat"
# System prompts for different phases
CONSULTATION_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
After collecting sufficient information (4-5 exchanges), summarize the findings and suggest when the user should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
Respond empathetically and clearly. Always be professional and thorough."""

MEDICINE_PROMPT = """You are a specialized medical assistant. Based on the patient information gathered, provide:
1. One specific over-the-counter medicine with proper adult dosing instructions
2. One practical home remedy that might help
3. Clear guidance on when to seek professional medical care
Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
Patient information: {patient_info}"""

# Global variables (module-level state, shared by all chat sessions in this process)
me_llama_model = None
me_llama_tokenizer = None
conversation_turns = 0
patient_data = []

def build_me_llama_prompt(system_prompt, history, user_input):
    """Format the conversation for the Me-LLaMA chat model (Llama-2 [INST] chat format).

    `history` is the tuple-style history passed by gr.ChatInterface:
    a list of (user_message, assistant_message) pairs.
    """
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"

    # Add conversation history
    for user_msg, assistant_msg in history:
        prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "

    # Add the current user input
    prompt += f"{user_input} [/INST] "
    return prompt
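
# Illustrative example of the resulting prompt layout (hypothetical values, not part of the app flow):
#   build_me_llama_prompt("You are a virtual doctor.",
#                         [("I have a cough", "How long has it lasted?")],
#                         "About three days")
# returns the single string:
#   "<s>[INST] <<SYS>>\nYou are a virtual doctor.\n<</SYS>>\n\n"
#   "I have a cough [/INST] How long has it lasted? </s><s>[INST] About three days [/INST] "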

def load_model_if_needed():
    """Lazily load the Me-LLaMA model and tokenizer on first use."""
    global me_llama_model, me_llama_tokenizer
    if me_llama_model is None:
        print("Loading Me-LLaMA 13B-chat model...")
        me_llama_tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL)
        me_llama_model = AutoModelForCausalLM.from_pretrained(
            ME_LLAMA_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        print("Me-LLaMA 13B-chat model loaded successfully!")

def generate_medicine_suggestions(patient_info):
    """Use Me-LLaMA to generate over-the-counter medicine and home-remedy suggestions."""
    load_model_if_needed()

    # Build a single-turn prompt for medicine suggestions
    prompt = f"<s>[INST] {MEDICINE_PROMPT.format(patient_info=patient_info)} [/INST] "
    inputs = me_llama_tokenizer(prompt, return_tensors="pt")

    # Move inputs to the same device as the model
    if torch.cuda.is_available():
        inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = me_llama_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=me_llama_tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (everything after the prompt)
    suggestion = me_llama_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return suggestion
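
# Assumption: this Space is intended to run on ZeroGPU hardware (suggested by `import spaces`
# and the lazy model loading), so the chat handler requests a GPU for the duration of each call.
# On non-ZeroGPU hardware the decorator is a no-op.
@spaces.GPU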
def generate_response(message, history):
    """Generate a response using only Me-LLaMA for both consultation and medicine suggestions."""
    global conversation_turns, patient_data

    # Load model if needed
    load_model_if_needed()

    # Track conversation turns
    conversation_turns += 1

    # Store patient data
    patient_data.append(message)

    # Phases 1-3: information gathering
    if conversation_turns < 4:
        # Build consultation prompt
        prompt = build_me_llama_prompt(CONSULTATION_PROMPT, history, message)
        inputs = me_llama_tokenizer(prompt, return_tensors="pt")

        # Move inputs to the same device as the model
        if torch.cuda.is_available():
            inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}

        # Generate consultation response
        with torch.no_grad():
            outputs = me_llama_model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=400,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=me_llama_tokenizer.eos_token_id,
            )

        # Decode and keep only the latest assistant turn
        full_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
        response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
        return response

    # Phase 4+: summary and medicine suggestions
    else:
        # First, get a summary of the consultation
        summary_prompt = build_me_llama_prompt(
            CONSULTATION_PROMPT + "\n\nNow summarize what you've learned and suggest when professional care may be needed.",
            history,
            message,
        )
        inputs = me_llama_tokenizer(summary_prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}

        # Generate summary
        with torch.no_grad():
            outputs = me_llama_model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=400,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=me_llama_tokenizer.eos_token_id,
            )
        summary_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
        summary = summary_response.split('[/INST]')[-1].split('</s>')[0].strip()

        # Then get medicine suggestions using the same model
        full_patient_info = "\n".join(patient_data) + f"\n\nMedical Summary: {summary}"
        medicine_suggestions = generate_medicine_suggestions(full_patient_info)

        # Combine both responses
        final_response = (
            f"**MEDICAL SUMMARY:**\n{summary}\n\n"
            f"**MEDICATION AND HOME CARE SUGGESTIONS:**\n{medicine_suggestions}\n\n"
            f"**DISCLAIMER:** This is AI-generated advice for informational purposes only. "
            f"Please consult a licensed healthcare provider for proper medical diagnosis and treatment."
        )
        return final_response

# Create the Gradio interface
demo = gr.ChatInterface(
    fn=generate_response,
    title="🏥 Complete Medical Assistant - Me-LLaMA 13B",
    description="Comprehensive medical consultation powered by Me-LLaMA 13B-chat. One model handles both consultation and medicine suggestions. Tell me about your symptoms!",
    examples=[
        "I have a persistent cough and sore throat for 3 days",
        "I've been having severe headaches and feel dizzy",
        "My stomach hurts and I feel nauseous after eating",
    ],
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()