Navya-Sree commited on
Commit
e815180
·
verified ·
1 Parent(s): f3b31dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import streamlit as st
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import sqlparse
5
+
6
+ # Set page config
7
+ st.set_page_config(
8
+ page_title="AI SQL Query Generator",
9
+ page_icon="🤖",
10
+ layout="centered"
11
+ )
12
+
13
+ # Load model and tokenizer
14
+ @st.cache_resource
15
+ def load_model():
16
+ model_name = "tscholak/cxmefzzi" # Pre-trained text-to-SQL model
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
19
+ return tokenizer, model
20
+
21
+ # Format SQL output
22
+ def format_sql(sql):
23
+ return sqlparse.format(sql, reindent=True, keyword_case='upper')
24
+
25
+ # Generate SQL from natural language
26
+ def generate_sql(input_text, tokenizer, model):
27
+ prefix = "Translate English to SQL: "
28
+ inputs = tokenizer(prefix + input_text, return_tensors="pt", max_length=512, truncation=True)
29
+ outputs = model.generate(**inputs, max_length=256)
30
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
31
+
32
+ # Streamlit UI
33
+ def main():
34
+ st.title("🤖 AI-Powered SQL Query Generator")
35
+ st.markdown("Convert natural language questions to SQL queries")
36
+
37
+ # Load model
38
+ tokenizer, model = load_model()
39
+
40
+ # User input
41
+ user_input = st.text_area(
42
+ "Enter your question in natural language:",
43
+ placeholder="e.g., Show all customers from California who made purchases after January 2023",
44
+ height=150
45
+ )
46
+
47
+ # Generate button
48
+ if st.button("Generate SQL"):
49
+ if user_input.strip() == "":
50
+ st.warning("Please enter a question")
51
+ else:
52
+ with st.spinner("Generating SQL query..."):
53
+ try:
54
+ # Generate and format SQL
55
+ raw_sql = generate_sql(user_input, tokenizer, model)
56
+ formatted_sql = format_sql(raw_sql)
57
+
58
+ # Display results
59
+ st.subheader("Generated SQL Query:")
60
+ st.code(formatted_sql, language="sql")
61
+
62
+ st.success("Query generated successfully!")
63
+
64
+ # Show raw output for debugging
65
+ with st.expander("Debug Info"):
66
+ st.write(f"Model: tscholak/cxmefzzi")
67
+ st.write(f"Raw Output: `{raw_sql}`")
68
+ except Exception as e:
69
+ st.error(f"Error generating SQL: {str(e)}")
70
+
71
+ # Footer
72
+ st.markdown("---")
73
+ st.markdown("### How to use:")
74
+ st.markdown("1. Enter a question about data you want to query")
75
+ st.markdown("2. Click 'Generate SQL'")
76
+ st.markdown("3. Copy the generated SQL and use it in your database")
77
+
78
+ st.markdown("### Example queries:")
79
+ st.code("Show the total sales per product category in 2022", language="text")
80
+ st.code("List employees hired before 2020 with salary above $50,000", language="text")
81
+ st.code("Count orders by customer country and sort descending", language="text")
82
+
83
+ if __name__ == "__main__":
84
+ main()