Balaprime commited on
Commit
9884367
·
verified ·
1 Parent(s): 7cc52eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -0
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import re
5
+ import sqlparse
6
+
7
+ # Load model and tokenizer
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ model = AutoModelForCausalLM.from_pretrained(
10
+ "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator",
11
+ torch_dtype="auto",
12
+ device_map="auto"
13
+ )
14
+ tokenizer = AutoTokenizer.from_pretrained("onkolahmet/Qwen2-0.5B-Instruct-SQL-generator")
15
+
16
+ # # Few-shot examples to include in each prompt
17
+ # examples = [
18
+ # {
19
+ # "question": "Get the names and emails of customers who placed an order in the last 30 days.",
20
+ # "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);"
21
+ # },
22
+ # {
23
+ # "question": "Find all employees with a salary greater than 50000.",
24
+ # "sql": "SELECT * FROM employees WHERE salary > 50000;"
25
+ # },
26
+ # {
27
+ # "question": "List all product names and their categories where the price is below 50.",
28
+ # "sql": "SELECT name, category FROM products WHERE price < 50;"
29
+ # },
30
+ # {
31
+ # "question": "How many users registered in the year 2022?",
32
+ # "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;"
33
+ # }
34
+ # ]
35
+
36
+ def generate_sql(question, context=None):
37
+ # Construct prompt with few-shot examples and context if available
38
+ prompt = "Translate natural language questions to SQL queries.\n\n"
39
+
40
+ # Add table context if available
41
+ if context and context.strip():
42
+ prompt += f"Table Context:\n{context}\n\n"
43
+
44
+ # # Add few-shot examples
45
+ # for ex in examples:
46
+ # prompt += f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n"
47
+
48
+ # Add the current question
49
+ prompt += f"Q: {question}\nSQL:"
50
+
51
+ # Tokenize and generate
52
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
53
+
54
+ # Generate SQL query
55
+ outputs = model.generate(
56
+ inputs.input_ids,
57
+ max_new_tokens=128,
58
+ do_sample=True,
59
+ eos_token_id=tokenizer.eos_token_id
60
+ )
61
+
62
+ # Extract and decode only the new generation
63
+ sql_query = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
64
+ return sql_query.strip()
65
+
66
+ def clean_sql_output(sql_text):
67
+ """
68
+ Clean and deduplicate SQL queries:
69
+ 1. Remove comments
70
+ 2. Remove duplicate queries
71
+ 3. Extract only the most relevant query
72
+ 4. Format properly
73
+ """
74
+ # Remove SQL comments (both single line and multi-line)
75
+ sql_text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
76
+ sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
77
+
78
+ # Remove markdown code block syntax if present
79
+ sql_text = re.sub(r'```sql|```', '', sql_text)
80
+
81
+ # Split into individual queries if multiple exist
82
+ if ';' in sql_text:
83
+ queries = [q.strip() for q in sql_text.split(';') if q.strip()]
84
+ else:
85
+ # If no semicolons, try to identify separate queries by SELECT statements
86
+ sql_text_cleaned = re.sub(r'\s+', ' ', sql_text)
87
+ select_matches = list(re.finditer(r'SELECT\s+', sql_text_cleaned, re.IGNORECASE))
88
+
89
+ if len(select_matches) > 1:
90
+ queries = []
91
+ for i in range(len(select_matches)):
92
+ start = select_matches[i].start()
93
+ end = select_matches[i+1].start() if i < len(select_matches) - 1 else len(sql_text_cleaned)
94
+ queries.append(sql_text_cleaned[start:end].strip())
95
+ else:
96
+ queries = [sql_text]
97
+
98
+ # Remove empty queries
99
+ queries = [q for q in queries if q.strip()]
100
+
101
+ if not queries:
102
+ return ""
103
+
104
+ # If we have multiple queries, need to deduplicate
105
+ if len(queries) > 1:
106
+ # Normalize queries for comparison (lowercase, remove extra spaces)
107
+ normalized_queries = []
108
+ for q in queries:
109
+ # Use sqlparse to format and normalize
110
+ try:
111
+ formatted = sqlparse.format(
112
+ q + ('' if q.strip().endswith(';') else ';'),
113
+ keyword_case='lower',
114
+ identifier_case='lower',
115
+ strip_comments=True,
116
+ reindent=True
117
+ )
118
+ normalized_queries.append(formatted)
119
+ except:
120
+ # If sqlparse fails, just do basic normalization
121
+ normalized = re.sub(r'\s+', ' ', q.lower().strip())
122
+ normalized_queries.append(normalized)
123
+
124
+ # Find unique queries
125
+ unique_queries = []
126
+ unique_normalized = []
127
+
128
+ for i, norm_q in enumerate(normalized_queries):
129
+ if norm_q not in unique_normalized:
130
+ unique_normalized.append(norm_q)
131
+ unique_queries.append(queries[i])
132
+
133
+ # Choose the most likely correct query:
134
+ # 1. Prefer queries with SELECT
135
+ # 2. Prefer longer queries (often more detailed)
136
+ # 3. Prefer first query if all else equal
137
+ select_queries = [q for q in unique_queries if re.search(r'SELECT\s+', q, re.IGNORECASE)]
138
+
139
+ if select_queries:
140
+ # Choose the longest SELECT query (likely most detailed)
141
+ best_query = max(select_queries, key=len)
142
+ elif unique_queries:
143
+ # If no SELECT queries, choose the longest query
144
+ best_query = max(unique_queries, key=len)
145
+ else:
146
+ # Fallback to the first query
147
+ best_query = queries[0]
148
+ else:
149
+ best_query = queries[0]
150
+
151
+ # Clean up the chosen query
152
+ best_query = best_query.strip()
153
+ if not best_query.endswith(';'):
154
+ best_query += ';'
155
+
156
+ # Final formatting to ensure consistent spacing
157
+ best_query = re.sub(r'\s+', ' ', best_query)
158
+
159
+ try:
160
+ # Use sqlparse to nicely format the SQL for display
161
+ formatted_sql = sqlparse.format(
162
+ best_query,
163
+ keyword_case='upper',
164
+ identifier_case='lower',
165
+ reindent=True,
166
+ indent_width=2
167
+ )
168
+ return formatted_sql
169
+ except:
170
+ return best_query
171
+
172
+ def process_input(question, table_context):
173
+ """Function to process user input through the model and return formatted results"""
174
+ if not question.strip():
175
+ return "Please enter a question."
176
+
177
+ # Generate SQL from the question and context
178
+ raw_sql = generate_sql(question, table_context)
179
+
180
+ # Clean the SQL output
181
+ cleaned_sql = clean_sql_output(raw_sql)
182
+
183
+ if not cleaned_sql:
184
+ return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."
185
+
186
+ return cleaned_sql
187
+
188
+ # Sample table context examples for the example selector
189
+ example_contexts = [
190
+ # Example 1
191
+ """
192
+ CREATE TABLE customers (
193
+ id INT PRIMARY KEY,
194
+ name VARCHAR(100),
195
+ email VARCHAR(100),
196
+ order_date DATE
197
+ );
198
+ """,
199
+
200
+ # Example 2
201
+ """
202
+ CREATE TABLE products (
203
+ id INT PRIMARY KEY,
204
+ name VARCHAR(100),
205
+ category VARCHAR(50),
206
+ price DECIMAL(10,2),
207
+ stock_quantity INT
208
+ );
209
+ """,
210
+
211
+ # Example 3
212
+ """
213
+ CREATE TABLE employees (
214
+ id INT PRIMARY KEY,
215
+ name VARCHAR(100),
216
+ department VARCHAR(50),
217
+ salary DECIMAL(10,2),
218
+ hire_date DATE
219
+ );
220
+ CREATE TABLE departments (
221
+ id INT PRIMARY KEY,
222
+ name VARCHAR(50),
223
+ manager_id INT,
224
+ budget DECIMAL(15,2)
225
+ );
226
+ """
227
+ ]
228
+
229
+ # Sample question examples
230
+ example_questions = [
231
+ "Get the names and emails of customers who placed an order in the last 30 days.",
232
+ "Find all products with less than 10 items in stock.",
233
+ "List all employees in the Sales department with a salary greater than 50000.",
234
+ "What is the total budget for departments with more than 5 employees?",
235
+ "Count how many products are in each category where the price is greater than 100."
236
+ ]
237
+
238
+ # Create the Gradio interface
239
+ with gr.Blocks(title="Text to SQL Converter") as demo:
240
+ gr.Markdown("# Text to SQL Query Converter")
241
+ gr.Markdown("Enter your question and optional table context to generate an SQL query.")
242
+
243
+
244
+
245
+ # Launch the app
246
+ demo.launch()