# RetailGenie / code / train_sqlgen_t5_local.py
# Last commit 5f946b0 (shubh7): "Adding application file"
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import os
import pandas as pd
# Absolute path of the repository root, i.e. the parent of the directory that
# holds this script; data/ and the saved model directory are resolved from it.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def load_model(model_name="cssupport/t5-small-awesome-text-to-sql"):
    """Load a T5 text-to-SQL model and its tokenizer, ready for inference.

    Args:
        model_name: Model id or local directory passed to ``from_pretrained``.
            Defaults to the ``cssupport/t5-small-awesome-text-to-sql``
            checkpoint; pass the directory written by ``save_pretrained``
            (e.g. this script's ``model_sqlgen_t5`` output) to reuse a saved copy.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model has been moved to
        ``device`` (CUDA when available, otherwise CPU) and switched to eval mode.
    """
    print("📦 Loading pre-trained text-to-SQL model...")
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)
    # eval() turns off training-only behaviour such as dropout before inference.
    model.eval()
    return model, tokenizer, device
def generate_sql(question, schema, model, tokenizer, device, max_length=512):
    """Translate a natural-language question into SQL for the given schema.

    Args:
        question: Natural-language question to answer.
        schema: Table definitions (e.g. CREATE TABLE text) placed in the prompt.
        model: Seq2seq model exposing ``generate``, e.g. from ``load_model``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the encoded inputs are moved to; should match ``model``.
        max_length: Maximum length in tokens of the generated output
            (default 512, the value previously hard-coded).

    Returns:
        The decoded SQL string with special tokens removed.

    Warns:
        UserWarning: when the prompt is longer than
            ``tokenizer.model_max_length``. The question is the last part of
            the prompt, so truncation would silently cut it off.
    """
    import warnings

    # Format input as expected by the model: schema first, then the question.
    input_prompt = f"tables:\n{schema}\nquery for: {question}"

    # With the tokenizer's default right-side truncation, an oversized schema
    # pushes the question out of the encoded input. Measure the untruncated
    # length first so this does not happen silently.
    prompt_tokens = len(tokenizer(input_prompt)["input_ids"])
    limit = tokenizer.model_max_length
    if prompt_tokens > limit:
        warnings.warn(
            f"Prompt is {prompt_tokens} tokens but the tokenizer limit is "
            f"{limit}; the end of the prompt (the question) will be truncated. "
            "Consider passing a smaller schema.",
            stacklevel=2,
        )

    inputs = tokenizer(
        input_prompt, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)

    # Single prompt in, so only the first generated sequence is decoded.
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_sql
def get_schema_from_csv(csv_path):
"""Generate CREATE TABLE statements from CSV file"""
df = pd.read_csv(csv_path)
columns = []
for col in df.columns:
# Infer column type
dtype = df[col].dtype
if dtype == 'int64':
col_type = 'INT'
elif dtype == 'float64':
col_type = 'DECIMAL(10,2)'
else:
col_type = 'VARCHAR(255)'
columns.append(f"{col} {col_type}")
table_name = os.path.splitext(os.path.basename(csv_path))[0]
create_table = f"CREATE TABLE {table_name} (\n " + ",\n ".join(columns) + "\n);"
return create_table
if __name__ == "__main__":
    # Load the pre-trained model
    model, tokenizer, device = load_model()

    # Save the model and tokenizer locally so later runs can load them from disk
    # (e.g. load_model(output_dir)).
    output_dir = os.path.join(PROJECT_ROOT, "model_sqlgen_t5")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Model successfully saved to {output_dir}")

    # Example usage: build a schema from the sample CSV and ask one question.
    csv_path = os.path.join(PROJECT_ROOT, "data", "retail_dataset.csv")
    if os.path.exists(csv_path):
        schema = get_schema_from_csv(csv_path)
        print("\nGenerated schema from CSV:")
        print(schema)
        question = "What is the total sales amount for each product category?"
        sql_query = generate_sql(question, schema, model, tokenizer, device)
        print("\nExample usage:")
        print(f"Question: {question}")
        print(f"Generated SQL: {sql_query}")
    else:
        # Say why the example did not run instead of exiting silently.
        print(f"\nNo example CSV found at {csv_path}; skipping example query.")