|
import torch |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
import gradio as gr |
|
|
|
|
|
# Directory holding the fine-tuned T5 checkpoint and its tokenizer files.
# A relative path, so it is resolved against the current working directory.
model_dir = "./"

# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = T5Tokenizer.from_pretrained(model_dir)

# Load the weights and move them onto the chosen device in one step
# (nn.Module.to returns the module itself).
model = T5ForConditionalGeneration.from_pretrained(model_dir).to(device)

# This app only runs inference, so keep the model in evaluation mode.
model.eval()
|
|
|
|
|
def generate_sql(schema, instructions, user_query):
    """Generate a SQL query from a schema, prompt instructions and a user question.

    The stripped inputs are combined into a single prompt of the form::

        <instructions>

        <schema>

        User Query: "<user_query>"

        SQL Query:

    The prompt is tokenized, passed through the module-level T5 ``model`` on
    ``device``, and the first generated sequence is decoded.

    Args:
        schema: Database schema text (e.g. ``CREATE TABLE`` statements).
            ``None`` is treated as an empty string.
        instructions: Free-form prompt/instructions placed at the top of the
            prompt. ``None`` is treated as an empty string.
        user_query: The natural-language question to translate into SQL.

    Returns:
        The decoded model output with special tokens removed.

    Raises:
        gr.Error: If ``user_query`` is missing or only whitespace. Gradio shows
            the message in the UI instead of running the model.
    """
    # Gradio can hand us None instead of "" (e.g. a null value sent through
    # the API); None.strip() would raise AttributeError.
    schema = (schema or "").strip()
    instructions = (instructions or "").strip()
    user_query = (user_query or "").strip()

    # Without a question the prompt is meaningless; don't spend up to 512
    # decoding steps producing arbitrary SQL.
    if not user_query:
        raise gr.Error("Please enter a user query.")

    # Keep this text byte-identical: it is the prompt format the model sees.
    combined_input = (
        f"{instructions}\n\n{schema}\n\n"
        f'User Query: "{user_query}"\n\n'
        "SQL Query:"
    )

    # NOTE(review): truncation=True cuts from the right by default, so an
    # over-long schema/instructions can drop the user query and the trailing
    # "SQL Query:" cue. Confirm the tokenizer's max length fits typical inputs.
    inputs = tokenizer(combined_input, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=512)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
# Build the web UI: a title, three text inputs (schema, instructions, user
# query), one output box, and a button that runs generate_sql on the inputs.
# Components are created in display order.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Text-to-SQL Generator")
    gr.Markdown("Enter the **schema**, **prompt/instructions**, and a **user query** to get the SQL output.")

    schema = gr.Textbox(
        label="Database Schema",
        placeholder="CREATE TABLE students (...) ...",
        lines=10,
    )
    instructions = gr.Textbox(
        label="SQL Instructions / Prompt",
        placeholder="Explain how to generate SQL queries...",
        lines=15,
    )
    user_query = gr.Textbox(
        label="User Query",
        placeholder="e.g., Show me students who never attended class",
    )

    output = gr.Textbox(label="Generated SQL Query")

    submit = gr.Button("Generate SQL")

    # Order must match generate_sql(schema, instructions, user_query).
    query_inputs = [schema, instructions, user_query]
    submit.click(generate_sql, query_inputs, output)

# Start the Gradio server.
# NOTE(review): this runs on import too; consider an
# `if __name__ == "__main__":` guard if the module is ever imported.
demo.launch()
|
|