techindia2025 committed on
Commit
8b29c0d
·
verified ·
1 Parent(s): d5f0232

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -41
app.py CHANGED
@@ -6,54 +6,66 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
6
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
7
  from langchain_community.chat_message_histories import ChatMessageHistory
8
 
 
9
  MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
10
 
11
- SYSTEM_PROMPT = (
12
- "You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data. "
13
- "Start by greeting the user politely and ask them to describe their health concern. "
14
- "For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details. "
15
- "Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity, possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history. "
16
- "Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification. "
17
- "Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician. "
18
- "Wait for the user's answer before asking more questions."
19
- )
20
-
 
21
 
22
  print("Loading model...")
23
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
- model = AutoModelForCausalLM.from_pretrained(
25
- MODEL_NAME,
26
- torch_dtype="auto",
27
- device_map="auto"
28
- )
29
- pipe = pipeline(
30
- "text-generation",
31
- model=model,
32
- tokenizer=tokenizer,
33
- max_new_tokens=512,
34
- temperature=0.7,
35
- top_p=0.9,
36
- pad_token_id=tokenizer.eos_token_id
37
- )
38
- llm = HuggingFacePipeline(pipeline=pipe)
39
- print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
40
 
41
- # LangChain prompt
42
  prompt = ChatPromptTemplate.from_messages([
43
  ("system", SYSTEM_PROMPT),
44
  MessagesPlaceholder(variable_name="history"),
45
  ("human", "{input}")
46
  ])
47
 
48
- # Memory store
49
  store = {}
50
 
51
  def get_session_history(session_id: str) -> ChatMessageHistory:
 
52
  if session_id not in store:
53
  store[session_id] = ChatMessageHistory()
54
  return store[session_id]
55
 
56
- # Chain with memory
57
  chain = prompt | llm
58
  chain_with_history = RunnableWithMessageHistory(
59
  chain,
@@ -62,22 +74,57 @@ chain_with_history = RunnableWithMessageHistory(
62
  history_messages_key="history"
63
  )
64
 
65
- @spaces.GPU
 
66
  def gradio_chat(user_message, history):
67
- session_id = "default-session" # For demo; can be made unique per user
68
- response = chain_with_history.invoke(
69
- {"input": user_message},
70
- config={"configurable": {"session_id": session_id}}
71
- )
72
- # LangChain returns a "AIMessage" object; get text
73
- return response.content if hasattr(response, "content") else str(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Gradio UI
76
  demo = gr.ChatInterface(
77
  fn=gradio_chat,
78
  title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
79
- description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI."
 
 
 
 
 
 
80
  )
81
 
 
82
  if __name__ == "__main__":
83
- demo.launch()
 
6
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
7
  from langchain_community.chat_message_histories import ChatMessageHistory
8
 
9
+ # Model configuration
10
  MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
11
 
12
+ # System prompt that guides the bot's behavior
13
+ SYSTEM_PROMPT = """
14
+ You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition,
15
+ symptoms, medical history, medications, lifestyle, and other relevant data. Start by greeting the user politely and ask
16
+ them to describe their health concern. For each user reply, ask only 1 or 2 follow-up questions at a time to gather more details.
17
+ Be structured and thorough in your questioning. Organize the information into categories: symptoms, duration, severity,
18
+ possible causes, past medical history, medications, allergies, habits (e.g., smoking, alcohol), and family history.
19
+ Always confirm and summarize what the user tells you. Respond empathetically and clearly. If unsure, ask for clarification.
20
+ Do NOT make a final diagnosis or suggest treatments. You are only here to collect and organize medical data to support a licensed physician.
21
+ Wait for the user's answer before asking more questions.
22
+ """
23
 
24
  print("Loading model...")
25
+ try:
26
+ # Initialize the tokenizer and model
27
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
28
+ model = AutoModelForCausalLM.from_pretrained(
29
+ MODEL_NAME,
30
+ torch_dtype="auto",
31
+ device_map="auto"
32
+ )
33
+
34
+ # Create a pipeline for text generation
35
+ pipe = pipeline(
36
+ "text-generation",
37
+ model=model,
38
+ tokenizer=tokenizer,
39
+ max_new_tokens=512,
40
+ temperature=0.7,
41
+ top_p=0.9,
42
+ pad_token_id=tokenizer.eos_token_id
43
+ )
44
+
45
+ llm = HuggingFacePipeline(pipeline=pipe)
46
+ print("Model loaded successfully!")
47
+ except Exception as e:
48
+ print(f"Error loading model: {e}")
49
+ # Fallback to a smaller model or provide an error message
50
+ raise
51
 
52
+ # LangChain prompt template
53
  prompt = ChatPromptTemplate.from_messages([
54
  ("system", SYSTEM_PROMPT),
55
  MessagesPlaceholder(variable_name="history"),
56
  ("human", "{input}")
57
  ])
58
 
59
+ # Memory store to maintain conversation history
60
  store = {}
61
 
62
  def get_session_history(session_id: str) -> ChatMessageHistory:
63
+ """Get or create a chat history for the given session ID"""
64
  if session_id not in store:
65
  store[session_id] = ChatMessageHistory()
66
  return store[session_id]
67
 
68
+ # Create a chain with memory
69
  chain = prompt | llm
70
  chain_with_history = RunnableWithMessageHistory(
71
  chain,
 
74
  history_messages_key="history"
75
  )
76
 
77
+ # Our handler for chat interactions
78
+ @spaces.GPU # Request GPU for this space
79
  def gradio_chat(user_message, history):
80
+ """Process the user message and return the chatbot response"""
81
+ # Use a unique session ID in production
82
+ session_id = "default-session"
83
+
84
+ # Invoke the chain with history
85
+ try:
86
+ response = chain_with_history.invoke(
87
+ {"input": user_message},
88
+ config={"configurable": {"session_id": session_id}}
89
+ )
90
+
91
+ # Extract the text from the response
92
+ response_text = response.content if hasattr(response, "content") else str(response)
93
+
94
+ # Format as "Virtual doctor: " response to match the expected format
95
+ formatted_response = f"Virtual doctor: {response_text}"
96
+
97
+ return formatted_response
98
+ except Exception as e:
99
+ print(f"Error processing message: {e}")
100
+ return "Virtual doctor: I apologize, but I'm experiencing technical difficulties. Please try again."
101
+
102
+ # Customize the CSS for better appearance
103
+ css = """
104
+ .gradio-container {
105
+ font-family: 'Arial', sans-serif;
106
+ }
107
+ .chat-bot .bot-message {
108
+ background-color: #f0f7ff !important;
109
+ }
110
+ .chat-bot .user-message {
111
+ background-color: #e6f7e6 !important;
112
+ }
113
+ """
114
 
115
+ # Create the Gradio interface
116
  demo = gr.ChatInterface(
117
  fn=gradio_chat,
118
  title="Medbot Chatbot (Llama-2 + LangChain + Gradio)",
119
+ description="Medical chatbot using Llama-2-7b-chat-hf, LangChain memory, and Gradio UI.",
120
+ examples=[
121
+ "I have a cough and my throat hurts",
122
+ "I've been having headaches for a week",
123
+ "My stomach has been hurting since yesterday"
124
+ ],
125
+ css=css
126
  )
127
 
128
+ # Launch the app
129
  if __name__ == "__main__":
130
+ demo.launch(share=False)