File size: 2,730 Bytes
495c53e
 
 
 
 
 
 
 
e8e1f28
 
 
 
 
 
 
495c53e
e8e1f28
495c53e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e1f28
495c53e
 
e8e1f28
495c53e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e1f28
 
495c53e
e8e1f28
495c53e
 
e8e1f28
495c53e
e8e1f28
 
495c53e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse
import gradio as gr

model_name = "defog/llama-3-sqlcoder-8b"

# Weight-loading options: {"": "cpu"} places the entire model on the CPU, and
# the weights are kept in full 32-bit float precision.
# NOTE(review): Llama 3 architectures are supported natively by transformers,
# so trust_remote_code=True is probably unnecessary; it allows Python code
# shipped inside the model repo to run locally — confirm and consider dropping.
_model_load_kwargs = {
    "torch_dtype": torch.float32,
    "device_map": {"": "cpu"},
    "trust_remote_code": True,
}

# Tokenizer and weights come from the same Hugging Face repository.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_kwargs)

# SQL prompt template written in the Llama 3 chat format (header / eot special
# tokens). `{question}` appears twice and is filled by str.format() in
# generate_query(), so any literal braces added to this template must be
# doubled ({{ }}). The schema shown to the model is fixed here: four tables
# (expenses, categories, users, budgets) plus join hints. The assistant turn
# ends inside an open ```sql fence, so the model is expected to continue with
# the SQL itself.
prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:

CREATE TABLE expenses (
  id INTEGER PRIMARY KEY,
  date DATE NOT NULL,
  amount DECIMAL(10,2) NOT NULL,
  category VARCHAR(50) NOT NULL,
  description TEXT,
  payment_method VARCHAR(20),
  user_id INTEGER
);

CREATE TABLE categories (
  id INTEGER PRIMARY KEY,
  name VARCHAR(50) UNIQUE NOT NULL,
  description TEXT
);

CREATE TABLE users (
  id INTEGER PRIMARY KEY,
  username VARCHAR(50) UNIQUE NOT NULL,
  email VARCHAR(100) UNIQUE NOT NULL,
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE budgets (
  id INTEGER PRIMARY KEY,
  user_id INTEGER,
  category VARCHAR(50),
  amount DECIMAL(10,2) NOT NULL,
  period VARCHAR(20) DEFAULT 'monthly',
  start_date DATE,
  end_date DATE
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

def generate_query(question):
    """Translate a natural-language question into a formatted SQL query.

    The question is substituted into the module-level ``prompt`` template and
    the model decodes greedily (no sampling, single beam). Because the prompt
    ends inside an open ```sql fence, the newly generated text is taken up to
    the first closing fence and pretty-printed with ``sqlparse``.

    Args:
        question: Natural-language question from the Gradio textbox.

    Returns:
        The reindented SQL on success; otherwise a message starting with "❌"
        (followed by the raw generated text when generation produced no
        usable SQL).
    """
    # Don't spend a full (slow, CPU-bound) generation on an empty request.
    if not question or not question.strip():
        return "❌ Please enter a question."

    formatted_prompt = prompt.format(question=question)
    # The template already begins with <|begin_of_text|>; without
    # add_special_tokens=False the tokenizer may prepend a second BOS token.
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    # Greedy decoding. temperature/top_p are omitted: they only apply when
    # sampling, and transformers warns when they are set with do_sample=False.
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    # generate() returns prompt + continuation; decode only the continuation so
    # text inside the prompt (including a ```sql typed into the question)
    # cannot be mistaken for the model's answer.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )
    raw_message = "❌ SQL could not be parsed. Raw Output:\n\n" + completion

    sql_code = completion.split("```", 1)[0].strip()
    if not sql_code:
        return raw_message
    try:
        return sqlparse.format(sql_code, reindent=True)
    except Exception:
        # Formatting is cosmetic; fall back to showing the raw generated text.
        return raw_message

# Gradio UI: one multi-line textbox in, plain text (formatted SQL or an error
# message from generate_query) out.
iface = gr.Interface(
    fn=generate_query,
    inputs=gr.Textbox(lines=3, placeholder="Ask your SQL question..."),
    outputs="text",
    title="🦙 LLaMA 3 SQLCoder (CPU)",
    description="Convert natural language into SQL queries based on the given schema. Running on CPU – may be slow.",
)

if __name__ == "__main__":
    iface.launch()