NLSQL / app.py
HusnaManakkot's picture
Update app.py
a588039 verified
raw
history blame
1.37 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Checkpoint shared by the tokenizer and the seq2seq model, so both are always
# loaded from the same repository.
# NOTE(review): confirm this checkpoint id exists on the Hub and is trained for
# text-to-SQL; the name suggests a summarization checkpoint.
MODEL_NAME = "Salesforce/codet5-base-multi-summarization-sql-en"

# Both objects are created once at import time and reused for every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def generate_sql(natural_language_query):
    """Run the seq2seq model on a natural-language query and return its output.

    Args:
        natural_language_query: The question entered by the user. ``None``
            and empty/whitespace-only strings are accepted.

    Returns:
        The decoded model output with special tokens stripped, or an empty
        string when there is no query text to process.
    """
    # Nothing to translate: skip the expensive model call instead of letting
    # the model produce arbitrary text from an empty prompt.
    if natural_language_query is None or not natural_language_query.strip():
        return ""

    # Tokenize the query, capping it at the same 512-token budget used for
    # generation so over-long inputs are truncated rather than passed whole.
    encoded = tokenizer(
        natural_language_query,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Pass the attention mask explicitly so generation does not have to infer
    # it from the input ids; only one sequence is generated, hence the [0].
    output_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=512,
    )[0]

    # Decode the generated token ids back to plain text.
    return tokenizer.decode(output_ids, skip_special_tokens=True)
# Sample prompts passed to the interface as its examples.
example_questions = [
    "What is the average salary of employees?",
    "List the names of employees who work in the IT department.",
    "Count the number of employees who joined after 2015.",
]
# Two-line text input where the user types the natural-language query.
query_box = gr.Textbox(
    lines=2,
    placeholder="Enter your natural language query here...",
)

# Wire the input box to generate_sql and show its return value as plain text.
interface = gr.Interface(
    fn=generate_sql,
    inputs=query_box,
    outputs="text",
    examples=example_questions,
    title="NL to SQL with CodeT5",
    description=(
        "This model converts natural language queries into SQL using the "
        "WikiSQL dataset. Try one of the example questions or enter your own!"
    ),
)

# Start the web app as soon as this module is executed.
interface.launch()