Sid26Roy committed on
Commit
495c53e
·
verified ·
1 Parent(s): 56e4575

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import sqlparse
4
+ import gradio as gr
5
+
6
# Hugging Face Hub identifier of the text-to-SQL checkpoint this app serves.
model_name = "defog/llama-3-sqlcoder-8b"
# Load the matching tokenizer once, at import time.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
8
+
9
def get_model(min_fp16_memory_bytes=20e9):
    """Load the causal LM named by ``model_name``, picking precision from GPU memory.

    The total memory of CUDA device 0 is probed. If it exceeds
    ``min_fp16_memory_bytes`` the model is loaded in float16; otherwise
    (including when no CUDA device is usable) it is loaded with
    ``load_in_4bit=True``.

    Args:
        min_fp16_memory_bytes: Total device memory, in bytes, that must be
            exceeded to load float16 weights. Defaults to 20e9.

    Returns:
        The object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    total_memory = 0
    if torch.cuda.is_available():
        try:
            total_memory = torch.cuda.get_device_properties(0).total_memory
        except (RuntimeError, AssertionError):
            # Best-effort probe: if the device cannot be queried, fall back to
            # the low-memory path instead of aborting start-up.
            total_memory = 0

    # Options shared by both precision modes.
    # NOTE(review): trust_remote_code=True allows code shipped in the model
    # repository to run locally -- confirm this is intended for this checkpoint.
    load_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
        "use_cache": True,
    }
    if total_memory > min_fp16_memory_bytes:
        load_kwargs["torch_dtype"] = torch.float16
    else:
        # NOTE(review): 4-bit loading presumably needs the bitsandbytes package
        # and a supported accelerator -- verify on the deployment target.
        load_kwargs["load_in_4bit"] = True

    return AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
32
+
33
# Instantiate the model once at import time.
model = get_model()

# Prompt template. The ``{question}`` placeholder appears twice and is filled
# in with ``str.format``; the DDL below is the fixed schema the model is asked
# to write SQL against. The template ends with an opened SQL code fence.
prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:

CREATE TABLE expenses (
id INTEGER PRIMARY KEY,
date DATE NOT NULL,
amount DECIMAL(10,2) NOT NULL,
category VARCHAR(50) NOT NULL,
description TEXT,
payment_method VARCHAR(20),
user_id INTEGER
);

CREATE TABLE categories (
id INTEGER PRIMARY KEY,
name VARCHAR(50) UNIQUE NOT NULL,
description TEXT
);

CREATE TABLE users (
id INTEGER PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE budgets (
id INTEGER PRIMARY KEY,
user_id INTEGER,
category VARCHAR(50),
amount DECIMAL(10,2) NOT NULL,
period VARCHAR(20) DEFAULT 'monthly',
start_date DATE,
end_date DATE
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""
82
+
83
def generate_query(question):
    """Generate a formatted SQL query for a natural-language question.

    The question is substituted into the module-level ``prompt`` template,
    the model greedily generates up to 400 new tokens, and the text up to the
    first closing code fence is reindented with ``sqlparse``.

    Args:
        question: The natural-language question as a string.

    Returns:
        The reindented SQL string. If the question is empty or only
        whitespace, a short message asking for a question. If formatting
        fails, a message containing the raw generated text.
    """
    if not question or not question.strip():
        # Nothing to answer; skip the generation call entirely.
        return "Please enter a question."

    formatted_prompt = prompt.format(question=question)
    # Put the inputs on whatever device the model reports, rather than
    # assuming "cuda" whenever a GPU happens to be visible.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Greedy decoding: sampling-only knobs (temperature, top_p) are omitted
    # because they do not apply when do_sample=False.
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    # Decode only the newly generated tokens, so that fence markers inside the
    # prompt (or typed into the question) cannot confuse the extraction below.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )

    try:
        # The prompt already opened the SQL fence, so the query is everything
        # up to the first closing fence (or the whole text if there is none).
        sql_code = completion.split("```")[0].strip()
        return sqlparse.format(sql_code, reindent=True)
    except Exception:
        # Best-effort UI boundary: show the raw text instead of crashing.
        return "SQL could not be parsed. Raw Output:\n\n" + completion
105
+
106
# Gradio UI: one multi-line text input wired to generate_query, plain-text output.
_question_box = gr.Textbox(
    lines=3,
    placeholder="Enter your natural language question...",
)
iface = gr.Interface(
    fn=generate_query,
    inputs=_question_box,
    outputs="text",
    title="LLaMA 3 SQLCoder 🦙",
    description="Enter a natural language question and get a SQL query based on predefined tables.",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()