Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import re | |
import sqlparse | |
import time | |
# --- Page setup: configure the page before any other Streamlit output, ---
# --- then render the heading and a one-line usage hint.                ---
st.set_page_config(layout="wide", page_title="Text to SQL Converter")
st.title("Text to SQL Query Converter")
_INTRO_TEXT = "Enter your question and optional table context to generate an SQL query."
st.markdown(_INTRO_TEXT)
# Model loading, cached for the lifetime of the Streamlit server process.
@st.cache_resource(show_spinner="Loading model... This may take a minute.")
def load_model(model_name="onkolahmet/text_to_sql"):
    """Load the text-to-SQL causal LM and its tokenizer onto the best device.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps one (model, tokenizer, device) tuple per
    ``model_name``, so the weights are downloaded and moved to the device
    once instead of on every button click. The spinner message is shown
    only while the function actually runs (i.e. on a cache miss).

    Args:
        model_name: Hugging Face Hub repo id (or local path) of the model and
            tokenizer to load.

    Returns:
        ``(model, tokenizer, device)`` where ``device`` is ``"cuda"`` when a
        GPU is available, otherwise ``"cpu"``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision on GPU, full precision on CPU.
    # Loaded without device_map (which needs the `accelerate` package) for
    # compatibility; the model is moved explicitly instead.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer, device
# Few-shot demonstrations embedded in every prompt, stored as
# (natural-language question, reference SQL) pairs and expanded below into
# the list-of-dicts shape the prompt builder reads.
_FEW_SHOT_PAIRS = (
    (
        "Get the names and emails of customers who placed an order in the last 30 days.",
        "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);",
    ),
    (
        "Find all employees with a salary greater than 50000.",
        "SELECT * FROM employees WHERE salary > 50000;",
    ),
    (
        "List all product names and their categories where the price is below 50.",
        "SELECT name, category FROM products WHERE price < 50;",
    ),
    (
        "How many users registered in the year 2022?",
        "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;",
    ),
)
examples = [{"question": q_text, "sql": sql_text} for q_text, sql_text in _FEW_SHOT_PAIRS]
def generate_sql(model, tokenizer, device, question, context=None,
                 max_new_tokens=128, do_sample=False):
    """Ask the causal LM to translate *question* into SQL with a few-shot prompt.

    The prompt consists of an instruction line, the optional table context,
    every entry of the module-level ``examples`` list rendered as ``Q:`` /
    ``SQL:`` pairs, and finally ``Q: <question>`` plus an open ``SQL:`` label
    for the model to complete.

    Args:
        model: Hugging Face causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the input tensors are moved to (e.g. ``"cuda"``/``"cpu"``).
        question: Natural-language question to translate.
        context: Optional schema / table definitions; ignored when blank.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Use sampling instead of greedy decoding. Defaults to
            ``False`` so the same question yields the same SQL on every run.

    Returns:
        The generated completion, cut before any follow-up ``Q:`` turn the
        model invents, with surrounding whitespace stripped.
    """
    prompt = "Translate natural language questions to SQL queries.\n\n"
    if context and context.strip():
        prompt += f"Table Context:\n{context}\n\n"
    prompt += "".join(
        f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n" for ex in examples
    )
    prompt += f"Q: {question}\nSQL:"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]

    # Fall back to EOS when the tokenizer defines no pad token, so generate()
    # does not have to guess (and warn) about padding.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids,
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # Decode only the tokens produced after the prompt.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # The prompt is a Q:/SQL: transcript, so the model tends to continue with
    # an invented next example; keep only the answer to *this* question.
    completion = completion.split("\nQ:", 1)[0]
    return completion.strip()
# Whole-word SELECT keyword, case-insensitive.
_SELECT_KEYWORD_RE = re.compile(r"\bSELECT\b", re.IGNORECASE)
# A set operator (optionally followed by ALL/DISTINCT) at the very end of a
# string; used to recognise a SELECT that continues a UNION/INTERSECT/EXCEPT.
_TRAILING_SET_OP_RE = re.compile(
    r"\b(?:UNION|INTERSECT|EXCEPT|MINUS)(?:\s+(?:ALL|DISTINCT))?\s*$",
    re.IGNORECASE,
)


def _strip_non_sql_noise(sql_text):
    """Remove comments, markdown fences and few-shot prompt echoes from *sql_text*."""
    # SQL comments, single-line and multi-line.
    sql_text = re.sub(r"--.*?$", "", sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r"/\*.*?\*/", "", sql_text, flags=re.DOTALL)
    # Markdown code block syntax.
    sql_text = re.sub(r"```sql|```", "", sql_text)
    # Few-shot "Q: ... / SQL: ..." prompts make the model invent further
    # examples; everything from the first line starting with "Q:" is such a
    # continuation and must not be mistaken for the answer.
    sql_text = re.split(r"^[ \t]*Q:", sql_text, maxsplit=1, flags=re.MULTILINE)[0]
    # Echoed "SQL:" labels at the start of a line.
    sql_text = re.sub(r"^[ \t]*SQL:[ \t]*", "", sql_text, flags=re.MULTILINE)
    return sql_text


def _split_on_top_level_selects(sql_text):
    """Split semicolon-free text into statements at top-level SELECT keywords.

    A SELECT starts a new statement only if it sits at parenthesis depth 0
    and does not directly follow a set operator, so subqueries, CTE bodies
    and UNION branches stay intact. (Parentheses inside string literals are
    not special-cased.) Returns ``[sql_text]`` unchanged when there is at
    most one such SELECT; otherwise the whitespace-collapsed pieces, each
    starting at its SELECT.
    """
    collapsed = re.sub(r"\s+", " ", sql_text)
    starts = []
    for match in _SELECT_KEYWORD_RE.finditer(collapsed):
        prefix = collapsed[:match.start()]
        depth = prefix.count("(") - prefix.count(")")
        if depth == 0 and not _TRAILING_SET_OP_RE.search(prefix):
            starts.append(match.start())
    if len(starts) <= 1:
        return [sql_text]
    bounds = starts + [len(collapsed)]
    return [collapsed[begin:end].strip() for begin, end in zip(bounds, bounds[1:])]


def _normalize_for_dedup(query):
    """Return a canonical form of *query* used only to detect duplicates."""
    try:
        return sqlparse.format(
            query + ("" if query.strip().endswith(";") else ";"),
            keyword_case="lower",
            identifier_case="lower",
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # sqlparse unavailable or failed: fold case and whitespace instead.
        return re.sub(r"\s+", " ", query.lower().strip())


def _pick_best_query(queries):
    """Deduplicate *queries* and return the longest SELECT, else the longest query.

    Ties go to the earliest candidate. *queries* must be non-empty.
    """
    unique_queries = []
    seen = set()
    for query in queries:
        key = _normalize_for_dedup(query)
        if key not in seen:
            seen.add(key)
            unique_queries.append(query)
    select_queries = [q for q in unique_queries if _SELECT_KEYWORD_RE.search(q)]
    return max(select_queries or unique_queries, key=len)


def clean_sql_output(sql_text):
    """Reduce raw model output to a single formatted SQL statement.

    Steps:
    1. Strip SQL comments, markdown fences and few-shot prompt echoes
       (everything from a line starting with ``Q:``; leading ``SQL:`` labels).
    2. Split into candidate statements on ``;`` or, when there are no
       semicolons, on top-level ``SELECT`` keywords.
    3. Drop duplicates and keep the longest SELECT (longest statement if none).
    4. Ensure a trailing ``;`` and pretty-print with sqlparse (upper-case
       keywords; identifier case is left as generated, since table and
       column names can be case-sensitive).

    Args:
        sql_text: Decoded model completion. ``None`` or ``""`` is accepted.

    Returns:
        The formatted statement, or ``""`` if nothing SQL-like remains.
    """
    if not sql_text:
        return ""

    sql_text = _strip_non_sql_noise(sql_text)

    if ";" in sql_text:
        queries = [q.strip() for q in sql_text.split(";") if q.strip()]
    else:
        queries = [q for q in _split_on_top_level_selects(sql_text) if q.strip()]
    if not queries:
        return ""

    best_query = _pick_best_query(queries).strip()
    if not best_query.endswith(";"):
        best_query += ";"

    try:
        # reindent=True makes sqlparse normalise whitespace between tokens,
        # without touching the inside of string literals.
        return sqlparse.format(
            best_query,
            keyword_case="upper",
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Unformatted fallback: at least put the statement on one line.
        return re.sub(r"\s+", " ", best_query)
# Load the model, tokenizer and target device.
model, tokenizer, device = load_model()
st.success("Model loaded successfully! Ready to generate SQL queries.")

# Sample table schemas keyed by the example question they accompany. The keys
# (in insertion order) are also the options of the "Try an example" select box.
example_contexts = {
    "List all products in the 'Electronics' category with price less than $500": """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
