# Hugging Face Space status: Sleeping
# Sample questions:
# =================
# Show me the total number of entries in the first table
# Select the top 10 customers from Canada with the highest sum of C_ACCTBAL values, in descending order
# Show me the total number of customers per nation, in ascending order
# Show me a query that lists totals for extended price, discounted extended price, discounted extended price plus tax, average quantity, average extended price, and average discount. These aggregates are grouped by RETURNFLAG and LINESTATUS, and listed in ascending order of RETURNFLAG and LINESTATUS. A count of the number of line items in each group is included.
import os
from urllib.parse import quote

import pandas
import streamlit as st
from langchain.chains import create_sql_query_chain
from langchain_community.llms import Replicate
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from snowflake.snowpark import Session
from snowflake.snowpark.exceptions import SnowparkSQLException
from transformers import AutoTokenizer
# Browser-tab title and icon for the app.
st.set_page_config(page_title="Snowflake Arctic", page_icon="🤖")
@st.cache_resource
def getSession():
    """Create the Snowpark session, LangChain SQL database and text-to-SQL chain.

    Connection settings come from the ``connections_snowflake`` section of
    ``st.secrets``; the Replicate token comes from ``st.secrets["REPLICATE_API_TOKEN"]``.

    Cached with ``st.cache_resource`` so the connections are built once and
    reused across Streamlit reruns. Without it, every rerun opened a new
    Snowpark session and a new SQLAlchemy engine that were never closed.

    Returns:
        tuple: ``(session, db, chain)`` -- the Snowpark ``Session``, the
        LangChain ``SQLDatabase`` for the same database/schema, and the chain
        built by ``create_sql_query_chain`` that turns a question into SQL.
    """
    section = st.secrets["connections_snowflake"]
    pars = {
        "account": section["account"],
        "user": section["user"],
        "password": section["password"],
        "database": section["database"],
        "schema": section["schema"],
        "warehouse": section["warehouse"],
        "role": section["role"],
    }
    session = Session.builder.configs(pars).create()

    # Credentials are embedded in a URL, so they must be percent-encoded;
    # otherwise characters such as '@', '/', ':' or '#' in the password
    # break the URL that SQLAlchemy parses.
    user = quote(pars["user"], safe="")
    password = quote(pars["password"], safe="")
    url = (
        f"snowflake://{user}:{password}@{pars['account']}"
        f"/{pars['database']}/{pars['schema']}"
        f"?warehouse={pars['warehouse']}&role={pars['role']}"
    )
    db = SQLDatabase.from_uri(url)

    # Exported to the environment before the Replicate LLM client is built.
    os.environ["REPLICATE_API_TOKEN"] = st.secrets["REPLICATE_API_TOKEN"]
    llm = Replicate(
        model="snowflake/snowflake-arctic-instruct",
        model_kwargs={"temperature": 0.75, "top_p": 1},
    )
    chain = create_sql_query_chain(llm, db)
    return session, db, chain
st.title("❄️ Snowflake Arctic with Replicate")
st.write("Returns and runs queries from questions in natural language.")

session, db, chain = getSession()

user_query = st.sidebar.text_area(
    "Ask a question:",
    value="Show me the total number of entries in the first table",
)
if not user_query.strip():
    # Nothing to translate: skip the LLM call and end this script run.
    st.info("Type a question in the sidebar to generate a query.")
    st.stop()

# The model output may end with whitespace/newlines after the final ';', in
# which case rstrip(';') alone would leave the semicolon in place.
sql = chain.invoke({"question": user_query}).strip().rstrip(";").strip()

# st.tabs takes ONE list of labels and returns one container per label.
tabQuery, tabData, tabLog = st.tabs(["Query", "Data", "Log"])
tabQuery.code(sql, language="sql")

# NOTE(review): the generated SQL is executed verbatim against Snowflake;
# use a read-only role so a bad generation cannot modify data.
try:
    tabData.dataframe(session.sql(sql))
except SnowparkSQLException as err:
    # Generated SQL may fail to compile or run; report it instead of
    # crashing the whole page.
    tabData.error(f"The generated query failed:\n\n{err}")

tabLog.code(db.table_info, language="sql")