# Hugging Face Spaces status at time of capture: Runtime error
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub checkpoint used for both the tokenizer and the model.
# NOTE(review): the Space this script came from reported "Runtime error".
# Confirm this checkpoint ID exists on the Hub and is actually trained for
# natural-language -> SQL. The name suggests a summarization model, which
# would translate in the opposite direction.
MODEL_NAME = "Salesforce/codet5-base-multi-summarization-sql-en"

# Load the tokenizer and model once at module import so every call to
# generate_sql reuses the same objects instead of reloading them per request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


def generate_sql(natural_language_query, max_length=512):
    """Run the loaded seq2seq model on a natural-language query.

    The query is tokenized with the module-level ``tokenizer`` and passed to
    ``model.generate``. The first generated sequence is decoded with special
    tokens stripped.

    Args:
        natural_language_query: Text entered by the user. ``None`` or a
            whitespace-only string short-circuits to an empty result without
            invoking the model.
        max_length: Maximum length, in tokens, of the generated sequence.
            Defaults to 512, the value previously hard-coded here.

    Returns:
        The decoded model output as a string, or ``""`` for blank input.
    """
    # An empty textbox submission would otherwise make the model generate
    # from nothing but special tokens; return an empty answer instead.
    if natural_language_query is None or not natural_language_query.strip():
        return ""

    # truncation=True caps over-long inputs at the tokenizer's model max
    # length instead of feeding the encoder an arbitrarily long sequence.
    encoded = tokenizer(
        natural_language_query,
        return_tensors="pt",
        truncation=True,
    )

    # Pass the attention mask explicitly so generate() does not have to
    # infer it from input_ids (and does not warn about its absence).
    output_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=max_length,
    )[0]

    return tokenizer.decode(output_ids, skip_special_tokens=True)


# Sample prompts listed under the input box in the Gradio UI.
# A flat list of strings is used because the interface has a single input
# component; each string is one example value for that textbox.
example_questions = [
    "What is the average salary of employees?",
    "List the names of employees who work in the IT department.",
    "Count the number of employees who joined after 2015.",
]

# Build the Gradio UI: a two-line textbox feeds generate_sql, and its string
# result is shown as plain text.
# NOTE(review): the description claims the model uses the WikiSQL dataset;
# nothing in this file confirms that. Verify against the checkpoint's model
# card before publishing.
interface = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(lines=2, placeholder="Enter your natural language query here..."),
    outputs="text",
    examples=example_questions,
    title="NL to SQL with CodeT5",
    description="This model converts natural language queries into SQL using the WikiSQL dataset. Try one of the example questions or enter your own!",
)

# Start the web server only when this file is executed as a script. Importing
# the module (e.g. from tests or another app) builds the interface but does
# not launch a blocking server.
if __name__ == "__main__":
    interface.launch()