onkolahmet commited on
Commit
e7e63fd
Β·
verified Β·
1 Parent(s): 6867082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -120
app.py CHANGED
@@ -1,19 +1,52 @@
1
- import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  import re
5
  import sqlparse
 
6
 
7
- # Load model and tokenizer
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model = AutoModelForCausalLM.from_pretrained(
10
- "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator",
11
- torch_dtype="auto",
12
- device_map="auto"
13
- )
14
- tokenizer = AutoTokenizer.from_pretrained("onkolahmet/Qwen2-0.5B-Instruct-SQL-generator")
15
 
16
- def generate_sql(question, context=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # Construct prompt with few-shot examples and context if available
18
  prompt = "Translate natural language questions to SQL queries.\n\n"
19
 
@@ -21,6 +54,9 @@ def generate_sql(question, context=None):
21
  if context and context.strip():
22
  prompt += f"Table Context:\n{context}\n\n"
23
 
 
 
 
24
 
25
  # Add the current question
26
  prompt += f"Q: {question}\nSQL:"
@@ -146,36 +182,39 @@ def clean_sql_output(sql_text):
146
  except:
147
  return best_query
148
 
149
- def process_input(question, table_context):
150
- """Function to process user input through the model and return formatted results"""
151
- if not question.strip():
152
- return "Please enter a question."
153
-
154
- # Generate SQL from the question and context
155
- raw_sql = generate_sql(question, table_context)
156
-
157
- # Clean the SQL output
158
- cleaned_sql = clean_sql_output(raw_sql)
159
-
160
- if not cleaned_sql:
161
- return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."
162
-
163
- return cleaned_sql
164
 
165
- # Sample table context examples for the example selector
166
- example_contexts = [
167
- # Example 1
168
- """
169
- CREATE TABLE customers (
170
- id INT PRIMARY KEY,
171
- name VARCHAR(100),
172
- email VARCHAR(100),
173
- order_date DATE
174
- );
175
- """,
176
 
177
- # Example 2
178
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  CREATE TABLE products (
180
  id INT PRIMARY KEY,
181
  name VARCHAR(100),
@@ -183,10 +222,8 @@ CREATE TABLE products (
183
  price DECIMAL(10,2),
184
  stock_quantity INT
185
  );
186
- """,
187
-
188
- # Example 3
189
- """
190
  CREATE TABLE employees (
191
  id INT PRIMARY KEY,
192
  name VARCHAR(100),
@@ -201,87 +238,97 @@ CREATE TABLE departments (
201
  manager_id INT,
202
  budget DECIMAL(15,2)
203
  );
204
- """
205
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- # Sample question examples
208
- example_questions = [
209
- "Get the names and emails of customers who placed an order in the last 30 days.",
210
- "Find all products with less than 10 items in stock.",
211
- "List all employees in the Sales department with a salary greater than 50000.",
212
- "What is the total budget for departments with more than 5 employees?",
213
- "Count how many products are in each category where the price is greater than 100."
214
- ]
215
 
216
- # Create the Gradio interface
217
- with gr.Blocks(title="Text to SQL Converter") as demo:
218
- gr.Markdown("# Text to SQL Query Converter")
219
- gr.Markdown("Enter your question and optional table context to generate an SQL query.")
220
-
221
- with gr.Row():
222
- with gr.Column():
223
- question_input = gr.Textbox(
224
- label="Your Question",
225
- placeholder="e.g., Find all products with price less than $50",
226
- lines=2
227
- )
228
 
229
- table_context = gr.Textbox(
230
- label="Table Context (Optional)",
231
- placeholder="Enter your database schema or table definitions here...",
232
- lines=10
233
- )
234
 
235
- submit_btn = gr.Button("Generate SQL Query")
236
-
237
- with gr.Column():
238
- sql_output = gr.Code(
239
- label="Generated SQL Query",
240
- language="sql",
241
- lines=12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  )
243
-
244
- # Examples section
245
- gr.Markdown("### Try some examples")
246
-
247
- example_selector = gr.Examples(
248
- examples=[
249
- ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
250
- ["Find the total number of employees in each department", example_contexts[2]],
251
- ["Get customers who placed orders in the last 7 days", example_contexts[0]],
252
- ["Count the number of products in each category", example_contexts[1]],
253
- ["Find the average salary by department", example_contexts[2]]
254
- ],
255
- inputs=[question_input, table_context]
256
- )
257
-
258
- # Set up the submit button to trigger the process_input function
259
- submit_btn.click(
260
- fn=process_input,
261
- inputs=[question_input, table_context],
262
- outputs=sql_output
263
- )
264
-
265
- # Also trigger on pressing Enter in the question input
266
- question_input.submit(
267
- fn=process_input,
268
- inputs=[question_input, table_context],
269
- outputs=sql_output
270
- )
271
-
272
- # Add information about the model
273
- gr.Markdown("""
274
- ### About
275
- This app uses a fine-tuned language model to convert natural language questions into SQL queries.
276
-
277
- - **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
278
- - **How to use**:
279
- 1. Enter your question in natural language
280
- 2. If you have specific table schemas, add them in the Table Context field
281
- 3. Click "Generate SQL Query" or press Enter
282
-
283
- Note: The model works best when table context is provided, but can generate generic SQL queries without it.
284
- """)
285
 
286
- # Launch the app
287
- demo.launch()
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  import re
5
  import sqlparse
6
+ import time
7
 
8
+ # App title and description
9
+ st.set_page_config(page_title="Text to SQL Converter", layout="wide")
10
+ st.title("Text to SQL Query Converter")
11
+ st.markdown("Enter your question and optional table context to generate an SQL query.")
 
 
 
 
12
 
13
+ # Model loading (with loading state indicator)
14
+ @st.cache_resource
15
+ def load_model():
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ st.info(f"Loading model on {device}... This may take a minute.")
18
+
19
+ # Load without device_map for compatibility
20
+ model = AutoModelForCausalLM.from_pretrained(
21
+ "onkolahmet/text_to_sql",
22
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
23
+ )
24
+ model = model.to(device)
25
+ tokenizer = AutoTokenizer.from_pretrained("onkolahmet/text_to_sql")
26
+
27
+ return model, tokenizer, device
28
+
29
+ # Few-shot examples to include in each prompt
30
+ examples = [
31
+ {
32
+ "question": "Get the names and emails of customers who placed an order in the last 30 days.",
33
+ "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);"
34
+ },
35
+ {
36
+ "question": "Find all employees with a salary greater than 50000.",
37
+ "sql": "SELECT * FROM employees WHERE salary > 50000;"
38
+ },
39
+ {
40
+ "question": "List all product names and their categories where the price is below 50.",
41
+ "sql": "SELECT name, category FROM products WHERE price < 50;"
42
+ },
43
+ {
44
+ "question": "How many users registered in the year 2022?",
45
+ "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;"
46
+ }
47
+ ]
48
+
49
+ def generate_sql(model, tokenizer, device, question, context=None):
50
  # Construct prompt with few-shot examples and context if available
51
  prompt = "Translate natural language questions to SQL queries.\n\n"
52
 
 
54
  if context and context.strip():
55
  prompt += f"Table Context:\n{context}\n\n"
56
 
57
+ # Add few-shot examples
58
+ for ex in examples:
59
+ prompt += f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n"
60
 
61
  # Add the current question
62
  prompt += f"Q: {question}\nSQL:"
 
182
  except:
183
  return best_query
184
 
185
+ # Load model (this happens once when the app starts)
186
+ model, tokenizer, device = load_model()
187
+ st.success("Model loaded successfully! Ready to generate SQL queries.")
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
+ # Main app interface
190
+ col1, col2 = st.columns([1, 1])
191
+
192
+ with col1:
193
+ # User inputs
194
+ question = st.text_area("Your Question",
195
+ placeholder="e.g., Find all products with price less than $50",
196
+ height=100)
 
 
 
197
 
198
+ table_context = st.text_area("Table Context (Optional)",
199
+ placeholder="Enter your database schema or table definitions here...",
200
+ height=200)
201
+
202
+ # Example selection
203
+ with st.expander("Try an example", expanded=False):
204
+ example_option = st.selectbox(
205
+ "Select an example:",
206
+ [
207
+ "List all products in the 'Electronics' category with price less than $500",
208
+ "Find the total number of employees in each department",
209
+ "Get customers who placed orders in the last 7 days",
210
+ "Count the number of products in each category",
211
+ "Find the average salary by department"
212
+ ]
213
+ )
214
+
215
+ # Sample table context examples mapped to questions
216
+ example_contexts = {
217
+ "List all products in the 'Electronics' category with price less than $500": """
218
  CREATE TABLE products (
219
  id INT PRIMARY KEY,
220
  name VARCHAR(100),
 
222
  price DECIMAL(10,2),
223
  stock_quantity INT
224
  );
225
+ """,
226
+ "Find the total number of employees in each department": """
 
 
227
  CREATE TABLE employees (
228
  id INT PRIMARY KEY,
229
  name VARCHAR(100),
 
238
  manager_id INT,
239
  budget DECIMAL(15,2)
240
  );
241
+ """,
242
+ "Get customers who placed orders in the last 7 days": """
243
+ CREATE TABLE customers (
244
+ id INT PRIMARY KEY,
245
+ name VARCHAR(100),
246
+ email VARCHAR(100),
247
+ order_date DATE
248
+ );
249
+ """,
250
+ "Count the number of products in each category": """
251
+ CREATE TABLE products (
252
+ id INT PRIMARY KEY,
253
+ name VARCHAR(100),
254
+ category VARCHAR(50),
255
+ price DECIMAL(10,2),
256
+ stock_quantity INT
257
+ );
258
+ """,
259
+ "Find the average salary by department": """
260
+ CREATE TABLE employees (
261
+ id INT PRIMARY KEY,
262
+ name VARCHAR(100),
263
+ department VARCHAR(50),
264
+ salary DECIMAL(10,2),
265
+ hire_date DATE
266
+ );
267
+ """
268
+ }
269
+
270
+ apply_example = st.button("Apply Example")
271
+ if apply_example:
272
+ question = example_option
273
+ table_context = example_contexts[example_option]
274
+ st.session_state.question = question
275
+ st.session_state.table_context = table_context
276
+ st.success("Example applied! Click 'Generate SQL Query' to see the result.")
277
 
278
+ # Button to generate SQL
279
+ generate_button = st.button("Generate SQL Query")
 
 
 
 
 
 
280
 
281
+ # Display results
282
+ with col2:
283
+ if generate_button and question:
284
+ with st.spinner("Generating SQL query..."):
285
+ # Record start time
286
+ start_time = time.time()
 
 
 
 
 
 
287
 
288
+ # Generate SQL
289
+ raw_sql = generate_sql(model, tokenizer, device, question, table_context)
290
+ cleaned_sql = clean_sql_output(raw_sql)
 
 
291
 
292
+ # Calculate elapsed time
293
+ elapsed_time = time.time() - start_time
294
+
295
+ # Display results
296
+ st.subheader("Generated SQL Query")
297
+ st.code(cleaned_sql, language="sql")
298
+ st.info(f"Query generated in {elapsed_time:.2f} seconds")
299
+
300
+ # Display explanation
301
+ st.subheader("Explanation")
302
+ st.write("This SQL query translates your natural language question into a database command.")
303
+
304
+ # Option to copy to clipboard (using JavaScript)
305
+ st.markdown(
306
+ f"""
307
+ <div style="margin-top: 20px;">
308
+ <button
309
+ onclick="navigator.clipboard.writeText(`{cleaned_sql}`);this.textContent='Copied!';setTimeout(()=>this.textContent='Copy to Clipboard',1500)"
310
+ style="background-color:#4CAF50;color:white;padding:8px 16px;border:none;border-radius:4px;cursor:pointer"
311
+ >
312
+ Copy to Clipboard
313
+ </button>
314
+ </div>
315
+ """,
316
+ unsafe_allow_html=True
317
  )
318
+ else:
319
+ st.info("Enter a question and click 'Generate SQL Query' to see the result here.")
320
+
321
+ # App footer and info
322
+ st.markdown("---")
323
+ st.markdown("""
324
+ ### About
325
+ This app uses a fine-tuned language model to convert natural language questions into SQL queries.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
+ - **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
328
+ - **How to use**:
329
+ 1. Enter your question in natural language
330
+ 2. If you have specific table schemas, add them in the Table Context field
331
+ 3. Click "Generate SQL Query"
332
+
333
+ Note: The model works best when table context is provided, but can generate generic SQL queries without it.
334
+ """)