# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
import warnings
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Optional, Union

import datasets
import jinja2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader, IterableDataset
from transformers import (
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    TrainerCallback,
    is_apex_available,
    is_wandb_available,
)
from transformers.trainer_utils import EvalPrediction, seed_worker
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..import_utils import is_vllm_available
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
from .online_dpo_config import OnlineDPOConfig
from .utils import (
    SIMPLE_CHAT_TEMPLATE,
    DPODataCollatorWithPadding,
    disable_dropout_in_model,
    empty_cache,
    generate_model_card,
    get_comet_experiment_url,
    get_reward,
    prepare_deepspeed,
    truncate_right,
)

if is_peft_available():
    from peft import PeftModel, get_peft_model

if is_apex_available():
    from apex import amp

if is_sagemaker_mp_enabled():
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

if is_vllm_available():
    from vllm import LLM, SamplingParams

if is_wandb_available():
    import wandb

logger = logging.get_logger(__name__)

class OnlineDPOTrainer(Trainer):
    r"""
    Initialize OnlineDPOTrainer.
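
    Example (a minimal sketch of the intended usage; the model, reward-model, and dataset names below are
    illustrative placeholders, not requirements):

    ```python
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
    from trl import OnlineDPOConfig, OnlineDPOTrainer

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
    reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")

    train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

    training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO")
    trainer = OnlineDPOTrainer(
        model=model,
        reward_model=reward_model,
        args=training_args,
        processing_class=tokenizer,
        reward_processing_class=reward_tokenizer,
        train_dataset=train_dataset,
    )
    trainer.train()
    ```
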
    Args:
        model (`transformers.PreTrainedModel` or `torch.nn.Module`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reference model to use for training. If `None` is specified, the reference model will be created
            from the model.
        reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`OnlineDPOConfig`):
            The online DPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator
            (`DPODataCollatorWithPadding`) will be used, which pads the sequences to the maximum length of the
            sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along with the model to make it easier to rerun an interrupted
            training or reuse the fine-tuned model.
        reward_processing_class (`PreTrainedTokenizerBase`, *optional*):
            Processing class (tokenizer) used to prepare prompts and completions for the reward model when
            `reward_model` is provided.
        peft_config (`dict`):
            The PEFT config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary
            mapping metric names to metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
""" | |
_tag_names = ["trl", "online-dpo"] | |
def __init__( | |
self, | |
model: Union[PreTrainedModel, nn.Module], | |
ref_model: Union[PreTrainedModel, nn.Module, None] = None, | |
reward_model: Union[PreTrainedModel, nn.Module, None] = None, | |
judge: Optional[BasePairwiseJudge] = None, | |
args: Optional[OnlineDPOConfig] = None, | |
data_collator: Optional[DataCollator] = None, | |
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, | |
eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None, | |
processing_class: Optional[ | |
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] | |
] = None, | |
reward_processing_class: Optional[PreTrainedTokenizerBase] = None, | |
peft_config: Optional[dict] = None, | |
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, | |
callbacks: Optional[list[TrainerCallback]] = None, | |
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), | |
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, | |
) -> None: | |
        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, either omit the `ref_model` argument or pass `None`."
            )

        self.ref_model = ref_model

        # Check that `args` and `processing_class` are provided before they are used below
        if args is None:
            raise ValueError("`args` must be provided.")
        if processing_class is None:
            raise ValueError("`processing_class` must be provided.")

        if reward_model is not None and judge is not None:
            warnings.warn(
                "Both `reward_model` and `judge` are provided. Please provide only one of them. "
                "Ignoring `judge` and using `reward_model`.",
                UserWarning,
            )
            judge = None
        elif reward_model is None and judge is None:
            raise ValueError("Either `reward_model` or `judge` must be provided.")
        self.reward_model = reward_model
        self.reward_processing_class = reward_processing_class
        self.judge = judge
        self.is_encoder_decoder = model.config.is_encoder_decoder

        if args.missing_eos_penalty is not None and judge is not None:
            raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
        # Convert to PEFT model if peft_config is provided
        if peft_config is not None:
            # Check if PEFT is available
            if not is_peft_available():
                raise ImportError(
                    "`peft_config` was passed but PEFT is not available. Please install PEFT with "
                    "`pip install peft` to use it."
                )

            # If the model is already a PeftModel, we need to merge and unload it.
            # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            # Get peft model with the given config
            model = get_peft_model(model, peft_config)

        # Disable dropout in the model and reference model
        if args.disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        # Handle the ref_model
        # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
        # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
        # the ref model from the model by copying it, disabling its gradients, and setting it to evaluation mode.
        if ref_model is None:  # No ref model provided, the most common case
            if peft_config is None:
                self.ref_model = create_reference_model(model)  # copy, disable gradients, set eval mode
            else:
                self.ref_model = None  # we don't need a ref model here, we can just disable the adapter.
        else:  # rare case, the user provided a ref model
            self.ref_model = ref_model
            self.ref_model.eval()

        # Disable the gradient and set the reward model in eval mode
        if self.reward_model is not None:
            self.reward_model.eval()

        # Define the collator if not provided
        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
        self.max_length = args.max_length

        self.stats = {
            "objective/kl": [],
            "objective/entropy": [],
            "objective/non_score_reward": [],
            "rewards/chosen": [],
            "rewards/rejected": [],
            "rewards/accuracies": [],
            "rewards/margins": [],
            "logps/chosen": [],
            "logps/rejected": [],
            "val/contain_eos_token": [],
            "beta": [],
        }
        if self.reward_model is not None:
            self.stats["objective/rlhf_reward"] = []
            self.stats["objective/scores_margin"] = []
            self.stats["objective/scores"] = []
        if args.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )
            self.generation_config = SamplingParams(
                n=2,  # 2 generations per prompt
                max_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=50,
                top_p=1.0,
                detokenize=False,  # to avoid vLLM decoding (we don't need it)
            )
            # vLLM dynamically adjusts the size of the key-value cache based on available GPU memory at instantiation.
            # A larger cache size improves speed, so we would expect gpu_memory_utilization=1.
            # However, at this stage, the optimizer's weights are not yet loaded onto the GPU; they will be loaded
            # after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough
            # space for them. Setting gpu_memory_utilization to 0.55 seems to work well in practice.
            self.llm = LLM(
                model=model.name_or_path,
                gpu_memory_utilization=args.gpu_memory_utilization,
                dtype=torch.float32,
                # Once released by vLLM, we will be able to distribute the model on multiple GPUs.
                # See https://github.com/vllm-project/vllm/pull/12071
                # tensor_parallel_size=torch.cuda.device_count(),
                # distributed_executor_backend="external_launcher",
            )
        else:
            self.generation_config = GenerationConfig(
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=50,
                top_p=1.0,
                do_sample=True,
                use_cache=False if args.gradient_checkpointing else True,
            )
        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
        # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
        # that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        self._beta = args.beta

        # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
        if self.is_deepspeed_enabled:
            if self.reward_model is not None:
                self.reward_model = prepare_deepspeed(
                    self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
                )
            if self.ref_model is not None:
                self.ref_model = prepare_deepspeed(
                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
                )
        else:
            if self.ref_model is not None:
                self.ref_model = self.ref_model.to(self.accelerator.device)
            if self.reward_model is not None:
                self.reward_model = self.reward_model.to(self.accelerator.device)
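
    # `args.beta` may be a single float or a list of floats (one value per epoch). The property below resolves the
    # value for the current epoch, falling back to the last entry once the schedule is exhausted.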
    @property
    def beta(self):
        if isinstance(self._beta, list):
            epoch = int(self.state.epoch)
            return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
        else:
            return self._beta
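
    # `tokenize_row` prefixes the tokenizer outputs with `prompt_` (e.g. `prompt_input_ids`), which is the key layout
    # that `DPODataCollatorWithPadding` and `_generate` below expect.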
    @staticmethod
    def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
        """Tokenize a single row from a DPO-specific dataset."""
        if not is_encoder_decoder:
            batch = tokenizer(feature["prompt"], add_special_tokens=False)
            # Add BOS token to head of prompt. Avoid adding if it's already there
            if tokenizer.bos_token_id is not None:
                prompt_len_input_ids = len(batch["input_ids"])
                if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
                    batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
                    batch["attention_mask"] = [1] + batch["attention_mask"]
        else:
            batch = tokenizer(feature["prompt"], add_special_tokens=True)
        batch = {f"prompt_{key}": value for key, value in batch.items()}
        return batch
    # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
    # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If we have persistent workers, don't do a fork bomb especially as eval datasets
        # don't change during training
        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        eval_dataset = (
            self.eval_dataset[eval_dataset]
            if isinstance(eval_dataset, str)
            else eval_dataset
            if eval_dataset is not None
            else self.eval_dataset
        )
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)
    def _generate_vllm(self, model, prompts):
        eos_token_id = self.processing_class.eos_token_id
        pad_token_id = self.processing_class.pad_token_id

        # Load the latest weights
        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
        llm_model.load_weights(model.state_dict().items())

        if is_conversational({"prompt": prompts[0]}):
            outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False)
        else:
            outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False)

        completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
        prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]

        # Create mask and pad the prompt and completion
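        # Prompts are left-padded (pad tokens prepended) so that every prompt ends at the same position, while
        # completions are right-padded; an EOS token is appended to completions that stopped before `max_tokens`
        # without emitting one.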
        max_prompt_length = max(len(ids) for ids in prompt_ids)
        prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
        prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
        max_tokens = self.generation_config.max_tokens
        completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
        completion_ids = [
            ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
            for ids in completion_ids
        ]
        completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]

        # Convert to tensors
        prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
        prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
        completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
        completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)

        return prompt_ids, prompt_mask, completion_ids, completion_mask
    def _generate(self, model, prompts):
        eos_token_id = self.processing_class.eos_token_id
        pad_token_id = self.processing_class.pad_token_id

        # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
        # policies with different tokenizers / chat templates.
        inputs = [{"prompt": prompt} for prompt in prompts]
        inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
        inputs = [self.tokenize_row(x, self.is_encoder_decoder, self.processing_class) for x in inputs]
        inputs = self.data_collator(inputs)

        # Sample 2 completions per prompt of size `max_new_tokens` from the model
        inputs = self._prepare_inputs(inputs)
        prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
        prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
        with unwrap_model_for_generation(
            model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            output = unwrapped_model.generate(
                input_ids=prompt_ids,
                attention_mask=prompt_mask,
                generation_config=self.generation_config,
            )

        completion_ids = output[:, prompt_ids.size(1) :]
        completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)

        return prompt_ids, prompt_mask, completion_ids, completion_mask
    def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
        # Get the number of tokens to truncate from prompt
        num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)

        # Truncate left to avoid oom
        prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
        prompt_mask = prompt_mask[:, num_tokens_to_truncate:]

        # Concat the prompt and completion
        prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
        prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)

        # Get the logprobs of the completions from the model
        output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)

        # There is 1 offset, because the model predicts the next token
        logits = output.logits[:, prompt_ids.size(1) - 1 : -1]

        # Take the completion tokens logprob
        logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
        return logprobs
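
    # One online DPO step, in outline:
    #   1. sample two completions per prompt from the current policy (optionally via vLLM),
    #   2. score the pair with the reward model or rank it with the judge to pick chosen/rejected completions,
    #   3. compute the DPO loss between chosen and rejected using the (adapter-disabled or frozen) reference model,
    #   4. log the statistics and backpropagate.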
    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        model.train()

        prompts = inputs["prompt"]
        batch_size = len(prompts)

        if self.args.use_vllm:
            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
        else:
            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)

        contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)

        logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
        with torch.no_grad():
            if self.ref_model is not None:
                ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
            else:  # peft case: we just need to disable the adapter
                with self.model.disable_adapter():
                    ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
        # Decode the completions, and format them if the input is conversational
        device = logprobs.device
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational({"prompt": prompts[0]}):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Get the reward from the reward model or judge
        if self.judge is not None:
            # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
            # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
            # independent of the model's chat template, we use the raw conversation data, and apply our own chat
            # template to it.
            if is_conversational({"prompt": prompts[0]}):
                environment = jinja2.Environment()
                template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
                prompts = [template.render(messages=prompt) for prompt in prompts]
                completions = [template.render(messages=completion) for completion in completions]

            ranks_of_first_completion = self.judge.judge(
                prompts, list(zip(completions[:batch_size], completions[batch_size:]))
            )

            # convert ranks to a True/False mask:
            # when rank == 0, it means the first completion is the best
            # when rank == 1, it means the second completion is the best
            mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
        else:
            # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
            # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
            prompts = 2 * prompts  # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
            if is_conversational({"prompt": prompts[0]}):
                examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
                examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
                prompts = [example["prompt"] for example in examples]
                completions = [example["completion"] for example in examples]

            # Tokenize the prompts
            prompts_ids = self.reward_processing_class(
                prompts, padding=True, return_tensors="pt", padding_side="left"
            )["input_ids"].to(device)
            context_length = prompts_ids.shape[1]

            # Tokenize the completions
            completions_ids = self.reward_processing_class(
                completions, padding=True, return_tensors="pt", padding_side="right"
            )["input_ids"].to(device)

            # Concatenate the prompts and completions and get the reward
            prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
            with torch.inference_mode():
                _, scores, _ = get_reward(
                    self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
                )

                # Filter completion. Ensure that the sample contains stop_token_id
                # Completions not passing that filter will receive a lower score.
                if self.args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty

            # Split the scores in 2 (the prompts of the first half are the same as the second half)
            first_half, second_half = scores.split(batch_size)

            # Get the indices of the chosen and rejected examples
            mask = first_half >= second_half

        batch_range = torch.arange(batch_size, device=device)
        chosen_indices = batch_range + (~mask * batch_size)
        rejected_indices = batch_range + (mask * batch_size)
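        # For example, with batch_size=3 and mask=[True, False, True] (first completion preferred for prompts 0 and
        # 2), chosen_indices=[0, 4, 2] and rejected_indices=[3, 1, 5], i.e. indices into the stacked
        # [first-completions, second-completions] batch of size 2 * batch_size.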
        # Build tensor so that the first half is the chosen examples and the second half the rejected examples
        cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0)  # cr = chosen and rejected
        cr_logprobs = logprobs[cr_indices]
        cr_ref_logprobs = ref_logprobs[cr_indices]

        # mask out the padding tokens
        padding_mask = ~completion_mask.bool()
        cr_padding_mask = padding_mask[cr_indices]

        cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
        cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)

        # Split the chosen and rejected examples
        chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
        chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
        pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
        ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum

        logits = pi_logratios - ref_logratios
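
        # DPO objective: logits = (log pi(y_w|x) - log pi(y_l|x)) - (log ref(y_w|x) - log ref(y_l|x)).
        # The sigmoid loss is -log(sigmoid(beta * logits)); the IPO loss is (logits - 1 / (2 * beta))^2, which
        # regresses the log-ratio margin towards 1 / (2 * beta).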
if self.args.loss_type == "sigmoid": | |
losses = -F.logsigmoid(self.beta * logits) | |
elif self.args.loss_type == "ipo": | |
losses = (logits - 1 / (2 * self.beta)) ** 2 | |
else: | |
raise NotImplementedError(f"invalid loss type {self.loss_type}") | |
        loss = losses.mean()

        # Log everything
        if self.reward_model is not None:
            scores_margin = scores[chosen_indices] - scores[rejected_indices]
            self.stats["objective/scores_margin"].append(
                self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
            )
            self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
        self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
        self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
        self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
        kl = logprobs - ref_logprobs
        mean_kl = kl.sum(1).mean()
        self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        non_score_reward = (-self.beta * kl).sum(1)
        mean_non_score_reward = non_score_reward.mean()
        self.stats["objective/non_score_reward"].append(
            self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
        )
        if self.reward_model is not None:
            rlhf_reward = scores + non_score_reward
            self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
        mean_entropy = -logprobs.sum(1).mean()
        self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
        chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
        gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
        self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
        rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
        gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
        self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
        margin = gathered_chosen_rewards - gathered_rejected_rewards
        self.stats["rewards/margins"].append(margin.mean().item())
        accuracy = margin > 0
        self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
        self.stats["beta"].append(self.beta)

        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            empty_cache()

        kwargs = {}
        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss, **kwargs)

        return loss.detach() / self.args.gradient_accumulation_steps
    # Same as Trainer._maybe_log_save_evaluate but log our metrics
    def _maybe_log_save_evaluate(
        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
    ):
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            logs: dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            if learning_rate is not None:
                logs["learning_rate"] = learning_rate
            else:
                logs["learning_rate"] = self._get_learning_rate()

            # Add our metrics
            for key, val in self.stats.items():
                logs[key] = sum(val) / len(val)
            self.stats = {key: [] for key in self.stats}  # reset stats

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()
            self.log(logs, start_time)

        metrics = None
        if self.control.should_evaluate:
            metrics = self._evaluate(trial, ignore_keys_for_eval)
            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

            if self.args.save_strategy == "best":
                self.control.should_save = is_new_best_metric

        if self.control.should_save:
            self._save_checkpoint(model, trial)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)
    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # Normalize `tags` to a mutable set so that both strings and lists are accepted
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")
        tags.update(self._tag_names)

        citation = textwrap.dedent("""\
        @article{guo2024direct,
            title = {{Direct Language Model Alignment from Online AI Feedback}},
            author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
            year = 2024,
            eprint = {arXiv:2402.04792}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="Online DPO",
            trainer_citation=citation,
            paper_title="Direct Language Model Alignment from Online AI Feedback",
            paper_id="2402.04792",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))