Sid26Roy commited on
Commit
c35bbf4
·
verified ·
1 Parent(s): a56d439

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -61
app.py CHANGED
@@ -1,57 +1,93 @@
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import sqlparse
4
- import gradio as gr
 
 
 
 
 
5
 
6
  model_name = "defog/llama-3-sqlcoder-8b"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
- # Load model on CPU
10
- model = AutoModelForCausalLM.from_pretrained(
11
- model_name,
12
- trust_remote_code=True,
13
- device_map={"": "cpu"},
14
- torch_dtype=torch.float32
15
- )
16
 
17
- # SQL Prompt Template
18
- prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  Generate a SQL query to answer this question: `{question}`
21
 
22
  DDL statements:
23
 
24
  CREATE TABLE expenses (
25
- id INTEGER PRIMARY KEY,
26
- date DATE NOT NULL,
27
- amount DECIMAL(10,2) NOT NULL,
28
- category VARCHAR(50) NOT NULL,
29
- description TEXT,
30
- payment_method VARCHAR(20),
31
- user_id INTEGER
32
  );
33
 
34
  CREATE TABLE categories (
35
- id INTEGER PRIMARY KEY,
36
- name VARCHAR(50) UNIQUE NOT NULL,
37
- description TEXT
38
  );
39
 
40
  CREATE TABLE users (
41
- id INTEGER PRIMARY KEY,
42
- username VARCHAR(50) UNIQUE NOT NULL,
43
- email VARCHAR(100) UNIQUE NOT NULL,
44
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
45
  );
46
 
47
  CREATE TABLE budgets (
48
- id INTEGER PRIMARY KEY,
49
- user_id INTEGER,
50
- category VARCHAR(50),
51
- amount DECIMAL(10,2) NOT NULL,
52
- period VARCHAR(20) DEFAULT 'monthly',
53
- start_date DATE,
54
- end_date DATE
55
  );
56
 
57
  -- expenses.user_id can be joined with users.id
@@ -63,38 +99,78 @@ The following SQL query best answers the question `{question}`:
63
  ```sql
64
  """
65
 
66
- # Main function
67
  def generate_query(question):
68
- formatted_prompt = prompt.format(question=question)
69
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
70
-
71
- generated_ids = model.generate(
72
- **inputs,
73
- num_return_sequences=1,
74
- eos_token_id=tokenizer.eos_token_id,
75
- pad_token_id=tokenizer.eos_token_id,
76
- max_new_tokens=400,
77
- do_sample=False,
78
- num_beams=1,
79
- temperature=0.0,
80
- top_p=1,
81
- )
82
-
83
- output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
84
  try:
85
- sql_code = output.split("```sql")[1].split("```")[0].strip()
86
- return sqlparse.format(sql_code, reindent=True)
87
- except Exception:
88
- return "❌ SQL could not be parsed. Raw Output:\n\n" + output
89
-
90
- # Gradio UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  iface = gr.Interface(
92
- fn=generate_query,
93
- inputs=gr.Textbox(lines=3, placeholder="Ask your SQL question..."),
94
- outputs="text",
95
- title="🦙 LLaMA 3 SQLCoder (CPU)",
96
- description="Convert natural language into SQL queries based on the given schema. Running on CPU – may be slow.",
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
98
 
99
  if __name__ == "__main__":
100
- iface.launch()
 
1
+ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import sqlparse
5
+ import psutil
6
+ import os
7
+
8
+ # Check available memory
9
+ def get_available_memory():
10
+ return psutil.virtual_memory().available
11
 
12
  model_name = "defog/llama-3-sqlcoder-8b"
 
13
 
14
+ # Initialize tokenizer
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
16
 
17
+ # CPU-compatible model loading
18
+ def load_model():
19
+ try:
20
+ available_memory = get_available_memory()
21
+ print(f"Available memory: {available_memory / 1e9:.1f} GB")
22
+
23
+ # For CPU deployment, we'll use float32 or float16 without quantization
24
+ if available_memory > 16e9: # 16GB+ RAM
25
+ print("Loading model in float16...")
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ model_name,
28
+ trust_remote_code=True,
29
+ torch_dtype=torch.float16,
30
+ device_map="cpu",
31
+ use_cache=True,
32
+ low_cpu_mem_usage=True
33
+ )
34
+ else:
35
+ print("Loading model in float32 with low memory usage...")
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ model_name,
38
+ trust_remote_code=True,
39
+ device_map="cpu",
40
+ use_cache=True,
41
+ low_cpu_mem_usage=True,
42
+ torch_dtype=torch.float32
43
+ )
44
+
45
+ return model
46
+ except Exception as e:
47
+ print(f"Error loading model: {e}")
48
+ return None
49
+
50
+ # Load model (this will take some time on first run)
51
+ print("Loading model... This may take a few minutes on CPU.")
52
+ model = load_model()
53
+
54
+ prompt_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
55
 
56
  Generate a SQL query to answer this question: `{question}`
57
 
58
  DDL statements:
59
 
60
  CREATE TABLE expenses (
61
+ id INTEGER PRIMARY KEY, -- Unique ID for each expense
62
+ date DATE NOT NULL, -- Date when the expense occurred
63
+ amount DECIMAL(10,2) NOT NULL, -- Amount spent
64
+ category VARCHAR(50) NOT NULL, -- Category of expense (food, transport, utilities, etc.)
65
+ description TEXT, -- Optional description of the expense
66
+ payment_method VARCHAR(20), -- How the payment was made (cash, credit_card, debit_card, bank_transfer)
67
+ user_id INTEGER -- ID of the user who made the expense
68
  );
69
 
70
  CREATE TABLE categories (
71
+ id INTEGER PRIMARY KEY, -- Unique ID for each category
72
+ name VARCHAR(50) UNIQUE NOT NULL, -- Category name (food, transport, utilities, entertainment, etc.)
73
+ description TEXT -- Optional description of the category
74
  );
75
 
76
  CREATE TABLE users (
77
+ id INTEGER PRIMARY KEY, -- Unique ID for each user
78
+ username VARCHAR(50) UNIQUE NOT NULL, -- Username
79
+ email VARCHAR(100) UNIQUE NOT NULL, -- Email address
80
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- When the user account was created
81
  );
82
 
83
  CREATE TABLE budgets (
84
+ id INTEGER PRIMARY KEY, -- Unique ID for each budget
85
+ user_id INTEGER, -- ID of the user who set the budget
86
+ category VARCHAR(50), -- Category for which budget is set
87
+ amount DECIMAL(10,2) NOT NULL, -- Budget amount
88
+ period VARCHAR(20) DEFAULT 'monthly', -- Budget period (daily, weekly, monthly, yearly)
89
+ start_date DATE, -- Budget start date
90
+ end_date DATE -- Budget end date
91
  );
92
 
93
  -- expenses.user_id can be joined with users.id
 
99
  ```sql
