Update src/streamlit_app.py
Browse files- src/streamlit_app.py +104 -0
src/streamlit_app.py
CHANGED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from pathlib import Path
|
3 |
+
from langchain.agents import create_sql_agent
|
4 |
+
from langchain.sql_database import SQLDatabase
|
5 |
+
from langchain.agents.agent_types import AgentType
|
6 |
+
from langchain.callbacks import StreamlitCallbackHandler
|
7 |
+
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
8 |
+
from sqlalchemy import create_engine
|
9 |
+
import sqlite3
|
10 |
+
from langchain_groq import ChatGroq
|
11 |
+
|
12 |
+
st.set_page_config(page_title="LangChain: Chat with SQL DB", page_icon="🦜")
|
13 |
+
st.title("🦜 LangChain: Chat with SQL DB")
|
14 |
+
|
15 |
+
LOCALDB = "USE_LOCALDB"
|
16 |
+
MYSQL = "USE_MYSQL"
|
17 |
+
|
18 |
+
radio_opt = ["Use SQLLite 3 Database- STUDENT.db", "Connect to you MySQL Database"]
|
19 |
+
selected_opt = st.sidebar.radio(label="Choose the DB which you want to chat", options=radio_opt)
|
20 |
+
|
21 |
+
if radio_opt.index(selected_opt) == 1:
|
22 |
+
db_uri = MYSQL
|
23 |
+
mysql_host = st.sidebar.text_input("Provide MySQL Host")
|
24 |
+
mysql_user = st.sidebar.text_input("MYSQL User")
|
25 |
+
mysql_password = st.sidebar.text_input("MYSQL password (leave blank if no password)", type="password")
|
26 |
+
mysql_db = st.sidebar.text_input("MySQL database")
|
27 |
+
else:
|
28 |
+
db_uri = LOCALDB
|
29 |
+
|
30 |
+
# Get API key from user input
|
31 |
+
api_key = st.sidebar.text_input(label="GROQ API Key", type="password")
|
32 |
+
|
33 |
+
if not api_key:
|
34 |
+
st.info("Please add the Groq API key to continue")
|
35 |
+
st.stop()
|
36 |
+
|
37 |
+
if db_uri == MYSQL and not all([mysql_host, mysql_user, mysql_db]):
|
38 |
+
st.info("Please enter all required MySQL database information")
|
39 |
+
st.stop()
|
40 |
+
|
41 |
+
# LLM model - using user-provided API key
|
42 |
+
llm = ChatGroq(groq_api_key=api_key, model_name="llama-3.1-8b-instant", streaming=True)
|
43 |
+
|
44 |
+
@st.cache_resource(ttl="2h")
|
45 |
+
def configure_db(db_uri, mysql_host=None, mysql_user=None, mysql_password=None, mysql_db=None):
|
46 |
+
if db_uri == LOCALDB:
|
47 |
+
dbfilepath = (Path(__file__).parent/"STUDENT.db").absolute()
|
48 |
+
print(dbfilepath)
|
49 |
+
creator = lambda: sqlite3.connect(f"file:{dbfilepath}?mode=ro", uri=True)
|
50 |
+
return SQLDatabase(create_engine("sqlite:///", creator=creator))
|
51 |
+
elif db_uri == MYSQL:
|
52 |
+
if not (mysql_host and mysql_user and mysql_db):
|
53 |
+
st.error("Please provide MySQL host, user, and database name.")
|
54 |
+
st.stop()
|
55 |
+
|
56 |
+
# Handle optional password
|
57 |
+
if mysql_password:
|
58 |
+
connection_string = f"mysql+mysqlconnector://{mysql_user}:{mysql_password}@{mysql_host}/{mysql_db}"
|
59 |
+
else:
|
60 |
+
connection_string = f"mysql+mysqlconnector://{mysql_user}@{mysql_host}/{mysql_db}"
|
61 |
+
|
62 |
+
return SQLDatabase(create_engine(connection_string))
|
63 |
+
|
64 |
+
# Configure database
|
65 |
+
try:
|
66 |
+
if db_uri == MYSQL:
|
67 |
+
db = configure_db(db_uri, mysql_host, mysql_user, mysql_password, mysql_db)
|
68 |
+
else:
|
69 |
+
db = configure_db(db_uri)
|
70 |
+
except Exception as e:
|
71 |
+
st.error(f"Error connecting to database: {str(e)}")
|
72 |
+
st.stop()
|
73 |
+
|
74 |
+
# Toolkit
|
75 |
+
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
76 |
+
agent = create_sql_agent(
|
77 |
+
llm=llm,
|
78 |
+
toolkit=toolkit,
|
79 |
+
verbose=True,
|
80 |
+
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION
|
81 |
+
)
|
82 |
+
|
83 |
+
if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
|
84 |
+
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
85 |
+
|
86 |
+
for msg in st.session_state.messages:
|
87 |
+
st.chat_message(msg["role"]).write(msg["content"])
|
88 |
+
|
89 |
+
user_query = st.chat_input(placeholder="Ask anything from the database")
|
90 |
+
|
91 |
+
if user_query:
|
92 |
+
st.session_state.messages.append({"role": "user", "content": user_query})
|
93 |
+
st.chat_message("user").write(user_query)
|
94 |
+
|
95 |
+
with st.chat_message("assistant"):
|
96 |
+
try:
|
97 |
+
streamlit_callback = StreamlitCallbackHandler(st.container())
|
98 |
+
response = agent.run(user_query, callbacks=[streamlit_callback])
|
99 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
100 |
+
st.write(response)
|
101 |
+
except Exception as e:
|
102 |
+
error_message = f"An error occurred: {str(e)}"
|
103 |
+
st.error(error_message)
|
104 |
+
st.session_state.messages.append({"role": "assistant", "content": error_message})
|