Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import re | |
import sqlparse | |
# Hub id shared by the model weights and the tokenizer.
_MODEL_ID = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

# Preferred compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Both objects are created once at import time and reused for every request.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
# # Few-shot examples to include in each prompt | |
# examples = [ | |
# { | |
# "question": "Get the names and emails of customers who placed an order in the last 30 days.", | |
# "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);" | |
# }, | |
# { | |
# "question": "Find all employees with a salary greater than 50000.", | |
# "sql": "SELECT * FROM employees WHERE salary > 50000;" | |
# }, | |
# { | |
# "question": "List all product names and their categories where the price is below 50.", | |
# "sql": "SELECT name, category FROM products WHERE price < 50;" | |
# }, | |
# { | |
# "question": "How many users registered in the year 2022?", | |
# "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;" | |
# } | |
# ] | |
def generate_sql(question, context=None):
    """Generate a raw SQL completion for a natural-language question.

    Builds a zero-shot prompt (an instruction line, the optional table
    context, then ``Q: <question>`` / ``SQL:``), samples up to 128 new tokens
    from the module-level ``model`` and decodes only the newly generated part.

    Args:
        question: Natural-language question to translate.
        context: Optional schema / table definitions. Skipped when ``None``,
            empty or whitespace-only.

    Returns:
        The decoded completion with surrounding whitespace stripped. The text
        is returned as generated; no cleaning or validation happens here.
    """
    prompt_parts = ["Translate natural language questions to SQL queries.\n\n"]
    if context and context.strip():
        prompt_parts.append(f"Table Context:\n{context}\n\n")
    prompt_parts.append(f"Q: {question}\nSQL:")
    prompt = "".join(prompt_parts)

    # The model is placed by device_map="auto", so follow the device it
    # actually landed on instead of assuming a separately computed one.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Pass the whole encoding so the attention mask reaches generate();
    # sampling is enabled, so repeated calls may return different text.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + completion; slice the prompt tokens off.
    prompt_length = inputs.input_ids.shape[-1]
    completion = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    return completion.strip()
def _top_level_select_starts(text):
    """Return offsets of SELECT keywords that start a new statement.

    A SELECT nested inside parentheses (subquery, CTE body) or directly
    following UNION / INTERSECT / EXCEPT belongs to the enclosing statement
    and is skipped. Parentheses inside string literals are not special-cased.
    """
    starts = []
    depth = 0
    scanned = 0
    for match in re.finditer(r'\bSELECT\b', text, re.IGNORECASE):
        # Update the parenthesis depth with the text since the last match.
        gap = text[scanned:match.start()]
        depth += gap.count('(') - gap.count(')')
        scanned = match.start()
        if depth > 0:
            continue
        if re.search(
            r'\b(?:UNION|INTERSECT|EXCEPT)(?:\s+(?:ALL|DISTINCT))?\s*$',
            text[:match.start()],
            re.IGNORECASE,
        ):
            continue
        starts.append(match.start())
    return starts


def _split_statements(sql_text):
    """Split generated text into candidate statements (may include blanks)."""
    if ';' in sql_text:
        # Plain split: a semicolon inside a string literal also splits here.
        return [part.strip() for part in sql_text.split(';') if part.strip()]

    # No terminators: fall back to statement boundaries at top-level SELECTs.
    flattened = re.sub(r'\s+', ' ', sql_text)
    starts = _top_level_select_starts(flattened)
    if len(starts) > 1:
        bounds = starts + [len(flattened)]
        return [flattened[a:b].strip() for a, b in zip(bounds, bounds[1:])]
    return [sql_text]