100
  """
101
 
 
102
  def generate_query(question):
103
+ if model is None:
104
+ return "Error: Model not loaded properly"
105
+
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  try:
107
+ updated_prompt = prompt_template.format(question=question)
108
+ inputs = tokenizer(updated_prompt, return_tensors="pt")
109
+
110
+ # Generate on CPU
111
+ with torch.no_grad():
112
+ generated_ids = model.generate(
113
+ **inputs,
114
+ num_return_sequences=1,
115
+ eos_token_id=tokenizer.eos_token_id,
116
+ pad_token_id=tokenizer.eos_token_id,
117
+ max_new_tokens=400,
118
+ do_sample=False,
119
+ num_beams=1,
120
+ temperature=0.0,
121
+ top_p=1,
122
+ )
123
+
124
+ outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
125
+
126
+ # Extract SQL from output
127
+ if "```sql" in outputs[0]:
128
+ sql_part = outputs[0].split("```sql")[1].split("```")[0].strip()
129
+ else:
130
+ # Fallback extraction
131
+ sql_part = outputs[0].split("The following SQL query best answers the question")[1].strip()
132
+ if sql_part.startswith("`"):
133
+ sql_part = sql_part[1:]
134
+ if "```" in sql_part:
135
+ sql_part = sql_part.split("```")[0].strip()
136
+
137
+ # Clean up the SQL
138
+ if sql_part.endswith(";"):
139
+ sql_part = sql_part[:-1]
140
+
141
+ # Format the SQL
142
+ formatted_sql = sqlparse.format(sql_part, reindent=True, keyword_case='upper')
143
+ return formatted_sql
144
+
145
+ except Exception as e:
146
+ return f"Error generating query: {str(e)}"
147
+
148
+ def gradio_interface(question):
149
+ if not question.strip():
150
+ return "Please enter a question."
151
+
152
+ return generate_query(question)
153
+
154
+ # Create Gradio interface
155
  iface = gr.Interface(
156
+ fn=gradio_interface,
157
+ inputs=gr.Textbox(
158
+ label="Question",
159
+ placeholder="Enter your question (e.g., 'Show me all expenses for food category')",
160
+ lines=3
161
+ ),
162
+ outputs=gr.Code(label="Generated SQL Query", language="sql"),
163
+ title="SQL Query Generator",
164
+ description="Generate SQL queries from natural language questions about expense tracking database.",
165
+ examples=[
166
+ ["Show me all expenses for food category"],
167
+ ["What's the total amount spent on transport this month?"],
168
+ ["Insert a new expense of 50 dollars for groceries on 2024-01-15"],
169
+ ["Find users who spent more than 1000 dollars total"],
170
+ ["Show me the budget vs actual spending for each category"]
171
+ ],
172
+ cache_examples=False
173
  )
174
 
175
  if __name__ == "__main__":
176
+ iface.launch()