# medbot_2 / app.py
# Author: techindia2025 · commit 9f6ac99 ("Update app.py", verified) · 3.42 kB
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub id of the chat model this app loads.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System instructions for the assistant, kept one line per entry so individual
# bullets are easy to edit; they are joined with newlines into SYSTEM_PROMPT.
_SYSTEM_PROMPT_LINES = (
    "You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.",
    "Ask 1-2 follow-up questions at a time to gather more details about:",
    "- Detailed description of symptoms",
    "- Duration (when did it start?)",
    "- Severity (scale of 1-10)",
    "- Aggravating or alleviating factors",
    "- Related symptoms",
    "- Medical history",
    "- Current medications and allergies",
    "After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.",
    "Respond empathetically and clearly. Always be professional and thorough.",
)
SYSTEM_PROMPT = "\n".join(_SYSTEM_PROMPT_LINES)
def _load_chat_model(model_id):
    """Load the tokenizer and float16 causal-LM weights for *model_id*.

    Prints a progress line before loading starts and another once both
    objects are ready. The model is loaded with ``device_map="auto"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    print("Loading model...")
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    print("Model loaded successfully!")
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_chat_model(MODEL_NAME)

# Module-level dict for tracking conversation turns (one dict per process).
conversation_turns = {}
def _history_pairs(history):
    """Normalize chat history into a list of ``[user, assistant]`` text pairs.

    Accepts either Gradio's "tuples" history (one ``[user, assistant]`` pair
    per exchange) or its "messages" history (dicts with ``"role"`` and
    ``"content"`` keys). ``None`` history is treated as empty, and ``None``
    message values become empty strings so the literal text "None" never
    reaches the prompt. In messages history, roles other than "user" and
    "assistant" are ignored, and consecutive assistant messages are joined
    with a newline.
    """
    pairs = []
    for item in history or ():
        if isinstance(item, dict):
            role = item.get("role")
            content = item.get("content")
            content = "" if content is None else content
            if role == "user":
                pairs.append([content, ""])
            elif role == "assistant":
                if not pairs:
                    # Assistant spoke first: pair it with an empty user turn.
                    pairs.append(["", content])
                elif pairs[-1][1]:
                    pairs[-1][1] = f"{pairs[-1][1]}\n{content}"
                else:
                    pairs[-1][1] = content
        else:
            user_msg, assistant_msg = item
            pairs.append([
                "" if user_msg is None else user_msg,
                "" if assistant_msg is None else assistant_msg,
            ])
    return pairs


def build_llama2_prompt(system_prompt, history, user_input):
    """Format the conversation history and user input for Llama-2 chat models.

    The system prompt is wrapped in a ``<<SYS>>`` block inside the first
    ``[INST]``. Each completed exchange is closed with ``</s><s>[INST]``.
    The new user input is left open after a final ``[/INST]`` for the model
    to continue.

    Args:
        system_prompt: Instructions placed inside the ``<<SYS>>`` block.
        history: Earlier exchanges, either as ``[user, assistant]`` pairs
            or as ``{"role": ..., "content": ...}`` dicts; ``None``/empty
            means no prior turns.
        user_input: The latest user message.

    Returns:
        The prompt string, which begins with ``"<s>"`` and ends with
        ``"[/INST] "``.
    """
    parts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for user_msg, assistant_msg in _history_pairs(history):
        parts.append(f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] ")
    parts.append(f"{user_input} [/INST] ")
    return "".join(parts)
# From this user turn onward, the model is asked to wrap up with a summary
# (matches the "4-5 exchanges" guidance in the system prompt).
_SUMMARY_AFTER_TURNS = 4
_SUMMARY_INSTRUCTION = (
    "Now summarize what you've learned and suggest when professional care may be needed."
)


def _count_prior_user_turns(history):
    """Return how many user messages *history* already contains.

    Works with Gradio "tuples" history (each item is one exchange) and
    "messages" history (only dicts with ``role == "user"`` are counted).
    """
    if not history:
        return 0
    return sum(
        1
        for item in history
        if not isinstance(item, dict) or item.get("role") == "user"
    )


@spaces.GPU
def generate_response(message, history):
    """Generate the assistant's reply to *message* with the Llama-2 chat model.

    Args:
        message: The user's latest chat message.
        history: Prior turns of this conversation as supplied by
            ``gr.ChatInterface``.

    Returns:
        Only the newly generated assistant text, with special tokens removed
        and surrounding whitespace stripped.
    """
    # Derive the turn number from this conversation's own history. A
    # process-wide counter under one fixed key never resets and is shared by
    # every visitor, so it would push every chat into summary mode after four
    # messages in total.
    turn = _count_prior_user_turns(history) + 1

    # Attach the wrap-up request to the *user* turn. Two alternatives are worse:
    # a blanket replace of "[/INST] " would also rewrite every earlier turn, and
    # text placed after "[/INST]" sits in the assistant's position and would
    # surface in the reply.
    user_turn = message
    if turn >= _SUMMARY_AFTER_TURNS:
        user_turn = f"{message}\n\n{_SUMMARY_INSTRUCTION}"

    prompt = build_llama2_prompt(SYSTEM_PROMPT, history, user_turn)

    # The prompt already starts with "<s>", so stop the tokenizer from
    # prepending a second BOS token.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    # NOTE(review): no truncation is applied; very long chats may exceed the
    # model's context window — consider trimming old turns.

    with torch.no_grad():
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens generated after the prompt. Splitting the whole
    # transcript on "[/INST]" breaks whenever that marker appears in the text.
    prompt_length = inputs.input_ids.shape[-1]
    new_tokens = output_ids[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# One-click starter prompts shown beneath the chat box.
_EXAMPLE_PROMPTS = [
    "I have a cough and my throat hurts",
    "I've been having headaches for a week",
    "My stomach has been hurting since yesterday",
]

# Chat front-end: every submitted message is answered by generate_response.
demo = gr.ChatInterface(
    fn=generate_response,
    title="Medical Assistant Chatbot",
    description="Ask about your symptoms and I'll help gather relevant information.",
    examples=_EXAMPLE_PROMPTS,
    theme="soft",
)


def main():
    """Launch the Gradio chat app."""
    demo.launch()


if __name__ == "__main__":
    main()