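# Medbot: Llama-2-7b-chat served through a transformers pipeline, wrapped in
# LangChain for prompting and per-session memory, with a Gradio chat UI.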
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory

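# Gated Meta checkpoint: downloading it assumes the environment is logged in
# with a Hugging Face token that has been granted Llama-2 access.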
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

SYSTEM_PROMPT = (
    "You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data. "
    "Start by greeting the user politely and ask them to describe their health concern. Based on their input, ask follow-up questions to gather as much relevant information as possible. "
    "Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity, possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history. "
    "Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification. "
    "Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician. "
    "Ask one or two questions at a time, and wait for user input."
)

print("Loading model...")
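# Load tokenizer and weights once at startup; device_map="auto" (which needs
# `accelerate` installed) places the model on the available GPU(s) or CPU.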
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto"
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,          # sampling must be enabled for temperature/top_p to apply
    temperature=0.7,
    top_p=0.9,
    return_full_text=False,  # return only the completion, not the echoed prompt
    pad_token_id=tokenizer.eos_token_id
)
llm = HuggingFacePipeline(pipeline=pipe)
print("Model loaded successfully!")

# LangChain prompt: system instructions, then the running history, then the new user turn
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}")
])

# In-process store of per-session chat histories; cleared on restart
store = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
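    """Return the chat history for `session_id`, creating it on first use."""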
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

# Chain with memory: RunnableWithMessageHistory fills the "history" placeholder
# from the session store and records each new exchange after every call
chain = prompt | llm
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history"
)

@spaces.GPU  # request a GPU for this call when hosted on a Hugging Face Space
def gradio_chat(user_message, history):
    # Gradio passes its own `history`, but conversation state lives in the
    # LangChain session store, so that argument is ignored here.
    session_id = "default-session"  # for the demo; could be made unique per user
    response = chain_with_history.invoke(
        {"input": user_message},
        config={"configurable": {"session_id": session_id}}
    )
    # A chat model returns an AIMessage; a plain LLM pipeline returns a string
    return response.content if hasattr(response, "content") else str(response)

# Gradio UI
demo = gr.ChatInterface(
    fn=gradio_chat,
    title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
    description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI."
)

if __name__ == "__main__":
    demo.launch()
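
# Rough local-run recipe (a sketch; assumes a CUDA GPU with enough memory and
# an HF account approved for Llama-2 access):
#   pip install torch transformers accelerate gradio spaces langchain-core langchain-community
#   huggingface-cli login
#   python app.py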