# SnowflakeArctic / app.py
# Last commit c187aed by hfenlume: "Update app.py"
# Example questions to try in the sidebar:
# ========================================
# Show me the total number of entries in the first table
# Select top 10 customers from Canada with highest sum of C_ACCTBAL value, in descending order
# Show me the total of Customers per Nation, in ascending order
# Show me a query that lists totals for extended price, discounted extended price, discounted extended price plus tax, average quantity, average extended price, and average discount. These aggregates are grouped by RETURNFLAG and LINESTATUS, and listed in ascending order of RETURNFLAG and LINESTATUS. A count of the number of line items in each group is included
import os
import streamlit as st
import pandas
from snowflake.snowpark import Session
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from transformers import AutoTokenizer
from langchain_community.llms import Replicate
from langchain_core.prompts import PromptTemplate
# Page metadata; Streamlit requires this to be the first st.* command the script runs.
st.set_page_config(page_title="Snowflake Arctic", page_icon="🤖")
@st.cache_resource(show_spinner="Connecting...")
def getSession():
    """Build the Snowflake session, LangChain SQL database and text-to-SQL chain.

    Connection settings are read from the ``connections_snowflake`` section of
    Streamlit secrets (account, user, password, database, schema, warehouse,
    role). The Replicate API token is read from the ``REPLICATE_API_TOKEN``
    secret. The result is cached with ``st.cache_resource``, so these objects
    are created once and reused across script reruns.

    Returns:
        tuple: ``(session, db, chain)`` where ``session`` is a Snowpark
        ``Session``, ``db`` is a LangChain ``SQLDatabase`` over the same
        database/schema, and ``chain`` is the ``create_sql_query_chain``
        runnable backed by the Snowflake Arctic model on Replicate.
    """
    from urllib.parse import quote

    section = st.secrets["connections_snowflake"]
    keys = ("account", "user", "password", "database", "schema", "warehouse", "role")
    pars = {key: section[key] for key in keys}
    session = Session.builder.configs(pars).create()

    def _enc(value):
        # Percent-encode every reserved character (including '@', ':', '/', '#',
        # and space). Otherwise e-mail style user names or passwords with
        # special characters would corrupt the connection URI.
        return quote(str(value), safe="")

    url = (
        f"snowflake://{_enc(pars['user'])}:{_enc(pars['password'])}@{pars['account']}"
        f"/{pars['database']}/{pars['schema']}"
        f"?warehouse={_enc(pars['warehouse'])}&role={_enc(pars['role'])}"
    )
    db = SQLDatabase.from_uri(url)

    # Replicate() below picks up the API token from this environment variable.
    os.environ["REPLICATE_API_TOKEN"] = st.secrets["REPLICATE_API_TOKEN"]
    llm = Replicate(
        model="snowflake/snowflake-arctic-instruct",
        model_kwargs={"temperature": 0.75, "top_p": 1},
    )
    chain = create_sql_query_chain(llm, db)
    return session, db, chain
st.title("❄️ Snowflake Arctic with Replicate")
st.write("Returns and runs queries from questions in natural language.")

session, db, chain = getSession()

user_query = st.sidebar.text_area(
    "Ask a question:",
    value="Show me the total number of entries in the first table",
)
if not user_query.strip():
    # Don't send an empty prompt to the LLM; wait for the user to type something.
    st.info("Type a question in the sidebar to generate a query.")
    st.stop()

# The generated SQL may end with a ';' followed by whitespace or a newline.
# Strip whitespace first so the trailing ';' is actually removed before the
# statement is passed to session.sql().
sql = chain.invoke({"question": user_query}).strip().rstrip(";").strip()

# st.tabs takes ONE list of labels and returns one container per label.
tabQuery, tabData, tabLog = st.tabs(["Query", "Data", "Log"])
tabQuery.code(sql, language="sql")
try:
    tabData.dataframe(session.sql(sql))
except Exception as err:  # UI boundary: LLM-generated SQL may be invalid
    tabData.error(f"Query failed: {err}")
tabLog.code(db.table_info, language="sql")