Spaces:
Sleeping
Sleeping
add conversational memory
Browse files- streamlit_app.py +22 -2
streamlit_app.py
CHANGED
|
@@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile
|
|
| 6 |
import dotenv
|
| 7 |
from grobid_quantities.quantities import QuantitiesAPI
|
| 8 |
from langchain.llms.huggingface_hub import HuggingFaceHub
|
|
|
|
| 9 |
|
| 10 |
dotenv.load_dotenv(override=True)
|
| 11 |
|
|
@@ -51,6 +52,9 @@ if 'ner_processing' not in st.session_state:
|
|
| 51 |
if 'uploaded' not in st.session_state:
|
| 52 |
st.session_state['uploaded'] = False
|
| 53 |
|
|
|
|
|
|
|
|
|
|
| 54 |
st.set_page_config(
|
| 55 |
page_title="Scientific Document Insights Q/A",
|
| 56 |
page_icon="π",
|
|
@@ -67,6 +71,11 @@ def new_file():
|
|
| 67 |
st.session_state['loaded_embeddings'] = None
|
| 68 |
st.session_state['doc_id'] = None
|
| 69 |
st.session_state['uploaded'] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
# @st.cache_resource
|
|
@@ -169,7 +178,7 @@ with st.sidebar:
|
|
| 169 |
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
|
| 170 |
|
| 171 |
st.markdown(
|
| 172 |
-
":warning: Mistral and Zephyr are
|
| 173 |
|
| 174 |
if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
|
| 175 |
if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
|
|
@@ -206,6 +215,11 @@ with st.sidebar:
|
|
| 206 |
# else:
|
| 207 |
# is_api_key_provided = st.session_state['api_key']
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
st.title("π Scientific Document Insights Q/A")
|
| 210 |
st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
|
| 211 |
|
|
@@ -298,7 +312,8 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
|
|
| 298 |
elif mode == "LLM":
|
| 299 |
with st.spinner("Generating response..."):
|
| 300 |
_, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
|
| 301 |
-
context_size=context_size
|
|
|
|
| 302 |
|
| 303 |
if not text_response:
|
| 304 |
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
|
|
@@ -317,5 +332,10 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
|
|
| 317 |
st.write(text_response)
|
| 318 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|
| 319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
|
| 321 |
play_old_messages()
|
|
|
|
| 6 |
import dotenv
|
| 7 |
from grobid_quantities.quantities import QuantitiesAPI
|
| 8 |
from langchain.llms.huggingface_hub import HuggingFaceHub
|
| 9 |
+
from langchain.memory import ConversationBufferWindowMemory
|
| 10 |
|
| 11 |
dotenv.load_dotenv(override=True)
|
| 12 |
|
|
|
|
| 52 |
if 'uploaded' not in st.session_state:
|
| 53 |
st.session_state['uploaded'] = False
|
| 54 |
|
| 55 |
+
if 'memory' not in st.session_state:
|
| 56 |
+
st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
|
| 57 |
+
|
| 58 |
st.set_page_config(
|
| 59 |
page_title="Scientific Document Insights Q/A",
|
| 60 |
page_icon="π",
|
|
|
|
| 71 |
st.session_state['loaded_embeddings'] = None
|
| 72 |
st.session_state['doc_id'] = None
|
| 73 |
st.session_state['uploaded'] = True
|
| 74 |
+
st.session_state['memory'].clear()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def clear_memory():
|
| 78 |
+
st.session_state['memory'].clear()
|
| 79 |
|
| 80 |
|
| 81 |
# @st.cache_resource
|
|
|
|
| 178 |
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
|
| 179 |
|
| 180 |
st.markdown(
|
| 181 |
+
":warning: Mistral and Zephyr are **FREE** to use. Requests might fail anytime. Use at your own risk. :warning: ")
|
| 182 |
|
| 183 |
if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
|
| 184 |
if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
|
|
|
|
| 215 |
# else:
|
| 216 |
# is_api_key_provided = st.session_state['api_key']
|
| 217 |
|
| 218 |
+
st.button(
|
| 219 |
+
'Reset chat memory.',
|
| 220 |
+
on_click=clear_memory(),
|
| 221 |
+
help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.")
|
| 222 |
+
|
| 223 |
st.title("π Scientific Document Insights Q/A")
|
| 224 |
st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
|
| 225 |
|
|
|
|
| 312 |
elif mode == "LLM":
|
| 313 |
with st.spinner("Generating response..."):
|
| 314 |
_, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
|
| 315 |
+
context_size=context_size,
|
| 316 |
+
memory=st.session_state.memory)
|
| 317 |
|
| 318 |
if not text_response:
|
| 319 |
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
|
|
|
|
| 332 |
st.write(text_response)
|
| 333 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|
| 334 |
|
| 335 |
+
for id in range(0, len(st.session_state.messages), 2):
|
| 336 |
+
question = st.session_state.messages[id]['content']
|
| 337 |
+
answer = st.session_state.messages[id + 1]['content']
|
| 338 |
+
st.session_state.memory.save_context({"input": question}, {"output": answer})
|
| 339 |
+
|
| 340 |
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
|
| 341 |
play_old_messages()
|