Spaces:
Runtime error
Runtime error
File size: 11,846 Bytes
e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd 1acff14 e7e63fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import sqlparse
import time
# Configure the browser tab title and wide layout, then render the page header.
st.set_page_config(layout="wide", page_title="Text to SQL Converter")
st.title(body="Text to SQL Query Converter")
st.markdown(body="Enter your question and optional table context to generate an SQL query.")
# Model loading. The decorator's show_spinner message is shown only while the
# function actually runs; an st.info() call inside a cached function would be
# replayed on every later rerun, leaving a stale "Loading model..." banner.
@st.cache_resource(show_spinner="Loading model... This may take a minute.")
def load_model(model_name="onkolahmet/text_to_sql"):
    """Load the text-to-SQL causal LM and its tokenizer (cached by Streamlit).

    Args:
        model_name: Hugging Face Hub repo id (or local path) of the checkpoint.
            Defaults to the checkpoint this app was built around.

    Returns:
        A ``(model, tokenizer, device)`` tuple. ``device`` is ``"cuda"`` when a
        GPU is available, otherwise ``"cpu"``; the model is already moved there.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # float16 weights on GPU, full float32 on CPU.
    dtype = torch.float16 if use_cuda else torch.float32
    # Load without device_map for compatibility, then move explicitly.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer, device
# Few-shot demonstrations prepended to every prompt, kept as (question, SQL)
# pairs and expanded into the {"question": ..., "sql": ...} dicts the prompt
# builder reads.
_FEW_SHOT_PAIRS = (
    (
        "Get the names and emails of customers who placed an order in the last 30 days.",
        "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);",
    ),
    (
        "Find all employees with a salary greater than 50000.",
        "SELECT * FROM employees WHERE salary > 50000;",
    ),
    (
        "List all product names and their categories where the price is below 50.",
        "SELECT name, category FROM products WHERE price < 50;",
    ),
    (
        "How many users registered in the year 2022?",
        "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;",
    ),
)
examples = [{"question": q_text, "sql": sql_text} for q_text, sql_text in _FEW_SHOT_PAIRS]
def generate_sql(model, tokenizer, device, question, context=None, *, max_new_tokens=128):
    """Ask the model to translate *question* into SQL using a few-shot prompt.

    The prompt consists of an instruction line, the optional table context,
    every pair from the module-level ``examples`` list rendered as "Q:"/"SQL:"
    lines, and finally the user's question followed by an open "SQL:" slot.

    Args:
        model: Causal language model (as returned by ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        device: Device string the model lives on, e.g. "cuda" or "cpu".
        question: Natural-language question to translate.
        context: Optional schema / table definitions; ignored when blank.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The text generated after the final "SQL:", cut before any new "Q:"
        line the model started on its own, with surrounding whitespace removed.
    """
    sections = ["Translate natural language questions to SQL queries.\n\n"]
    if context and context.strip():
        sections.append(f"Table Context:\n{context}\n\n")
    sections.extend(f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n" for ex in examples)
    sections.append(f"Q: {question}\nSQL:")
    prompt = "".join(sections)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    # Pass the attention mask and a pad token explicitly; without them
    # generate() has to infer both and emits a warning.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    outputs = model.generate(
        input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # Decode only the tokens produced after the prompt.
    completion = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    # The few-shot layout invites the model to continue with another invented
    # "Q: ... / SQL: ..." pair; keep only the answer to the user's question.
    completion = completion.split("\nQ:", 1)[0]
    return completion.strip()
def _strip_noise(sql_text):
    """Remove SQL comments, markdown fences and leaked few-shot scaffolding."""
    # "-- ..." to end of line, and "/* ... */" blocks.
    # NOTE(review): these patterns ignore string literals, so a literal that
    # contains "--" or "/*" is cut as well.
    text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
    text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
    # Markdown fences: ``` optionally followed by a "sql" tag, in any case.
    text = re.sub(r'```(?:sql)?', '', text, flags=re.IGNORECASE)
    # A model prompted with "Q: ...\nSQL: ..." pairs may keep going after its
    # answer by starting a new "Q:" line. Everything from there on answers a
    # question the user never asked -- and, being long and containing SELECT,
    # would otherwise win the "longest SELECT" pick -- so drop it.
    continuation = re.search(r'\n\s*Q:', text)
    if continuation:
        text = text[:continuation.start()]
    # Drop any "SQL:" label echoed at the start of a line.
    return re.sub(r'^[ \t]*SQL:[ \t]*', '', text, flags=re.MULTILINE)


def _split_queries(sql_text):
    """Split text into non-empty candidate statements (on ';', else on SELECT)."""
    if ';' in sql_text:
        return [q.strip() for q in sql_text.split(';') if q.strip()]
    # No terminators: treat each SELECT keyword as the start of a statement.
    collapsed = re.sub(r'\s+', ' ', sql_text)
    starts = [m.start() for m in re.finditer(r'SELECT\s+', collapsed, re.IGNORECASE)]
    if len(starts) > 1:
        ends = starts[1:] + [len(collapsed)]
        pieces = [collapsed[a:b].strip() for a, b in zip(starts, ends)]
        return [p for p in pieces if p]
    return [sql_text] if sql_text.strip() else []


def _normalize_query(query):
    """Return a canonical form of *query* used only for duplicate detection."""
    try:
        return sqlparse.format(
            query + ('' if query.strip().endswith(';') else ';'),
            keyword_case='lower',
            identifier_case='lower',
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # sqlparse unavailable or failed: fall back to case/whitespace folding.
        return re.sub(r'\s+', ' ', query.lower().strip())


def _pick_best_query(queries):
    """Deduplicate *queries* and return the longest SELECT (else longest) one."""
    if len(queries) == 1:
        return queries[0]
    unique, seen = [], set()
    for query in queries:
        key = _normalize_query(query)
        if key not in seen:
            seen.add(key)
            unique.append(query)
    selects = [q for q in unique if re.search(r'SELECT\s+', q, re.IGNORECASE)]
    # max() keeps the first of equally long candidates.
    return max(selects or unique, key=len)


def _format_query(query):
    """Terminate *query* with ';', collapse whitespace and pretty-print it."""
    query = query.strip()
    if not query.endswith(';'):
        query += ';'
    query = re.sub(r'\s+', ' ', query)
    try:
        return sqlparse.format(
            query,
            keyword_case='upper',
            identifier_case='lower',
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Formatting is cosmetic; return the unformatted statement on failure.
        return query


def clean_sql_output(sql_text):
    """Reduce raw model output to a single formatted SQL statement.

    Steps:
      1. Remove SQL comments, markdown code fences and any few-shot
         continuation ("Q: ..." lines) generated past the answer.
      2. Split into candidate statements (on ';', or on SELECT keywords).
      3. Drop duplicates and keep the longest SELECT statement, or the
         longest statement when none is a SELECT.
      4. Terminate with ';' and pretty-print with sqlparse when possible.

    Args:
        sql_text: Raw text produced by the model; None or "" is accepted.

    Returns:
        The chosen statement, or "" when nothing usable remains.
    """
    if not sql_text:
        return ""
    queries = _split_queries(_strip_noise(sql_text))
    if not queries:
        return ""
    return _format_query(_pick_best_query(queries))
# Load model (cached by st.cache_resource, so reruns reuse it)
model, tokenizer, device = load_model()
st.success("Model loaded successfully! Ready to generate SQL queries.")

# Example questions offered in the "Try an example" picker, each mapped to the
# schema that gets pasted into the Table Context box when it is applied.
example_contexts = {
"List all products in the 'Electronics' category with price less than $500": """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
""",
"Find the total number of employees in each department": """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(50),
manager_id INT,
budget DECIMAL(15,2)
);
""",
"Get customers who placed orders in the last 7 days": """
CREATE TABLE customers (
id INT PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
order_date DATE
);
""",
"Count the number of products in each category": """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
""",
"Find the average salary by department": """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
""",
}


def _apply_example():
    """Copy the selected example into the question and context text areas.

    Runs as the "Apply Example" button's on_click callback, i.e. before the
    script reruns and re-creates the keyed widgets, so the written values
    become the text areas' contents and persist across later reruns (such as
    the one triggered by "Generate SQL Query").
    """
    choice = st.session_state["example_option"]
    st.session_state["question"] = choice
    st.session_state["table_context"] = example_contexts[choice]
    # One-shot flag so the confirmation is shown only on the next run.
    st.session_state["example_applied"] = True


# Main app interface
col1, col2 = st.columns([1, 1])
with col1:
    # Keyed widgets: their values live in st.session_state, which is what
    # lets _apply_example fill them in.
    question = st.text_area(
        "Your Question",
        placeholder="e.g., Find all products with price less than $50",
        height=100,
        key="question",
    )
    table_context = st.text_area(
        "Table Context (Optional)",
        placeholder="Enter your database schema or table definitions here...",
        height=200,
        key="table_context",
    )
    # Example selection
    with st.expander("Try an example", expanded=False):
        example_option = st.selectbox(
            "Select an example:",
            list(example_contexts),
            key="example_option",
        )
        apply_example = st.button("Apply Example", on_click=_apply_example)
        if st.session_state.pop("example_applied", False):
            st.success("Example applied! Click 'Generate SQL Query' to see the result.")
    # Button to generate SQL
    generate_button = st.button("Generate SQL Query")
# Display results
with col2:
    if generate_button and (question or "").strip():
        with st.spinner("Generating SQL query..."):
            # perf_counter is monotonic, so the timing is unaffected by clock changes.
            start_time = time.perf_counter()
            raw_sql = generate_sql(model, tokenizer, device, question, table_context)
            cleaned_sql = clean_sql_output(raw_sql)
            elapsed_time = time.perf_counter() - start_time
        st.subheader("Generated SQL Query")
        if cleaned_sql:
            # st.code renders its own copy-to-clipboard button, so no custom
            # HTML/JS is needed; splicing model output into inline JavaScript
            # would break (or inject) on quotes, backticks or "${".
            st.code(cleaned_sql, language="sql")
        else:
            st.warning("The model did not return a SQL query. Try rephrasing the question or adding table context.")
        st.info(f"Query generated in {elapsed_time:.2f} seconds")
        # Display explanation
        st.subheader("Explanation")
        st.write("This SQL query translates your natural language question into a database command.")
    else:
        st.info("Enter a question and click 'Generate SQL Query' to see the result here.")

# App footer and info
# NOTE(review): the model link below names a different repo id than the one
# load_model() downloads ("onkolahmet/text_to_sql"); confirm which is correct.
st.markdown("---")
st.markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.
- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
1. Enter your question in natural language
2. If you have specific table schemas, add them in the Table Context field
3. Click "Generate SQL Query"
Note: The model works best when table context is provided, but can generate generic SQL queries without it.
""")