# NOTE(review): the following header was Hugging Face Spaces page residue
# (blame hashes and a line-number gutter) captured by a page scrape, not
# source code. Preserved as a comment so the module parses:
#   Spaces: Running on Zero — file size 4,656 bytes.
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
# Model configuration: Llama-2 chat checkpoint pulled from the HF Hub
# (gated model — the Space needs an access token with Llama-2 permissions).
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System prompt that guides the bot's behavior. It restricts the assistant
# to structured intake questioning only — no diagnosis, no treatment advice.
SYSTEM_PROMPT = """
You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition,
symptoms, medical history, medications, lifestyle, and other relevant data. Start by greeting the user politely and ask
them to describe their health concern. For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details.
Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity,
possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history.
Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification.
Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician.
Wait for the user's answer before asking more questions.
"""
print("Loading model...")
try:
    # Initialize the tokenizer and model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto",  # pick the dtype recommended by the checkpoint
        device_map="auto",   # place/shard layers on whatever devices exist
    )
    # Text-generation pipeline consumed by LangChain below.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # FIX: without do_sample=True, temperature/top_p are silently
        # ignored and generation is greedy (modulo the model's own
        # generation_config defaults).
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # FIX: return only the newly generated text. The pipeline default
        # (return_full_text=True) echoes the entire prompt — system prompt
        # plus chat history — back inside every bot reply.
        return_full_text=False,
        pad_token_id=tokenizer.eos_token_id,  # Llama has no pad token
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # Re-raise so the Space fails loudly instead of serving a broken UI;
    # the traceback lands in the Space logs.
    raise
# LangChain prompt template: system instructions first, then the running
# conversation history, then the latest human turn.
_prompt_messages = [
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)
# Memory store to maintain conversation history, keyed by session id.
store = {}


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use."""
    try:
        return store[session_id]
    except KeyError:
        history = ChatMessageHistory()
        store[session_id] = history
        return history
# Create a chain with memory: format the prompt, then run it through the
# local HF pipeline.
chain = prompt | llm
# Wrap the chain so every invoke() reads the session's prior messages into
# the "history" prompt variable and appends the new turn afterwards.
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history"
)
# Our handler for chat interactions.
@spaces.GPU  # Request GPU for this space
def gradio_chat(user_message, history):
    """Run one chat turn through the LangChain chain and return the reply.

    Gradio's ``history`` argument is unused: conversation memory lives in
    the LangChain message store, keyed by session id.
    """
    # NOTE(review): a single shared id means every visitor shares one
    # conversation — use a unique session ID in production.
    session_id = "default-session"
    try:
        result = chain_with_history.invoke(
            {"input": user_message},
            config={"configurable": {"session_id": session_id}},
        )
    except Exception as e:
        print(f"Error processing message: {e}")
        return "Virtual doctor: I apologize, but I'm experiencing technical difficulties. Please try again."
    # Chat-model outputs expose .content; a plain LLM returns a string.
    reply = result.content if hasattr(result, "content") else str(result)
    # Prefix the answer to match the expected transcript format.
    return f"Virtual doctor: {reply}"
# Customize the CSS for better appearance
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.chat-bot .bot-message {
background-color: #f0f7ff !important;
}
.chat-bot .user-message {
background-color: #e6f7e6 !important;
}
"""
# Create the Gradio interface. ChatInterface wires gradio_chat as the
# per-turn handler; examples render as one-click starter messages.
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
    description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday"
    ],
    css=css
)
# Launch the app only when executed as a script (not on import).
# FIX: removed the trailing " |" scrape artifact after launch(), which was
# a syntax error.
if __name__ == "__main__":
    demo.launch(share=False)