NLSQL / app.py
HusnaManakkot's picture
Update app.py
5665aa8 verified
raw
history blame
826 Bytes
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub checkpoint used for both the tokenizer and the model.
# Kept in a single constant so the two can never drift apart.
# NOTE(review): the checkpoint name reads like a SQL -> English
# *summarization* model, i.e. the opposite direction from nl_to_sql;
# confirm it actually produces SQL from natural-language questions.
MODEL_NAME = "Salesforce/codet5-base-multi-summarization-sql-en"

# Load the tokenizer and model once at import time so every call to
# nl_to_sql reuses them (from_pretrained may download weights on first use).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def nl_to_sql(natural_language_query: str, *, max_length: int = 512) -> str:
    """Translate a natural-language question into a SQL query string.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        natural_language_query: The question to translate, as plain text.
        max_length: Upper bound on the generated sequence length in tokens,
            forwarded to ``model.generate``. Defaults to 512, the value that
            was previously hard-coded.

    Returns:
        The decoded model output with special tokens removed.
    """
    # Tokenize the input query. Keep the attention mask as well as the ids so
    # generate() receives an explicit mask instead of having to infer one.
    encoding = tokenizer(natural_language_query, return_tensors="pt")
    # Generate the SQL query; generate() returns a batch, so take the only row.
    output_ids = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        max_length=max_length,
    )[0]
    # Decode the generated token ids back into text.
    return tokenizer.decode(output_ids, skip_special_tokens=True)
# Example usage: runs only when this file is executed as a script, not when it
# is imported (e.g. to reuse nl_to_sql from another module), so importing the
# module does not trigger a model generation as a side effect.
if __name__ == "__main__":
    natural_language_query = "What is the average salary of employees?"
    sql_query = nl_to_sql(natural_language_query)
    print(f"SQL Query: {sql_query}")