Spaces:
Runtime error
Runtime error
File size: 826 Bytes
5665aa8 5cacb61 5665aa8 5cacb61 5665aa8 f525ef3 5665aa8 f525ef3 5665aa8 abe7c03 120ccfd 5665aa8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# The tokenizer and the seq2seq weights are both fetched from one checkpoint,
# so the identifier is spelled out once and shared by the two loaders.
# NOTE(review): confirm this checkpoint id resolves on the Hugging Face Hub.
_CHECKPOINT = "Salesforce/codet5-base-multi-summarization-sql-en"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)
def nl_to_sql(natural_language_query: str) -> str:
    """Generate SQL text from a natural-language question.

    The question is encoded with the module-level ``tokenizer``, run through
    the module-level ``model`` with ``generate``, and the first generated
    sequence is decoded back into a string.

    Args:
        natural_language_query: The question to translate.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Keep the whole encoding rather than only ``input_ids`` so that the
    # attention mask produced by the tokenizer can be handed to generate().
    encoded = tokenizer(natural_language_query, return_tensors="pt")

    generated = model.generate(
        input_ids=encoded["input_ids"],
        # .get() keeps this working for a tokenizer that emits no mask.
        attention_mask=encoded.get("attention_mask"),
        max_length=512,
    )

    # generate() returns one row per input; a single query was submitted,
    # so the first row is the only result.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Example usage: translate one sample question and print the generated text.
# These statements sit at module level, so they execute whenever this file
# runs (or is imported), after the tokenizer and model above have loaded.
natural_language_query = "What is the average salary of employees?"
sql_query = nl_to_sql(natural_language_query)
print(f"SQL Query: {sql_query}")
|