File size: 3,026 Bytes
e815180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# app.py
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import sqlparse

# Configure the browser tab (title, icon) and page layout. Streamlit expects
# this call to come before any other st.* command that renders output.
st.set_page_config(
    page_title="AI SQL Query Generator",
    page_icon="🤖",  # robot-face emoji (U+1F916); was stored as mis-decoded bytes
    layout="centered",
)

# Load model and tokenizer
@st.cache_resource
def load_model():
    """Load the text-to-SQL tokenizer and seq2seq model.

    The function is wrapped in ``st.cache_resource`` so the returned objects
    are reused across Streamlit reruns instead of being loaded again.

    Returns:
        A ``(tokenizer, model)`` tuple, both loaded from the same checkpoint.
    """
    checkpoint = "tscholak/cxmefzzi"  # Pre-trained text-to-SQL model
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )

# Format SQL output
def format_sql(sql):
    """Pretty-print a SQL string with ``sqlparse``.

    Args:
        sql: The SQL text to format.

    Returns:
        The result of ``sqlparse.format`` with re-indentation enabled and
        keywords converted to upper case.
    """
    formatter_options = {
        "reindent": True,
        "keyword_case": "upper",
    }
    return sqlparse.format(sql, **formatter_options)

# Generate SQL from natural language
def generate_sql(input_text, tokenizer, model):
    """Turn a natural-language question into a SQL string using the model.

    Args:
        input_text: The user's question in plain English.
        tokenizer: Callable tokenizer that also provides ``decode``.
        model: Object exposing ``generate(**inputs, max_length=...)``.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens skipped.
    """
    # The instruction prefix is prepended to the question before tokenizing.
    prompt = "Translate English to SQL: " + input_text
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generated = model.generate(**encoded, max_length=256)
    # Only the first returned sequence is decoded.
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Streamlit UI
def main():
    """Render the Streamlit page for the SQL generator.

    Shows a title, loads the tokenizer and model via ``load_model``, collects
    a natural-language question, and on "Generate SQL" displays the formatted
    query plus a debug expander with the raw model output. A usage guide and
    example questions are always rendered below.
    """
    st.title("🤖 AI-Powered SQL Query Generator")
    st.markdown("Convert natural language questions to SQL queries")

    # Load model
    tokenizer, model = load_model()

    # User input
    user_input = st.text_area(
        "Enter your question in natural language:",
        placeholder="e.g., Show all customers from California who made purchases after January 2023",
        height=150
    )

    # Generate button
    if st.button("Generate SQL"):
        # Whitespace-only input counts as empty.
        if not user_input.strip():
            st.warning("Please enter a question")
        else:
            with st.spinner("Generating SQL query..."):
                # Broad catch is deliberate: this is the UI boundary, and any
                # failure is reported on the page instead of crashing the app.
                try:
                    # Generate and format SQL
                    raw_sql = generate_sql(user_input, tokenizer, model)
                    formatted_sql = format_sql(raw_sql)

                    # Display results
                    st.subheader("Generated SQL Query:")
                    st.code(formatted_sql, language="sql")

                    st.success("Query generated successfully!")

                    # Show raw output for debugging
                    with st.expander("Debug Info"):
                        st.write("Model: tscholak/cxmefzzi")
                        st.write(f"Raw Output: `{raw_sql}`")
                except Exception as e:
                    st.error(f"Error generating SQL: {e}")

    # Footer
    st.markdown("---")
    st.markdown("### How to use:")
    st.markdown("1. Enter a question about data you want to query")
    st.markdown("2. Click 'Generate SQL'")
    st.markdown("3. Copy the generated SQL and use it in your database")

    st.markdown("### Example queries:")
    st.code("Show the total sales per product category in 2022", language="text")
    st.code("List employees hired before 2020 with salary above $50,000", language="text")
    st.code("Count orders by customer country and sort descending", language="text")

# Script entry point: build the UI only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()