# NOTE(review): the three lines that were here ("Spaces:" / "Sleeping" /
# "Sleeping") were Hugging Face Spaces page chrome captured by a web scrape,
# not Python — converted to comments so the file parses.
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System prompt that guides the bot's behavior: the assistant only gathers and
# organizes medical intake information; it must never diagnose or treat.
SYSTEM_PROMPT = """
You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition,
symptoms, medical history, medications, lifestyle, and other relevant data. Start by greeting the user politely and ask
them to describe their health concern. For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details.
Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity,
possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history.
Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification.
Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician.
Wait for the user's answer before asking more questions.
"""
# ---------------------------------------------------------------------------
# Model initialization (runs once at import time)
# ---------------------------------------------------------------------------
print("Loading model...")
try:
    # Tokenizer and weights are fetched from the Hugging Face Hub;
    # device_map="auto" places layers on whatever GPUs/CPU are available.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto",
        device_map="auto",
    )
    # Text-generation pipeline wrapped by LangChain below.
    # NOTE(review): by default a text-generation pipeline echoes the prompt in
    # its output (return_full_text=True); filter_response() downstream exists
    # to compensate for that leakage — confirm before changing either side.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    print("Model loaded successfully!")
except Exception as e:
    # Log and re-raise: the app cannot run without a model, so abort start-up.
    print(f"Error loading model: {e}")
    raise
# Prompt template: system instructions, the running conversation history, the
# new user turn, and a trailing system reminder intended to keep the model
# from leaking its instructions into the reply.
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}"),
    ("system", "Remember to respond as Virtual Doctor without including system instructions in your reply.")
])
# In-memory map of session ID -> chat history. Process-local: it is lost on
# restart and shared across all workers in this process.
store = {}


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use."""
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]
# Post-processing: the raw pipeline output can echo the system prompt, so the
# model's replies are filtered before being shown to the user.
def filter_response(response_text):
    """Strip leaked system-prompt text from a raw model reply.

    Returns the reply guaranteed to start with the ``"Virtual Doctor: "``
    prefix, with any lines that look like echoed system instructions removed.
    """
    lowered = response_text.lower()
    # If the reply clearly contains system-prompt text, keep only the segment
    # after the last role marker found.
    if "system" in lowered and (
        "your goal is" in lowered or "professional virtual doctor" in lowered
    ):
        for marker in ["Virtual Doctor:", "Virtual doctor:", "Human:"]:
            if marker in response_text:
                parts = response_text.split(marker)
                if len(parts) > 1:
                    # Take the last part, after any echoed instructions.
                    response_text = parts[-1].strip()
                    break
    # Second pass: drop runs of lines that look like instructions; resume
    # keeping lines once a doctor-voiced marker reappears.
    filtered_text = []
    skip_line = False
    for line in response_text.split('\n'):
        lower_line = line.lower()
        if any(phrase in lower_line for phrase in [
            "system:", "your goal is", "start by greeting", "wait for the user",
            "do not make a final diagnosis", "be structured", "ask only 1 or 2"
        ]):
            skip_line = True
        elif any(marker in line for marker in ["Virtual Doctor:", "Virtual doctor:", "Hello", "Thank you"]):
            skip_line = False
        if not skip_line:
            filtered_text.append(line)
    clean_text = '\n'.join(filtered_text).strip()
    # Guarantee the canonical speaker prefix on every reply.
    if not clean_text.startswith("Virtual Doctor:") and not clean_text.startswith("Virtual doctor:"):
        clean_text = f"Virtual Doctor: {clean_text}"
    return clean_text
# Compose prompt -> LLM, then wrap it so each invocation reads and appends to
# the per-session message history supplied by get_session_history().
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
)
# Chat handler invoked by Gradio on every user turn.
# NOTE(review): the original "Request GPU for this space" comment together with
# the (otherwise unused) `import spaces` suggests this function was meant to
# carry a `@spaces.GPU` decorator on a ZeroGPU Space — confirm before deploy.
def gradio_chat(user_message, history):
    """Run one chat turn and return the cleaned bot reply.

    ``history`` is supplied by gr.ChatInterface but intentionally unused:
    conversation state lives server-side in ``store`` via
    RunnableWithMessageHistory.
    """
    # Single shared session ID: every visitor shares one conversation history.
    # Use a unique per-user session ID in production.
    session_id = "default-session"
    try:
        response = chain_with_history.invoke(
            {"input": user_message},
            config={"configurable": {"session_id": session_id}},
        )
        # LangChain may return a message object or a plain string.
        response_text = response.content if hasattr(response, "content") else str(response)
        # Strip any leaked system-prompt text before showing the reply.
        return filter_response(response_text)
    except Exception as e:
        # UI boundary: never crash the chat widget; log and apologize instead.
        print(f"Error processing message: {e}")
        return "Virtual Doctor: I apologize, but I'm experiencing technical difficulties. Please try again."
# Custom CSS: set the app font and tint bot vs. user message bubbles.
css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.chat-bot .bot-message {
    background-color: #f0f7ff !important;
}
.chat-bot .user-message {
    background-color: #e6f7e6 !important;
}
"""
# Build the chat UI: gr.ChatInterface wires gradio_chat to a chatbot widget
# with clickable example prompts and the custom CSS above.
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
    description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday"
    ],
    css=css
)
# Launch the app when run as a script (no public share link).
if __name__ == "__main__":
    demo.launch(share=False)