|
|
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import sqlparse |
|
|
|
|
|
# Page-level settings: browser-tab title, tab icon, and a centered layout.
st.set_page_config(
    page_title="AI SQL Query Generator",
    # Robot-face emoji, written as an escape so the icon survives re-encoding
    # of this source file.
    page_icon="\U0001F916",
    layout="centered",
)
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the tokenizer and seq2seq model used for SQL generation.

    Both objects are loaded from the ``tscholak/cxmefzzi`` checkpoint. The
    function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "tscholak/cxmefzzi"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
|
|
|
|
|
def format_sql(sql):
    """Pretty-print a SQL string with ``sqlparse``.

    Args:
        sql: SQL text to format.

    Returns:
        The result of ``sqlparse.format`` with re-indentation enabled and
        keywords converted to upper case.
    """
    formatting_options = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(sql, **formatting_options)
|
|
|
|
|
def generate_sql(
    input_text,
    tokenizer,
    model,
    *,
    prefix="Translate English to SQL: ",
    max_input_length=512,
    max_output_length=256,
):
    """Turn a natural-language question into SQL text using a seq2seq model.

    Args:
        input_text: The user's question.
        tokenizer: Callable tokenizer; invoked with the prompt and must also
            provide ``decode``.
        model: Object exposing ``generate(**inputs, max_length=...)``.
        prefix: Text prepended to ``input_text`` to form the prompt.
        max_input_length: Token limit passed to the tokenizer; truncation is
            enabled, so longer prompts are cut to this length.
        max_output_length: ``max_length`` passed to ``model.generate``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    prompt = prefix + input_text
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )
    generated = model.generate(**encoded, max_length=max_output_length)
    # Only the first returned sequence is used.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
|
|
|
def _show_generated_sql(question, tokenizer, model):
    """Generate SQL for ``question`` and render the result or an error."""
    with st.spinner("Generating SQL query..."):
        try:
            raw_sql = generate_sql(question, tokenizer, model)
            formatted_sql = format_sql(raw_sql)
        except Exception as exc:
            # UI boundary: report the failure in the page rather than
            # aborting the whole script run.
            st.error(f"Error generating SQL: {exc}")
            return

    st.subheader("Generated SQL Query:")
    st.code(formatted_sql, language="sql")
    st.success("Query generated successfully!")

    with st.expander("Debug Info"):
        st.write("Model: tscholak/cxmefzzi")
        st.write(f"Raw Output: `{raw_sql}`")


def _show_usage_help():
    """Render the static usage instructions and example questions."""
    st.markdown("---")
    st.markdown("### How to use:")
    st.markdown("1. Enter a question about data you want to query")
    st.markdown("2. Click 'Generate SQL'")
    st.markdown("3. Copy the generated SQL and use it in your database")

    st.markdown("### Example queries:")
    st.code("Show the total sales per product category in 2022", language="text")
    st.code("List employees hired before 2020 with salary above $50,000", language="text")
    st.code("Count orders by customer country and sort descending", language="text")


def main():
    """Build the page: question input, SQL generation, and usage help.

    Loads the tokenizer and model via ``load_model``, reads a question from a
    text area, and on button press either warns about empty input or shows the
    generated SQL. The usage help is rendered on every run.
    """
    # "\U0001F916" is the robot-face emoji, escaped to survive re-encoding.
    st.title("\U0001F916 AI-Powered SQL Query Generator")
    st.markdown("Convert natural language questions to SQL queries")

    tokenizer, model = load_model()

    user_input = st.text_area(
        "Enter your question in natural language:",
        placeholder="e.g., Show all customers from California who made purchases after January 2023",
        height=150,
    )

    if st.button("Generate SQL"):
        if not user_input.strip():
            st.warning("Please enter a question")
        else:
            # The text is passed on unstripped; stripping is only used for the
            # emptiness check above.
            _show_generated_sql(user_input, tokenizer, model)

    _show_usage_help()
|
|
|
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()