# Compliance / finetune.py
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
# Load model and tokenizer
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# distilgpt2 has no pad token, so reuse the end-of-sequence token for padding
tokenizer.pad_token = tokenizer.eos_token
# Load dialogue dataset
dataset = load_dataset("HuggingFaceH4/ultrachat", split="train[:1%]") # Use 1% for demo
# Preprocess dataset
def preprocess(examples):
    # With batched=True the mapper receives a dict of column lists, so zip the columns
    prompts = [
        f"User: {prompt} Assistant: {response}"
        for prompt, response in zip(examples["prompt"], examples["response"])
    ]
    return tokenizer(prompts, truncation=True, padding="max_length", max_length=512)
tokenized_dataset = dataset.map(preprocess, batched=True)
# Training arguments
training_args = TrainingArguments(
    output_dir="./evo_finetuned",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,
)
# Data collator copies input_ids into labels so the causal LM loss can be computed
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
# Fine-tune
trainer.train()
# Save model
model.save_pretrained("evo_finetuned")
tokenizer.save_pretrained("evo_finetuned")
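# --- Optional post-training check (a minimal sketch, not part of the original
# script): reload the saved checkpoint and generate one reply to confirm the
# fine-tuned model loads. The prompt below is only a placeholder that mirrors
# the "User: ... Assistant:" template used in preprocess().
from transformers import pipeline

generator = pipeline("text-generation", model="evo_finetuned", tokenizer="evo_finetuned")
sample = generator("User: How do I reset my password? Assistant:", max_new_tokens=50)
print(sample[0]["generated_text"])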