import os
import time
import threading
import streamlit as st
from itertools import tee
from chain import ChainBuilder
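# NOTE: `threading` is currently only used by the commented-out request-queue
# semaphore below; it can be removed if that stays disabled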
DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
# remove these secrets from the container
# VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
# VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")
if DATABRICKS_HOST is None:
raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
raise ValueError("DATABRICKS_TOKEN environment variable must be set")
MODEL_AVATAR_URL = "./VU.jpeg"
# MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
EXAMPLE_PROMPTS = [
"How is a data lake used at Vanderbilt University Medical Center?",
"In a table, what are some of the greatest hurdles to healthcare in the United States?",
"What does EDW stand for in the context of Vanderbilt University Medical Center?",
"Code a sql statement that can query a database named 'VUMC'.",
"Write a short story about a country concert in Nashville, Tennessee.",
"Tell me about maximum out-of-pocket costs in healthcare.",
]
TITLE = "Vanderbilt AI Assistant"
DESCRIPTION = """Welcome to the first-generation Vanderbilt AI assistant! \n
This AI assistant is built atop the Databricks DBRX large language model
and is augmented with additional organization-specific knowledge. In particular, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
terms like **Data Lake**, **EDW**, **HCERA**, and **thousands more**. (Ask the assistant if you don't know what any of these terms mean!) **Disclaimer**: The model has **no access to PHI**. \n
Try querying the model with any of the example prompts below for a simple introduction to both Vanderbilt-specific and general knowledge queries. The purpose of this
model is to give VUMC employees access to an intelligent assistant that improves and expedites VUMC work. \n
Feedback and ideas are very welcome! Please send any feedback, ideas, or issues to **[email protected]**.
We hope to gradually improve this AI assistant into a large-scale, all-inclusive tool that complements the work of all VUMC staff; your comments are invaluable in this process."""
GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
# @st.cache_resource
# def get_global_semaphore():
# return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()
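# If the queue is re-enabled, each model call would be wrapped in the shared
# semaphore. A minimal sketch, assuming a hypothetical QUEUE_SIZE env var (not
# defined anywhere in this app):
# QUEUE_SIZE = int(os.environ.get("QUEUE_SIZE", "4"))
# with global_semaphore:  # blocks until a request slot frees up
#     chat_completion = chain_call(messages)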
st.set_page_config(layout="wide")
# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1
# TOKEN_CHUNK_SIZE_ENV = os.environ.get("TOKEN_CHUNK_SIZE")
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains") # add a Vanderbilt related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")
# load custom CSS styling for the app
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
if "messages" not in st.session_state:
st.session_state["messages"] = []
def clear_chat_history():
st.session_state["messages"] = []
st.button('Clear Chat', on_click=clear_chat_history)
# build our chain outside the chat handlers - simply pass it the chat history for chat completion
chain = ChainBuilder().build_chain()
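# NOTE: Streamlit re-executes this whole script on every interaction, so this
# module-level build still runs on each rerun; wrapping build_chain() in a
# @st.cache_resource function would make it a true once-per-process instantiation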
def last_role_is_user():
return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"
def text_stream(stream):
for chunk in stream:
if chunk["content"] is not None:
yield chunk["content"]
def get_stream_warning_error(stream):
error = None
warning = None
for chunk in stream:
if chunk["error"] is not None:
error = chunk["error"]
if chunk["warning"] is not None:
warning = chunk["warning"]
return warning, error
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))  # requires: from tenacity import retry, wait_random_exponential, stop_after_attempt
def chain_call(history):
# *** original code for instantiating the DBRX model through the OpenAI client *** skip this and introduce our chain eventually
# extra_body = {}
# if SAFETY_FILTER:
# extra_body["enable_safety_filter"] = SAFETY_FILTER
# chat_completion = client.chat.completions.create(
# messages=[
# {"role": m["role"], "content": m["content"]}
# for m in history
# ],
# model="databricks-dbrx-instruct",
# stream=True,
# max_tokens=MAX_TOKENS,
# temperature=0.7,
# extra_body= extra_body
# )
# *** can we stream the chain's response by incorporating the above OpenAI streaming functionality?
# *** Look back at the predict_stream function and see if we can incorporate that!
# *** looks like we want to use either chain.stream() or chain.astream()
# test first with invoke
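    # A hedged async alternative, assuming the chain is a LangChain runnable whose
    # astream() yields string chunks (the same contract chain.stream() is given below):
    # async def chain_call_async(history):
    #     chain_input = {"messages": [{"role": m["role"], "content": m["content"]} for m in history]}
    #     async for chunk in chain.astream(chain_input):
    #         yield chunk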
    # unused reference payload kept for testing chain.invoke(); the live path below uses chain.stream()
    input_example = {'messages':
[{'content': 'What does EDW stand for?', 'role': 'user'},
{'content': 'Enterprise Data Warehouse.', 'role': 'assistant'},
{'content': 'Thank you. What is the data lake?', 'role': 'user'},
{'content': 'A data lake is a centralized repository of structured and unstructured data. It allows data to be stored in its native state, without the need for transformations, so that it can be consumed by other users later. It is not just a term for storage, but also covers functionalities required for a platform, including data analysis, machine learning, cataloging and data movement.', 'role': 'assistant'},
{'content': 'Can you tell me more about how they are used?', 'role': 'user'},
{'content': 'At Vanderbilt University Medical Center, a data lake is used as a centralized repository for storing and managing large amounts of data in its native format. This allows for the data to be easily accessed and analyzed by different teams and business units within the organization. The data lake also provides functionalities such as data analysis, machine learning, cataloging and data movement, making it a versatile tool for handling diverse data sets.\n\nAn Enterprise Data Warehouse (EDW) is used for executing analytic queries on structured data. It is optimized for this purpose, with data being stored in a way that allows for efficient querying and analysis. This makes it a useful tool for teams that need to perform complex analyses on large data sets.\n\nA data mart is a specific organizational structure or pattern used in the context of data warehouses. It is a layer that has specific subdivisions for each business unit or team, such as finance, marketing, and product. This allows users to consume data in a format that meets their specific needs.\n\nA data lakehouse is a term used to describe approaches that attempt to combine the data structure and management features of a data warehouse with the low cost of storage of a data lake. This includes a structured transactional layer, which allows for efficient querying and analysis of data. This approach aims to provide the benefits of both data lakes and data warehouses in a single platform.', 'role': 'assistant'},
{'content': 'Nice answer. Can you tell me what the HCERA is?', 'role': 'user'}]}
    chain_input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}  # avoid shadowing the built-in `input`
    # search_result = vector_store.similarity_search(query=st.session_state["messages"][-1]["content"], k=5)
    # chat_completion = search_result # TODO update this after we implement our chain
    # chat_completion = chain.invoke(input_example) # *** TODO here we will pass only the chat history, the chain handles the system prompt
    chat_completion = chain.stream(chain_input)
return chat_completion
def write_response():
stream = chat_completion(st.session_state["messages"])
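    # tee() duplicates the generator so one copy can be rendered while the other is
    # scanned for warnings/errors; tee buffers chunks until the second copy reads them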
content_stream, error_stream = tee(stream)
response = st.write_stream(text_stream(content_stream))
stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
# if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
if isinstance(response, list):
response = None
return response, stream_warning, stream_error
def chat_completion(messages):
# history_dbrx_format = [
# {"role": "system", "content": SYSTEM_PROMPT} # no longer need this because the chain handles all of this for us
# ]
# history_dbrx_format = history_dbrx_format + messages
# if (len(history_dbrx_format)-1)//2 >= MAX_CHAT_TURNS:
# yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
# return
chat_completion = None
error = None
    # *** original code for querying DBRX through the OpenAI client for chat completion
# wait to be in queue
# with global_semaphore:
# try:
# chat_completion = chat_api_call(history_dbrx_format)
# except Exception as e:
# error = e
# chat_completion = chain_call(history_dbrx_format)
    try:
        chat_completion = chain_call(messages)  # pass only the prior messages; the chain supplies the system prompt
    except Exception as e:
        error = e
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return
max_token_warning = None
partial_message = ""
chunk_counter = 0
for chunk in chat_completion:
# if chunk.choices[0].delta.content is not None:
# TODO *** we need to refactor this logic to match what happens with the response from our chain - it should be strings or an iterator of strings
# if chunk.page_content is not None:
if chunk is not None:
chunk_counter += 1
# partial_message += chunk.choices[0].delta.content
# partial_message += f"* {chunk.page_content} [{chunk.metadata}]"
partial_message += chunk
if chunk_counter % TOKEN_CHUNK_SIZE == 0:
chunk_counter = 0
yield {"content": partial_message, "error": None, "warning": None}
partial_message = ""
# if chunk.choices[0].finish_reason == "length":
# max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
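    # the final yield flushes any chunks left over when the stream ends mid-window and
    # is where a max-token warning would surface if the finish_reason check above is restored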
yield {"content": partial_message, "error": None, "warning": max_token_warning}
# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
def handle_user_input(user_input):
with history:
response, stream_warning, stream_error = [None, None, None]
if last_role_is_user():
# retry the assistant if the user tries to send a new message
with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
response, stream_warning, stream_error = write_response()
else:
st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
with st.chat_message("user", avatar="π§βπ»"):
st.markdown(user_input)
stream = chat_completion(st.session_state["messages"])
with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
response, stream_warning, stream_error = write_response()
st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
def feedback():
with st.container():
st.title("Feedback Interface")
        sentiment_mapping = [":material/thumb_down:", ":material/thumb_up:"]
        rating = st.feedback()  # thumbs widget: returns 0 (down), 1 (up), or None until clicked
        # guard against indexing sentiment_mapping with None before the user clicks
        if rating is not None:
            feedback = st.text_input(f"Please detail your rationale for choosing {sentiment_mapping[rating]}: ", "")
            # rating = st.radio("Rate your experience:", ["👎", "Neutral", "👍"])  # alternative rating UI
            review = {"rating": rating, "feedback": feedback}
            st.write(review)
            time.sleep(5)  # keep the confirmation visible briefly
        # Save the feedback data; a minimal sketch assuming append-only JSON Lines
        # storage in the local filesystem (requires `import json`):
        # if st.button("Submit"):
        #     with open("feedback.json", "a") as f:
        #         json.dump(review, f)
        #         f.write("\n")
        #     st.write("Thank you for your feedback!")
main = st.container()
with main:
history = st.container(height=400)
with history:
for message in st.session_state["messages"]:
avatar = "π§βπ»"
if message["role"] == "assistant":
avatar = MODEL_AVATAR_URL
with st.chat_message(message["role"], avatar=avatar):
if message["content"] is not None:
st.markdown(message["content"])
# receive feedback on AI outputs if the user feels inclined to give it
# rating = st.radio("Rate your experience:", ["Very satisfied", "Somewhat satisfied", "Neutral", "Somewhat dissatisfied", "Very dissatisfied"])
# st.button("Provide Feedback", on_click=feedback)
if message["error"] is not None:
st.error(message["error"],icon="π¨")
if message["warning"] is not None:
st.warning(message["warning"],icon="β οΈ")
if prompt := st.chat_input("Type a message!", max_chars=5000):
handle_user_input(prompt)
st.markdown("\n") #add some space for iphone users
with st.container():
st.button('Give Feedback on Last Response', on_click=feedback)
with st.sidebar:
with st.container():
st.title("Examples")
for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)