# Author: Navya-Sree — commit e815180 ("Create app.py", verified)
# app.py
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import sqlparse
# Set page config: browser-tab title, tab icon (robot emoji 🤖) and a
# centered, fixed-width page layout.
st.set_page_config(
    page_title="AI SQL Query Generator",
    page_icon="🤖",
    layout="centered",
)
# Load model and tokenizer
@st.cache_resource
def load_model(model_name="tscholak/cxmefzzi"):
    """Load the seq2seq text-to-SQL tokenizer and model, cached by Streamlit.

    ``st.cache_resource`` memoizes the result per distinct ``model_name`` so
    Streamlit reruns reuse the already-loaded objects instead of calling
    ``from_pretrained`` again.

    Args:
        model_name: Model id or path passed to ``from_pretrained`` for both
            the tokenizer and the model. Defaults to ``"tscholak/cxmefzzi"``,
            the pre-trained text-to-SQL checkpoint this app was built around.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
# Format SQL output
def format_sql(sql):
    """Return *sql* pretty-printed by sqlparse.

    The statement is re-indented and its keywords are upper-cased; the
    formatted text is returned unchanged otherwise.
    """
    style = {"reindent": True, "keyword_case": "upper"}
    return sqlparse.format(sql, **style)
# Generate SQL from natural language
def generate_sql(
    input_text,
    tokenizer,
    model,
    *,
    prefix="Translate English to SQL: ",
    max_input_length=512,
    max_output_length=256,
):
    """Translate a natural-language question into a raw (unformatted) SQL string.

    The question is prepended with *prefix*, tokenized into PyTorch tensors
    (truncated to *max_input_length* tokens), fed to ``model.generate``, and
    the first generated sequence is decoded with special tokens removed.

    Args:
        input_text: The user's natural-language question.
        tokenizer: Tokenizer used both to encode the prompt and decode output.
        model: Seq2seq model whose ``generate`` method produces token ids.
        prefix: Prompt text prepended to *input_text* before tokenization.
        max_input_length: Maximum number of prompt tokens; longer prompts
            are truncated rather than rejected.
        max_output_length: Value forwarded as ``max_length`` to
            ``model.generate``.

    Returns:
        The decoded model output as a string.
    """
    # NOTE(review): this is a generic T5-style task prompt. Confirm it matches
    # the input serialization the checkpoint was fine-tuned on (Spider-trained
    # models may expect "question | db_id | schema" instead).
    encoded = tokenizer(
        prefix + input_text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    )
    generated = model.generate(**encoded, max_length=max_output_length)
    # Only one prompt was encoded, so take the first returned sequence.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Streamlit UI
def main():
    """Render the text-to-SQL page.

    Shows the title, loads the cached tokenizer/model, reads a question from
    a text area, and -- when the "Generate SQL" button is pressed -- renders
    the generated SQL (or an error). Usage notes and example questions are
    always rendered at the bottom of the page.
    """
    st.title("🤖 AI-Powered SQL Query Generator")
    st.markdown("Convert natural language questions to SQL queries")

    # Load model (cached by load_model, so reruns reuse the same objects)
    tokenizer, model = load_model()

    user_input = st.text_area(
        "Enter your question in natural language:",
        placeholder="e.g., Show all customers from California who made purchases after January 2023",
        height=150,
    )

    if st.button("Generate SQL"):
        _run_generation(user_input, tokenizer, model)

    _render_footer()


def _run_generation(user_input, tokenizer, model):
    """Validate the question, generate + format SQL, and render the outcome."""
    if not user_input.strip():
        st.warning("Please enter a question")
        return

    with st.spinner("Generating SQL query..."):
        # Keep the try body to the steps that can fail; rendering is below.
        try:
            raw_sql = generate_sql(user_input, tokenizer, model)
            formatted_sql = format_sql(raw_sql)
        except Exception as e:  # UI boundary: surface any failure to the user
            st.error(f"Error generating SQL: {e}")
            return

        st.subheader("Generated SQL Query:")
        st.code(formatted_sql, language="sql")
        st.success("Query generated successfully!")

        # Show raw output for debugging
        with st.expander("Debug Info"):
            # NOTE: keep in sync with the model id used by load_model().
            st.write("Model: tscholak/cxmefzzi")
            st.write("Raw Output:")
            # st.code instead of inline markdown backticks: SQL containing
            # backticks (e.g. MySQL-quoted identifiers) would break the markup.
            st.code(raw_sql, language="text")


def _render_footer():
    """Render the usage instructions and example questions."""
    st.markdown("---")
    st.markdown("### How to use:")
    st.markdown("1. Enter a question about data you want to query")
    st.markdown("2. Click 'Generate SQL'")
    st.markdown("3. Copy the generated SQL and use it in your database")
    st.markdown("### Example queries:")
    for example in (
        "Show the total sales per product category in 2022",
        "List employees hired before 2020 with salary above $50,000",
        "Count orders by customer country and sort descending",
    ):
        st.code(example, language="text")


if __name__ == "__main__":
    main()