Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -298,6 +298,137 @@ def main():
|
|
298 |
filename = generate_filename(file_content_area, choice)
|
299 |
create_file(filename, file_content_area, response)
|
300 |
st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
|
301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
if __name__ == "__main__":
|
303 |
main()
|
|
|
298 |
filename = generate_filename(file_content_area, choice)
|
299 |
create_file(filename, file_content_area, response)
|
300 |
st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
|
301 |
+
|
302 |
+
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
from langchain.chains import ConversationChain
|
307 |
+
from langchain.chains.conversation.memory import ConversationEntityMemory
|
308 |
+
from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE
|
309 |
+
from langchain.llms import OpenAI
|
310 |
+
|
311 |
+
if "generated" not in st.session_state:
|
312 |
+
st.session_state["generated"] = []
|
313 |
+
|
314 |
+
if "past" not in st.session_state:
|
315 |
+
st.session_state["past"] = []
|
316 |
+
|
317 |
+
if "input" not in st.session_state:
|
318 |
+
st.session_state["input"] = ""
|
319 |
+
|
320 |
+
if "stored_session" not in st.session_state:
|
321 |
+
st.session_state["stored_session"] = []
|
322 |
+
|
323 |
+
|
324 |
+
# Define function to get user input
|
325 |
+
def get_text():
|
326 |
+
"""
|
327 |
+
Get the user input text.
|
328 |
+
|
329 |
+
Returns:
|
330 |
+
(str): The text entered by the user
|
331 |
+
"""
|
332 |
+
input_text = st.text_input("You: ", st.session_state["input"], key="input",
|
333 |
+
placeholder="Your AI assistant here! Ask me anything ...",
|
334 |
+
label_visibility='hidden')
|
335 |
+
return input_text
|
336 |
+
|
337 |
+
# Define function to start a new chat
|
338 |
+
def new_chat():
|
339 |
+
"""
|
340 |
+
Clears session state and starts a new chat.
|
341 |
+
"""
|
342 |
+
save = []
|
343 |
+
for i in range(len(st.session_state['generated'])-1, -1, -1):
|
344 |
+
save.append("User:" + st.session_state["past"][i])
|
345 |
+
save.append("Bot:" + st.session_state["generated"][i])
|
346 |
+
st.session_state["stored_session"].append(save)
|
347 |
+
st.session_state["generated"] = []
|
348 |
+
st.session_state["past"] = []
|
349 |
+
st.session_state["input"] = ""
|
350 |
+
st.session_state.entity_memory.entity_store = {}
|
351 |
+
st.session_state.entity_memory.buffer.clear()
|
352 |
+
|
353 |
+
# Set up sidebar with various options
|
354 |
+
with st.sidebar.expander("π οΈ ", expanded=False):
|
355 |
+
# Option to preview memory store
|
356 |
+
if st.checkbox("Preview memory store"):
|
357 |
+
with st.expander("Memory-Store", expanded=False):
|
358 |
+
st.session_state.entity_memory.store
|
359 |
+
# Option to preview memory buffer
|
360 |
+
if st.checkbox("Preview memory buffer"):
|
361 |
+
with st.expander("Bufffer-Store", expanded=False):
|
362 |
+
st.session_state.entity_memory.buffer
|
363 |
+
MODEL = st.selectbox(label='Model', options=['gpt-3.5-turbo','text-davinci-003','text-davinci-002','code-davinci-002'])
|
364 |
+
K = st.number_input(' (#)Summary of prompts to consider',min_value=3,max_value=1000)
|
365 |
+
|
366 |
+
# Set up the Streamlit app layout
|
367 |
+
#st.title("π€ Chat Bot with π§ ")
|
368 |
+
#st.subheader(" Powered by π¦ LangChain + OpenAI + Streamlit")
|
369 |
+
|
370 |
+
# Ask the user to enter their OpenAI API key
|
371 |
+
#API_O = st.sidebar.text_input("API-KEY", type="password")
|
372 |
+
API_O = os.getenv('OPENAI_KEY')
|
373 |
+
|
374 |
+
# Session state storage would be ideal
|
375 |
+
if API_O:
|
376 |
+
# Create an OpenAI instance
|
377 |
+
llm = OpenAI(temperature=0,
|
378 |
+
openai_api_key=API_O,
|
379 |
+
model_name=MODEL,
|
380 |
+
verbose=False)
|
381 |
+
|
382 |
+
# Create a ConversationEntityMemory object if not already created
|
383 |
+
if 'entity_memory' not in st.session_state:
|
384 |
+
st.session_state.entity_memory = ConversationEntityMemory(llm=llm, k=K )
|
385 |
+
|
386 |
+
# Create the ConversationChain object with the specified configuration
|
387 |
+
Conversation = ConversationChain(
|
388 |
+
llm=llm,
|
389 |
+
prompt=ENTITY_MEMORY_CONVERSATION_TEMPLATE,
|
390 |
+
memory=st.session_state.entity_memory
|
391 |
+
)
|
392 |
+
|
393 |
+
|
394 |
+
# Add a button to start a new chat
|
395 |
+
st.sidebar.button("Embedding Memory Chat", on_click = new_chat, type='primary')
|
396 |
+
|
397 |
+
# Get the user input
|
398 |
+
user_input = get_text()
|
399 |
+
|
400 |
+
# Generate the output using the ConversationChain object and the user input, and add the input/output to the session
|
401 |
+
if user_input:
|
402 |
+
output = Conversation.run(input=user_input)
|
403 |
+
st.session_state.past.append(user_input)
|
404 |
+
st.session_state.generated.append(output)
|
405 |
+
|
406 |
+
# Allow to download as well
|
407 |
+
download_str = []
|
408 |
+
# Display the conversation history using an expander, and allow the user to download it
|
409 |
+
with st.expander("Conversation", expanded=True):
|
410 |
+
for i in range(len(st.session_state['generated'])-1, -1, -1):
|
411 |
+
st.info(st.session_state["past"][i],icon="π§")
|
412 |
+
st.success(st.session_state["generated"][i], icon="π€")
|
413 |
+
download_str.append(st.session_state["past"][i])
|
414 |
+
download_str.append(st.session_state["generated"][i])
|
415 |
+
|
416 |
+
# Can throw error - requires fix
|
417 |
+
download_str = '\n'.join(download_str)
|
418 |
+
if download_str:
|
419 |
+
st.download_button('Download',download_str)
|
420 |
+
|
421 |
+
# Display stored conversation sessions in the sidebar
|
422 |
+
for i, sublist in enumerate(st.session_state.stored_session):
|
423 |
+
with st.sidebar.expander(label= f"Conversation-Session:{i}"):
|
424 |
+
st.write(sublist)
|
425 |
+
|
426 |
+
# Allow the user to clear all stored conversation sessions
|
427 |
+
if st.session_state.stored_session:
|
428 |
+
if st.sidebar.checkbox("Clear-all"):
|
429 |
+
del st.session_state.stored_session
|
430 |
+
|
431 |
+
|
432 |
+
|
433 |
if __name__ == "__main__":
|
434 |
main()
|