# medbot_2 / app.py — medical-intake chatbot (Llama-2 + LangChain + Gradio).
# NOTE: the following page metadata was captured with the file and is kept
# here as a comment so the module stays importable:
#   techindia2025's picture · Update app.py · commit 6d5190c (verified)
#   raw / history / blame · 3.14 kB
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
# Hugging Face model id for the chat backbone (gated; requires HF access token).
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System prompt that frames the assistant as a data-gathering medical intake
# bot: it collects and organizes symptoms/history but explicitly must not
# diagnose or suggest treatment.
SYSTEM_PROMPT = (
"You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data. "
"Start by greeting the user politely and ask them to describe their health concern. Based on their input, ask follow-up questions to gather as much relevant information as possible. "
"Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity, possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history. "
"Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification. "
"Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician. "
"Ask one or two questions at a time, and wait for user input."
)
print("Loading model...")

# Load tokenizer and weights; device_map="auto" shards the model across
# whatever accelerators are available, torch_dtype="auto" keeps the
# checkpoint's native precision.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto"
)

# Text-generation pipeline wrapped for LangChain.
# FIX: do_sample=True is required for temperature/top_p to have any effect;
# without it generation defaults to greedy decoding and transformers silently
# ignores (and warns about) both sampling parameters.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    # Llama-2 has no pad token; reuse EOS to silence padding warnings.
    pad_token_id=tokenizer.eos_token_id
)
llm = HuggingFacePipeline(pipeline=pipe)
print("Model loaded successfully!")
# LangChain prompt: system framing, then the accumulated conversation
# (injected by RunnableWithMessageHistory under the "history" key), then the
# current user turn under "input".
prompt = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="history"),
("human", "{input}")
])
# In-process conversation memory, keyed by session id. Lives for the process
# lifetime only (lost on restart) and grows by one entry per new session.
store = {}
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the message history for *session_id*, creating it on first use."""
    try:
        return store[session_id]
    except KeyError:
        history = ChatMessageHistory()
        store[session_id] = history
        return history
# Chain with memory: prompt rendering piped into the LLM, then wrapped so
# each invocation automatically loads/saves per-session chat history.
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="input",     # key in the invoke() payload holding the user turn
history_messages_key="history"  # prompt placeholder that receives past messages
)
@spaces.GPU
def gradio_chat(user_message, history):
    """Gradio callback: route one user turn through the memory-backed chain.

    The `history` argument (Gradio's own transcript) is intentionally unused;
    conversation state is kept in the LangChain session store instead.
    """
    session_id = "default-session"  # For demo; can be made unique per user
    result = chain_with_history.invoke(
        {"input": user_message},
        config={"configurable": {"session_id": session_id}}
    )
    # Chat models return an AIMessage with a .content attribute;
    # HuggingFacePipeline returns a plain string — handle both.
    if hasattr(result, "content"):
        return result.content
    return str(result)
# Gradio UI: ChatInterface wires gradio_chat as the per-message handler and
# renders a standard chat transcript.
demo = gr.ChatInterface(
fn=gradio_chat,
title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI."
)
# Launch the web server only when run as a script (Spaces imports this module).
if __name__ == "__main__":
demo.launch()