File size: 2,730 Bytes
495c53e e8e1f28 495c53e e8e1f28 495c53e e8e1f28 495c53e e8e1f28 495c53e e8e1f28 495c53e e8e1f28 495c53e e8e1f28 495c53e e8e1f28 495c53e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse
import gradio as gr
# Hugging Face checkpoint used for both the tokenizer and the model weights.
model_name = "defog/llama-3-sqlcoder-8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the model on CPU: the empty-string key in ``device_map`` assigns the
# whole model to the "cpu" device, and weights are loaded as float32.
_cpu_load_options = {
    "trust_remote_code": True,
    "device_map": {"": "cpu"},
    "torch_dtype": torch.float32,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_cpu_load_options)
# SQL prompt template.
#
# ``prompt`` is a ``str.format`` template with a single ``{question}``
# placeholder (used twice). It wraps the request in chat header tokens, lists
# the schema as DDL followed by join hints, and ends by opening a ```sql code
# fence so that the model's completion starts directly with the query.
#
# NOTE(review): no blank line follows the ``<|end_header_id|>`` tokens here;
# confirm this matches the reference prompt on the model card.

# Schema the model is allowed to query, as CREATE TABLE statements.
_SCHEMA_DDL = """\
CREATE TABLE expenses (
id INTEGER PRIMARY KEY,
date DATE NOT NULL,
amount DECIMAL(10,2) NOT NULL,
category VARCHAR(50) NOT NULL,
description TEXT,
payment_method VARCHAR(20),
user_id INTEGER
);
CREATE TABLE categories (
id INTEGER PRIMARY KEY,
name VARCHAR(50) UNIQUE NOT NULL,
description TEXT
);
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE budgets (
id INTEGER PRIMARY KEY,
user_id INTEGER,
category VARCHAR(50),
amount DECIMAL(10,2) NOT NULL,
period VARCHAR(20) DEFAULT 'monthly',
start_date DATE,
end_date DATE
);
"""

# SQL comments telling the model which columns relate the tables.
_JOIN_HINTS = (
    "-- expenses.user_id can be joined with users.id\n"
    "-- expenses.category can be joined with categories.name\n"
    "-- budgets.user_id can be joined with users.id\n"
    "-- budgets.category can be joined with categories.name"
)

prompt = "".join(
    [
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n",
        "Generate a SQL query to answer this question: `{question}`\n",
        "DDL statements:\n",
        _SCHEMA_DDL,
        _JOIN_HINTS,
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
        "The following SQL query best answers the question `{question}`:\n",
        "```sql\n",
    ]
)
# Main function
def generate_query(question):
    """Generate a formatted SQL query that answers a natural-language question.

    The question is substituted into the module-level ``prompt`` template,
    the model greedily completes the SQL code fence that the template leaves
    open, and the completion is re-indented with ``sqlparse``.

    Args:
        question: Natural-language question about the schema in ``prompt``.

    Returns:
        The re-indented SQL text. If the question is empty, a short hint is
        returned instead. If no SQL can be extracted or formatted, a
        diagnostic message followed by the raw model completion is returned.
    """
    if not question or not question.strip():
        # Skip a slow CPU generation pass when there is nothing to answer.
        return "Please enter a question."

    formatted_prompt = prompt.format(question=question)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
    prompt_length = inputs["input_ids"].shape[1]

    # Greedy decoding. Sampling-only knobs (temperature, top_p) are left out
    # because they have no effect when do_sample=False.
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    # The returned sequence starts with the prompt tokens. Decode only the
    # newly generated part so that code fences inside the user's question
    # (which the prompt echoes) cannot confuse the extraction below.
    output = tokenizer.batch_decode(
        generated_ids[:, prompt_length:], skip_special_tokens=True
    )[0]

    # The prompt already opened the ```sql fence, so the query is everything
    # up to the closing fence (or the whole completion if it was never closed).
    sql_code = output.partition("```")[0].strip()
    failure_message = "❌ SQL could not be parsed. Raw Output:\n\n" + output
    if not sql_code:
        return failure_message

    try:
        return sqlparse.format(sql_code, reindent=True)
    except Exception:
        # Best effort at the UI boundary: show the raw completion rather than
        # surfacing a formatter error to the user.
        return failure_message
# Gradio UI: a single multi-line text box whose contents are passed to
# generate_query, with the returned string shown as plain text.
iface = gr.Interface(
    fn=generate_query,
    inputs=gr.Textbox(lines=3, placeholder="Ask your SQL question..."),
    outputs="text",
    title="🦙 LLaMA 3 SQLCoder (CPU)",
    description=(
        "Convert natural language into SQL queries based on the given schema. "
        "Running on CPU — may be slow."
    ),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|