# Hugging Face Space header residue ("Spaces: Running on Zero"):
# this app is built to run on ZeroGPU hardware.
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
# Model configuration
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System prompt that guides the bot's behavior: structured medical intake
# only — the model must not diagnose or suggest treatment.
SYSTEM_PROMPT = """
You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition,
symptoms, medical history, medications, lifestyle, and other relevant data. Start by greeting the user politely and ask
them to describe their health concern. For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details.
Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity,
possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history.
Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification.
Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician.
Wait for the user's answer before asking more questions.
"""
print("Loading model...")
try:
    # Tokenizer and weights are pulled from the Hugging Face Hub.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto",   # let transformers pick fp16/bf16 when supported
        device_map="auto",    # place layers on whatever device(s) are available
    )

    # Text-generation pipeline that LangChain wraps below.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,       # required: temperature/top_p are ignored under greedy decoding
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    print("Model loaded successfully!")
except Exception as e:
    # Log the failure for the Space logs, then abort startup — the app
    # cannot serve requests without the model.
    print(f"Error loading model: {e}")
    raise
# Conversation template: system instructions first, then the accumulated
# session history, then the newest user turn.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}"),
    ]
)
# In-memory conversation store, keyed by session ID.
store = {}


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use."""
    history = store.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        store[session_id] = history
    return history
# Pipe the prompt into the LLM, then wrap the chain so each invocation is
# rehydrated with its session's prior messages.
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
)
# Chat handler. @spaces.GPU requests ZeroGPU hardware for the duration of the
# call — the `import spaces` at the top of the file exists for this; without
# the decorator a ZeroGPU Space never gets a GPU for inference.
@spaces.GPU
def gradio_chat(user_message, history):
    """Run one chat turn through the history-aware chain.

    Args:
        user_message: The latest message typed by the user.
        history: Gradio-managed chat history (unused here — LangChain keeps
            its own history via ``chain_with_history``).

    Returns:
        The model reply as a string prefixed with "Virtual doctor: ".
    """
    # NOTE(review): one shared session means every user shares the same
    # memory — use a unique per-user session ID in production.
    session_id = "default-session"
    try:
        response = chain_with_history.invoke(
            {"input": user_message},
            config={"configurable": {"session_id": session_id}},
        )
        # LangChain may return a message object (.content) or a plain string.
        response_text = response.content if hasattr(response, "content") else str(response)
        # Prefix to match the expected "Virtual doctor: " transcript format.
        return f"Virtual doctor: {response_text}"
    except Exception as e:
        # UI boundary: log and degrade gracefully instead of crashing the app.
        print(f"Error processing message: {e}")
        return "Virtual doctor: I apologize, but I'm experiencing technical difficulties. Please try again."
# Light styling: distinct background tints for bot vs. user chat bubbles.
css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.chat-bot .bot-message {
    background-color: #f0f7ff !important;
}
.chat-bot .user-message {
    background-color: #e6f7e6 !important;
}
"""
# Assemble the chat UI; gr.ChatInterface wires gradio_chat to a chatbot widget.
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
    description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday",
    ],
    css=css,
)
# Start the Gradio server when executed as a script.
if __name__ == "__main__":
    demo.launch(share=False)