|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
# Hugging Face model ID for the causal LM used to answer questions.
# Swapping this string is the only change needed to try a different model.
model_name = "EleutherAI/gpt-neo-2.7B"

# Load the tokenizer and model weights from the Hub (or the local cache).
# NOTE(review): this is a 2.7B-parameter checkpoint — expect a large download
# and significant RAM/VRAM use on first run; confirm the host can handle it.
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
|
# In-memory description of the database: one entry per table, each with its
# column list and an optional foreign-key relation written as
# "column -> other_table.column" (None when the table has no relation).
# Insertion order here drives the order tables appear in the generated prompt.
schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        # products.category_id references categories.id
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": ["id", "category_name"],
        # Leaf table: nothing to join to.
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "customer_name", "product_id", "order_date"],
        # orders.product_id references products.product_id
        "relations": "product_id -> products.product_id",
    },
}
|
|
|
def generate_context(schema):
    """Render the schema dict into a prompt context string for the model.

    Args:
        schema: Mapping of table name -> {"columns": list[str],
            "relations": str | None}. Tables are emitted in insertion order;
            a "Relations:" line is included only for truthy values.

    Returns:
        A string with a "### Database Schema ###" section describing every
        table, followed by a fixed "### Instructions ###" section.
    """
    # Collect fragments and join once at the end instead of repeated
    # string += in a loop (which is quadratic in the worst case).
    parts = ["### Database Schema ###\n\n"]
    for table, details in schema.items():
        parts.append(f"Table: {table}\nColumns: {', '.join(details['columns'])}\n")
        # Skip the line entirely for None/empty relations.
        if details.get("relations"):
            parts.append(f"Relations: {details['relations']}\n")
        parts.append("\n")

    parts.append("### Instructions ###\n")
    parts.append(
        "Understand the database schema thoroughly to identify the relevant tables, their columns, and the relationships between them. If a question involves data from multiple tables, use appropriate joins to connect them. The questions might not always be related to SQL query generation — they can be about understanding what the database is storing in each field or column. Provide descriptions for the fields in the tables, including their meanings and any relevant details about the data they store. Be aware that the questions could also request information about how tables and columns are interrelated. Think about how to extract and explain data from the schema based on the user's query.\n"
    )
    return "".join(parts)
|
|
|
|
|
# Build the static prompt prefix once at import time; reused for every question.
context = generate_context(schema)
|
|
|
def answer_question(context, question):
    """Generate an SQL query or database-related response using the model.

    Args:
        context: Schema/instruction prompt prefix (see generate_context).
        question: The user's natural-language question.

    Returns:
        The decoded model output (includes the echoed prompt, since
        decoder-only models return prompt + continuation).
    """
    prompt = f"{context}\n\nUser Question: {question}\nSQL Query or Answer:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly so padding/truncated positions are ignored.
        attention_mask=inputs.attention_mask,
        # Bug fix: max_length counts the prompt tokens too, so a long schema
        # prompt could leave zero room for the answer. max_new_tokens bounds
        # only the generated continuation.
        max_new_tokens=256,
        num_beams=5,
        early_stopping=True,
        # GPT-Neo defines no pad token; reuse EOS to silence the warning and
        # get well-defined padding during beam search.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
# Demo driver: run each canned question through the assistant and print the
# model's response.
print("Database Assistant is ready. Ask your questions!")

questions = [
    "describe the product table for me, what kind of data it is storing and all",
]

for user_question in questions:
    print(f"Question: {user_question}")
    print("\nGenerated Response:\n", answer_question(context, user_question), "\n")
|
|
|
|