Create app.py
app.py
ADDED
@@ -0,0 +1,177 @@
import os

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langsmith import traceable

from langchain_nomic.embeddings import NomicEmbeddings
from langchain_groq import ChatGroq
from langchain_community.vectorstores import Chroma
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chains import RetrievalQA
from langchain.agents import Tool, AgentExecutor, create_react_agent
from langchain import hub

import streamlit as st

load_dotenv()

GROQ_API_KEY = os.getenv("GROQ_API_KEY")
HF_API_KEY = os.getenv("HF_API_KEY")        # loaded but not used below
COHERE_API_KEY = os.getenv("COHERE_API_KEY")  # loaded but not used below

# LangSmith reads its configuration from environment variables, so these
# settings are exported to the environment rather than kept as plain
# Python variables (which LangSmith would never see).
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY", "")
os.environ["LANGSMITH_PROJECT"] = "pr-smug-rancher-51"

llm = ChatGroq(
    temperature=0,
    model="llama-3.3-70b-versatile",  # alternatives: "llama-3.1-70b-versatile", "llama3-70b-8192"
    api_key=GROQ_API_KEY,
)


def fake_db_retrieval():
    # The original file calls this helper without defining it; a minimal
    # placeholder is provided here so that get_answer() is runnable.
    return "Cancer is a group of diseases in which abnormal cells divide without control."


@traceable
def get_answer(question):
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a medical expert called Dr.Med! Here is some info about cancer: {facts}"),
        ("user", "{question}"),
    ])

    chain = prompt | llm | StrOutputParser()

    answer = chain.invoke({"question": question, "facts": fake_db_retrieval()})
    return answer


embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
db = "db1"

vector_store = Chroma(
    collection_name="chromadb3",
    persist_directory=db,
    embedding_function=embedding_model,
)

conversational_memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=5,  # Number of messages stored in memory
    return_messages=True,  # Must return the messages in the response
)

qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
)

# Define the list of tool objects to be used by the agent.
tools = [
    Tool(
        name="Medical_KB",
        func=qa.run,
        description=(
            "use this tool when answering medical knowledge queries to get "
            "more information about the topic"
        ),
    )
]

prompt = hub.pull("hwchase17/react-chat")
agent = create_react_agent(
    tools=tools,
    llm=llm,
    prompt=prompt,
)

# Create an agent executor by passing in the agent and tools
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    memory=conversational_memory,
    max_iterations=30,
    max_execution_time=600,
    # early_stopping_method="generate",
    handle_parsing_errors=True,
)


# Function for continuing the conversation
def continue_conversation(user_message, history):
    # Invoke the agent and get the response
    response = agent_executor.invoke({"input": user_message})
    output = response["output"]

    # Prepend the new exchange to the history (latest conversation comes first)
    history.insert(0, {"role": "Patient", "message": user_message})
    history.insert(0, {"role": "Doctor", "message": output})

    # Return the current response and the full history (hidden state)
    return output, history


# Streamlit UI
def main():
    st.set_page_config(page_title="Medical Chatbot", page_icon="👨‍⚕️")
    st.title("Medical Chatbot")

    # Initialize the conversation history
    if "history" not in st.session_state:
        st.session_state.history = []

    # Sidebar for memory display
    with st.sidebar:
        st.header("Conversation History")
        st.write("This section contains the conversation history.")

    # Create a container for the chat
    chat_container = st.container()

    # Display the chat history with the latest conversation at the top
    for chat in st.session_state.history:
        if chat["role"] == "Patient":
            chat_container.markdown(f"**Patient:** {chat['message']}")
        else:
            chat_container.markdown(f"**Doctor:** {chat['message']}")

    # User input text box at the bottom
    user_input = st.text_input(
        "Ask a question:",
        key="input",
        placeholder="Describe your symptoms or ask a medical question",
    )

    if user_input:
        # Get the response and update the conversation history
        output, updated_history = continue_conversation(user_input, st.session_state.history)

        # Update the session state with the new history
        st.session_state.history = updated_history

        # Display memory of past conversation in an expandable section
        with st.expander("Memory", expanded=True):
            for chat in st.session_state.history:
                st.write(f"**{chat['role']}:** {chat['message']}")


if __name__ == "__main__":
    main()
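The app loads its credentials with load_dotenv(), so it expects a local .env file next to app.py. A minimal sketch with placeholder values (the variable names come from the code above; only GROQ_API_KEY and LANGSMITH_API_KEY are actually consumed, and real keys should never be committed):

# .env — placeholder values only
GROQ_API_KEY=gsk_...
HF_API_KEY=hf_...
COHERE_API_KEY=...
LANGSMITH_API_KEY=lsv2_...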
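To run the Space locally, the imports above imply roughly the following dependencies. This is a sketch: exact versions are an assumption, and langchainhub is listed because hub.pull() relies on it in the LangChain releases this code targets.

streamlit
python-dotenv
langchain
langchain-community
langchain-core
langchain-groq
langchain-nomic
langchainhub
langsmith
chromadb

Start the app with: streamlit run app.py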