# Hugging Face Space page status at capture time: "Runtime error".
import pandas as pd | |
import os | |
import argparse | |
import shutil | |
import tempfile | |
import json | |
from google.cloud import storage | |
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments | |
from datasets import Dataset | |
from sklearn.preprocessing import LabelEncoder | |
import torch | |
# CLI arguments
parser = argparse.ArgumentParser(
    description="Fine-tune DistilBERT as an intent classifier and upload it to GCS."
)
parser.add_argument(
    "--dataset_path",
    type=str,
    required=True,
    help='CSV file containing "question" and "intent" columns.',
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="GCS destination of the form gs://BUCKET[/PREFIX]; the model is "
    "uploaded under <output_dir>/intent.",
)
args = parser.parse_args()
# The upload step interprets output_dir as gs://bucket/prefix. Reject anything
# else now, before the (long) training run, instead of failing or uploading to
# a wrongly-parsed bucket afterwards.
_gcs_bucket = args.output_dir[len("gs://"):].split("/", 1)[0]
if not args.output_dir.startswith("gs://") or not _gcs_bucket:
    parser.error("--output_dir must be a GCS URI of the form gs://BUCKET[/PREFIX]")
# Load dataset
print("📦 Loading dataset from:", args.dataset_path)
df = pd.read_csv(args.dataset_path)
# Keep only the two columns used for training. Rows with a missing question or
# intent would crash LabelEncoder / the tokenizer, so drop them. reset_index
# yields a fresh frame with a clean 0..n-1 index: this avoids pandas'
# SettingWithCopyWarning when "label" is added below, and keeps a stray index
# column out of the Dataset.
_n_rows = len(df)
df = df[["question", "intent"]].dropna().reset_index(drop=True)
if len(df) < _n_rows:
    print(f"⚠️ Dropped {_n_rows - len(df)} row(s) with missing question/intent")
if df.empty:
    raise ValueError(f"No usable rows in {args.dataset_path}")
# Label encoding: one integer id per distinct intent.
le = LabelEncoder()
df["label"] = le.fit_transform(df["intent"])
# intent name -> integer id. LabelEncoder maps classes_[i] to i; the values are
# built-in int (not numpy integers) so json.dump can serialize the mapping.
label_mapping = {str(name): int(idx) for idx, name in enumerate(le.classes_)}
dataset = Dataset.from_pandas(df)
# Tokenizer and model
model_name = "distilbert-base-uncased"
# The transformers class is `DistilBertTokenizerFast` (camel-case "Bert"), the
# name imported at the top of the file.
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
# One output logit per distinct intent.
model = DistilBertForSequenceClassification.from_pretrained(
    model_name, num_labels=len(label_mapping)
)


def tokenize(example):
    """Tokenize the "question" field of a row or a batch of rows.

    Text is truncated and padded to exactly 128 tokens, so every example has
    the same length.
    """
    return tokenizer(
        example["question"], truncation=True, padding="max_length", max_length=128
    )


# batched=True hands the fast tokenizer lists of questions instead of one row
# at a time; the resulting columns are the same.
dataset = dataset.map(tokenize, batched=True)
# Training configuration. Evaluation is left at its default ("no"): the keyword
# controlling it was renamed from `evaluation_strategy` to `eval_strategy` and
# the old name is rejected by recent transformers releases, so omitting it
# works across versions.
training_args = TrainingArguments(
    output_dir="./results_intent_classifier",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_dir="./logs_intent",
    logging_steps=5,
    save_strategy="epoch",  # checkpoint into output_dir after every epoch
)
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
# Resolve the GCS destination first, so a bad URI fails before anything is
# written. "gs://bucket/a/b" -> bucket "bucket", object prefix "a/b/intent".
# Keys are joined with "/" explicitly: os.path.join would insert "\" on Windows.
gcs_model_path = args.output_dir.rstrip("/") + "/intent"
if not gcs_model_path.startswith("gs://"):
    raise ValueError(
        f"--output_dir must be a gs://BUCKET[/PREFIX] URI, got {args.output_dir!r}"
    )
bucket_name, _, base_path = gcs_model_path[len("gs://"):].partition("/")
if not bucket_name:
    raise ValueError(f"--output_dir has no bucket name: {args.output_dir!r}")

# Save to temp dir. TemporaryDirectory deletes the staged files when the block
# exits, whether the upload succeeds or raises.
with tempfile.TemporaryDirectory() as local_dir:
    model.save_pretrained(local_dir)
    tokenizer.save_pretrained(local_dir)
    with open(os.path.join(local_dir, "label_mapping.json"), "w", encoding="utf-8") as f:
        # Cast to built-in str/int: numpy scalars are not JSON-serializable.
        json.dump({str(k): int(v) for k, v in label_mapping.items()}, f)

    # Upload to GCS
    client = storage.Client()
    bucket = client.bucket(bucket_name)  # one bucket handle for every file
    for fname in sorted(os.listdir(local_dir)):
        local_path = os.path.join(local_dir, fname)
        gcs_blob_path = f"{base_path}/{fname}"
        print(f"⬆️ Uploading {fname} to gs://{bucket_name}/{gcs_blob_path}")
        bucket.blob(gcs_blob_path).upload_from_filename(local_path)

print(f"✅ Intent model successfully uploaded to gs://{bucket_name}/{base_path}")