File size: 2,523 Bytes
c187aed
 
 
 
 
 
 
0b69762
 
c187aed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Example questions to try (paste one into the sidebar):
# ======================================================
# Show me the total number of entries in the first table
# Select the top 10 customers from Canada with the highest sum of C_ACCTBAL values, in descending order
# Show me the total number of customers per nation, in ascending order
# Show me a query that lists totals for extended price, discounted extended price, discounted extended price plus tax, average quantity, average extended price, and average discount. These aggregates are grouped by RETURNFLAG and LINESTATUS, and listed in ascending order of RETURNFLAG and LINESTATUS. A count of the number of line items in each group is included

import os
import streamlit as st
import pandas
from snowflake.snowpark import Session
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from transformers import AutoTokenizer
from langchain_community.llms import Replicate
from langchain_core.prompts import PromptTemplate

st.set_page_config(page_title="Snowflake Arctic", page_icon="🤖")


def _snowflake_sqlalchemy_url(pars):
    """Build a snowflake-sqlalchemy connection URL from connection parameters.

    The user name and password are percent-encoded so that characters such as
    ``@``, ``/``, ``:``, ``#`` or ``%`` in the credentials cannot break the URL
    structure (SQLAlchemy decodes them again when parsing the URL).
    """
    from urllib.parse import quote

    user = quote(pars["user"], safe="")
    password = quote(pars["password"], safe="")
    return (
        f"snowflake://{user}:{password}@{pars['account']}"
        f"/{pars['database']}/{pars['schema']}"
        f"?warehouse={pars['warehouse']}&role={pars['role']}"
    )


@st.cache_resource(show_spinner="Connecting...")
def getSession():
    """Create the Snowflake session, the LangChain SQL database and the SQL chain.

    Connection settings are read from the ``connections_snowflake`` section of
    ``st.secrets``. ``st.cache_resource`` caches the result, so the connections
    and the LLM client are built once and reused across Streamlit reruns.

    Returns:
        tuple: ``(session, db, chain)``, where
            ``session`` is a Snowpark ``Session`` used to execute queries,
            ``db`` is a LangChain ``SQLDatabase`` that exposes schema info to the LLM,
            ``chain`` is a runnable that turns a ``{"question": ...}`` dict into SQL.
    """
    section = st.secrets["connections_snowflake"]
    pars = {
        "account": section["account"],
        "user": section["user"],
        "password": section["password"],
        "database": section["database"],
        "schema": section["schema"],
        "warehouse": section["warehouse"],
        "role": section["role"],
    }
    session = Session.builder.configs(pars).create()

    # A second connection through SQLAlchemy: LangChain's SQLDatabase needs it
    # to introspect table schemas for the prompt.
    db = SQLDatabase.from_uri(_snowflake_sqlalchemy_url(pars))

    # The Replicate LLM wrapper reads its API token from this environment variable.
    os.environ['REPLICATE_API_TOKEN'] = st.secrets["REPLICATE_API_TOKEN"]
    # NOTE(review): temperature 0.75 is fairly high for SQL generation; a lower
    # value may give more deterministic queries — confirm intent.
    llm = Replicate(
        model="snowflake/snowflake-arctic-instruct",
        model_kwargs={"temperature": 0.75, "top_p": 1},
    )
    chain = create_sql_query_chain(llm, db)
    return session, db, chain

st.title("❄️ Snowflake Arctic with Replicate")
st.write("Returns and runs queries from questions in natural language.")

# Page flow (re-executed by Streamlit on every interaction):
# read the question, ask the LLM chain for SQL, then show the SQL,
# its result set and the schema info the LLM was given in three tabs.
session, db, chain = getSession()

user_query = st.sidebar.text_area(
    "Ask a question:",
    value="Show me the total number of entries in the first table",
)
if not user_query.strip():
    # Nothing to translate; avoid a pointless (and billable) LLM call.
    st.info("Type a question in the sidebar to generate a query.")
    st.stop()

# Strip surrounding whitespace first so a trailing "; \n" is also removed;
# Snowpark's session.sql() runs a single statement without the terminator.
sql = chain.invoke({"question": user_query}).strip().rstrip(';')

# st.tabs takes ONE sequence of labels, not one positional argument per tab.
tabQuery, tabData, tabLog = st.tabs(["Query", "Data", "Log"])
tabQuery.code(sql, language="sql")
# NOTE(review): this executes LLM-generated SQL verbatim. Make sure the
# configured Snowflake role is read-only / least-privileged.
try:
    # session.sql() is lazy; the query actually runs when the dataframe is rendered.
    tabData.dataframe(session.sql(sql))
except Exception as err:  # UI boundary: report the failure instead of crashing the page
    tabData.error(f"Query failed: {err}")
tabLog.code(db.table_info, language="sql")