Spaces: Running on Zero
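The listing below is a complete app.py for a Gradio chat app that serves Llama-2 from a Space on ZeroGPU ("Zero") hardware. The part that makes it run on Zero is the @spaces.GPU decorator on the generation function: ZeroGPU Spaces attach a GPU only while a decorated function is executing, so any code that touches CUDA belongs inside it. Note that meta-llama/Llama-2-7b-chat-hf is a gated model, so the Space must be granted access to it on the Hub.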
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.

Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies

After collecting sufficient information (4-5 exchanges), summarize the findings and suggest when the user should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.

Respond empathetically and clearly. Always be professional and thorough."""
print("Loading model...") | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
model = AutoModelForCausalLM.from_pretrained( | |
MODEL_NAME, | |
torch_dtype=torch.float16, | |
device_map="auto" | |
) | |
print("Model loaded successfully!") | |
def build_llama2_prompt(system_prompt, history, user_input):
    """Format the conversation history and user input for Llama-2 chat models."""
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"

    # Add prior exchanges, closing each with </s> and opening a new [INST] block
    for user_msg, assistant_msg in history:
        prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "

    # Add the current user input, leaving the prompt open for the model's reply
    prompt += f"{user_input} [/INST] "
    return prompt
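# For a single prior exchange, the function above yields a prompt shaped like:
#
#   <s>[INST] <<SYS>>
#   ...system prompt...
#   <</SYS>>
#
#   I have a cough [/INST] Sorry to hear that, when did it start? </s><s>[INST] Since yesterday [/INST]
#
# The model's reply is whatever it generates after the final [/INST].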
@spaces.GPU  # on ZeroGPU Spaces, a GPU is attached only while this function runs
def generate_response(message, history):
    """Generate a response using the Llama-2 model with proper formatting."""
    # history is the list of prior (user, assistant) pairs that Gradio passes in,
    # so the turn number can be derived from it -- no shared global state needed
    turn = len(history) + 1

    # After enough exchanges, steer the model toward a summary instead of more questions
    if turn >= 4:
        message += "\n\nNow summarize what you've learned and suggest when professional care may be needed."

    # Build the prompt with proper Llama-2 formatting
    prompt = build_llama2_prompt(SYSTEM_PROMPT, history, message)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the prompt we fed in
    generated_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
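# Quick sanity check (hypothetical call, outside the Gradio UI):
#   print(generate_response("I have a cough and my throat hurts", history=[]))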
# Create the Gradio interface
demo = gr.ChatInterface(
    fn=generate_response,
    title="Medical Assistant Chatbot",
    description="Ask about your symptoms and I'll help gather relevant information.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday",
    ],
    theme="soft",
)
if __name__ == "__main__":
    demo.launch()
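For the Space to build, its Python dependencies must be declared alongside app.py. A minimal requirements.txt sketch (the exact package set is an assumption; pin versions as your Space requires):

gradio
spaces
torch
transformers
accelerate

accelerate is what backs device_map="auto" in the model-loading call above. The spaces package is typically preinstalled on ZeroGPU hardware, but listing it is harmless.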