File size: 4,656 Bytes
b80af5b
 
6d5190c
 
 
 
 
b80af5b
8b29c0d
6d5190c
aca454d
8b29c0d
 
 
 
 
 
 
 
 
 
 
d5f0232
6d5190c
8b29c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aca454d
8b29c0d
6d5190c
 
 
 
 
aca454d
8b29c0d
6d5190c
aca454d
6d5190c
8b29c0d
6d5190c
 
 
aca454d
8b29c0d
6d5190c
 
 
 
 
 
 
aca454d
8b29c0d
 
6d5190c
8b29c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b80af5b
8b29c0d
6d5190c
 
 
8b29c0d
 
 
 
 
 
 
6d5190c
b80af5b
8b29c0d
b80af5b
8b29c0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory

# Model configuration
# HF Hub ID of the chat-tuned Llama-2 7B model. NOTE(review): this repo is
# gated on huggingface.co — download requires an accepted licence and an HF
# auth token in the environment; confirm the Space has one configured.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System prompt that guides the bot's behavior.
# Injected as the "system" message of the prompt template below; it restricts
# the bot to structured intake/data collection and explicitly forbids
# diagnosis or treatment advice.
SYSTEM_PROMPT = """
You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, 
symptoms, medical history, medications, lifestyle, and other relevant data. Start by greeting the user politely and ask 
them to describe their health concern. For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details. 
Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity, 
possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history. 
Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification. 
Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician. 
Wait for the user's answer before asking more questions.
"""

print("Loading model...")
try:
    # Initialize the tokenizer and model; device_map="auto" lets accelerate
    # place the weights across available GPUs/CPU automatically.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto",
        device_map="auto"
    )

    # Create a pipeline for text generation.
    # return_full_text=False: emit ONLY the newly generated tokens. Without
    # this the HF pipeline prepends the entire rendered prompt (system
    # instructions + history) to its output, which would be shown verbatim
    # to the user by gradio_chat.
    # do_sample=True: temperature/top_p only take effect when sampling is
    # enabled; otherwise decoding is greedy and those settings are ignored.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_full_text=False,
        pad_token_id=tokenizer.eos_token_id
    )

    llm = HuggingFacePipeline(pipeline=pipe)
    print("Model loaded successfully!")
except Exception as e:
    # No fallback model is configured: log the failure and abort startup,
    # since the app cannot function without the LLM.
    print(f"Error loading model: {e}")
    raise

# LangChain prompt template.
# Per-turn message layout: fixed system instructions, then the accumulated
# conversation (filled under the "history" key by RunnableWithMessageHistory),
# then the current user message bound to the "input" variable.
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}")
])

# Memory store to maintain conversation history.
# Maps session_id -> ChatMessageHistory. Module-level, so transcripts persist
# for the lifetime of the process and are shared by all requests.
store: dict[str, ChatMessageHistory] = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use.

    Histories live in the module-level ``store`` dict, so a session's
    transcript survives across calls for the lifetime of the process.
    """
    try:
        return store[session_id]
    except KeyError:
        history = ChatMessageHistory()
        store[session_id] = history
        return history

# Create a chain with memory.
# prompt | llm: render the chat template, then generate with the HF pipeline.
# RunnableWithMessageHistory wraps the chain so every invoke() loads past
# messages via get_session_history (keyed by the session_id in the call's
# config) and appends the new exchange afterwards.
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",  # where the current user turn is bound
    history_messages_key="history"  # where past messages are injected
)

# Our handler for chat interactions
@spaces.GPU  # Request GPU for this space
def gradio_chat(user_message, history):
    """Process the user message and return the chatbot response"""
    # Use a unique session ID in production
    session_id = "default-session"
    
    # Invoke the chain with history
    try:
        response = chain_with_history.invoke(
            {"input": user_message},
            config={"configurable": {"session_id": session_id}}
        )
        
        # Extract the text from the response
        response_text = response.content if hasattr(response, "content") else str(response)
        
        # Format as "Virtual doctor: " response to match the expected format
        formatted_response = f"Virtual doctor: {response_text}"
        
        return formatted_response
    except Exception as e:
        print(f"Error processing message: {e}")
        return "Virtual doctor: I apologize, but I'm experiencing technical difficulties. Please try again."

# Customize the CSS for better appearance.
# Tints bot messages light blue and user messages light green; !important is
# needed to override the Gradio theme's own message styling.
css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.chat-bot .bot-message {
    background-color: #f0f7ff !important;
}
.chat-bot .user-message {
    background-color: #e6f7e6 !important;
}
"""

# Create the Gradio interface.
# gr.ChatInterface provides the full chat UI and calls gradio_chat(message,
# history) for every user turn; `examples` are clickable starter prompts.
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
    description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday"
    ],
    css=css
)

# Launch the app (share=False: no public Gradio tunnel is created).
if __name__ == "__main__":
    demo.launch(share=False)