# Source: Hugging Face Space "medbot_2" by techindia2025 — file app.py
# Commit 8b29c0d ("Update app.py", verified), 4.66 kB as published.
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
# Model configuration: Hugging Face Hub model ID passed to from_pretrained() below.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
# System prompt that guides the bot's behavior: a data-collecting "virtual doctor"
# that asks 1-2 follow-up questions per turn, organizes the answers into fixed
# categories, and must not diagnose or suggest treatments. Used as the "system"
# message of the LangChain prompt template. The text is sent to the model verbatim.
SYSTEM_PROMPT = """
You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition,
symptoms, medical history, medications, lifestyle, and other relevant data. Start by greeting the user politely and ask
them to describe their health concern. For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details.
Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity,
possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history.
Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification.
Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician.
Wait for the user's answer before asking more questions.
"""
# Load the tokenizer, model and generation pipeline once at import time and
# wrap them as a LangChain LLM. Any failure here is fatal: the app cannot
# serve requests without a model.
print("Loading model...")
try:
    # Initialize the tokenizer and model; device_map="auto" lets transformers
    # place the weights on the available device(s).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto",
        device_map="auto",
    )
    # Create a pipeline for text generation.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # temperature/top_p are only honored in sampling mode; enable it
        # explicitly instead of relying on the model's generation_config.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # Reuse EOS as the pad token (the Llama-2 tokenizer ships without a
        # dedicated pad token).
        pad_token_id=tokenizer.eos_token_id,
        # Return only the newly generated text. By default the pipeline
        # returns prompt + completion, so every reply would echo the system
        # prompt and the entire chat history back to the user.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    print("Model loaded successfully!")
except Exception as e:
    # No fallback model is configured: report the problem and re-raise so the
    # app fails loudly instead of starting without a model.
    print(f"Error loading model: {e}")
    raise
# Prompt layout, in order: the fixed system instructions, the prior turns
# supplied under the "history" key, and the new user message bound to the
# "{input}" template variable.
prompt = ChatPromptTemplate.from_messages(
    messages=[
        ("system", SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}"),
    ]
)
# In-process registry of chat histories keyed by session ID. Entries are never
# removed, so histories live for as long as this module's process does.
store = {}


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the history for *session_id*, creating and registering an empty one on first use."""
    try:
        return store[session_id]
    except KeyError:
        fresh_history = store[session_id] = ChatMessageHistory()
        return fresh_history
# Compose prompt -> LLM, then wrap the pipeline with message history: on each
# call the session's history is looked up via get_session_history. "input" is
# the key carrying the new user message, and "history" is the prompt
# placeholder that receives the prior messages.
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
    runnable=chain,
    get_session_history=get_session_history,
    input_messages_key="input",
    history_messages_key="history",
)
# Our handler for chat interactions
@spaces.GPU  # Request a GPU for each call of this handler
def gradio_chat(user_message, history, request: gr.Request = None):
    """Generate the virtual doctor's reply to one user message.

    Args:
        user_message: The latest message typed by the user.
        history: Conversation so far as tracked by Gradio. Unused: context
            comes from the LangChain message history instead.
        request: Gradio injects the current request when the handler declares
            a ``gr.Request`` parameter. Its ``session_hash`` keys the
            conversation memory so each browser session gets its own history.
            When absent (e.g. a direct call), a shared fallback session is used.

    Returns:
        The model's reply prefixed with ``"Virtual doctor: "``, or a fixed
        apology message if generation fails.
    """
    # A single hard-coded session ID would make every visitor share (and see
    # the effects of) one medical conversation, so key by Gradio session.
    # NOTE(review): if @spaces.GPU runs this in a separate worker process,
    # writes to the module-level history store may not persist between
    # calls. Confirm on ZeroGPU hardware.
    session_id = getattr(request, "session_hash", None) or "default-session"
    try:
        response = chain_with_history.invoke(
            {"input": user_message},
            config={"configurable": {"session_id": session_id}},
        )
    except Exception as e:
        # UI boundary: log the error and show a friendly message instead of
        # surfacing a stack trace in the chat window.
        print(f"Error processing message: {e}")
        return "Virtual doctor: I apologize, but I'm experiencing technical difficulties. Please try again."
    # Chat models return a message object with .content; plain LLMs return str.
    response_text = response.content if hasattr(response, "content") else str(response)
    return f"Virtual doctor: {response_text.strip()}"
# Custom CSS passed to gr.ChatInterface(css=...) below: Arial font for the
# whole app and tinted backgrounds for bot (light blue) vs. user (light green)
# messages.
# NOTE(review): the ".chat-bot", ".bot-message" and ".user-message" selectors
# only take effect if the installed Gradio version's Chatbot component emits
# those class names. Confirm in the browser inspector.
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.chat-bot .bot-message {
background-color: #f0f7ff !important;
}
.chat-bot .user-message {
background-color: #e6f7e6 !important;
}
"""
# Assemble the chat UI around gradio_chat. The examples are shown as
# one-click starter messages, and the custom CSS above is applied.
demo = gr.ChatInterface(
    fn=gradio_chat,
    css=css,
    title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
    description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday",
    ],
)


if __name__ == "__main__":
    # Start the server without creating a public share link.
    demo.launch(share=False)