def _comparison_key(query):
    """Return a canonical form of *query* used only to detect duplicates."""
    try:
        return sqlparse.format(
            query if query.strip().endswith(';') else query + ';',
            keyword_case='lower',
            identifier_case='lower',
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # Best effort: fall back to case/whitespace folding if formatting fails.
        return re.sub(r'\s+', ' ', query.lower().strip())


def clean_sql_output(sql_text):
    """Reduce raw model output to a single formatted SQL statement.

    Steps:
      1. Strip ``--`` and ``/* */`` comments and markdown code fences.
      2. Split into statements (on ``;``, else on top-level ``SELECT``).
      3. Drop duplicates that differ only in case or whitespace.
      4. Keep the longest statement containing ``SELECT`` (or the longest
         statement overall when none does); ties go to the earliest.
      5. Ensure a trailing ``;``, collapse whitespace and pretty-print.

    Args:
        sql_text: Raw text produced by the model; may be empty or ``None``.

    Returns:
        The formatted statement, or ``""`` when no statement remains.
    """
    if not sql_text:
        return ""

    # Comments first, then fences, so commented-out fences vanish as well.
    text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
    text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
    text = re.sub(r'```sql|```', '', text)

    queries = [q for q in _split_statements(text) if q.strip()]
    if not queries:
        return ""

    # Order-preserving de-duplication on the canonical form.
    unique_queries = []
    seen = set()
    for query in queries:
        key = _comparison_key(query)
        if key not in seen:
            seen.add(key)
            unique_queries.append(query)

    select_queries = [
        q for q in unique_queries if re.search(r'\bSELECT\b', q, re.IGNORECASE)
    ]
    best_query = max(select_queries or unique_queries, key=len)

    best_query = best_query.strip()
    if not best_query.endswith(';'):
        best_query += ';'
    best_query = re.sub(r'\s+', ' ', best_query)

    try:
        return sqlparse.format(
            best_query,
            keyword_case='upper',
            identifier_case='lower',
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Pretty-printing is cosmetic; return the unformatted statement.
        return best_query
def process_input(question, table_context):
    """Turn a question (plus optional schema text) into the text shown to the user.

    Returns a prompt to enter a question when ``question`` is blank, an apology
    when cleaning leaves no SQL, and the cleaned SQL statement otherwise.
    """
    # Guard clause: nothing to translate.
    if not question.strip():
        return "Please enter a question."

    # Model completion -> single cleaned statement.
    candidate = clean_sql_output(generate_sql(question, table_context))

    if candidate:
        return candidate
    return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."
# Schema snippets used as table context by the example selector.
_CUSTOMERS_SCHEMA = """
CREATE TABLE customers (
id INT PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
order_date DATE
);
"""

_PRODUCTS_SCHEMA = """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
"""

_STAFF_SCHEMA = """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(50),
manager_id INT,
budget DECIMAL(15,2)
);
"""

# Index 0: customers, 1: products, 2: employees + departments.
example_contexts = [_CUSTOMERS_SCHEMA, _PRODUCTS_SCHEMA, _STAFF_SCHEMA]

# Sample natural-language questions.
example_questions = [
    "Get the names and emails of customers who placed an order in the last 30 days.",
    "Find all products with less than 10 items in stock.",
    "List all employees in the Sales department with a salary greater than 50000.",
    "What is the total budget for departments with more than 5 employees?",
    "Count how many products are in each category where the price is greater than 100.",
]
# Footer text describing the model and how to use the page.
_ABOUT_MARKDOWN = """
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.
- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
1. Enter your question in natural language
2. If you have specific table schemas, add them in the Table Context field
3. Click "Generate SQL Query" or press Enter
Note: The model works best when table context is provided, but can generate generic SQL queries without it.
"""

# (question, table context) rows offered by the examples widget.
_EXAMPLE_ROWS = [
    ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
    ["Find the total number of employees in each department", example_contexts[2]],
    ["Get customers who placed orders in the last 7 days", example_contexts[0]],
    ["Count the number of products in each category", example_contexts[1]],
    ["Find the average salary by department", example_contexts[2]],
]

# Build the page: inputs on the left, generated SQL on the right.
with gr.Blocks(title="Text to SQL Converter") as demo:
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2,
            )
            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10,
            )
            submit_btn = gr.Button("Generate SQL Query")
        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=12,
            )

    # Both text boxes feed the examples widget and the two event handlers.
    query_inputs = [question_input, table_context]

    gr.Markdown("### Try some examples")
    example_selector = gr.Examples(examples=_EXAMPLE_ROWS, inputs=query_inputs)

    # The button click and pressing Enter in the question box run the same handler.
    handler_kwargs = dict(fn=process_input, inputs=query_inputs, outputs=sql_output)
    submit_btn.click(**handler_kwargs)
    question_input.submit(**handler_kwargs)

    gr.Markdown(_ABOUT_MARKDOWN)

# Start serving the app.
demo.launch()