import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import sqlparse
import time
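# Runtime dependencies inferred from the imports above: streamlit, transformers,
# torch, and sqlparse (assumed to be listed in the Space's requirements file).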
# App title and description
st.set_page_config(page_title="Text to SQL Converter", layout="wide")
st.title("Text to SQL Query Converter")
st.markdown("Enter your question and optional table context to generate an SQL query.")
# Model loading (with loading state indicator)
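# st.cache_resource keeps one copy of the model and tokenizer in memory across
# reruns and sessions, so the download/load below happens only once per app process.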
@st.cache_resource
def load_model():
device = "cuda" if torch.cuda.is_available() else "cpu"
st.info(f"Loading model on {device}... This may take a minute.")
# Load without device_map for compatibility
model = AutoModelForCausalLM.from_pretrained(
"onkolahmet/text_to_sql",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("onkolahmet/text_to_sql")
return model, tokenizer, device
# Few-shot examples to include in each prompt
examples = [
    {
        "question": "Get the names and emails of customers who placed an order in the last 30 days.",
        "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);"
    },
    {
        "question": "Find all employees with a salary greater than 50000.",
        "sql": "SELECT * FROM employees WHERE salary > 50000;"
    },
    {
        "question": "List all product names and their categories where the price is below 50.",
        "sql": "SELECT name, category FROM products WHERE price < 50;"
    },
    {
        "question": "How many users registered in the year 2022?",
        "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;"
    }
]
def generate_sql(model, tokenizer, device, question, context=None):
    # Construct prompt with few-shot examples and context if available
    prompt = "Translate natural language questions to SQL queries.\n\n"
    # Add table context if available
    if context and context.strip():
        prompt += f"Table Context:\n{context}\n\n"
    # Add few-shot examples
    for ex in examples:
        prompt += f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n"
    # Add the current question
    prompt += f"Q: {question}\nSQL:"
    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Generate SQL query (attention mask and pad token passed explicitly to avoid warnings)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=128,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )
    # Extract and decode only the new generation
    sql_query = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
    return sql_query.strip()
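# For reference, the assembled prompt looks roughly like this (illustrative only):
#
#   Translate natural language questions to SQL queries.
#
#   Table Context:
#   <schema, if provided>
#
#   Q: <few-shot question>
#   SQL: <few-shot query>
#   ...
#   Q: <user question>
#   SQL:
#
# The model is expected to continue the text after the final "SQL:" marker.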
def clean_sql_output(sql_text):
"""
Clean and deduplicate SQL queries:
1. Remove comments
2. Remove duplicate queries
3. Extract only the most relevant query
4. Format properly
"""
# Remove SQL comments (both single line and multi-line)
sql_text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
# Remove markdown code block syntax if present
sql_text = re.sub(r'```sql|```', '', sql_text)
# Split into individual queries if multiple exist
if ';' in sql_text:
queries = [q.strip() for q in sql_text.split(';') if q.strip()]
else:
# If no semicolons, try to identify separate queries by SELECT statements
sql_text_cleaned = re.sub(r'\s+', ' ', sql_text)
select_matches = list(re.finditer(r'SELECT\s+', sql_text_cleaned, re.IGNORECASE))
if len(select_matches) > 1:
queries = []
for i in range(len(select_matches)):
start = select_matches[i].start()
end = select_matches[i+1].start() if i < len(select_matches) - 1 else len(sql_text_cleaned)
queries.append(sql_text_cleaned[start:end].strip())
else:
queries = [sql_text]
# Remove empty queries
queries = [q for q in queries if q.strip()]
if not queries:
return ""
# If we have multiple queries, need to deduplicate
if len(queries) > 1:
# Normalize queries for comparison (lowercase, remove extra spaces)
normalized_queries = []
for q in queries:
# Use sqlparse to format and normalize
try:
formatted = sqlparse.format(
q + ('' if q.strip().endswith(';') else ';'),
keyword_case='lower',
identifier_case='lower',
strip_comments=True,
reindent=True
)
normalized_queries.append(formatted)
except:
# If sqlparse fails, just do basic normalization
normalized = re.sub(r'\s+', ' ', q.lower().strip())
normalized_queries.append(normalized)
# Find unique queries
unique_queries = []
unique_normalized = []
for i, norm_q in enumerate(normalized_queries):
if norm_q not in unique_normalized:
unique_normalized.append(norm_q)
unique_queries.append(queries[i])
# Choose the most likely correct query:
# 1. Prefer queries with SELECT
# 2. Prefer longer queries (often more detailed)
# 3. Prefer first query if all else equal
select_queries = [q for q in unique_queries if re.search(r'SELECT\s+', q, re.IGNORECASE)]
if select_queries:
# Choose the longest SELECT query (likely most detailed)
best_query = max(select_queries, key=len)
elif unique_queries:
# If no SELECT queries, choose the longest query
best_query = max(unique_queries, key=len)
else:
# Fallback to the first query
best_query = queries[0]
else:
best_query = queries[0]
# Clean up the chosen query
best_query = best_query.strip()
if not best_query.endswith(';'):
best_query += ';'
# Final formatting to ensure consistent spacing
best_query = re.sub(r'\s+', ' ', best_query)
try:
# Use sqlparse to nicely format the SQL for display
formatted_sql = sqlparse.format(
best_query,
keyword_case='upper',
identifier_case='lower',
reindent=True,
indent_width=2
)
return formatted_sql
except:
return best_query
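# Illustrative example of the cleaning step (not executed by the app):
#   clean_sql_output("SELECT * FROM users; SELECT * FROM users; -- duplicate")
# drops the comment, deduplicates the repeated statement, and returns a single
# reformatted query, approximately "SELECT *\nFROM users;" depending on how
# sqlparse reindents it.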
# Load model (this happens once when the app starts)
model, tokenizer, device = load_model()
st.success("Model loaded successfully! Ready to generate SQL queries.")
# Main app interface
col1, col2 = st.columns([1, 1])
with col1:
    # User inputs (keyed so the example button below can pre-fill them)
    question = st.text_area(
        "Your Question",
        placeholder="e.g., Find all products with price less than $50",
        height=100,
        key="question"
    )
    table_context = st.text_area(
        "Table Context (Optional)",
        placeholder="Enter your database schema or table definitions here...",
        height=200,
        key="table_context"
    )
    # Example selection
    with st.expander("Try an example", expanded=False):
        example_option = st.selectbox(
            "Select an example:",
            [
                "List all products in the 'Electronics' category with price less than $500",
                "Find the total number of employees in each department",
                "Get customers who placed orders in the last 7 days",
                "Count the number of products in each category",
                "Find the average salary by department"
            ]
        )
        # Sample table context examples mapped to questions
        example_contexts = {
            "List all products in the 'Electronics' category with price less than $500": """
            CREATE TABLE products (
                id INT PRIMARY KEY,
                name VARCHAR(100),
                category VARCHAR(50),
                price DECIMAL(10,2),
                stock_quantity INT
            );
            """,
            "Find the total number of employees in each department": """
            CREATE TABLE employees (
                id INT PRIMARY KEY,
                name VARCHAR(100),
                department VARCHAR(50),
                salary DECIMAL(10,2),
                hire_date DATE
            );
            CREATE TABLE departments (
                id INT PRIMARY KEY,
                name VARCHAR(50),
                manager_id INT,
                budget DECIMAL(15,2)
            );
            """,
            "Get customers who placed orders in the last 7 days": """
            CREATE TABLE customers (
                id INT PRIMARY KEY,
                name VARCHAR(100),
                email VARCHAR(100),
                order_date DATE
            );
            """,
            "Count the number of products in each category": """
            CREATE TABLE products (
                id INT PRIMARY KEY,
                name VARCHAR(100),
                category VARCHAR(50),
                price DECIMAL(10,2),
                stock_quantity INT
            );
            """,
            "Find the average salary by department": """
            CREATE TABLE employees (
                id INT PRIMARY KEY,
                name VARCHAR(100),
                department VARCHAR(50),
                salary DECIMAL(10,2),
                hire_date DATE
            );
            """
        }
        # Apply the example through an on_click callback: callbacks run before the
        # widgets are re-rendered, so the keyed text areas above pick up these values
        # on the rerun (assigning widget session state after rendering would fail).
        def _apply_example(q, ctx):
            st.session_state["question"] = q
            st.session_state["table_context"] = ctx

        apply_example = st.button(
            "Apply Example",
            on_click=_apply_example,
            args=(example_option, example_contexts[example_option]),
        )
        if apply_example:
            st.success("Example applied! Click 'Generate SQL Query' to see the result.")
    # Button to generate SQL
    generate_button = st.button("Generate SQL Query")
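# Note: st.button returns True only for the single rerun triggered by its click,
# so the output below is rendered on that run and will disappear on the next
# interaction unless it is stored in session state.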
# Display results
with col2:
    if generate_button and question:
        with st.spinner("Generating SQL query..."):
            # Record start time
            start_time = time.time()
            # Generate SQL
            raw_sql = generate_sql(model, tokenizer, device, question, table_context)
            cleaned_sql = clean_sql_output(raw_sql)
            # Calculate elapsed time
            elapsed_time = time.time() - start_time
        # Display results
        st.subheader("Generated SQL Query")
        st.code(cleaned_sql, language="sql")
        st.info(f"Query generated in {elapsed_time:.2f} seconds")
        # Display explanation
        st.subheader("Explanation")
        st.write("This SQL query translates your natural language question into a database command.")
        # Option to copy to clipboard (using JavaScript)
        st.markdown(
            f"""
<div style="margin-top: 20px;">
  <button
   onclick="navigator.clipboard.writeText(`{cleaned_sql}`);this.textContent='Copied!';setTimeout(()=>this.textContent='Copy to Clipboard',1500)"
   style="background-color:#4CAF50;color:white;padding:8px 16px;border:none;border-radius:4px;cursor:pointer"
  >
  Copy to Clipboard
  </button>
</div>
""",
            unsafe_allow_html=True
        )
    else:
        st.info("Enter a question and click 'Generate SQL Query' to see the result here.")
# App footer and info
st.markdown("---")
st.markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.

- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
  1. Enter your question in natural language
  2. If you have specific table schemas, add them in the Table Context field
  3. Click "Generate SQL Query"

Note: The model works best when table context is provided, but it can also generate generic SQL queries without it.
""")