File size: 2,797 Bytes
5f946b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import os
import pandas as pd

# Project root: the parent of the directory that contains this file.
_THIS_DIR = os.path.dirname(__file__)
PROJECT_ROOT = os.path.abspath(os.path.join(_THIS_DIR, os.pardir))

def load_model(model_name="cssupport/t5-small-awesome-text-to-sql"):
    """Load a T5 text-to-SQL model and its tokenizer, ready for inference.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the checkpoint this script
            was written around; a directory previously written with
            ``save_pretrained`` can be passed instead to reuse a local copy.

    Returns:
        A ``(model, tokenizer, device)`` tuple. ``device`` is CUDA when
        ``torch.cuda.is_available()`` reports it, otherwise CPU; the model has
        been moved to that device and switched to eval mode.
    """
    print("📦 Loading pre-trained text-to-SQL model...")
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)
    # Inference only: eval mode disables dropout and similar training behavior.
    model.eval()
    return model, tokenizer, device

def generate_sql(question, schema, model, tokenizer, device):
    """Generate a SQL query for a natural-language question.

    Args:
        question: The question to translate into SQL.
        schema: Table definition text placed in the prompt ahead of the question.
        model: Object exposing ``generate(**inputs, max_length=...)``.
        tokenizer: Callable that encodes the prompt and exposes ``decode``.
        device: Device the encoded inputs are moved to before generation.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Prompt layout: schema first, then the question.
    prompt = f"tables:\n{schema}\nquery for: {question}"

    # Encode the prompt as PyTorch tensors and move them to the target device.
    encoded = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
    encoded = encoded.to(device)

    # Gradients are not needed for generation.
    with torch.no_grad():
        sequences = model.generate(**encoded, max_length=512)

    return tokenizer.decode(sequences[0], skip_special_tokens=True)

def get_schema_from_csv(csv_path):
    """Build a CREATE TABLE statement describing the columns of a CSV file.

    The table is named after the file (base name without its extension).
    Columns whose pandas dtype is ``int64`` become INT, ``float64`` columns
    become DECIMAL(10,2), and every other dtype becomes VARCHAR(255).
    """
    def _sql_type(dtype):
        # Map a pandas column dtype onto the SQL type used in the schema text.
        if dtype == 'int64':
            return 'INT'
        if dtype == 'float64':
            return 'DECIMAL(10,2)'
        return 'VARCHAR(255)'

    frame = pd.read_csv(csv_path)
    column_defs = [f"{name} {_sql_type(frame[name].dtype)}" for name in frame.columns]

    file_name = os.path.basename(csv_path)
    table_name, _extension = os.path.splitext(file_name)
    body = ",\n    ".join(column_defs)
    return f"CREATE TABLE {table_name} (\n    {body}\n);"

if __name__ == "__main__":
    # Fetch the pre-trained checkpoint along with its tokenizer and device.
    model, tokenizer, device = load_model()

    # Keep a local copy of the model and tokenizer under the project root.
    save_dir = os.path.join(PROJECT_ROOT, "model_sqlgen_t5")
    print(f"💾 Saving model to {save_dir}")
    os.makedirs(save_dir, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_dir)
    print(f"✅ Model successfully saved to {save_dir}")

    # Demo: runs only when the sample retail CSV is present.
    sample_csv = os.path.join(PROJECT_ROOT, "data", "retail_dataset.csv")
    if os.path.exists(sample_csv):
        demo_question = "What is the total sales amount for each product category?"

        table_schema = get_schema_from_csv(sample_csv)
        print("\nGenerated schema from CSV:")
        print(table_schema)

        demo_sql = generate_sql(demo_question, table_schema, model, tokenizer, device)
        print("\nExample usage:")
        print(f"Question: {demo_question}")
        print(f"Generated SQL: {demo_sql}")