# NOTE(review): the lines "Spaces:" / "Runtime error" captured above this script
# were hosting-platform (Hugging Face Spaces) log output, not Python code; kept
# only as this comment so the file parses.
# Standard library
import argparse
import os
import shutil
import tempfile

# Third-party
import pandas as pd
import torch
from datasets import Dataset
from google.cloud import storage
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
# Command-line interface: both options are mandatory — where to read the CSV
# dataset from, and where (a GCS URI) the trained model should be written.
arg_parser = argparse.ArgumentParser()
for flag in ("--dataset_path", "--output_dir"):
    arg_parser.add_argument(flag, type=str, required=True)
args = arg_parser.parse_args()
# Load the question/SQL pairs and shape them into the two text columns the
# tokenizer step expects, prefixing each question with the task instruction.
print("📦 Loading dataset from:", args.dataset_path)
frame = pd.read_csv(args.dataset_path)[["question", "sql"]]
frame = frame.rename(columns={"question": "input_text", "sql": "target_text"})
frame["input_text"] = "translate question to SQL: " + frame["input_text"]
dataset = Dataset.from_pandas(frame)
# Pull the pretrained checkpoint; t5-small is fine-tuned below on the
# question -> SQL pairs.
model_name = "t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)
def preprocess(example):
    """Tokenize a single example for seq2seq training.

    Encodes the prompt (``input_text``) as model inputs and the SQL string
    (``target_text``) as ``labels``, both padded/truncated to 128 tokens.
    Pad positions in the labels are replaced with -100 — the ignore_index of
    the LM cross-entropy loss — so the model is not trained to emit <pad>
    tokens (leaving the real pad id in the labels is a classic T5 bug that
    inflates loss and degrades generation).
    """
    input_enc = tokenizer(example["input_text"], truncation=True, padding="max_length", max_length=128)
    target_enc = tokenizer(example["target_text"], truncation=True, padding="max_length", max_length=128)
    pad_id = tokenizer.pad_token_id
    input_enc["labels"] = [t if t != pad_id else -100 for t in target_enc["input_ids"]]
    return input_enc
# Tokenize every row (question prompt -> SQL labels).
tokenized_dataset = dataset.map(preprocess)

# Training configuration. Evaluation is intentionally disabled (no eval set is
# built); since "no" is already the default, the evaluation-strategy argument
# is omitted entirely — it was renamed `eval_strategy` in transformers 4.41
# and the old `evaluation_strategy` spelling was removed in 4.46, so passing
# it crashes on recent versions while omitting it behaves identically on all.
training_args = TrainingArguments(
    output_dir="./results_t5_sqlgen",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=5,
    save_strategy="epoch",
)

# Fine-tune the model on the tokenized pairs.
trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset)
trainer.train()
# Save the fine-tuned model + tokenizer into a scratch directory, then upload
# every produced file to GCS under <output_dir>/sqlgen.
local_dir = tempfile.mkdtemp()
try:
    model.save_pretrained(local_dir)
    tokenizer.save_pretrained(local_dir)

    # Build the destination URI with "/" explicitly: os.path.join would insert
    # the local OS separator ("\" on Windows), corrupting the object name.
    gcs_model_path = args.output_dir.rstrip("/") + "/sqlgen"
    if not gcs_model_path.startswith("gs://"):
        raise ValueError(f"--output_dir must be a gs:// URI, got: {args.output_dir}")
    # gs://<bucket_name>/<base_path>
    bucket_name, _, base_path = gcs_model_path[len("gs://"):].partition("/")

    client = storage.Client()
    bucket = client.bucket(bucket_name)  # loop-invariant: resolve the bucket once
    for fname in os.listdir(local_dir):
        local_path = os.path.join(local_dir, fname)
        gcs_blob_path = f"{base_path}/{fname}" if base_path else fname
        print(f"⬆️ Uploading {fname} to gs://{bucket_name}/{gcs_blob_path}")
        bucket.blob(gcs_blob_path).upload_from_filename(local_path)
    print(f"✅ Model successfully uploaded to gs://{bucket_name}/{base_path}")
finally:
    # Don't leak the scratch directory (this is why shutil is imported).
    shutil.rmtree(local_dir, ignore_errors=True)