|
import os
import re
import sqlite3
from contextlib import closing

import google.generativeai as genai
import pandas as pd
import streamlit as st
|
|
|
|
|
# Page setup: browser-tab title and wide layout, followed by the on-page heading.
st.set_page_config(
    layout="wide",
    page_title="Zero SQL",
)
st.title("Zero SQL - Natural Language to SQL Query")
|
|
|
|
|
@st.cache_resource
def initialize_database(db_path="database.db"):
    """Create the demo schema in *db_path* and seed it with sample data.

    Creates the ``Produkte`` (products) and ``Bestellungen`` (orders) tables
    if they do not exist yet. It then inserts three sample products and
    three sample orders into whichever of the two tables is still empty.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the cached call
    instead of re-running the setup on every script rerun.

    Args:
        db_path: SQLite database file to create/populate. Defaults to
            ``"database.db"``, the file the query code reads.
    """
    products = [
        ('Laptop', 999.99),
        ('Smartphone', 699.99),
        ('Tablet', 399.99),
    ]
    # (ProduktID, Menge, Bestelldatum, Person). The product IDs assume the
    # three products above were inserted into an empty table as IDs 1..3.
    orders = [
        (1, 2, '2024-10-20', 'Max Mustermann'),
        (2, 1, '2024-10-21', 'Erika Musterfrau'),
        (3, 3, '2024-10-22', 'Hans Meier'),
    ]

    # closing() guarantees the connection is released even if a statement
    # below raises; uncommitted inserts are discarded in that case.
    with closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS Produkte (
                ProduktID INTEGER PRIMARY KEY AUTOINCREMENT,
                Produktname TEXT NOT NULL,
                Preis REAL NOT NULL
            )
        ''')

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS Bestellungen (
                BestellungID INTEGER PRIMARY KEY AUTOINCREMENT,
                ProduktID INTEGER NOT NULL,
                Menge INTEGER NOT NULL,
                Bestelldatum TEXT NOT NULL,
                Person TEXT NOT NULL,
                FOREIGN KEY (ProduktID) REFERENCES Produkte(ProduktID)
            )
        ''')

        # Seed each table only when it is empty so reruns don't duplicate rows.
        cursor.execute("SELECT COUNT(*) FROM Produkte")
        if cursor.fetchone()[0] == 0:
            cursor.executemany("INSERT INTO Produkte (Produktname, Preis) VALUES (?, ?)", products)

        cursor.execute("SELECT COUNT(*) FROM Bestellungen")
        if cursor.fetchone()[0] == 0:
            cursor.executemany("INSERT INTO Bestellungen (ProduktID, Menge, Bestelldatum, Person) VALUES (?, ?, ?, ?)", orders)

        conn.commit()
|
|
|
|
|
# Ensure the SQLite file and its sample rows exist before any query runs.
initialize_database()

# Example prompts listed in the sidebar as Markdown bullet points.
_SAMPLE_QUESTIONS = (
    "- Show total sales per product",
    "- List orders from Max Mustermann",
    "- Find most popular product by quantity",
)

# Sidebar: Gemini API key entry plus a few example questions.
with st.sidebar:
    st.header("Configuration")
    api_key_iput = st.text_input("Gemini API Key", type="password")

    st.markdown("---")
    st.markdown("**Sample Questions:**")
    for sample_question in _SAMPLE_QUESTIONS:
        st.markdown(sample_question)
|
|
|
|
|
# Natural-language request form. `user_input` holds the text area's value
# and `submitted` is True on the run triggered by the submit button.
with st.form("query_form"):
    user_input = st.text_area(
        "Enter your data request in natural language:",
        # Restored from mojibake ("β¬500" was a mis-decoded "€500").
        placeholder="e.g. Show all orders over €500",
        height=100,
    )
    # Emoji restored from a mis-decoded 4-byte UTF-8 sequence ("π ...").
    submitted = st.form_submit_button("🔍 Generate Query")
|
|
|
# Schema description sent to the model. It mirrors the tables created by
# initialize_database(); keep the two in sync when the schema changes.
_SCHEMA_PROMPT = """Given these SQLite tables:

CREATE TABLE Produkte (
    ProduktID INTEGER PRIMARY KEY,
    Produktname TEXT NOT NULL,
    Preis REAL NOT NULL
);

CREATE TABLE Bestellungen (
    BestellungID INTEGER PRIMARY KEY,
    ProduktID INTEGER NOT NULL,
    Menge INTEGER NOT NULL,
    Bestelldatum TEXT NOT NULL,  -- ISO date, e.g. '2024-10-20'
    Person TEXT NOT NULL,
    FOREIGN KEY (ProduktID) REFERENCES Produkte(ProduktID)
);

Generate ONLY one raw, read-only SQLite SELECT query for the following request.
Output ONLY the pure SQL statement without any formatting,
explanations, or markdown blocks."""


def _clean_generated_sql(raw_text):
    """Return *raw_text* with Markdown code-fence markers removed.

    The model is told not to use Markdown but may still wrap the query in
    a fenced block. Every "```", "```sql" and "```sqlite" marker is
    dropped (case-insensitively) and surrounding whitespace is stripped.
    """
    return re.sub(r"```(?:sqlite|sql)?", "", raw_text, flags=re.IGNORECASE).strip()


def _run_read_only_query(sql_query):
    """Execute *sql_query* against database.db opened read-only.

    Returns:
        ``(rows, column_names)``. ``column_names`` is empty when the
        statement yields no result columns.

    Raises:
        sqlite3.Error: for invalid SQL or any attempt to modify the database.
    """
    # The SQL is model-generated text, so open the file with mode=ro: a
    # DROP/DELETE/UPDATE fails instead of altering data. closing() makes
    # sure the connection is released (a bare `with conn:` does not close).
    with closing(sqlite3.connect("file:database.db?mode=ro", uri=True)) as conn:
        cursor = conn.execute(sql_query)
        rows = cursor.fetchall()
        # description is None for statements that produce no columns.
        column_names = [desc[0] for desc in cursor.description or ()]
    return rows, column_names


if submitted:
    # Sidebar input wins; fall back to the environment variable.
    api_key = api_key_iput or os.getenv("GEMINI_API_KEY")

    if not (user_input and user_input.strip()):
        st.error("📝 Please enter your data request!")
    elif not api_key:
        st.error(
            "🔑 Please enter a Gemini API key in the sidebar "
            "or set the GEMINI_API_KEY environment variable."
        )
    else:
        try:
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel('gemini-2.0-flash')

            full_prompt = f"{_SCHEMA_PROMPT}\n\nUser Request: {user_input}"
            response = model.generate_content(
                full_prompt,
                generation_config={"temperature": 0.3},
            )
            sql_query = _clean_generated_sql(response.text)

            # Show the query before running it so a failing query can still
            # be inspected by the user.
            st.subheader("Generated SQL")
            st.code(sql_query, language="sql")

            results, column_names = _run_read_only_query(sql_query)

            st.subheader("Results")
            if results:
                df = pd.DataFrame(results, columns=column_names)
                st.dataframe(
                    data=df,
                    use_container_width=True,
                    hide_index=True,
                )
            else:
                # Valid emoji icon; the previous mojibake value was rejected
                # and surfaced as an "Unexpected Error" on empty results.
                st.info("No results found", icon="ℹ️")

        except sqlite3.Error as e:
            st.error(f"🚨 SQL Error: {e}")
        except Exception as e:  # UI boundary: report anything else to the user
            st.error(f"💥 Unexpected Error: {e}")