Spaces:
Sleeping
Sleeping
File size: 5,045 Bytes
fe47927 5484f7d 84925d4 fe47927 84925d4 314ddd7 5e2bcc2 314ddd7 84925d4 5e2bcc2 84925d4 d12ffe7 fe47927 69c50aa fe47927 d12ffe7 69c50aa d12ffe7 60e6743 fe47927 d12ffe7 fe47927 d12ffe7 fe47927 d12ffe7 fe47927 d12ffe7 71a6503 fe47927 d12ffe7 71a6503 d12ffe7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import streamlit as st
import replicate
import os
import snowflake.connector
import snowflake.snowpark.functions
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from transformers import AutoTokenizer
# Snowflake connection settings.
#
# Credentials must never be committed to source control: they are read from
# environment variables (hosting platforms typically expose configured secrets
# this way). The previously hard-coded password should be considered leaked
# and rotated.
_missing_snowflake_vars = [
    name
    for name in ("SNOWFLAKE_USER", "SNOWFLAKE_PASSWORD")
    if not os.environ.get(name)
]
if _missing_snowflake_vars:
    # Fail fast with an actionable message instead of an opaque login error.
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_snowflake_vars)
    )

connection_parameters = {
    "account": os.environ.get("SNOWFLAKE_ACCOUNT", "ap20346.ap-south-1"),
    "user": os.environ["SNOWFLAKE_USER"],
    "password": os.environ["SNOWFLAKE_PASSWORD"],
    # NOTE(review): ACCOUNTADMIN is kept as the default only for backward
    # compatibility; a least-privilege role is strongly preferable.
    "role": os.environ.get("SNOWFLAKE_ROLE", "ACCOUNTADMIN"),
    "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE", "COMPUTE_WH"),
    "database": os.environ.get("SNOWFLAKE_DATABASE", "SNOWFLAKE_SAMPLE_DATA"),
    "schema": os.environ.get("SNOWFLAKE_SCHEMA", "TPCH_SF1"),
}

# NOTE(review): this session is not referenced by the rest of the visible
# code, and the statement runs each time the script is executed -- confirm
# whether it is still needed.
test_session = Session.builder.configs(connection_parameters).create()
# Avatars shown next to each chat participant (the assistant uses the
# Snowflake logo file, the user an emoji).
icons = {
    "assistant": "./Snowflake_Logomark_blue.svg",
    "user": "⛷️",
}

# Page title
st.set_page_config(page_title="Snowflake Arctic Enu")

# Sidebar: Replicate credentials followed by the sampling controls
with st.sidebar:
    st.title('Snowflake Arctic Enu')

    if 'REPLICATE_API_TOKEN' not in st.secrets:
        # No stored secret: ask the user for a token and warn while the
        # entered value does not have the expected prefix and length.
        replicate_api = st.text_input('Enter Replicate API token:', type='password')
        if not replicate_api.startswith('r8_') or len(replicate_api) != 40:
            st.warning('Please enter your Replicate API token.', icon='⚠️')
            st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")
    else:
        replicate_api = st.secrets['REPLICATE_API_TOKEN']

    # Export the token through the process environment.
    os.environ['REPLICATE_API_TOKEN'] = replicate_api

    st.subheader("Adjust model parameters")
    temperature = st.sidebar.slider(
        'temperature', min_value=0.01, max_value=5.0, value=0.3, step=0.01
    )
    top_p = st.sidebar.slider(
        'top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01
    )
# Single source of truth for the assistant's opening message, so that a fresh
# session and a cleared chat always greet the user with the same text.
_ASSISTANT_GREETING = "Hi. I'm Snowflake Arctic demo of Enu, a new, efficient, intelligent, and truly open language model created by Snowflake AI Research. Ask me anything."


def _initial_messages():
    """Return a new chat history holding only the assistant greeting."""
    return [{"role": "assistant", "content": _ASSISTANT_GREETING}]


# Store LLM-generated responses
if "messages" not in st.session_state:
    st.session_state.messages = _initial_messages()

# Display the chat messages stored so far
for message in st.session_state.messages:
    with st.chat_message(message["role"], avatar=icons[message["role"]]):
        st.write(message["content"])


def clear_chat_history():
    """Reset the stored conversation to just the assistant greeting.

    Used as the ``on_click`` callback of the 'Clear chat history' buttons.
    """
    st.session_state.messages = _initial_messages()


st.sidebar.button('Clear chat history', on_click=clear_chat_history)
st.sidebar.caption('Built by [Snowflake](https://snowflake.com/) to demonstrate [Snowflake Arctic](https://www.snowflake.com/blog/arctic-open-and-efficient-foundation-language-models-snowflake). App hosted on [Streamlit Community Cloud](https://streamlit.io/cloud). Model hosted by [Replicate](https://replicate.com/snowflake/snowflake-arctic-instruct).')
st.sidebar.caption('Build your own app powered by Arctic and [enter to win](https://arctic-streamlit-hackathon.devpost.com/) $10k in prizes.')
@st.cache_resource(show_spinner=False)
def get_tokenizer():
    """Load the tokenizer used to check that prompts are not too long.

    The "huggyllama/llama-7b" tokenizer is used for now; the intent is to
    replace it with ArcticTokenizer eventually.
    """
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    return tokenizer
def get_num_tokens(prompt):
    """Return how many tokens the tokenizer splits ``prompt`` into."""
    return len(get_tokenizer().tokenize(prompt))
# Function for generating Snowflake Arctic response
def generate_arctic_response():
    """Stream the model's reply to the conversation held in session state.

    Every stored message is wrapped in ``<|im_start|>`` / ``<|im_end|>``
    markers (any role other than "user" is rendered as an assistant turn),
    followed by an open assistant turn for the model to complete. If the
    assembled prompt is 3072 tokens or longer, an error and a
    'Clear chat history' button are shown and the script run is stopped.

    Yields:
        Each event from ``replicate.stream`` converted with ``str``.
    """
    turns = [
        (
            "<|im_start|>user\n"
            if entry["role"] == "user"
            else "<|im_start|>assistant\n"
        )
        + entry["content"]
        + "<|im_end|>"
        for entry in st.session_state.messages
    ]
    # Open assistant turn; the empty final item makes the joined prompt end
    # with a newline.
    turns += ["<|im_start|>assistant", ""]
    prompt_str = "\n".join(turns)

    if get_num_tokens(prompt_str) >= 3072:
        st.error("Conversation length too long. Please keep it under 3072 tokens.")
        st.button('Clear chat history', on_click=clear_chat_history, key="clear_chat_history")
        st.stop()

    model_input = {
        "prompt": prompt_str,
        # The prompt is passed through verbatim as the template.
        "prompt_template": r"{prompt}",
        "temperature": temperature,
        "top_p": top_p,
    }
    for event in replicate.stream("snowflake/snowflake-arctic-instruct", input=model_input):
        yield str(event)
# User-provided prompt: record it and echo it in the chat. The input box is
# disabled while no Replicate token is available.
if prompt := st.chat_input(disabled=not replicate_api):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar="⛷️"):
        st.write(prompt)

# The newest message still needs an answer whenever it is not from the assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant", avatar="./Snowflake_Logomark_blue.svg"):
        # Render the reply as it streams in; write_stream returns the full text.
        full_response = st.write_stream(generate_arctic_response())
    st.session_state.messages.append({"role": "assistant", "content": full_response})