Spaces:
Sleeping
Sleeping
File size: 6,535 Bytes
b80af5b 6d5190c b80af5b 8b29c0d 6d5190c aca454d 8b29c0d d5f0232 6d5190c 8b29c0d aca454d 645c015 6d5190c 645c015 6d5190c aca454d 8b29c0d 6d5190c aca454d 6d5190c 8b29c0d 6d5190c aca454d 645c015 6d5190c aca454d 8b29c0d 6d5190c 8b29c0d 645c015 8b29c0d 645c015 8b29c0d 645c015 8b29c0d b80af5b 8b29c0d 6d5190c 8b29c0d 6d5190c b80af5b 8b29c0d b80af5b 8b29c0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
# Model configuration
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
# System prompt that guides the bot's behavior
SYSTEM_PROMPT = """
You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition,
symptoms, medical history, medications, lifestyle, and other relevant data. Start by greeting the user politely and ask
them to describe their health concern. For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details.
Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity,
possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history.
Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification.
Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician.
Wait for the user's answer before asking more questions.
"""
print("Loading model...")
try:
# Initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype="auto",
device_map="auto"
)
# Create a pipeline for text generation
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=512,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
llm = HuggingFacePipeline(pipeline=pipe)
print("Model loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
# Fallback to a smaller model or provide an error message
raise
# Modify the prompt template with a clearer structure to prevent system prompt leakage
prompt = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="history"),
("human", "{input}"),
("system", "Remember to respond as Virtual Doctor without including system instructions in your reply.")
])
# Memory store to maintain conversation history
store = {}
def get_session_history(session_id: str) -> ChatMessageHistory:
"""Get or create a chat history for the given session ID"""
if session_id not in store:
store[session_id] = ChatMessageHistory()
return store[session_id]
# Create a more robust filtering chain that will intercept the model's responses
def filter_response(response_text):
"""Filter out system prompts and format the response correctly"""
# Remove any system prompt references
if "system" in response_text.lower() and ("your goal is" in response_text.lower() or "professional virtual doctor" in response_text.lower()):
# Find the actual doctor response after any system text
for marker in ["Virtual Doctor:", "Virtual doctor:", "Human:"]:
if marker in response_text:
parts = response_text.split(marker)
if len(parts) > 1:
# Get the last part after any system prompts
response_text = parts[-1].strip()
break
# Remove any remaining system prompt text or instructions
filtered_text = []
skip_line = False
for line in response_text.split('\n'):
lower_line = line.lower()
if any(phrase in lower_line for phrase in [
"system:", "your goal is", "start by greeting", "wait for the user",
"do not make a final diagnosis", "be structured", "ask only 1 or 2"
]):
skip_line = True
elif any(marker in line for marker in ["Virtual Doctor:", "Virtual doctor:", "Hello", "Thank you"]):
skip_line = False
if not skip_line:
filtered_text.append(line)
clean_text = '\n'.join(filtered_text).strip()
# Ensure proper formatting with "Virtual Doctor:" prefix
if not clean_text.startswith("Virtual Doctor:") and not clean_text.startswith("Virtual doctor:"):
clean_text = f"Virtual Doctor: {clean_text}"
return clean_text
# Chain with memory
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
chain,
get_session_history,
input_messages_key="input",
history_messages_key="history"
)
# Our handler for chat interactions
@spaces.GPU # Request GPU for this space
def gradio_chat(user_message, history):
"""Process the user message and return the chatbot response"""
# Use a unique session ID in production
session_id = "default-session"
# Invoke the chain with history
try:
response = chain_with_history.invoke(
{"input": user_message},
config={"configurable": {"session_id": session_id}}
)
# Extract the text from the response
response_text = response.content if hasattr(response, "content") else str(response)
# Apply our filtering function to clean up the response
clean_response = filter_response(response_text)
return clean_response
except Exception as e:
print(f"Error processing message: {e}")
return "Virtual Doctor: I apologize, but I'm experiencing technical difficulties. Please try again."
# Customize the CSS for better appearance
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.chat-bot .bot-message {
background-color: #f0f7ff !important;
}
.chat-bot .user-message {
background-color: #e6f7e6 !important;
}
"""
# Create the Gradio interface
demo = gr.ChatInterface(
fn=gradio_chat,
title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
examples=[
"I have a cough and my throat hurts",
"I've been having headaches for a week",
"My stomach has been hurting since yesterday"
],
css=css
)
# Launch the app
if __name__ == "__main__":
demo.launch(share=False) |