# Hugging Face Space — "Running on Zero" (ZeroGPU hardware).
import gradio as gr | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# System prompt that steers the model through a staged medical-intake
# conversation: greet -> gather details -> summarize -> recommend -> disclaim.
# NOTE(review): the scraped copy had " | |" table artifacts on every line of
# this string, which leaked into the prompt sent to the model; removed here.
SYSTEM_PROMPT = """
You are a knowledgeable medical assistant. Follow these steps in order:
1) INITIAL ASSESSMENT: First, warmly greet the user and ask about their primary concern.
2) ASK FOLLOW-UP QUESTIONS: For any health concern mentioned, systematically gather information by asking 1-2 specific follow-up questions at a time about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
- Family history of similar conditions
3) SUMMARIZE FINDINGS: Once you have gathered sufficient information (at least 4-5 exchanges with the user), organize what you've learned into clear categories:
- Symptoms
- Duration
- Severity
- Possible Causes
- Medications/Allergies
- Family History
4) PROVIDE RECOMMENDATIONS: Only after gathering comprehensive information, suggest:
- One specific OTC medicine with proper adult dosing
- One practical home remedy
- When they should seek professional medical care
5) END WITH DISCLAIMER: Always end with a clear medical disclaimer that you are not a licensed medical professional and your suggestions are not a substitute for professional medical advice.
IMPORTANT: Do not skip ahead to recommendations without gathering comprehensive information through multiple exchanges. Your primary goal is information gathering through thoughtful questions.
"""
# Display name -> Hugging Face Hub repo id for the selectable models.
MODELS = {
    "TinyLlama-1.1B": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Llama-2-7b": "meta-llama/Llama-2-7b-chat-hf",
}

# Process-level caches so each model/tokenizer is downloaded and
# instantiated at most once per worker (keyed by display name).
loaded_models = {}
loaded_tokenizers = {}
def load_model(model_name):
    """Return the ``(model, tokenizer)`` pair for *model_name*, loading lazily.

    The pair is created on first request and cached in the module-level
    ``loaded_models`` / ``loaded_tokenizers`` dicts for subsequent calls.
    """
    if model_name in loaded_models:
        return loaded_models[model_name], loaded_tokenizers[model_name]

    print(f"Loading {model_name}...")
    repo_id = MODELS[model_name]
    tok = AutoTokenizer.from_pretrained(repo_id)
    mdl = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype="auto",
        device_map="auto",  # place weights on GPU when one is available
    )
    loaded_models[model_name] = mdl
    loaded_tokenizers[model_name] = tok
    print(f"{model_name} loaded successfully!")
    return mdl, tok
# Eagerly load the small default model at import time so the first chat
# request doesn't pay the download/initialization cost.
print("Pre-loading TinyLlama model...")
load_model("TinyLlama-1.1B")
# ZeroGPU Spaces only attach a GPU to functions decorated with @spaces.GPU —
# the original comment said "Required by ZeroGPU!" but the decorator was missing.
@spaces.GPU
def generate_response(message, history, model_choice):
    """Generate one assistant reply from the selected model.

    Args:
        message: The user's latest chat message.
        history: Prior turns as ``(user, assistant)`` string pairs.
        model_choice: Display-name key into ``MODELS``.

    Returns:
        The model's reply text, with surrounding whitespace stripped.
    """
    # Load (or fetch from cache) the selected model.
    model, tokenizer = load_model(model_choice)

    # Build a plain-text prompt: system instructions, then the transcript.
    parts = [SYSTEM_PROMPT, "\n\n"]
    for human, assistant in history:
        parts.append(f"User: {human}\nAssistant: {assistant}\n")
    parts.append(f"User: {message}\nAssistant:")
    formatted_prompt = "".join(parts)

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # be explicit; avoids masking warnings
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # some chat models define no pad token
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return response.strip()
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Medical Assistant Chatbot")
    gr.Markdown("This chatbot uses LLM models to provide medical information and assistance. Please note that this is not a substitute for professional medical advice.")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="TinyLlama-1.1B",
            label="Select Model",
        )
    # Pass generate_response directly — the previous lambda wrapper added
    # nothing. The dropdown value is forwarded as the extra argument on each call.
    chatbot = gr.ChatInterface(
        fn=generate_response,
        additional_inputs=[model_dropdown],
    )

if __name__ == "__main__":
    demo.launch()