ubaid975 commited on
Commit
ce43f2f
·
verified ·
1 Parent(s): 3718936

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +98 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,100 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ from langchain_community.utilities import SQLDatabase
3
+ from langchain.chat_models import ChatOpenAI
4
+ from langchain.agents import create_sql_agent
5
+ from langchain_groq import ChatGroq
6
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
7
+
8
+ import tempfile
9
+ import sqlite3
10
+ import pandas as pd
11
+
12
+
13
+ def is_valid_sqlite(file_path):
14
+ try:
15
+ with sqlite3.connect(file_path) as conn:
16
+ conn.execute("SELECT name FROM sqlite_master LIMIT 1;")
17
+ return True
18
+ except sqlite3.DatabaseError:
19
+ return False
20
+
21
+
22
+ def text_to_sql(query: str, db_path: str, llm_provider: str, api_key: str, model_name: str):
23
+ try:
24
+ db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
25
+
26
+ if llm_provider == 'OPENAI':
27
+ llm = ChatOpenAI(api_key=api_key, model=model_name)
28
+ elif llm_provider == 'OPEN_ROUTER':
29
+ llm = ChatOpenAI(api_key=api_key, base_url='https://openrouter.ai/api/v1', model=model_name)
30
+ elif llm_provider == 'GROQ':
31
+ llm = ChatGroq(api_key=api_key, model=model_name)
32
+ else:
33
+ return "Unsupported LLM provider selected."
34
+
35
+ toolkit = SQLDatabaseToolkit(llm=llm, db=db)
36
+ db_chain = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
37
+ return db_chain.run(query)
38
+
39
+ except Exception as e:
40
+ return f"Error: {str(e)}"
41
+
42
+
43
+ def show_tables_as_df(db_path):
44
+ conn = sqlite3.connect(db_path)
45
+ cursor = conn.cursor()
46
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
47
+ tables = cursor.fetchall()
48
+
49
+ if tables:
50
+ for table_name in tables:
51
+ table = table_name[0]
52
+ st.subheader(f"Table: {table}")
53
+ df = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 10", conn)
54
+ st.dataframe(df)
55
+ else:
56
+ st.write("No tables found in database.")
57
+
58
+ conn.close()
59
+
60
+
61
+ # Streamlit UI
62
+ st.title('🗃️ Chat with SQLite Database')
63
+
64
+ st.write("This app lets you interact with a SQLite database using natural language queries powered by LLMs.")
65
+
66
+ uploaded_file = st.file_uploader("Upload SQLite Database (.db file)", type=["db"])
67
+
68
+ llm_provider = st.radio("Choose LLM Provider", options=['OPEN_ROUTER', 'GROQ', 'OPENAI'])
69
+ model_name = st.text_input("Enter the Model Name", value='nousresearch/deephermes-3-mistral-24b-preview:free')
70
+ api_key = st.text_input("Enter Your API Key", type="password")
71
+ query = st.text_area("Enter Your Query")
72
+
73
+ try:
74
+ if uploaded_file is not None:
75
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmpfile:
76
+ tmpfile.write(uploaded_file.read())
77
+ tmp_db_path = tmpfile.name
78
+
79
+ if not is_valid_sqlite(tmp_db_path):
80
+ st.error("The uploaded file is not a valid SQLite database.")
81
+ else:
82
+ st.success("Valid SQLite database uploaded!")
83
+
84
+ # Show tables as pandas DataFrames
85
+ st.info("Displaying first 10 rows from each table:")
86
+ show_tables_as_df(tmp_db_path)
87
+
88
+ if st.button("RUN Query"):
89
+ if not api_key or not model_name:
90
+ st.error("Please provide API key and model name.")
91
+ elif not query.strip():
92
+ st.error("Please enter a query.")
93
+ else:
94
+ st.info(f"Running query on `{uploaded_file.name}`...")
95
+ result = text_to_sql(query, tmp_db_path, llm_provider, api_key, model_name)
96
+ st.success("Query Result:")
97
+ st.write(result)
98
 
99
+ except Exception as e:
100
+ st.error(f"Error: {str(e)}")