# Spaces status: Runtime error
import torch | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
import os | |
import pandas as pd | |
# Absolute path of the directory one level above the folder containing this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
def load_model():
    """Load the pre-trained text-to-SQL T5 checkpoint and its tokenizer.

    The tokenizer and model are both fetched with ``from_pretrained`` for
    ``cssupport/t5-small-awesome-text-to-sql``. The model is moved to CUDA
    when ``torch.cuda.is_available()`` is true, otherwise to the CPU, and is
    switched to evaluation mode before being returned.

    Returns:
        A ``(model, tokenizer, device)`` tuple.
    """
    print("📦 Loading pre-trained text-to-SQL model...")

    checkpoint = "cssupport/t5-small-awesome-text-to-sql"
    target = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(target)

    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint).to(device)
    model.eval()

    return model, tokenizer, device
def generate_sql(question, schema, model, tokenizer, device):
    """Generate a SQL query for *question* against the tables in *schema*.

    Args:
        question: Natural-language question to translate.
        schema: Table definitions placed under the ``tables:`` header of the prompt.
        model: Object whose ``generate`` method produces output token ids.
        tokenizer: Callable that encodes the prompt and exposes ``decode``.
        device: Device the encoded inputs are moved to before generation.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Prompt layout: schema under "tables:", then the question after "query for:".
    prompt = "tables:\n{}\nquery for: {}".format(schema, question)

    encoded = tokenizer(
        prompt,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    # Gradients are not needed for generation.
    with torch.no_grad():
        generated = model.generate(**encoded, max_length=512)

    return tokenizer.decode(generated[0], skip_special_tokens=True)
def _quote_identifier(name):
    """Return *name* as a SQL identifier, double-quoting it when it is not a plain one."""
    text = str(name)
    if text.isascii() and text.isidentifier():
        return text
    # Standard SQL delimited identifier; embedded double quotes are doubled.
    return '"' + text.replace('"', '""') + '"'


def get_schema_from_csv(csv_path):
    """Generate a ``CREATE TABLE`` statement from a CSV file.

    The CSV is read with ``pandas.read_csv`` and each column's inferred dtype
    is mapped to a SQL type: any integer dtype becomes ``INT``, any float
    dtype becomes ``DECIMAL(10,2)``, and everything else becomes
    ``VARCHAR(255)``. The table name is the CSV file name without its
    extension.

    Column and table names that are not plain ASCII identifiers (for example
    names containing spaces or dots) are wrapped in double quotes so the
    statement stays well-formed; plain names are emitted unchanged. SQL
    reserved words are not detected and are therefore left unquoted.

    Args:
        csv_path: Path to the CSV file to inspect.

    Returns:
        The ``CREATE TABLE`` statement as a single string.
    """
    df = pd.read_csv(csv_path)

    columns = []
    for col, dtype in df.dtypes.items():
        # Use pandas' dtype predicates rather than comparing against the
        # literal 'int64'/'float64' so other widths and nullable dtypes
        # are classified correctly too.
        if pd.api.types.is_integer_dtype(dtype):
            col_type = 'INT'
        elif pd.api.types.is_float_dtype(dtype):
            col_type = 'DECIMAL(10,2)'
        else:
            col_type = 'VARCHAR(255)'
        columns.append(f"{_quote_identifier(col)} {col_type}")

    table_name = _quote_identifier(
        os.path.splitext(os.path.basename(csv_path))[0]
    )
    create_table = f"CREATE TABLE {table_name} (\n " + ",\n ".join(columns) + "\n);"
    return create_table
if __name__ == "__main__":
    # Fetch the pre-trained checkpoint together with its tokenizer and device.
    model, tokenizer, device = load_model()

    # Write the model and the tokenizer into a folder under the project root.
    output_dir = os.path.join(PROJECT_ROOT, "model_sqlgen_t5")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)
    print(f"✅ Model successfully saved to {output_dir}")

    # Demo: derive a schema from the sample CSV (only if it is present)
    # and generate SQL for one example question.
    csv_path = os.path.join(PROJECT_ROOT, "data", "retail_dataset.csv")
    if os.path.exists(csv_path):
        schema = get_schema_from_csv(csv_path)
        print("\nGenerated schema from CSV:")
        print(schema)

        question = "What is the total sales amount for each product category?"
        sql_query = generate_sql(question, schema, model, tokenizer, device)

        print("\nExample usage:")
        print(f"Question: {question}")
        print(f"Generated SQL: {sql_query}")