"""Export a pre-trained T5 text-to-SQL model and demo it on a CSV-derived schema.

Run as a script, this module loads the checkpoint, saves the model and
tokenizer under ``<PROJECT_ROOT>/model_sqlgen_t5``, and, when
``<PROJECT_ROOT>/data/retail_dataset.csv`` exists, prints a schema inferred
from that CSV together with a SQL query generated for a sample question.
"""
import os

import pandas as pd
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Get project root directory
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Checkpoint loaded by load_model() when the caller does not name one.
DEFAULT_MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"


def load_model(model_name=DEFAULT_MODEL_NAME):
    """Load the text-to-SQL model and tokenizer and move the model to a device.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained``
            calls. Defaults to the checkpoint this script has always used;
            passing the directory written by ``main()`` lets callers reuse
            the locally saved copy.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model is in eval mode and
        lives on ``device``, which is CUDA when available and CPU otherwise.
    """
    print("📦 Loading pre-trained text-to-SQL model...")
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)
    model.eval()
    return model, tokenizer, device


def generate_sql(question, schema, model, tokenizer, device, max_length=512):
    """Generate a SQL query for a natural-language question.

    Args:
        question: The question to translate into SQL.
        schema: Schema text (e.g. CREATE TABLE statements) placed in the prompt.
        model: Model returned by ``load_model``.
        tokenizer: Tokenizer returned by ``load_model``.
        device: Device the tokenized inputs are moved to; should match the
            device the model is on.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Defaults to 512, the limit that was previously hard-coded.

    Returns:
        The decoded generation with special tokens removed.
    """
    # Format input as expected by the model
    input_prompt = f"tables:\n{schema}\nquery for: {question}"

    # Tokenize the input prompt
    inputs = tokenizer(
        input_prompt, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    # Inference only: skip gradient tracking while generating.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)

    # Decode the first (and only) generated sequence
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _sql_type_for(dtype):
    """Map a pandas dtype to the SQL column type used in the generated schema."""
    # Use pandas' dtype predicates rather than comparing against the literal
    # names 'int64'/'float64', so other widths and nullable variants of the
    # same numeric kinds are recognised as numeric too.
    if pd.api.types.is_integer_dtype(dtype):
        return "INT"
    if pd.api.types.is_float_dtype(dtype):
        return "DECIMAL(10,2)"
    return "VARCHAR(255)"


def get_schema_from_csv(csv_path):
    """Generate a CREATE TABLE statement from a CSV file.

    The table name is the CSV file name without its extension. Each CSV
    column becomes one SQL column: integer dtypes map to ``INT``, float
    dtypes to ``DECIMAL(10,2)`` and everything else to ``VARCHAR(255)``.

    Args:
        csv_path: Path of the CSV file; it is read with ``pandas.read_csv``.

    Returns:
        The CREATE TABLE statement as a single string.
    """
    df = pd.read_csv(csv_path)
    columns = [f"{col} {_sql_type_for(df[col].dtype)}" for col in df.columns]
    table_name = os.path.splitext(os.path.basename(csv_path))[0]
    return f"CREATE TABLE {table_name} (\n " + ",\n ".join(columns) + "\n);"


def main():
    """Save the pre-trained model locally, then run the CSV example if present."""
    # Load the pre-trained model
    model, tokenizer, device = load_model()

    # Save the model locally for future use
    output_dir = os.path.join(PROJECT_ROOT, "model_sqlgen_t5")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Model successfully saved to {output_dir}")

    # Example usage with CSV
    csv_path = os.path.join(PROJECT_ROOT, "data", "retail_dataset.csv")
    if not os.path.exists(csv_path):
        # The demo is optional; without the sample dataset there is nothing to show.
        return

    schema = get_schema_from_csv(csv_path)
    print("\nGenerated schema from CSV:")
    print(schema)

    question = "What is the total sales amount for each product category?"
    sql_query = generate_sql(question, schema, model, tokenizer, device)
    print("\nExample usage:")
    print(f"Question: {question}")
    print(f"Generated SQL: {sql_query}")


if __name__ == "__main__":
    main()