|
import logging
import re
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
|
# Target device for the tokenized inputs: the GPU when CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
# Checkpoint identifier shared by the tokenizer and the model.
model_name = "HuggingFaceH4/zephyr-7b-beta"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Half precision when `device` is "cuda", full precision otherwise.
if device == "cuda":
    _model_dtype = torch.float16
else:
    _model_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=_model_dtype,
    device_map="auto",
)
|
|
|
|
|
# System prompt sent as the first ("system") message of every conversation.
# It restricts the assistant to medical topics and lists, per question
# category, the points each answer should cover.
# The text is passed to the model verbatim, so it must stay free of stray
# formatting characters and padding blank lines.
MEDICAL_SYSTEM_PROMPT = """You are a medical AI assistant. You MUST ONLY answer health and medical-related questions.
Your responses should be professional, accurate, and focused on medical topics only.
For any non-medical questions, respond with a redirection to medical topics.
For medication queries, provide general information and recommend consulting a healthcare professional.

USE BULLET POINTS TO STRUCTURE YOUR ANSWERS AND THEY MUST BE PUNCTUATED AND GRAMMATICALLY CORRECT.
Your responses should be concise and informative.

For Questions About Body Parts/Organs:
- Provide anatomical details
- Explain primary functions
- Describe structure
- Mention related conditions

For Medication Queries:
- List generic/brand names
- Mention drug class
- Explain general uses
- Note common side effects
- List important drug interactions
- Include a disclaimer about consulting healthcare providers

For Symptom Analysis:
- List possible causes
- Describe key characteristics
- Suggest self-care measures
- Highlight when to seek immediate medical care

For Treatment Information:
- Explain conservative measures
- Describe medical interventions
- Suggest prevention strategies
- Outline follow-up care

For Emergency Questions:
- Provide immediate actions
- Suggest temporary measures while waiting for help
- Clarify when to contact emergency services

For Mental Health Questions:
- List common symptoms
- Suggest coping strategies
- Explain when to seek professional help
- Include crisis support information

For Preventive Care:
- Provide lifestyle recommendations
- Mention screening tests
- Include vaccination information if applicable

For Diet/Nutrition:
- Give dietary recommendations
- Highlight key nutrients
- Offer practical meal planning tips"""
|
|
|
|
|
# Keywords whose presence marks a query as medical.
# The last row lists compounds in which a keyword is NOT the leading part of
# the word; they are spelled out because matching is anchored at word starts.
_MEDICAL_KEYWORDS = (
    "health", "disease", "symptom", "doctor", "medicine", "medical", "treatment",
    "hospital", "clinic", "diagnosis", "patient", "drug", "prescription", "therapy",
    "cancer", "diabetes", "heart", "blood", "pain", "surgery", "vaccine", "infection",
    "allergy", "diet", "nutrition", "vitamin", "exercise", "mental health", "depression",
    "anxiety", "disorder", "syndrome", "chronic", "acute", "emergency", "pharmacy",
    "dosage", "side effect", "contraindication", "body", "organ", "immune", "virus",
    "bacterial", "fungal", "parasite", "genetic", "hereditary", "congenital", "prenatal",
    "headaches", "ache", "stomach ache", "skin", "head", "arm", "leg", "chest", "back",
    "throat", "eye", "ear", "nose", "mouth",
    "toothache", "stomachache", "autoimmune", "antibacterial", "antifungal",
)

# One precompiled, case-insensitive alternation. The leading \b requires a
# keyword to begin at a word boundary, so "ear" matches "ear", "ears" and
# "earache" but no longer matches inside "learn" or "year". Suffixes stay
# free so plurals and derived forms ("symptoms", "healthy") still match.
_MEDICAL_KEYWORD_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(keyword) for keyword in _MEDICAL_KEYWORDS) + r")",
    re.IGNORECASE,
)


def is_medical_query(query):
    """Return True when *query* contains a medical keyword.

    A keyword counts only when it starts at a word boundary; matching is
    case-insensitive. This is a cheap lexical gate, not a classifier: a
    medical question that uses none of the listed keywords is rejected.

    Args:
        query: The user's message. Anything that is not a non-empty string
            (``None``, ``""``, other types) is treated as non-medical.

    Returns:
        bool: True if at least one keyword is found, False otherwise.
    """
    if not isinstance(query, str) or not query:
        return False
    return _MEDICAL_KEYWORD_RE.search(query) is not None
|
|
|
|
|
def _history_to_messages(history):
    """Convert chat history into a list of ``{"role", "content"}`` dicts.

    Two history layouts are accepted: a sequence of ``(user, assistant)``
    pairs, and a sequence of dicts carrying ``"role"`` and ``"content"`` keys
    (the two layouts Gradio's ChatInterface can hand to its callback).
    Entries whose content is not a non-empty string (for example a ``None``
    assistant slot) are skipped so they never reach the chat template.
    """
    converted = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                converted.append({"role": role, "content": content})
            continue

        user_msg, assistant_msg = turn
        if isinstance(user_msg, str) and user_msg:
            converted.append({"role": "user", "content": user_msg})
        if isinstance(assistant_msg, str) and assistant_msg:
            converted.append({"role": "assistant", "content": assistant_msg})
    return converted


def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
    """Stream the assistant's reply to *message*.

    This is a generator: each yielded value is the full reply text produced
    so far, so the caller can re-render the growing answer.

    Args:
        message: The latest user message.
        history: Earlier turns, either ``(user, assistant)`` pairs or
            ``{"role", "content"}`` dicts.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Yields:
        str: A fixed refusal when the message is not medical; otherwise the
        cumulative reply text, or a fixed apology if nothing was generated.
    """
    # Because this function contains `yield`, a plain `return "<text>"` would
    # discard the text; the refusal has to be yielded to reach the caller.
    if not is_medical_query(message):
        yield "I'm specialized in medical topics only. I cannot answer this question. How can I assist with a health-related concern instead?"
        return

    messages = [
        {"role": "system", "content": MEDICAL_SYSTEM_PROMPT},
        *_history_to_messages(history),
        {"role": "user", "content": message},
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    def _generate():
        # Runs in the worker thread. If generation fails, the streamer is
        # closed explicitly; otherwise the consumer loop below would wait
        # forever for text that never arrives.
        try:
            model.generate(**generation_kwargs)
        except Exception:
            logging.getLogger(__name__).exception("Text generation failed")
            streamer.end()

    # Generation runs in a background thread so tokens can be consumed from
    # the streamer here while they are still being produced.
    thread = Thread(target=_generate)
    thread.start()

    partial_response = ""
    for new_text in streamer:
        partial_response += new_text
        yield partial_response

    # The stream is finished, so the worker is done (or about to be).
    thread.join()

    if not partial_response:
        yield "I apologize, but I encountered an error generating a response. Please try again."
|
|
|
|
|
# Sample questions passed to the chat interface as `examples`.
_EXAMPLE_QUESTIONS = [
    "What are the symptoms of diabetes?",
    "How can I improve my diet for heart health?",
    "What is the treatment for a migraine?",
    "What are the side effects of aspirin?",
    "What are the causes of high blood pressure?",
]

# Chat UI wired to `respond`; title, description, theme and CSS are cosmetic.
demo = gr.ChatInterface(
    fn=respond,
    title="MedexDroid - Medical Assistant",
    description="An AI Medical Assistant. Please ask health-related questions only.",
    examples=_EXAMPLE_QUESTIONS,
    theme=gr.themes.Soft(),
    css=".gradio-container {background-color: #f0f4f8}",
)


# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|