# Captured from a web file viewer: 6,204 bytes, 159 lines; revision IDs listed
# by the viewer: c718aa0, c723280, 183474c, 9ebf574, 75a2814, 8393aff.
import streamlit as st
import torch
import uuid
import nest_asyncio
import asyncio
import os
# Presumably the common workaround for Streamlit's file watcher failing when it
# probes torch.classes.__path__ -- confirm it is still needed for the pinned
# Streamlit/torch versions.
torch.classes.__path__ = []

# Setup for HTTP API Calls: identifiers created once per browser session and
# then reused on every rerun.
if "device_id" not in st.session_state:
    st.session_state["device_id"] = str(uuid.uuid4())
if "feedback_key" not in st.session_state:
    st.session_state["feedback_key"] = 0
def launch_bot():
    """Render the chat demo page for the current Streamlit run.

    On the first run of a browser session, a configuration is built from
    environment variables and cached in ``st.session_state``. A
    ``VectaraQuery`` client built from that configuration is cached there too.

    Every run then renders:
      * the sidebar (about text, language selector, "Start Over" button);
      * the header and the chat history;
      * the example-question pills (first turn only);
      * the chat input and the assistant's reply;
      * a thumbs feedback widget.

    Environment variables read on the first run:
        api_key, title, source_data_desc -- required (``KeyError`` if unset).
        streaming, prompt_name, examples  -- optional.

    NOTE(review): ``OmegaConf``, ``isTrue``, ``languages``, ``max_examples``,
    ``pills``, ``VectaraQuery``, ``escape_dollars_outside_latex``,
    ``send_amplitude_data``, ``streamlit_feedback`` and ``thumbs_feedback``
    are not imported in the visible source -- confirm they are in scope.
    """
    # The avatars are written as escapes so they survive re-encoding. The
    # pasted source showed them as mojibake.
    assistant_avatar = "\U0001F916"  # robot face
    user_avatar = "\U0001F9D1\u200D\U0001F4BB"  # technologist (person + ZWJ + laptop)
    greeting = "How may I help you?"

    def reset():
        """Clear the conversation back to the assistant greeting."""
        st.session_state.messages = [{"role": "assistant", "content": greeting, "avatar": assistant_avatar}]
        st.session_state.ex_prompt = None
        st.session_state.first_turn = True

    # Both generators read `vq`, which is bound further down. They are only
    # called after that assignment.
    def generate_response(question):
        """Return the complete answer to *question* in the selected language."""
        response = vq.submit_query(question, languages[st.session_state.language])
        return response

    def generate_streaming_response(question):
        """Return the streaming result for *question*; it is consumed by st.write_stream."""
        response = vq.submit_query_streaming(question, languages[st.session_state.language])
        return response

    def show_example_questions():
        """Offer the example questions on the first turn.

        Returns True when one was picked. The pick is stored in
        ``st.session_state.ex_prompt``, and ``first_turn`` is cleared.
        """
        if st.session_state.example_messages and st.session_state.first_turn:
            selected_example = pills("Questions to Try:", st.session_state.example_messages, index=None)
            if selected_example:
                st.session_state.ex_prompt = selected_example
                st.session_state.first_turn = False
                return True
        return False

    if "cfg" not in st.session_state:
        # First run of this browser session: build and cache the configuration.
        corpus_keys = ["first", "last"]  # TODO: was str(os.environ['corpus_keys']).split(',')
        cfg = OmegaConf.create({
            "corpus_keys": corpus_keys,
            "api_key": str(os.environ["api_key"]),
            "title": os.environ["title"],
            "source_data_desc": os.environ["source_data_desc"],
            "streaming": isTrue(os.environ.get("streaming", False)),
            "prompt_name": os.environ.get("prompt_name", None),
            "examples": os.environ.get("examples", None),
            "language": "English",
        })
        st.session_state.cfg = cfg
        st.session_state.ex_prompt = None
        st.session_state.first_turn = True
        st.session_state.language = cfg.language
        # `examples` is optional. A missing value means "no examples" rather
        # than an AttributeError from None.split().
        example_messages = [example.strip() for example in (cfg.examples or "").split(",")]
        st.session_state.example_messages = [em for em in example_messages if em][:max_examples]
        st.session_state.vq = VectaraQuery(cfg.api_key, cfg.corpus_keys, cfg.prompt_name)

    cfg = st.session_state.cfg
    vq = st.session_state.vq
    st.set_page_config(page_title=cfg.title, layout="wide")

    # left side content
    with st.sidebar:
        st.markdown(f"## About\n\n"
                    f"This demo uses outside RAG to ask questions about {cfg.source_data_desc}\n")

        cfg.language = st.selectbox('Language:', languages.keys())
        if st.session_state.language != cfg.language:
            # A language switch starts a fresh conversation.
            st.session_state.language = cfg.language
            reset()
            st.rerun()

        st.markdown("\n")
        bc1, _ = st.columns([1, 1])
        with bc1:
            if st.button('Start Over'):
                reset()
                st.rerun()

        st.markdown("---")
        st.markdown(
            "## Temporary test demo only\n"
        )

    st.markdown(f"<center> <h2> Header Demo Test: {cfg.title} </h2> </center>", unsafe_allow_html=True)

    if "messages" not in st.session_state:
        reset()

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message["avatar"]):
            st.write(message["content"])

    example_container = st.empty()
    with example_container:
        if show_example_questions():
            example_container.empty()
            st.rerun()

    # select prompt from example question or user provided input
    if st.session_state.ex_prompt:
        prompt = st.session_state.ex_prompt
    else:
        prompt = st.chat_input()
    if prompt:
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avatar})
        with st.chat_message("user", avatar=user_avatar):
            st.write(prompt)
        st.session_state.ex_prompt = None

    # Generate a new response if last message is not from assistant
    if st.session_state.messages[-1]["role"] != "assistant":
        # Answer the pending user message taken from history, not `prompt`.
        # On a rerun where the previous run stopped before storing a reply,
        # `prompt` is None, but the unanswered question is still the last message.
        question = st.session_state.messages[-1]["content"]
        with st.chat_message("assistant", avatar=assistant_avatar):
            if cfg.streaming:
                stream = generate_streaming_response(question)
                response = st.write_stream(stream)
            else:
                with st.spinner("Thinking..."):
                    response = generate_response(question)
                    st.write(response)

            # Only the stored copy is escaped; the reply was already displayed above.
            response = escape_dollars_outside_latex(response)
            message = {"role": "assistant", "content": response, "avatar": assistant_avatar}
            st.session_state.messages.append(message)

            # Send query and response to Amplitude Analytics
            send_amplitude_data(
                user_query=st.session_state.messages[-2]["content"],
                chat_response=st.session_state.messages[-1]["content"],
                demo_name=cfg["title"],
                language=st.session_state.language
            )
            st.rerun()

    # Offer feedback on real answers only, not on the initial greeting.
    last_message = st.session_state.messages[-1]
    if last_message["role"] == "assistant" and last_message["content"] != greeting:
        streamlit_feedback(feedback_type="thumbs", on_submit=thumbs_feedback, key=st.session_state.feedback_key,
                           kwargs={"user_query": st.session_state.messages[-2]["content"],
                                   "chat_response": last_message["content"],
                                   "demo_name": cfg["title"],
                                   "response_language": st.session_state.language})

    # Filler content for the sticky-header test.
    # NOTE(review): the pasted source lost its indentation; confirm this loop
    # belongs inside launch_bot() rather than at module level.
    for i in range(100):
        st.write(f"This is scrollable content line {i}")
if __name__ == "__main__":
    # Page setup for the sticky-toolbar test page. Older Streamlit releases
    # require set_page_config() to be the first Streamlit command of a run and
    # allow it only once.
    st.set_page_config(page_title="Sticky toolbar test", layout="wide")
    # Patch asyncio to allow nested event loops. This is presumably for a
    # dependency that calls asyncio.run() while a loop is already running;
    # confirm it is still needed.
    nest_asyncio.apply()
    # NOTE(review): the chat app is currently disabled. launch_bot() is a
    # plain (synchronous) function, so the line below would render the app
    # and then raise ValueError, because asyncio.run() receives None instead
    # of a coroutine. To re-enable the app, call launch_bot() directly. Also
    # drop the set_page_config() call above, since launch_bot() makes its own.
    #asyncio.run(launch_bot())