# Hugging Face Spaces app — "Running on Zero" (ZeroGPU hardware).
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
# Hugging Face model id of the chat-tuned Llama-2 checkpoint.
# NOTE(review): this is a gated repo — an accepted license / HF auth token
# is required at download time; verify the Space has HF_TOKEN configured.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System prompt constraining the assistant to structured medical intake
# (no diagnosis or treatment). Injected once per conversation via the
# ChatPromptTemplate defined below.
SYSTEM_PROMPT = (
    "You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data. "
    "Start by greeting the user politely and ask them to describe their health concern. Based on their input, ask follow-up questions to gather as much relevant information as possible. "
    "Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity, possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history. "
    "Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification. "
    "Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician. "
    "Ask one or two questions at a time, and wait for user input."
)
print("Loading model...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",  # let the checkpoint pick fp16/bf16 as appropriate
    device_map="auto",   # place/shard across whatever devices are available
)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    # do_sample=True is required for temperature/top_p to take effect;
    # without it the pipeline decodes greedily and silently ignores both.
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    # Return only the newly generated text, not prompt + completion —
    # otherwise the chat callback would echo the whole prompt to the user.
    return_full_text=False,
    # Llama-2 has no dedicated pad token; reuse EOS to silence the warning.
    pad_token_id=tokenizer.eos_token_id,
)

llm = HuggingFacePipeline(pipeline=pipe)
print("Model loaded successfully!")
# LangChain prompt: system instructions, then the accumulated chat history,
# then the latest user turn.
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}"),
])

# In-memory per-session chat histories, keyed by session id.
# NOTE(review): lives for the process lifetime and grows unbounded — fine
# for a demo, not for production.
store = {}
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use.

    Used by RunnableWithMessageHistory as the history factory; entries are
    kept in the module-level ``store`` dict for the life of the process.
    """
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]
# Chain with memory: prompt -> LLM, wrapped so each invocation reads from
# and appends to the per-session history supplied by get_session_history.
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",    # key of the new user turn in the input dict
    history_messages_key="history",  # placeholder name in the prompt template
)
def gradio_chat(user_message, history):
    """Run one chat turn through the memory-backed chain.

    Gradio's ChatInterface passes its own transcript as *history*; it is
    intentionally ignored here — conversation state is maintained by
    LangChain's RunnableWithMessageHistory instead.
    """
    session_id = "default-session"  # For demo; can be made unique per user
    response = chain_with_history.invoke(
        {"input": user_message},
        config={"configurable": {"session_id": session_id}},
    )
    # A chat model would return an AIMessage (with .content); a plain LLM
    # pipeline like HuggingFacePipeline returns a bare string — handle both.
    return response.content if hasattr(response, "content") else str(response)
# Gradio UI: ChatInterface drives gradio_chat with (message, history) pairs.
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
    description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
)

if __name__ == "__main__":
    demo.launch()