Sid26Roy committed on
Commit
e8e1f28
·
verified ·
1 Parent(s): 495c53e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -32
app.py CHANGED
@@ -6,32 +6,15 @@ import gradio as gr
6
  model_name = "defog/llama-3-sqlcoder-8b"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
- # Check GPU memory if available, otherwise default to 4-bit mode
10
- def get_model():
11
- try:
12
- available_memory = torch.cuda.get_device_properties(0).total_memory
13
- except:
14
- available_memory = 0
15
-
16
- if available_memory > 20e9:
17
- return AutoModelForCausalLM.from_pretrained(
18
- model_name,
19
- trust_remote_code=True,
20
- torch_dtype=torch.float16,
21
- device_map="auto",
22
- use_cache=True,
23
- )
24
- else:
25
- return AutoModelForCausalLM.from_pretrained(
26
- model_name,
27
- trust_remote_code=True,
28
- load_in_4bit=True,
29
- device_map="auto",
30
- use_cache=True,
31
- )
32
-
33
- model = get_model()
34
 
 
35
  prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
36
 
37
  Generate a SQL query to answer this question: `{question}`
@@ -80,9 +63,10 @@ The following SQL query best answers the question `{question}`:
80
  ```sql
81
  """
82
 
 
83
  def generate_query(question):
84
  formatted_prompt = prompt.format(question=question)
85
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
86
 
87
  generated_ids = model.generate(
88
  **inputs,
@@ -100,16 +84,16 @@ def generate_query(question):
100
  try:
101
  sql_code = output.split("```sql")[1].split("```")[0].strip()
102
  return sqlparse.format(sql_code, reindent=True)
103
- except:
104
- return "SQL could not be parsed. Raw Output:\n\n" + output
105
 
106
- # Gradio Interface
107
  iface = gr.Interface(
108
  fn=generate_query,
109
- inputs=gr.Textbox(lines=3, placeholder="Enter your natural language question..."),
110
  outputs="text",
111
- title="LLaMA 3 SQLCoder 🦙",
112
- description="Enter a natural language question and get a SQL query based on predefined tables.",
113
  )
114
 
115
  if __name__ == "__main__":
 
6
  model_name = "defog/llama-3-sqlcoder-8b"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
+ # Load model on CPU
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_name,
12
+ trust_remote_code=True,
13
+ device_map={"": "cpu"},
14
+ torch_dtype=torch.float32
15
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # SQL Prompt Template
18
  prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
19
 
20
  Generate a SQL query to answer this question: `{question}`
 
63
  ```sql
64
  """
65
 
66
+ # Main function
67
  def generate_query(question):
68
  formatted_prompt = prompt.format(question=question)
69
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
70
 
71
  generated_ids = model.generate(
72
  **inputs,
 
84
  try:
85
  sql_code = output.split("```sql")[1].split("```")[0].strip()
86
  return sqlparse.format(sql_code, reindent=True)
87
+ except Exception:
88
+ return "❌ SQL could not be parsed. Raw Output:\n\n" + output
89
 
90
+ # Gradio UI
91
  iface = gr.Interface(
92
  fn=generate_query,
93
+ inputs=gr.Textbox(lines=3, placeholder="Ask your SQL question..."),
94
  outputs="text",
95
+ title="🦙 LLaMA 3 SQLCoder (CPU)",
96
+ description="Convert natural language into SQL queries based on the given schema. Running on CPU – may be slow.",
97
  )
98
 
99
  if __name__ == "__main__":