"""Gradio demo: turn natural-language questions into SQL with defog/llama-3-sqlcoder-8b on CPU."""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse
import gradio as gr

model_name = "defog/llama-3-sqlcoder-8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load model on CPU in full precision.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map={"": "cpu"},
    torch_dtype=torch.float32,
)


def _stop_token_ids(tok):
    """Return the token ids that should end generation.

    Always includes ``tok.eos_token_id``. Llama 3 style prompts close each turn
    with ``<|eot_id|>``, so that id is added as well when the tokenizer knows it;
    otherwise greedy decoding may keep going until ``max_new_tokens``.
    """
    ids = [tok.eos_token_id]
    eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
    # Unknown tokens map to unk_token_id (or None); don't treat those as stops.
    if eot_id is not None and eot_id != tok.unk_token_id and eot_id not in ids:
        ids.append(eot_id)
    return [i for i in ids if i is not None]


STOP_TOKEN_IDS = _stop_token_ids(tokenizer)

# SQL Prompt Template
prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:
CREATE TABLE expenses (
  id INTEGER PRIMARY KEY,
  date DATE NOT NULL,
  amount DECIMAL(10,2) NOT NULL,
  category VARCHAR(50) NOT NULL,
  description TEXT,
  payment_method VARCHAR(20),
  user_id INTEGER
);

CREATE TABLE categories (
  id INTEGER PRIMARY KEY,
  name VARCHAR(50) UNIQUE NOT NULL,
  description TEXT
);

CREATE TABLE users (
  id INTEGER PRIMARY KEY,
  username VARCHAR(50) UNIQUE NOT NULL,
  email VARCHAR(100) UNIQUE NOT NULL,
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE budgets (
  id INTEGER PRIMARY KEY,
  user_id INTEGER,
  category VARCHAR(50),
  amount DECIMAL(10,2) NOT NULL,
  period VARCHAR(20) DEFAULT 'monthly',
  start_date DATE,
  end_date DATE
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""


# Main function
def generate_query(question):
    """Translate a natural-language question into a formatted SQL query.

    The question is substituted into ``prompt`` and the model is decoded
    greedily. Only the newly generated tokens are decoded, so the prompt text
    (which itself contains a ```sql fence) is never mistaken for model output.

    Args:
        question: Natural-language question about the schema in ``prompt``.

    Returns:
        The generated SQL reindented by ``sqlparse``, or an error message
        (including the raw model output when generation produced no SQL).
    """
    question = (question or "").strip()
    if not question:
        return "❌ Please enter a question."

    formatted_prompt = prompt.format(question=question)
    # The template already starts with <|begin_of_text|>; stop the tokenizer
    # from prepending a second BOS token.
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to("cpu")
    prompt_length = inputs["input_ids"].shape[1]

    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=STOP_TOKEN_IDS,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        # Greedy decoding; temperature/top_p only apply when sampling.
        do_sample=False,
        num_beams=1,
    )

    # Decode only the continuation, not the echoed prompt.
    output = tokenizer.decode(
        generated_ids[0, prompt_length:], skip_special_tokens=True
    )
    # The prompt opened a ```sql fence; the SQL is everything before the
    # model's closing fence (or the whole continuation if it never closes it).
    sql_code = output.split("```", 1)[0].strip()
    if not sql_code:
        return "❌ SQL could not be parsed. Raw Output:\n\n" + output

    try:
        return sqlparse.format(sql_code, reindent=True)
    except Exception:
        # Formatting is cosmetic; at this UI boundary fall back to the raw SQL.
        return sql_code


# Gradio UI
iface = gr.Interface(
    fn=generate_query,
    inputs=gr.Textbox(lines=3, placeholder="Ask your SQL question..."),
    outputs="text",
    title="🦙 LLaMA 3 SQLCoder (CPU)",
    description="Convert natural language into SQL queries based on the given schema. Running on CPU – may be slow.",
)

if __name__ == "__main__":
    iface.launch()