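# testbot / app.py
# josondev's picture
# Update app.py
# b0f7227 verified
# (The four lines above appear to be page-header residue copied from the
#  hosting site's file viewer; they are kept as comments so the module parses.)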
import re
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Target device: use CUDA when torch reports it available, else fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Hub checkpoint to serve. Any causal LM you have access to can be substituted,
# but its tokenizer must provide a chat template (respond calls
# tokenizer.apply_chat_template).
model_name = "HuggingFaceH4/zephyr-7b-beta"

# Weight precision: float16 only when running on CUDA, float32 otherwise.
if device == "cuda":
    _weights_dtype = torch.float16
else:
    _weights_dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # let transformers/accelerate decide where the weights live
    torch_dtype=_weights_dtype,
)
# System prompt sent as the first ("system") chat message on every model call
# (see the messages list built in respond). It restricts the assistant to
# medical topics and prescribes a bullet-point outline per question category.
# NOTE: this only *instructs* the model; the hard off-topic gate is the keyword
# check in is_medical_query, which runs before the model is ever invoked.
MEDICAL_SYSTEM_PROMPT = """You are a medical AI assistant. You MUST ONLY answer health and medical-related questions.
Your responses should be professional, accurate, and focused on medical topics only.
For any non-medical questions, respond with a redirection to medical topics.
For medication queries, provide general information and recommend consulting a healthcare professional.
USE BULLET POINTS TO STRUCTURE YOUR ANSWERS AND THEY MUST BE PUNCTUATED AND GRAMMATICALLY CORRECT.
Your responses should be concise and informative.
For Questions About Body Parts/Organs:
- Provide anatomical details
- Explain primary functions
- Describe structure
- Mention related conditions
For Medication Queries:
- List generic/brand names
- Mention drug class
- Explain general uses
- Note common side effects
- List important drug interactions
- Include a disclaimer about consulting healthcare providers
For Symptom Analysis:
- List possible causes
- Describe key characteristics
- Suggest self-care measures
- Highlight when to seek immediate medical care
For Treatment Information:
- Explain conservative measures
- Describe medical interventions
- Suggest prevention strategies
- Outline follow-up care
For Emergency Questions:
- Provide immediate actions
- Suggest temporary measures while waiting for help
- Clarify when to contact emergency services
For Mental Health Questions:
- List common symptoms
- Suggest coping strategies
- Explain when to seek professional help
- Include crisis support information
For Preventive Care:
- Provide lifestyle recommendations
- Mention screening tests
- Include vaccination information if applicable
For Diet/Nutrition:
- Give dietary recommendations
- Highlight key nutrients
- Offer practical meal planning tips"""
# Medical vocabulary for the off-topic gate. Each entry is matched as a whole
# word (case-insensitive), optionally followed by a plural "s"/"es"; multi-word
# entries tolerate any run of whitespace between their words.
_MEDICAL_KEYWORDS = (
    "health", "healthy", "disease", "symptom", "doctor", "medicine", "medical",
    "treatment", "hospital", "clinic", "diagnosis", "patient", "drug",
    "prescription", "therapy", "cancer", "diabetes", "heart", "blood", "pain",
    "painful", "surgery", "vaccine", "infection", "allergy", "diet", "nutrition",
    "vitamin", "exercise", "mental health", "depression", "anxiety", "disorder",
    "syndrome", "chronic", "acute", "emergency", "pharmacy", "dosage",
    "side effect", "contraindication", "body", "organ", "immune", "virus",
    "bacterial", "fungal", "parasite", "genetic", "hereditary", "congenital",
    "prenatal", "headache", "ache", "stomach ache", "skin", "head", "arm", "leg",
    "chest", "back", "throat", "eye", "ear", "nose", "mouth",
)

# Compiled once at import time. The word boundaries are what make the gate work:
# a plain substring search let "ear" match "year"/"learn", "ache" match
# "teacher", "leg" match "legal", "arm" match "army", etc., so nearly every
# query was treated as medical.
_MEDICAL_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(r"\s+".join(map(re.escape, kw.split())) for kw in _MEDICAL_KEYWORDS)
    + r")(?:e?s)?\b",
    re.IGNORECASE,
)


def is_medical_query(query):
    """Return True if *query* mentions at least one medical keyword.

    Matching is case-insensitive and whole-word: a keyword counts only when it
    appears as its own word, optionally pluralised with "s"/"es"
    (e.g. "symptoms", "viruses"), so unrelated words that merely contain a
    keyword ("year", "teacher", "legal") do not trigger a match.

    Args:
        query: The user's message text. Empty or None yields False.

    Returns:
        bool: True when a keyword is found, False otherwise.
    """
    if not query:
        return False
    return _MEDICAL_PATTERN.search(query) is not None
# Response function for the chatbot
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts both history shapes Gradio's ChatInterface can pass: pairs of
    ``(user_text, assistant_text)`` and ``{"role": ..., "content": ...}`` dicts.
    Turns whose content is empty, None, or not a string are skipped, because
    they cannot be rendered by the tokenizer's chat template.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, assistant_msg = turn
        if isinstance(user_msg, str) and user_msg:
            messages.append({"role": "user", "content": user_msg})
        if isinstance(assistant_msg, str) and assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def _generate_into_streamer(streamer, generation_kwargs):
    """Run ``model.generate`` on a worker thread and always terminate *streamer*.

    The streamer is created without a timeout, so if generation raised before
    finishing, the consuming loop in ``respond`` would block on it indefinitely.
    Ending the streamer on failure lets that loop finish (and fall through to
    its error message); the exception is re-raised so the thread's excepthook
    still reports it.
    """
    try:
        model.generate(**generation_kwargs)
    except Exception:
        streamer.end()
        raise


def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
    """Stream a reply to *message* for Gradio's ChatInterface.

    Off-topic messages (per ``is_medical_query``) get a fixed refusal without
    calling the model. Otherwise the system prompt, prior turns, and the new
    message are rendered with the tokenizer's chat template and generated on a
    background thread, yielding the accumulated text as it streams in.

    Args:
        message: The new user message.
        history: Prior turns, either as (user, assistant) pairs or as
            role/content dicts.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        str: The cumulative response text so far (or a refusal/error message).
    """
    if not is_medical_query(message):
        # This function is a generator, so ``return <value>`` would be swallowed
        # by StopIteration and the user would see nothing; yield it instead.
        yield "I'm specialized in medical topics only. I cannot answer this question. How can I assist with a health-related concern instead?"
        return

    messages = [
        {"role": "system", "content": MEDICAL_SYSTEM_PROMPT},
        *_history_to_messages(history),
        {"role": "user", "content": message},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # skip_prompt: only the newly generated text is streamed back.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # daemon=True so an abandoned generation (e.g. client disconnect) never
    # keeps the process alive on shutdown.
    thread = Thread(
        target=_generate_into_streamer,
        args=(streamer, generation_kwargs),
        daemon=True,
    )
    thread.start()

    partial_response = ""
    for new_text in streamer:
        partial_response += new_text
        yield partial_response
    thread.join()

    # Reached with an empty response when generation produced nothing or failed.
    if not partial_response:
        yield "I apologize, but I encountered an error generating a response. Please try again."
# One-click starter questions shown beneath the chat box.
_EXAMPLE_QUESTIONS = [
    "What are the symptoms of diabetes?",
    "How can I improve my diet for heart health?",
    "What is the treatment for a migraine?",
    "What are the side effects of aspirin?",
    "What are the causes of high blood pressure?",
]

# Chat UI wired to the streaming ``respond`` generator.
demo = gr.ChatInterface(
    fn=respond,
    title="MedexDroid - Medical Assistant",
    description="An AI Medical Assistant. Please ask health-related questions only.",
    examples=_EXAMPLE_QUESTIONS,
    css=".gradio-container {background-color: #f0f4f8}",
    theme=gr.themes.Soft(),
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()