# RetailGenie — code/cloud_train_sqlgen_t5_script.py
# (Hugging Face repo header: author shubh7, commit 5f946b0, "Adding application file")
# Fine-tunes t5-small to translate natural-language questions into SQL,
# then uploads the trained model artifacts to Google Cloud Storage.
import pandas as pd
import os
import argparse
import shutil
import tempfile
from google.cloud import storage
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset
import torch
# ---------------------------------------------------------------
# Command-line interface: the caller supplies the training CSV and
# the GCS destination for the finished model.
# ---------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
args = parser.parse_args()

print("📦 Loading dataset from:", args.dataset_path)

# Keep only the (question, sql) columns and recast them as a seq2seq
# (input_text, target_text) pair with a T5-style task prefix.
df = (
    pd.read_csv(args.dataset_path)[["question", "sql"]]
    .rename(columns={"question": "input_text", "sql": "target_text"})
)
df["input_text"] = "translate question to SQL: " + df["input_text"]
dataset = Dataset.from_pandas(df)
# Load tokenizer and model
# Base checkpoint: pretrained "t5-small" pulled from the Hugging Face hub;
# both objects are reused below for preprocessing, training, and saving.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
def preprocess(example):
    """Tokenize one (input_text, target_text) pair for seq2seq training.

    Both sides are padded/truncated to 128 tokens. Padding positions in
    the labels are replaced with -100 so the cross-entropy loss ignores
    them — otherwise the model would be penalized for every <pad> slot.
    """
    input_enc = tokenizer(
        example["input_text"], truncation=True, padding="max_length", max_length=128
    )
    target_enc = tokenizer(
        example["target_text"], truncation=True, padding="max_length", max_length=128
    )
    # HF Trainer treats label id -100 as "ignore" in the loss.
    input_enc["labels"] = [
        tok if tok != tokenizer.pad_token_id else -100
        for tok in target_enc["input_ids"]
    ]
    return input_enc

tokenized_dataset = dataset.map(preprocess)
# ---------------------------------------------------------------
# Trainer configuration: checkpoints are written locally under
# ./results_t5_sqlgen (the GCS upload happens separately after
# training), one checkpoint per epoch, no held-out evaluation.
# ---------------------------------------------------------------
training_args = TrainingArguments(
    output_dir="./results_t5_sqlgen",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=5,
    save_strategy="epoch",
    evaluation_strategy="no",
)

# Fine-tune on the full tokenized dataset (no eval split is used).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)
trainer.train()
# Save the trained model to a temporary local directory, then mirror every
# file to GCS. The temp dir is removed afterwards (original code leaked it).
local_dir = tempfile.mkdtemp()
try:
    model.save_pretrained(local_dir)
    tokenizer.save_pretrained(local_dir)

    # args.output_dir is a gs://bucket/prefix URI. Derive the bucket name and
    # object prefix. GCS object names always use "/" separators, so build them
    # with explicit joins — os.path.join would emit "\" on Windows.
    gcs_model_path = f"{args.output_dir.rstrip('/')}/sqlgen"
    parts = gcs_model_path.split("/")
    bucket_name = parts[2]            # "gs://b/p" -> ["gs:", "", "b", "p", ...]
    base_path = "/".join(parts[3:])

    client = storage.Client()
    bucket = client.bucket(bucket_name)   # one bucket handle, reused per file
    for fname in os.listdir(local_dir):
        local_path = os.path.join(local_dir, fname)
        gcs_blob_path = f"{base_path}/{fname}" if base_path else fname
        print(f"⬆️ Uploading {fname} to gs://{bucket_name}/{gcs_blob_path}")
        bucket.blob(gcs_blob_path).upload_from_filename(local_path)
    print(f"✅ Model successfully uploaded to gs://{bucket_name}/{base_path}")
finally:
    # Clean up the scratch directory so repeated runs don't leak disk space.
    shutil.rmtree(local_dir, ignore_errors=True)