# Spaces: Runtime error
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub checkpoint ID used for both the tokenizer and the model.
# NOTE(review): confirm this checkpoint ID actually exists on the Hub and is
# trained for text -> SQL. If the ID cannot be resolved, from_pretrained raises
# at startup. Salesforce's published CodeT5 summarization checkpoint
# ("Salesforce/codet5-base-multi-sum") presumably maps code -> English, which
# is the opposite direction from what nl_to_sql needs. Verify before relying on it.
MODEL_NAME = "Salesforce/codet5-base-multi-summarization-sql-en"

# Load the tokenizer and model once, at import time, from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def nl_to_sql(natural_language_query, max_length=512):
    """Translate a natural-language question into a SQL query string.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        natural_language_query: The question to translate, as plain text.
        max_length: Maximum length, in tokens, passed to ``model.generate``.
            Defaults to 512, the value this function previously hard-coded.

    Returns:
        The first generated sequence, decoded to text with special tokens
        removed.
    """
    # Tokenize the input as a PyTorch batch of one sequence.
    encoded = tokenizer(natural_language_query, return_tensors="pt")
    # Pass the attention mask explicitly rather than leaving generate() to
    # infer it from input_ids.
    output_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=max_length,
    )[0]
    # Decode the generated token IDs back to a SQL string.
    return tokenizer.decode(output_ids, skip_special_tokens=True)
if __name__ == "__main__":
    # Example usage. This runs only when the file is executed as a script, so
    # importing the module does not trigger model inference.
    natural_language_query = "What is the average salary of employees?"
    sql_query = nl_to_sql(natural_language_query)
    print(f"SQL Query: {sql_query}")