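"""Tests for TRL's trainer callbacks: WinRateCallback, LogCompletionsCallback, and MergeModelCallback."""
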
import json
import os
import tempfile
import unittest

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
from transformers.testing_utils import require_peft, require_wandb
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_peft_available

from tests.testing_utils import require_comet, require_mergekit
from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback
from trl.mergekit_utils import MergeConfig


if is_peft_available():
    from peft import LoraConfig


class HalfPairwiseJudge(BasePairwiseJudge):
    """Naive pairwise judge that always returns [1, 0] for two prompts, or [0.3, 0.9] when scores are requested."""

    def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
        # The tests below always judge batches of exactly two prompts.
        assert len(prompts) == 2
        if return_scores:
            return [0.3, 0.9]
        return [1, 0]


class TrainerWithRefModel(Trainer):
    """Trainer that also stores a reference model, as expected by the WinRateCallback tests."""

    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class):
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
        )
        self.ref_model = ref_model

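# The WinRateCallback tests below use a tiny Qwen2 model and the trl-internal-testing/zen
# dataset so that training and generation stay cheap.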
class WinRateCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
        dataset["train"] = dataset["train"].select(range(8))
        # The naive judge picks one winner out of each pair of prompts, so the logged win rate is always 0.5.
        self.expected_winrates = [
            {"eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
            {"eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
            {"eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
            {"eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
            {"eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
            {"eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
            {"eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
        ]

        def tokenize_function(examples):
            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
            out["labels"] = out["input_ids"].copy()
            return out

        self.dataset = dataset.map(tokenize_function, batched=True)

        self.generation_config = GenerationConfig(max_length=32)
        self.judge = HalfPairwiseJudge()

    def test_basic(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = TrainerWithRefModel(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)

    def test_without_ref_model(self):
        # Same setup as test_basic, but the trainer has no ref_model attribute.
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)

    def test_soft_judge(self):
        """Test that the soft judge functionality works correctly."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = TrainerWithRefModel(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()

            expected_soft_winrates = [
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
            ]

            winrate_history = [
                {k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]}
                for h in trainer.state.log_history
                if "eval_avg_win_prob" in h
            ]
            self.assertListEqual(winrate_history, expected_soft_winrates)

    @require_peft
    def test_lora(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            peft_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.model.add_adapter(peft_config)
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)

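# LogCompletionsCallback generates completions for the evaluation prompts and logs them as a
# table to the configured tracker; the tests below inspect the logged tables rather than the
# quality of the completions.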
class LogCompletionsCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
        dataset["train"] = dataset["train"].select(range(8))

        def tokenize_function(examples):
            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
            out["labels"] = out["input_ids"].copy()
            return out

        self.dataset = dataset.map(tokenize_function, batched=True)

        self.generation_config = GenerationConfig(max_length=32)

    @require_wandb
    def test_basic_wandb(self):
        import wandb

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                report_to="wandb",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
            trainer.add_callback(completions_callback)
            trainer.train()

            # Load the completions table that the callback logged to wandb
            completions_path = wandb.run.summary.completions["path"]
            json_path = os.path.join(wandb.run.dir, completions_path)
            with open(json_path) as f:
                completions = json.load(f)

            # Check that the table contains the expected columns
            self.assertIn("step", completions["columns"])
            self.assertIn("prompt", completions["columns"])
            self.assertIn("completion", completions["columns"])

            # Check that the first eval prompt appears in the logged data
            self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])

    @require_comet
    def test_basic_comet(self):
        import comet_ml

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                report_to="comet_ml",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
            trainer.add_callback(completions_callback)
            trainer.train()

            # End the running experiment so its assets can be fetched through the API below
            experiment = comet_ml.get_running_experiment()
            assert experiment is not None
            experiment.end()

            # Expected number of logged completion tables for this run (one every two steps, plus one more)
            steps = len(self.dataset["train"]) + len(self.dataset["test"])
            tables_logged = int(steps / 2) + 1

            api_experiment = comet_ml.APIExperiment(previous_experiment=experiment.id)
            tables = api_experiment.get_asset_list("dataframe")
            assert tables is not None
            assert len(tables) == tables_logged
            assert all(table["fileName"] == "completions.csv" for table in tables)

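# MergeModelCallback merges the model being trained with another model (via mergekit) when a
# checkpoint is saved; these tests only assert that a "merged" directory is written, either for
# the last checkpoint or for every checkpoint.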
@require_mergekit
class MergeModelCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    def test_callback(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                num_train_epochs=1,
                report_to="none",
                save_strategy="steps",
                save_steps=1,
            )
            config = MergeConfig()
            merge_callback = MergeModelCallback(config)
            trainer = DPOTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset,
                processing_class=self.tokenizer,
                callbacks=[merge_callback],
            )
            trainer.train()
            last_checkpoint = get_last_checkpoint(tmp_dir)
            merged_path = os.path.join(last_checkpoint, "merged")
            self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.")

    def test_every_checkpoint(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                num_train_epochs=1,
                report_to="none",
                save_strategy="steps",
                save_steps=1,
            )
            config = MergeConfig()
            merge_callback = MergeModelCallback(config, merge_at_every_checkpoint=True)
            trainer = DPOTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset,
                processing_class=self.tokenizer,
                callbacks=[merge_callback],
            )
            trainer.train()

            checkpoints = sorted(
                [os.path.join(tmp_dir, cp) for cp in os.listdir(tmp_dir) if cp.startswith("checkpoint-")]
            )

            for checkpoint in checkpoints:
                merged_path = os.path.join(checkpoint, "merged")
                self.assertTrue(
                    os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
                )