awacke1 commited on
Commit
e6c9dd5
Β·
1 Parent(s): 4536ad8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -1
app.py CHANGED
@@ -298,6 +298,137 @@ def main():
298
  filename = generate_filename(file_content_area, choice)
299
  create_file(filename, file_content_area, response)
300
  st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
301
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  if __name__ == "__main__":
303
  main()
 
298
  filename = generate_filename(file_content_area, choice)
299
  create_file(filename, file_content_area, response)
300
  st.sidebar.markdown(get_table_download_link(filename), unsafe_allow_html=True)
301
+
302
+
303
+
304
+
305
+
306
+ from langchain.chains import ConversationChain
307
+ from langchain.chains.conversation.memory import ConversationEntityMemory
308
+ from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE
309
+ from langchain.llms import OpenAI
310
+
311
+ if "generated" not in st.session_state:
312
+ st.session_state["generated"] = []
313
+
314
+ if "past" not in st.session_state:
315
+ st.session_state["past"] = []
316
+
317
+ if "input" not in st.session_state:
318
+ st.session_state["input"] = ""
319
+
320
+ if "stored_session" not in st.session_state:
321
+ st.session_state["stored_session"] = []
322
+
323
+
324
+ # Define function to get user input
325
+ def get_text():
326
+ """
327
+ Get the user input text.
328
+
329
+ Returns:
330
+ (str): The text entered by the user
331
+ """
332
+ input_text = st.text_input("You: ", st.session_state["input"], key="input",
333
+ placeholder="Your AI assistant here! Ask me anything ...",
334
+ label_visibility='hidden')
335
+ return input_text
336
+
337
+ # Define function to start a new chat
338
+ def new_chat():
339
+ """
340
+ Clears session state and starts a new chat.
341
+ """
342
+ save = []
343
+ for i in range(len(st.session_state['generated'])-1, -1, -1):
344
+ save.append("User:" + st.session_state["past"][i])
345
+ save.append("Bot:" + st.session_state["generated"][i])
346
+ st.session_state["stored_session"].append(save)
347
+ st.session_state["generated"] = []
348
+ st.session_state["past"] = []
349
+ st.session_state["input"] = ""
350
+ st.session_state.entity_memory.entity_store = {}
351
+ st.session_state.entity_memory.buffer.clear()
352
+
353
+ # Set up sidebar with various options
354
+ with st.sidebar.expander("πŸ› οΈ ", expanded=False):
355
+ # Option to preview memory store
356
+ if st.checkbox("Preview memory store"):
357
+ with st.expander("Memory-Store", expanded=False):
358
+ st.session_state.entity_memory.store
359
+ # Option to preview memory buffer
360
+ if st.checkbox("Preview memory buffer"):
361
+ with st.expander("Bufffer-Store", expanded=False):
362
+ st.session_state.entity_memory.buffer
363
+ MODEL = st.selectbox(label='Model', options=['gpt-3.5-turbo','text-davinci-003','text-davinci-002','code-davinci-002'])
364
+ K = st.number_input(' (#)Summary of prompts to consider',min_value=3,max_value=1000)
365
+
366
+ # Set up the Streamlit app layout
367
+ #st.title("πŸ€– Chat Bot with 🧠")
368
+ #st.subheader(" Powered by 🦜 LangChain + OpenAI + Streamlit")
369
+
370
+ # Ask the user to enter their OpenAI API key
371
+ #API_O = st.sidebar.text_input("API-KEY", type="password")
372
+ API_O = os.getenv('OPENAI_KEY')
373
+
374
+ # Session state storage would be ideal
375
+ if API_O:
376
+ # Create an OpenAI instance
377
+ llm = OpenAI(temperature=0,
378
+ openai_api_key=API_O,
379
+ model_name=MODEL,
380
+ verbose=False)
381
+
382
+ # Create a ConversationEntityMemory object if not already created
383
+ if 'entity_memory' not in st.session_state:
384
+ st.session_state.entity_memory = ConversationEntityMemory(llm=llm, k=K )
385
+
386
+ # Create the ConversationChain object with the specified configuration
387
+ Conversation = ConversationChain(
388
+ llm=llm,
389
+ prompt=ENTITY_MEMORY_CONVERSATION_TEMPLATE,
390
+ memory=st.session_state.entity_memory
391
+ )
392
+
393
+
394
+ # Add a button to start a new chat
395
+ st.sidebar.button("Embedding Memory Chat", on_click = new_chat, type='primary')
396
+
397
+ # Get the user input
398
+ user_input = get_text()
399
+
400
+ # Generate the output using the ConversationChain object and the user input, and add the input/output to the session
401
+ if user_input:
402
+ output = Conversation.run(input=user_input)
403
+ st.session_state.past.append(user_input)
404
+ st.session_state.generated.append(output)
405
+
406
+ # Allow to download as well
407
+ download_str = []
408
+ # Display the conversation history using an expander, and allow the user to download it
409
+ with st.expander("Conversation", expanded=True):
410
+ for i in range(len(st.session_state['generated'])-1, -1, -1):
411
+ st.info(st.session_state["past"][i],icon="🧐")
412
+ st.success(st.session_state["generated"][i], icon="πŸ€–")
413
+ download_str.append(st.session_state["past"][i])
414
+ download_str.append(st.session_state["generated"][i])
415
+
416
+ # Can throw error - requires fix
417
+ download_str = '\n'.join(download_str)
418
+ if download_str:
419
+ st.download_button('Download',download_str)
420
+
421
+ # Display stored conversation sessions in the sidebar
422
+ for i, sublist in enumerate(st.session_state.stored_session):
423
+ with st.sidebar.expander(label= f"Conversation-Session:{i}"):
424
+ st.write(sublist)
425
+
426
+ # Allow the user to clear all stored conversation sessions
427
+ if st.session_state.stored_session:
428
+ if st.sidebar.checkbox("Clear-all"):
429
+ del st.session_state.stored_session
430
+
431
+
432
+
433
  if __name__ == "__main__":
434
  main()