File size: 826 Bytes
5665aa8
5cacb61
5665aa8
 
 
5cacb61
5665aa8
 
 
f525ef3
5665aa8
 
f525ef3
5665aa8
 
abe7c03
120ccfd
5665aa8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub checkpoint shared by the tokenizer and the model. Keeping it
# in a single constant guarantees both are always loaded from the same source.
# NOTE(review): the checkpoint name suggests a summarization model (SQL ->
# English) rather than a natural-language -> SQL generator; confirm it is the
# intended checkpoint for this script.
MODEL_NAME = "Salesforce/codet5-base-multi-summarization-sql-en"

# Load the tokenizer and model once, at module import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

def nl_to_sql(natural_language_query, max_length=512):
    """Run the module-level seq2seq model on a query and return the decoded text.

    Args:
        natural_language_query: Text passed to the module-level ``tokenizer``.
        max_length: Upper bound forwarded to ``model.generate`` as its
            ``max_length`` argument. Defaults to 512, the value that was
            previously hard-coded.

    Returns:
        The first generated sequence, decoded to a string with special
        tokens removed.
    """
    # Tokenize and move the tensors to wherever the model lives, so the call
    # keeps working if the model has been moved off the CPU.
    encoding = tokenizer(natural_language_query, return_tensors="pt").to(model.device)

    # Forward the attention mask together with the ids instead of letting
    # generate() infer one; .get() tolerates tokenizers that do not emit it.
    generated = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding.get("attention_mask"),
        max_length=max_length,
    )

    # A single query was encoded, so only the first sequence is relevant.
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Example usage. Guarded so that importing this module does not trigger a
# generation pass and a print; running the file as a script behaves as before.
if __name__ == "__main__":
    natural_language_query = "What is the average salary of employees?"
    sql_query = nl_to_sql(natural_language_query)
    print(f"SQL Query: {sql_query}")