""",
    "Find the total number of employees in each department": """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(50),
manager_id INT,
budget DECIMAL(15,2)
);
""",
    "Get customers who placed orders in the last 7 days": """
CREATE TABLE customers (
id INT PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
order_date DATE
);
""",
    "Count the number of products in each category": """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
""",
    "Find the average salary by department": """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
""",
}


def _apply_example():
    """Copy the selected example question and schema into the input widgets.

    Registered as the ``on_click`` callback of the "Apply Example" button.
    Streamlit runs callbacks before the next script rerun, so the text areas
    keyed ``"question"`` and ``"table_context"`` are created with the new
    values. Assigning a widget's key after that widget has already been
    created in the current run is rejected by Streamlit.
    """
    choice = st.session_state["example_option"]
    st.session_state["question"] = choice
    st.session_state["table_context"] = example_contexts[choice]
    # One-shot flag so the confirmation message appears only on the next run.
    st.session_state["example_applied"] = True


def _render_generated_sql(question, table_context):
    """Generate SQL for *question* and render it (or an error) in the current container."""
    with st.spinner("Generating SQL query..."):
        start_time = time.perf_counter()
        try:
            raw_sql = generate_sql(model, tokenizer, device, question, table_context)
        except Exception as exc:  # UI boundary: report instead of crashing the page
            st.error(f"SQL generation failed: {exc}")
            return
        cleaned_sql = clean_sql_output(raw_sql)
        elapsed_time = time.perf_counter() - start_time

    if not cleaned_sql:
        st.warning(
            "The model did not return a usable SQL query. "
            "Try rephrasing the question or adding table context."
        )
        return

    st.subheader("Generated SQL Query")
    # st.code shows its own copy-to-clipboard button. The query is model output
    # and may contain quotes or backticks, so it is never interpolated into
    # raw HTML/JavaScript.
    st.code(cleaned_sql, language="sql")
    st.info(f"Query generated in {elapsed_time:.2f} seconds")

    st.subheader("Explanation")
    st.write("This SQL query translates your natural language question into a database command.")


# Main app interface
col1, col2 = st.columns([1, 1])

with col1:
    # The keys bind these widgets to st.session_state so _apply_example can fill them.
    question = st.text_area(
        "Your Question",
        placeholder="e.g., Find all products with price less than $50",
        height=100,
        key="question",
    )
    table_context = st.text_area(
        "Table Context (Optional)",
        placeholder="Enter your database schema or table definitions here...",
        height=200,
        key="table_context",
    )

    with st.expander("Try an example", expanded=False):
        st.selectbox("Select an example:", list(example_contexts), key="example_option")
        st.button("Apply Example", on_click=_apply_example)
    if st.session_state.pop("example_applied", False):
        st.success("Example applied! Click 'Generate SQL Query' to see the result.")

    generate_button = st.button("Generate SQL Query")

# Display results
with col2:
    if not generate_button:
        st.info("Enter a question and click 'Generate SQL Query' to see the result here.")
    elif not question.strip():
        st.warning("Please enter a question before generating a query.")
    else:
        _render_generated_sql(question, table_context)
# App footer and info.
# NOTE(review): the model linked below differs from the repo id loaded in code
# ("onkolahmet/text_to_sql"); confirm which one is intended.
# The numbered steps are indented so they render as a sub-list of
# "How to use", and the blank line keeps "Note:" out of step 3.
st.markdown("---")
st.markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.

- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
    1. Enter your question in natural language
    2. If you have specific table schemas, add them in the Table Context field
    3. Click "Generate SQL Query"

Note: The model works best when table context is provided, but can generate generic SQL queries without it.
""")