# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
import unittest

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
from transformers.testing_utils import require_peft, require_wandb
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_peft_available

from tests.testing_utils import require_comet, require_mergekit
from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback
from trl.mergekit_utils import MergeConfig


if is_peft_available():
    from peft import LoraConfig


class HalfPairwiseJudge(BasePairwiseJudge):
    """Naive pairwise judge that always returns [1, 0] for two prompts, or [0.3, 0.9] when scores are requested."""

    def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
        # Just check that the batch size is 2
        assert len(prompts) == 2
        if return_scores:
            return [0.3, 0.9]
        return [1, 0]


class TrainerWithRefModel(Trainer):
    # This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
    # ref_model attribute.
    def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class):
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
        )
        self.ref_model = ref_model


class WinRateCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
        dataset["train"] = dataset["train"].select(range(8))
        self.expected_winrates = [
            {"eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
            {"eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
            {"eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
            {"eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
            {"eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
            {"eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
            {"eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
        ]

        def tokenize_function(examples):
            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
            out["labels"] = out["input_ids"].copy()
            return out

        self.dataset = dataset.map(tokenize_function, batched=True)

        self.generation_config = GenerationConfig(max_length=32)
        self.judge = HalfPairwiseJudge()

    def test_basic(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = TrainerWithRefModel(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)

    def test_without_ref_model(self):
        # Same as before, but without the ref_model attribute. It should use the model attribute instead.
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)

    def test_soft_judge(self):
        """Test that the soft judge functionality works correctly."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = TrainerWithRefModel(
                model=self.model,
                ref_model=self.ref_model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()

            # Expected values based on the judge returning soft scores [0.3, 0.9] for each pair,
            # i.e. an average win probability of 1 - (0.3 + 0.9) / 2 = 0.4 for the trained model
            expected_soft_winrates = [
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
                {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
            ]

            winrate_history = [
                {k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]}
                for h in trainer.state.log_history
                if "eval_avg_win_prob" in h
            ]
            self.assertListEqual(winrate_history, expected_soft_winrates)

    @require_peft
    def test_lora(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            peft_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.model.add_adapter(peft_config)
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="none",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            win_rate_callback = WinRateCallback(
                judge=self.judge, trainer=trainer, generation_config=self.generation_config
            )
            trainer.add_callback(win_rate_callback)
            trainer.train()
            winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
            self.assertListEqual(winrate_history, self.expected_winrates)


class LogCompletionsCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
        dataset["train"] = dataset["train"].select(range(8))

        def tokenize_function(examples):
            out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
            out["labels"] = out["input_ids"].copy()
            return out

        self.dataset = dataset.map(tokenize_function, batched=True)

        self.generation_config = GenerationConfig(max_length=32)

    @require_wandb
    def test_basic_wandb(self):
        import wandb

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="wandb",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
            trainer.add_callback(completions_callback)
            trainer.train()

            # Get the completions table logged to the current run
            completions_path = wandb.run.summary.completions["path"]
            json_path = os.path.join(wandb.run.dir, completions_path)
            with open(json_path) as f:
                completions = json.load(f)

            # Check that the columns are correct
            self.assertIn("step", completions["columns"])
            self.assertIn("prompt", completions["columns"])
            self.assertIn("completion", completions["columns"])

            # Check that the prompt is in the log
            self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0])

    @require_comet
    def test_basic_comet(self):
        import comet_ml

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                eval_steps=2,  # evaluate every 2 steps
                per_device_train_batch_size=2,  # 8 samples in total so 4 batches of 2 per epoch
                per_device_eval_batch_size=2,
                report_to="comet_ml",
            )
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset["train"],
                eval_dataset=self.dataset["test"],
                processing_class=self.tokenizer,
            )
            completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
            trainer.add_callback(completions_callback)
            trainer.train()

            # Close the experiment to make sure all pending data are flushed
            experiment = comet_ml.get_running_experiment()
            assert experiment is not None
            experiment.end()

            # Get the experiment assets and check that all required tables were logged
            steps = len(self.dataset["train"]) + len(self.dataset["test"])
            tables_logged = int(steps / 2) + 1  # +1 to include zero step

            api_experiment = comet_ml.APIExperiment(previous_experiment=experiment.id)
            tables = api_experiment.get_asset_list("dataframe")
            assert tables is not None
            assert len(tables) == tables_logged
        assert all(table["fileName"] == "completions.csv" for table in tables)


@require_mergekit
class MergeModelCallbackTester(unittest.TestCase):
    def setUp(self):
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    def test_callback(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                num_train_epochs=1,
                report_to="none",
                save_strategy="steps",
                save_steps=1,
            )
            config = MergeConfig()
            merge_callback = MergeModelCallback(config)
            trainer = DPOTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset,
                processing_class=self.tokenizer,
                callbacks=[merge_callback],
            )
            trainer.train()
            last_checkpoint = get_last_checkpoint(tmp_dir)
            merged_path = os.path.join(last_checkpoint, "merged")
            self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.")

    def test_every_checkpoint(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                num_train_epochs=1,
                report_to="none",
                save_strategy="steps",
                save_steps=1,
            )
            config = MergeConfig()
            merge_callback = MergeModelCallback(config, merge_at_every_checkpoint=True)
            trainer = DPOTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset,
                processing_class=self.tokenizer,
                callbacks=[merge_callback],
            )
            trainer.train()
            checkpoints = sorted(
                [os.path.join(tmp_dir, cp) for cp in os.listdir(tmp_dir) if cp.startswith("checkpoint-")]
            )
            for checkpoint in checkpoints:
                merged_path = os.path.join(checkpoint, "merged")
                self.assertTrue(
                    os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
                )
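

# Minimal convenience entry point so this module can be run directly with
# `python tests/test_callbacks.py`. This is a standard unittest convention, added
# here as an optional extra; the suite is normally invoked through pytest.
if __name__ == "__main__":
    unittest.main()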