trl-sandbox / trl /trainer /callbacks.py
ivangabriele's picture
feat: initialize project
2f5127c verified
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Union
import pandas as pd
import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import gather_object, is_wandb_available
from transformers import (
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import has_length
from transformers.utils import is_rich_available
from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_mergekit_available
from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
from .utils import log_table_to_comet_experiment
if is_rich_available():
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
if is_wandb_available():
import wandb
def _generate_completions(
    prompts: list[str],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    accelerator: Accelerator,
    generation_config: Optional[GenerationConfig],
    batch_size: int = 1,
) -> list[str]:
    """
    Generates completions for a list of pre-formatted prompts from the given model.

    The prompt tokens are removed from each generated sequence by position (the first `len(input_ids)` tokens), which
    is only correct when the tokenizer pads on the left. Callers are expected to set `tokenizer.padding_side = "left"`
    before calling this function.

    Args:
        prompts (list[str]): A list of input prompts for which completions are to be generated.
        model (PreTrainedModel): The model to be used for generation. It may be wrapped (e.g. by the trainer); it is
            unwrapped via `unwrap_model_for_generation` before generating.
        tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding.
        accelerator (Accelerator): The accelerator to be used for model execution.
        generation_config (GenerationConfig or None): Configuration for text generation, forwarded unchanged to
            `generate`.
        batch_size (int, optional): The number of prompts to process in each batch. Default is 1.

    Returns:
        list[str]: A list of generated text completions corresponding to the input prompts, in the same order.
    """
    completions = []
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        # Take the device from the unwrapped model: `model` may be a training wrapper (e.g. `trainer.model_wrapped`)
        # that is not guaranteed to expose `.device`, while the unwrapped model is the network that actually runs.
        # Hoisted out of the loop since it does not change between batches.
        device = unwrapped_model.device
        for start in range(0, len(prompts), batch_size):
            batch = prompts[start : start + batch_size]
            tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            generations = unwrapped_model.generate(
                **tokenized_batch,
                generation_config=generation_config,
            )
            for prompt_ids, generation in zip(tokenized_batch.input_ids, generations):
                # `generate` returns prompt + new tokens; with left padding each row's (padded) prompt occupies the
                # first `len(prompt_ids)` positions, so slicing there keeps only the newly generated tokens.
                new_tokens = generation[len(prompt_ids) :]
                completions.append(tokenizer.decode(new_tokens, skip_special_tokens=True))
    return completions
class SyncRefModelCallback(TrainerCallback):
    """
    Callback that periodically blends the trained model's weights into a reference model.

    Every `args.ref_model_sync_steps` global steps, each reference parameter is updated in place as
    `ref = (1 - alpha) * ref + alpha * policy`, where `alpha = args.ref_model_mixup_alpha`.
    """

    def __init__(
        self,
        ref_model: Union[PreTrainedModel, torch.nn.Module],
        accelerator: Optional[Accelerator],
    ):
        # The model that receives the blended weights; when None, `on_step_end` does nothing.
        self.ref_model = ref_model
        # Optional accelerator, used only to unwrap the trained model before syncing.
        self.accelerator = accelerator

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        # Pairwise in-place update: target <- (1 - alpha) * target + alpha * source.
        source_params = model.parameters()
        for dst, src in zip(target_model.parameters(), source_params):
            dst.data.mul_(1.0 - alpha).add_(src.data, alpha=alpha)

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        plugin = AcceleratorState().deepspeed_plugin
        uses_zero3 = plugin is not None and plugin.zero_stage == 3
        if not uses_zero3:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)
            return

        import deepspeed

        # Under ZeRO-3 parameters are partitioned across ranks: gather the full tensors of both models, apply the
        # update on rank 0 only (the declared modifier rank), and let DeepSpeed propagate it when the context exits.
        all_params = [*model.parameters(), *target_model.parameters()]
        with deepspeed.zero.GatheredParameters(all_params, modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                SyncRefModelCallback._sync_target_model(model, target_model, alpha)

    def on_step_end(self, args, state, control, **kwargs):
        model: PreTrainedModel = kwargs["model"]
        if self.ref_model is None:
            return
        if state.global_step % args.ref_model_sync_steps != 0:
            return
        if self.accelerator:
            model = self.accelerator.unwrap_model(model)
        self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)
class RichProgressCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that displays the progress of training or evaluation using Rich.

    On the world-process-zero process it renders a live panel containing a training progress bar, an evaluation /
    prediction progress bar and a status line showing the latest logs. On other processes every hook is a no-op.
    The Rich objects only exist between `on_train_begin` and `on_train_end`.
    """

    def __init__(self):
        if not is_rich_available():
            raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.")

        self.training_bar = None
        self.prediction_bar = None

        self.training_task_id = None
        self.prediction_task_id = None

        self.rich_group = None
        self.rich_console = None

        self.training_status = None
        self.current_step = None

    def on_train_begin(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            self.training_bar = Progress()
            self.prediction_bar = Progress()

            self.rich_console = Console()

            self.training_status = self.rich_console.status("Nothing to log yet ...")

            self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status)))
            self.rich_group.start()

            self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps)
            self.current_step = 0

    def on_step_end(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            # Advance by the delta since the last update so the bar stays correct if steps are skipped.
            # NOTE(review): `update=True` is not a named `Progress.update` parameter; Rich stores it as a custom task
            # field. Kept as-is to preserve behavior.
            self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True)
            self.current_step = state.global_step

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
        # `prediction_bar` is only created in `on_train_begin`; a standalone `trainer.evaluate()` / `predict()` call
        # would otherwise fail on `None.add_task`, so skip the bar in that case.
        if state.is_world_process_zero and self.prediction_bar is not None and has_length(eval_dataloader):
            if self.prediction_task_id is None:
                self.prediction_task_id = self.prediction_bar.add_task(
                    "[blue]Predicting on the evaluation dataset", total=len(eval_dataloader)
                )
            self.prediction_bar.update(self.prediction_task_id, advance=1, update=True)

    def on_evaluate(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            if self.prediction_task_id is not None:
                self.prediction_bar.remove_task(self.prediction_task_id)
                self.prediction_task_id = None

    def on_predict(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            if self.prediction_task_id is not None:
                self.prediction_bar.remove_task(self.prediction_task_id)
                self.prediction_task_id = None

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero and self.training_bar is not None:
            # Build a filtered copy rather than `logs.pop(...)`: the same `logs` dict is handed to the other
            # callbacks, and popping would strip `total_flos` from what they receive.
            displayed_logs = {key: value for key, value in logs.items() if key != "total_flos"}
            self.training_status.update(f"[bold green]Status = {str(displayed_logs)}")

    def on_train_end(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            self.rich_group.stop()

            self.training_bar = None
            self.prediction_bar = None
            self.training_task_id = None
            self.prediction_task_id = None
            self.rich_group = None
            self.rich_console = None
            self.training_status = None
            self.current_step = None
def _win_rate_completions_df(
    state: TrainerState,
    prompts: list[str],
    completions: list[tuple[str, str]],
    winner_indices: list[int],
) -> pd.DataFrame:
    """
    Build a table of judged completions for logging.

    Args:
        state (`TrainerState`):
            Trainer state; only `state.global_step` is read. It is stored as a string in the `step` column.
        prompts (`list[str]`):
            The prompts that were judged.
        completions (`list[tuple[str, str]]`):
            One `(reference_completion, policy_completion)` pair per prompt.
        winner_indices (`list[int]`):
            The judge's verdict per prompt: the index within the pair of the winning completion (`0` = reference,
            `1` = policy).

    Returns:
        `pd.DataFrame` with columns `step`, `prompt`, `reference_model`, `policy` and `winner_index`, one row per
        prompt (rows beyond the shortest input are dropped, as with `zip`).
    """
    step = str(state.global_step)
    rows = [
        (step, prompt, pair[0], pair[1], winner)
        for prompt, pair, winner in zip(prompts, completions, winner_indices)
    ]
    return pd.DataFrame(rows, columns=["step", "prompt", "reference_model", "policy", "winner_index"])
class WinRateCallback(TrainerCallback):
    """
    A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference.

    It generates completions using prompts from the evaluation dataset and compares the trained model's outputs against
    a reference. The reference is either the initial version of the model (before training) or the reference model, if
    available in the trainer. At each evaluation, a judge determines how often the trained model's completions win
    against the reference completions. The win rate is then logged in the trainer's logs under the key
    `"eval_win_rate"`.

    Usage:
    ```python
    trainer = DPOTrainer(...)
    judge = PairRMJudge()
    win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
    trainer.add_callback(win_rate_callback)
    ```

    Args:
        judge (`BasePairwiseJudge`):
            The judge to use for comparing completions.
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
            column containing the prompts for generating completions. If the `Trainer` has a reference model (via the
            `ref_model` attribute), it will use this reference model for generating the reference completions;
            otherwise, it defaults to using the initial model.
        generation_config (`GenerationConfig`, *optional*):
            The generation config to use for generating completions.
        num_prompts (`int` or `None`, *optional*, defaults to `None`):
            The number of prompts to generate completions for. If not provided, defaults to the number of examples
            in the evaluation dataset.
        shuffle_order (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the order of the completions before judging.
        use_soft_judge (`bool`, *optional*, defaults to `False`):
            Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the
            second.
    """

    def __init__(
        self,
        judge: BasePairwiseJudge,
        trainer: Trainer,
        generation_config: Optional[GenerationConfig] = None,
        num_prompts: Optional[int] = None,
        shuffle_order: bool = True,
        use_soft_judge: bool = False,
    ):
        self.judge = judge
        self.trainer = trainer
        self.shuffle_order = shuffle_order
        self.generation_config = generation_config
        # Per-process reference completions, produced once in `on_train_begin` and reused at every evaluation.
        self.ref_completions = []
        self.use_soft_judge = use_soft_judge

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to use the WinRateCallback.")
        else:
            self.eval_dataset = self.trainer.eval_dataset

        if num_prompts is not None:
            self.eval_dataset = self.eval_dataset.select(range(num_prompts))

    def _judge_and_gather(self, prompts, completions):
        """
        Judge this process's `(reference, policy)` completion pairs, then gather all results across processes.

        Returns:
            `(prompts, completions, winner_indices, ref_win_probs)`, each gathered from all processes.
            `ref_win_probs` (probability that the reference completion wins) is `None` unless `use_soft_judge` is set.
        """
        ref_win_probs = None
        if self.use_soft_judge:
            ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
            # The score is the probability that the first (reference) completion wins.
            winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
            ref_win_probs = gather_object(ref_win_probs)
        else:
            winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
        prompts = gather_object(prompts)
        completions = gather_object(completions)
        winner_indices = gather_object(winner_indices)
        return prompts, completions, winner_indices, ref_win_probs

    def _log_win_rate(self, args, state, prompts, completions, winner_indices, ref_win_probs):
        """Log the win rate (and average policy win probability for soft judges) and the completions table; main process only."""
        if not self.trainer.accelerator.is_main_process:
            return

        # Index 1 is the policy completion, so the win rate is the fraction of prompts where the policy won.
        win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
        if self.use_soft_judge:
            avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
            self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
        else:
            self.trainer.log({"eval_win_rate": win_rate})

        if "wandb" in args.report_to:
            import wandb

            if wandb.run is not None:
                df = _win_rate_completions_df(
                    state=state,
                    prompts=prompts,
                    completions=completions,
                    winner_indices=winner_indices,
                )
                wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})

        if "comet_ml" in args.report_to:
            df = _win_rate_completions_df(
                state=state,
                prompts=prompts,
                completions=completions,
                winner_indices=winner_indices,
            )
            log_table_to_comet_experiment(
                name="win_rate_completions.csv",
                table=df,
            )

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # When the trainer is initialized, we generate completions for the reference model.
        tokenizer = kwargs["processing_class"]
        accelerator = self.trainer.accelerator
        # Use the reference model if available, otherwise use the initial model
        model = getattr(self.trainer, "ref_model", None)
        # At this point, there are two cases where `ref_model` is None:
        # 1. The method doesn't require a reference model.
        # 2. The method uses a reference model, but `ref_model` is set to None.
        #    This occurs when using PEFT, where the reference model can be obtained by simply disabling the model's adapter.
        #    In theory, we should disable the adapter here, but since it's zero-initialized at the start of training,
        #    the model behaves identically with or without the adapter.
        #    Therefore, there's no need to explicitly disable it at this point.
        if model is None:
            model = self.trainer.model_wrapped

        # `_generate_completions` strips prompts by position, which needs left padding. Restore the previous setting
        # afterwards so the shared processing class is not left-padded for the rest of training.
        original_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
                self.ref_completions = _generate_completions(
                    prompts,
                    model=model,
                    tokenizer=tokenizer,
                    accelerator=accelerator,
                    generation_config=self.generation_config,
                    batch_size=args.per_device_eval_batch_size,
                )
                # Compute initial win rate as a reference point (reference judged against itself).
                completions = list(zip(self.ref_completions, self.ref_completions))
                prompts, completions, winner_indices, ref_win_probs = self._judge_and_gather(prompts, completions)
        finally:
            tokenizer.padding_side = original_padding_side

        self._log_win_rate(args, state, prompts, completions, winner_indices, ref_win_probs)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # At every evaluation step, we generate completions for the model and compare them with the reference
        # completions that have been generated at the beginning of training. We then compute the win rate and log it to
        # the trainer.
        tokenizer = kwargs["processing_class"]
        accelerator = self.trainer.accelerator
        model = self.trainer.model_wrapped

        # Left padding is needed for generation only; restore the caller's setting afterwards.
        original_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
                completions = _generate_completions(
                    prompts,
                    model=model,
                    tokenizer=tokenizer,
                    accelerator=accelerator,
                    generation_config=self.generation_config,
                    batch_size=args.per_device_eval_batch_size,
                )
                # Pair each reference completion (index 0) with the current policy completion (index 1). The split is
                # the same as in `on_train_begin`, so the per-process reference list lines up with these prompts.
                completions = list(zip(self.ref_completions, completions))
                prompts, completions, winner_indices, ref_win_probs = self._judge_and_gather(prompts, completions)
        finally:
            tokenizer.padding_side = original_padding_side

        self._log_win_rate(args, state, prompts, completions, winner_indices, ref_win_probs)
class LogCompletionsCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases and/or Comet.

    Usage:
    ```python
    trainer = DPOTrainer(...)
    completions_callback = LogCompletionsCallback(trainer=trainer)
    trainer.add_callback(completions_callback)
    ```

    Args:
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
            column containing the prompts for generating completions.
        generation_config (`GenerationConfig`, *optional*):
            The generation config to use for generating completions.
        num_prompts (`int` or `None`, *optional*):
            The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
        freq (`int` or `None`, *optional*):
            The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. If
            neither is set, a warning is emitted and nothing is logged.
    """

    def __init__(
        self,
        trainer: Trainer,
        generation_config: Optional[GenerationConfig] = None,
        num_prompts: Optional[int] = None,
        freq: Optional[int] = None,
    ):
        self.trainer = trainer
        self.generation_config = generation_config
        self.freq = freq
        # Accumulated `(step, prompt, completion)` rows; the whole history is re-logged each time.
        self.table = []
        self._last_logged_step = -1

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to use the LogCompletionsCallback.")
        else:
            self.eval_dataset = self.trainer.eval_dataset

        if num_prompts is not None:
            self.eval_dataset = self.eval_dataset.select(range(num_prompts))

    def on_step_end(self, args, state, control, **kwargs):
        # Only log once per step (this method may be called multiple times)
        if state.global_step == self._last_logged_step:
            return

        # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps)
        freq = self.freq or state.eval_steps
        if not freq:
            # `state.eval_steps` can be None (e.g. when `eval_steps` is not set), which would otherwise fail on
            # `global_step % None`. Without a frequency there is no schedule to follow, so skip logging.
            import warnings

            warnings.warn(
                "LogCompletionsCallback: no `freq` was given and the trainer has no `eval_steps`; completions will "
                "not be logged."
            )
            return
        if state.global_step % freq != 0:
            return

        tokenizer = kwargs["processing_class"]
        accelerator = self.trainer.accelerator
        model = self.trainer.model_wrapped

        # `_generate_completions` strips prompts by position, which needs left padding. Restore the previous setting
        # afterwards so the shared processing class is not left-padded for the rest of training.
        original_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
                prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts]
                completions = _generate_completions(
                    prompts,
                    model=model,
                    tokenizer=tokenizer,
                    accelerator=accelerator,
                    generation_config=self.generation_config,
                    batch_size=args.per_device_eval_batch_size,
                )
                completions = gather_object(completions)
                prompts = gather_object(prompts)
        finally:
            tokenizer.padding_side = original_padding_side

        # Build the data to log
        if self.trainer.accelerator.is_main_process:
            global_step = [str(state.global_step)] * len(prompts)
            data = list(zip(global_step, prompts, completions))
            self.table.extend(data)
            table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table)

            if "wandb" in args.report_to:
                import wandb

                # Same guard as the win-rate logging: `wandb.log` requires an active run.
                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=table)})

            if "comet_ml" in args.report_to:
                log_table_to_comet_experiment(
                    name="completions.csv",
                    table=table,
                )

        # Save the last logged step, so we don't log the same completions multiple times
        self._last_logged_step = state.global_step
class MergeModelCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based on a merge configuration.

    The merge reads the saved checkpoint at `<output_dir>/checkpoint-<global_step>` and writes the merged model to a
    `merged` subdirectory of that checkpoint. It runs either after every save or once at the end of training.

    Args:
        merge_config ([`MergeConfig`], *optional*, defaults to `None`):
            Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used.
        merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to merge the model at every checkpoint.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the merged model to the Hub after merging.

    Example:

    ```python
    !pip install mergekit

    from trl.mergekit_utils import MergeConfig
    from trl import MergeModelCallback

    config = MergeConfig()
    merge_callback = MergeModelCallback(config)
    trainer = DPOTrainer(..., callbacks=[merge_callback])
    ```
    """

    def __init__(
        self,
        merge_config: Optional["MergeConfig"] = None,
        merge_at_every_checkpoint: bool = False,
        push_to_hub: bool = False,
    ):
        if not is_mergekit_available():
            raise ImportError(
                "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
            )
        self.merge_config = merge_config or MergeConfig()
        self.merge_at_every_checkpoint = merge_at_every_checkpoint
        self.push_to_hub = push_to_hub

    def _merge_and_maybe_push(self, output_dir, global_step, model):
        config = self.merge_config
        checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
        config.policy_model_path = checkpoint_dir
        if config.target_model_path is None:
            # Fall back to `model.config._name_or_path` (presumably the name/path the model was loaded from).
            config.target_model_path = model.config._name_or_path
        merged_dir = os.path.join(checkpoint_dir, "merged")

        merge_models(config.create(), merged_dir)

        if not self.push_to_hub:
            return
        upload_model_to_hf(merged_dir, f"{output_dir}_checkpoint-{global_step}_merged")

    def on_save(self, args, state, control, model=None, **kwargs):
        if not self.merge_at_every_checkpoint:
            return
        self._merge_and_maybe_push(args.output_dir, state.global_step, model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        # When merging per checkpoint, `on_save` has already handled it; otherwise merge once here.
        if self.merge_at_every_checkpoint:
            return
        self._merge_and_maybe_push(args.output_dir, state.global_step, model)