# Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import inspect | |
import os | |
import random | |
import textwrap | |
import warnings | |
from collections import defaultdict | |
from contextlib import contextmanager, nullcontext | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any, Callable, Literal, Optional, Union | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from accelerate import PartialState | |
from accelerate.utils import tqdm | |
from datasets import Dataset, IterableDataset | |
from torch import autocast | |
from torch.utils.data import DataLoader | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
BaseImageProcessor, | |
DataCollator, | |
FeatureExtractionMixin, | |
PreTrainedModel, | |
PreTrainedTokenizerBase, | |
ProcessorMixin, | |
Trainer, | |
is_comet_available, | |
is_wandb_available, | |
) | |
from transformers.data.data_collator import DataCollatorMixin | |
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES | |
from transformers.trainer_callback import TrainerCallback | |
from transformers.trainer_utils import EvalLoopOutput | |
from transformers.utils import is_liger_kernel_available, is_peft_available | |
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt | |
from ..models import create_reference_model, prepare_deepspeed | |
from ..models.utils import prepare_fsdp | |
from .callbacks import SyncRefModelCallback | |
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType | |
from .utils import ( | |
RunningMoments, | |
cap_exp, | |
disable_dropout_in_model, | |
empty_cache, | |
flush_left, | |
flush_right, | |
generate_model_card, | |
get_comet_experiment_url, | |
log_table_to_comet_experiment, | |
pad, | |
pad_to_length, | |
peft_module_casting_to_bf16, | |
selective_log_softmax, | |
) | |
if is_peft_available(): | |
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training | |
if is_liger_kernel_available(): | |
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss | |
if is_wandb_available(): | |
import wandb | |
def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    """Shift input ids one token to the right, prepending `decoder_start_token_id`."""
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
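# Illustration (assumed toy values, descriptive only):
#   shift_tokens_right(torch.tensor([[5, 6, 7]]), decoder_start_token_id=0)
# returns tensor([[0, 5, 6]]): the sequence shifted one position to the right, with the decoder start
# token prepended and the last token dropped.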
@dataclass
class DataCollatorForPreference(DataCollatorMixin):
""" | |
Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they | |
are not all of the same length. | |
Args: | |
pad_token_id (`int`): | |
Token ID to use for padding. | |
return_tensors (`str`, *optional*, defaults to `"pt"`): | |
Type of Tensor to return. Only `"pt"` is currently supported. | |
Examples: | |
```python | |
>>> from trl import DataCollatorForPreference | |
>>> collator = DataCollatorForPreference(pad_token_id=0) | |
>>> examples = [ | |
... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, | |
... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]} | |
... ] | |
>>> collator(examples) | |
{'prompt_input_ids': tensor([[1, 2, 3], | |
[0, 7, 8]]), | |
'prompt_attention_mask': tensor([[1, 1, 1], | |
[0, 1, 1]]), | |
'chosen_input_ids': tensor([[ 4, 5], | |
[ 9, 10]]), | |
'chosen_attention_mask': tensor([[1, 1], | |
[1, 1]]), | |
'rejected_input_ids': tensor([[ 6, 0, 0], | |
[11, 12, 13]]), | |
'rejected_attention_mask': tensor([[1, 0, 0], | |
[1, 1, 1]]) | |
} | |
``` | |
""" | |
pad_token_id: int | |
return_tensors: str = "pt" | |
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: | |
# Convert to tensor | |
prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples] | |
prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids] | |
chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] | |
chosen_attention_mask = [torch.ones_like(input_ids) for input_ids in chosen_input_ids] | |
rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] | |
rejected_attention_mask = [torch.ones_like(input_ids) for input_ids in rejected_input_ids] | |
if "pixel_values" in examples[0]: | |
pixel_values = [torch.tensor(example["pixel_values"]) for example in examples] | |
if "pixel_attention_mask" in examples[0]: | |
pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples] | |
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: | |
ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples]) | |
ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples]) | |
# Pad | |
output = {} | |
output["prompt_input_ids"] = pad(prompt_input_ids, padding_value=self.pad_token_id, padding_side="left") | |
output["prompt_attention_mask"] = pad(prompt_attention_mask, padding_value=0, padding_side="left") | |
output["chosen_input_ids"] = pad(chosen_input_ids, padding_value=self.pad_token_id) | |
output["chosen_attention_mask"] = pad(chosen_attention_mask, padding_value=0) | |
output["rejected_input_ids"] = pad(rejected_input_ids, padding_value=self.pad_token_id) | |
output["rejected_attention_mask"] = pad(rejected_attention_mask, padding_value=0) | |
if "pixel_values" in examples[0]: | |
output["pixel_values"] = pad(pixel_values, padding_value=0.0) | |
if "pixel_attention_mask" in examples[0]: | |
output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0) | |
if "image_sizes" in examples[0]: | |
output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples]) | |
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: | |
output["ref_chosen_logps"] = ref_chosen_logps | |
output["ref_rejected_logps"] = ref_rejected_logps | |
return output | |
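# A minimal sketch of wiring this collator up outside the trainer (the `tokenizer` and `tokenized_dataset`
# names below are assumed to exist in user code; they are not defined in this module):
#   collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)
#   loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=8, collate_fn=collator)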
class DPOTrainer(Trainer): | |
""" | |
Trainer for Direct Preference Optimization (DPO) method. | |
This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. | |
Args: | |
model (`Union[str, PreTrainedModel]`): | |
Model to be trained. Can be either: | |
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or | |
a path to a *directory* containing model weights saved using | |
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is | |
              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
in `args.model_init_kwargs`. | |
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. | |
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a causal language modeling head. Used for implicit reward
            computation and loss. If no reference model is provided, the trainer will create a reference model
            with the same architecture as the model to be optimized.
args ([`DPOConfig`], *optional*, defaults to `None`): | |
Configuration for this trainer. If `None`, a default configuration is used. | |
data_collator (`DataCollator`, *optional*): | |
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. | |
Will default to [`DataCollatorForPreference`]. | |
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): | |
            Dataset to use for training. DPO supports [preference](#preference) type datasets. The format of
            the samples can be either:
- [Standard](dataset_formats#standard): Each sample contains plain text. | |
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role | |
and content). | |
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): | |
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. | |
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): | |
Processing class used to process the data. If `None`, the processing class is loaded from the model's name | |
with [`~transformers.AutoTokenizer.from_pretrained`]. | |
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): | |
            The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`]
            and return a dictionary mapping metric names to metric values. *Note:* When passing
            `TrainingArguments` with `batch_eval_metrics` set to `True`, your `compute_metrics` function must
            take a boolean `compute_result` argument. This will be triggered after the last eval batch to
            signal that the function needs to calculate and return the global summary statistics rather than
            accumulating the batch-level statistics.
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): | |
List of callbacks to customize the training loop. Will add those to the list of default callbacks | |
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). | |
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] | |
method. | |
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): | |
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your | |
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. | |
optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`): | |
A tuple containing the optimizer class and keyword arguments to use. | |
Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. | |
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): | |
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made | |
by this function will be reflected in the predictions received by `compute_metrics`. | |
Note that the labels (second parameter) will be `None` if the dataset does not have them. | |
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): | |
PEFT configuration used to wrap the model. If `None`, the model is not wrapped. | |
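    Example (a minimal sketch; the dataset and model ids below are illustrative choices, not requirements of
    this trainer):

    ```python
    from datasets import load_dataset
    from trl import DPOConfig, DPOTrainer

    train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
    trainer = DPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        args=DPOConfig(output_dir="Qwen2-0.5B-DPO"),
        train_dataset=train_dataset,
    )
    trainer.train()
    ```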
""" | |
_tag_names = ["trl", "dpo"] | |
def __init__( | |
self, | |
model: Union[str, nn.Module, PreTrainedModel], | |
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, | |
args: Optional[DPOConfig] = None, | |
data_collator: Optional[DataCollator] = None, # type: ignore | |
train_dataset: Optional[Union[Dataset, IterableDataset]] = None, | |
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, | |
processing_class: Optional[ | |
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] | |
] = None, | |
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, | |
callbacks: Optional[list[TrainerCallback]] = None, | |
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), | |
optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, | |
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, | |
peft_config: Optional["PeftConfig"] = None, | |
): | |
# Args | |
model_id = model if isinstance(model, str) else model.config._name_or_path | |
if args is None: | |
model_name = model_id.split("/")[-1] | |
args = DPOConfig(f"{model_name}-DPO") | |
# Handle the tokenizer | |
if processing_class is None: | |
processing_class = AutoTokenizer.from_pretrained(model_id) | |
if args.padding_value is not None: | |
self.padding_value = args.padding_value | |
else: | |
if hasattr(processing_class, "pad_token_id") and processing_class.pad_token_id is not None: | |
self.padding_value = processing_class.pad_token_id | |
elif hasattr(processing_class, "tokenizer") and processing_class.tokenizer.pad_token_id is not None: | |
self.padding_value = processing_class.tokenizer.pad_token_id | |
else: | |
raise ValueError( | |
"`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in the " | |
"`processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " | |
"`tokenizer.pad_token` (e.g., `tokenizer.pad_token = tokenizer.eos_token`) before instantiating " | |
"the trainer." | |
) | |
# Model | |
if not isinstance(model, str) and ref_model is model: | |
raise ValueError( | |
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " | |
"same as `model`, you must mass a copy of it, or `None` if you use peft." | |
) | |
if args.model_init_kwargs is not None and not isinstance(model, str): | |
warnings.warn( | |
"You passed model_init_kwargs to the `DPOConfig`, but your model is already instantiated. " | |
"The `model_init_kwargs` will be ignored." | |
) | |
if isinstance(model, str): | |
model = self._create_model_from_path(model, args) | |
if args.ref_model_init_kwargs is not None and not isinstance(ref_model, str): | |
warnings.warn( | |
"You passed ref_model_init_kwargs to the `DPOConfig`, but your ref_model is already instantiated. " | |
"The `ref_model_init_kwargs` will be ignored." | |
) | |
if isinstance(ref_model, str): | |
ref_model = self._create_model_from_path(ref_model, args, is_ref=True) | |
# PEFT configuration and model wrapping | |
model = self._prepare_peft_model(model, ref_model, peft_config, args) | |
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): | |
raise ValueError( | |
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed." | |
" Please install `wandb` or `comet-ml` to resolve." | |
) | |
self.is_encoder_decoder = model.config.is_encoder_decoder | |
self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys() | |
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) | |
self.model_adapter_name = args.model_adapter_name | |
self.ref_adapter_name = args.ref_adapter_name | |
self.reference_free = args.reference_free | |
if ref_model: | |
self.ref_model = ref_model | |
elif self.is_peft_model or args.precompute_ref_log_probs: | |
# The `model` with adapters turned off will be used as the reference model | |
self.ref_model = None | |
else: | |
self.ref_model = create_reference_model(model) | |
# Disable dropout in the model and reference model | |
if args.disable_dropout: | |
disable_dropout_in_model(model) | |
if self.ref_model is not None: | |
disable_dropout_in_model(self.ref_model) | |
# Liger kernel | |
if args.use_liger_loss: | |
if not is_liger_kernel_available(): | |
raise ImportError( | |
"You set `use_liger_loss=True` but the liger kernel is not available. " | |
"Please install liger-kernel first: `pip install liger-kernel`" | |
) | |
if args.loss_type != "sigmoid": | |
raise ValueError( | |
"You set `use_liger_loss=True` but the loss type is not `sigmoid`. " | |
"Please set `loss_type='sigmoid'` to use the liger kernel." | |
) | |
self.dpo_loss_fn = LigerFusedLinearDPOLoss( | |
ignore_index=args.label_pad_token_id, | |
beta=args.beta, | |
use_ref_model=not args.reference_free, | |
average_log_prob=False, | |
) | |
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the | |
# input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the | |
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and | |
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens | |
# of the input, floating-point operations will not be computed." To suppress this warning, we set the | |
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate | |
# that the warning has already been issued. | |
model.warnings_issued["estimate_tokens"] = True | |
# Data collator | |
if data_collator is None: | |
data_collator = DataCollatorForPreference(pad_token_id=self.padding_value) | |
self.generate_during_eval = args.generate_during_eval | |
self.label_pad_token_id = args.label_pad_token_id | |
self.max_prompt_length = args.max_prompt_length | |
self.max_completion_length = args.max_completion_length | |
self.max_length = args.max_length | |
self.truncation_mode = args.truncation_mode | |
self.precompute_ref_log_probs = args.precompute_ref_log_probs | |
self.use_logits_to_keep = args.use_logits_to_keep | |
if args.padding_free: | |
if model.config._attn_implementation != "flash_attention_2": | |
warnings.warn( | |
"Padding-free training is enabled, but the attention implementation is not set to " | |
"'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " | |
"'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " | |
"other implementations may lead to unexpected behavior. To ensure compatibility, set " | |
"`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " | |
"attention mechanism can handle flattened sequences." | |
) | |
if args.per_device_train_batch_size == 1: | |
warnings.warn( | |
"You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " | |
"of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " | |
"to at least 2." | |
) | |
self.padding_free = args.padding_free | |
        # Since ref_logps are precomputed on the first call to get_train/eval_dataloader, keep track of the
        # first call to avoid recomputing them on subsequent calls.
self._precomputed_train_ref_log_probs = False | |
self._precomputed_eval_ref_log_probs = False | |
if ( | |
args.loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] | |
and args.label_smoothing > 0 | |
): | |
warnings.warn( | |
f"You are using the {args.loss_type} loss type that does not support label smoothing. The " | |
"`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.", | |
UserWarning, | |
) | |
if args.loss_type == "kto_pair": | |
raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.") | |
self.beta = args.beta | |
self.label_smoothing = args.label_smoothing | |
self.loss_type = args.loss_type | |
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) | |
self.use_weighting = args.use_weighting | |
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) | |
if self.aux_loss_enabled and self.aux_loss_coef == 0.0: | |
warnings.warn( | |
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " | |
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " | |
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " | |
"loss.", | |
UserWarning, | |
) | |
self._stored_metrics = defaultdict(lambda: defaultdict(list)) | |
self.f_divergence_type = args.f_divergence_type | |
self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} | |
self.dataset_num_proc = args.dataset_num_proc | |
# Dataset preparation | |
train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") | |
if eval_dataset is not None: | |
if isinstance(eval_dataset, dict): | |
eval_dataset = { | |
key: self._prepare_dataset(dataset, processing_class, args, key) | |
for key, dataset in eval_dataset.items() | |
} | |
else: | |
eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") | |
super().__init__( | |
model=model, | |
args=args, | |
data_collator=data_collator, | |
train_dataset=train_dataset, | |
eval_dataset=eval_dataset, | |
processing_class=processing_class, | |
compute_metrics=compute_metrics, | |
callbacks=callbacks, | |
optimizers=optimizers, | |
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, | |
preprocess_logits_for_metrics=preprocess_logits_for_metrics, | |
) | |
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the | |
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set | |
# self.model_accepts_loss_kwargs to False to enable scaling. | |
self.model_accepts_loss_kwargs = False | |
# Add tags for models that have been loaded with the correct transformers version | |
if hasattr(self.model, "add_model_tags"): | |
self.model.add_model_tags(self._tag_names) | |
if not hasattr(self, "accelerator"): | |
raise AttributeError( | |
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." | |
) | |
# Deepspeed Zero-3 does not support precompute_ref_log_probs | |
if self.is_deepspeed_enabled: | |
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: | |
raise ValueError( | |
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." | |
) | |
if self.ref_model is None: | |
if not (self.is_peft_model or self.precompute_ref_log_probs): | |
raise ValueError( | |
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" | |
) | |
if args.sync_ref_model: | |
raise ValueError( | |
"You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." | |
) | |
else: | |
if self.is_deepspeed_enabled: | |
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) | |
elif self.is_fsdp_enabled: | |
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) | |
else: | |
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) | |
if args.sync_ref_model: | |
if self.precompute_ref_log_probs: | |
raise ValueError( | |
"You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." | |
) | |
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) | |
if self.loss_type == "bco_pair": | |
self.running = RunningMoments(self.accelerator) | |
def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool = False) -> PreTrainedModel: | |
"""Creates a model from a path or model identifier.""" | |
if not is_ref: | |
model_init_kwargs = args.model_init_kwargs or {} | |
else: | |
model_init_kwargs = args.ref_model_init_kwargs or {} | |
# Handle torch dtype | |
torch_dtype = model_init_kwargs.get("torch_dtype") | |
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: | |
pass # torch_dtype is already a torch.dtype or "auto" or None | |
elif isinstance(torch_dtype, str): # it's a str, but not "auto" | |
torch_dtype = getattr(torch, torch_dtype) | |
model_init_kwargs["torch_dtype"] = torch_dtype | |
else: | |
raise ValueError( | |
"Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " | |
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." | |
) | |
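        # For illustration (assumed config value, not a default): `model_init_kwargs={"torch_dtype": "bfloat16"}`
        # in `DPOConfig` resolves to `torch.bfloat16` here, while `"auto"` and `None` are forwarded to
        # `from_pretrained` unchanged.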
# Disable caching if gradient checkpointing is enabled (not supported) | |
# if args.gradient_checkpointing: | |
# model_init_kwargs["use_cache"] = False | |
# Create model | |
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) | |
return model | |
def _prepare_peft_model( | |
self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig | |
) -> PreTrainedModel: | |
"""Prepares a model for PEFT training.""" | |
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` | |
# has been called in order to properly call autocast if needed. | |
self._peft_has_been_casted_to_bf16 = False | |
if not is_peft_available() and peft_config is not None: | |
raise ValueError( | |
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" | |
) | |
elif is_peft_available() and peft_config is not None: | |
# if model is a peft model and we have a peft_config, we merge and unload it first | |
if isinstance(model, PeftModel): | |
model = model.merge_and_unload() | |
if ref_model is not None and not args.force_use_ref_model: | |
raise ValueError( | |
"You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference" | |
" model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init." | |
" if you want to use a different ref_model." | |
) | |
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): | |
_support_gc_kwargs = hasattr( | |
args, "gradient_checkpointing_kwargs" | |
) and "gradient_checkpointing_kwargs" in list( | |
inspect.signature(prepare_model_for_kbit_training).parameters | |
) | |
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} | |
if _support_gc_kwargs: | |
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs | |
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) | |
else: | |
model = self._prepare_gradient_checkpointing(model, args) | |
# get peft model with the given config | |
model = get_peft_model(model, peft_config) | |
if args.bf16 and getattr(model, "is_loaded_in_4bit", False): | |
peft_module_casting_to_bf16(model) | |
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager | |
self._peft_has_been_casted_to_bf16 = True | |
else: | |
model = self._prepare_gradient_checkpointing(model, args) | |
return model | |
def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig): | |
"""Prepare the gradienting checkpointing for the model.""" | |
# For models that use gradient_checkpointing, we need to attach a hook that enables input | |
# to explicitly have `requires_grad=True`, otherwise training will either silently | |
# fail or completely fail. | |
if args.gradient_checkpointing: | |
# For backward compatibility with older versions of transformers | |
if hasattr(model, "enable_input_require_grads"): | |
model.enable_input_require_grads() | |
else: | |
def make_inputs_require_grad(module, input, output): | |
output.requires_grad_(True) | |
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) | |
return model | |
def _prepare_dataset( | |
self, | |
dataset: Union[Dataset, IterableDataset], | |
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], | |
args: DPOConfig, | |
dataset_name: str, | |
) -> Union[Dataset, IterableDataset]: | |
# Build the kwargs for the `map` function | |
map_kwargs = {} | |
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size | |
map_kwargs["num_proc"] = args.dataset_num_proc | |
map_kwargs["writer_batch_size"] = 10 | |
with PartialState().main_process_first(): | |
# Extract prompt if needed | |
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` | |
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" | |
dataset = dataset.map(maybe_extract_prompt, **map_kwargs) | |
# Apply the chat template if needed | |
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` | |
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" | |
dataset = dataset.map( | |
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs | |
) | |
# Tokenize the dataset | |
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` | |
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" | |
dataset = dataset.map( | |
self.tokenize_row if not self.is_vision_model else self.process_row, | |
remove_columns=["chosen", "rejected"], | |
fn_kwargs={ | |
"processing_class": processing_class, | |
"max_prompt_length": args.max_prompt_length, | |
"max_completion_length": args.max_completion_length, | |
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) | |
"add_special_tokens": False, | |
}, | |
**map_kwargs, | |
) | |
return dataset | |
    @staticmethod
    def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
""" | |
Tokenize a row of the dataset. | |
Args: | |
features (`dict[str, str]`): | |
Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. | |
processing_class (`PreTrainedTokenizerBase`): | |
Processing class used to process the data. | |
max_prompt_length (`int` or `None`): | |
Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. | |
max_completion_length (`int` or `None`): | |
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. | |
add_special_tokens (`bool`): | |
Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, | |
the prompt sequence will have a bos token prepended and an eos token appended. In any case, the | |
completion sequences will have an eos token appended. | |
Returns: | |
`dict[str, list[int]]`: | |
Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and | |
`"rejected_input_ids". | |
Example: | |
```python | |
>>> from transformers import GPT2Tokenizer | |
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") | |
>>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} | |
>>> DPOTrainer.tokenize_row( | |
... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False | |
... ) | |
{'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} | |
``` | |
""" | |
tokenizer = processing_class # the processing class is a tokenizer | |
prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] | |
chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] | |
rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] | |
# Add special tokens (typically for encoder-decoder models) | |
if add_special_tokens: | |
if tokenizer.bos_token_id is not None: | |
prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids | |
if tokenizer.eos_token_id is not None: | |
prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] | |
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] | |
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] | |
# Truncate prompt and completion sequences | |
if max_prompt_length is not None: | |
prompt_input_ids = prompt_input_ids[-max_prompt_length:] | |
if max_completion_length is not None: | |
chosen_input_ids = chosen_input_ids[:max_completion_length] | |
rejected_input_ids = rejected_input_ids[:max_completion_length] | |
return { | |
"prompt_input_ids": prompt_input_ids, | |
"chosen_input_ids": chosen_input_ids, | |
"rejected_input_ids": rejected_input_ids, | |
} | |
    @staticmethod
    def process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
""" | |
Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. | |
""" | |
processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor | |
processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) | |
prompt_input_ids = processed_features["input_ids"][0] | |
pixel_values = processed_features["pixel_values"][0] | |
chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] | |
rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] | |
# Add special tokens (typically for encoder-decoder models) | |
if add_special_tokens: | |
if tokenizer.bos_token_id is not None: | |
prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids | |
if tokenizer.eos_token_id is not None: | |
prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] | |
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] | |
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] | |
# Truncate prompt and completion sequences | |
if max_prompt_length is not None: | |
prompt_input_ids = prompt_input_ids[-max_prompt_length:] | |
if max_completion_length is not None: | |
chosen_input_ids = chosen_input_ids[:max_completion_length] | |
rejected_input_ids = rejected_input_ids[:max_completion_length] | |
output = { | |
"prompt_input_ids": prompt_input_ids, | |
"pixel_values": pixel_values, | |
"chosen_input_ids": chosen_input_ids, | |
"rejected_input_ids": rejected_input_ids, | |
} | |
if "pixel_attention_mask" in processed_features: | |
output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] | |
if "image_sizes" in processed_features: | |
output["image_sizes"] = processed_features["image_sizes"][0] | |
return output | |
def _set_signature_columns_if_needed(self): | |
# If `self.args.remove_unused_columns` is True, non-signature columns are removed. | |
# By default, this method sets `self._signature_columns` to the model's expected inputs. | |
# In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. | |
# Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override. | |
if self._signature_columns is None: | |
self._signature_columns = [ | |
"prompt_input_ids", | |
"chosen_input_ids", | |
"rejected_input_ids", | |
"image_sizes", | |
"ref_chosen_logps", | |
"ref_rejected_logps", | |
] | |
def get_train_dataloader(self) -> DataLoader: | |
""" | |
Returns the training [`~torch.utils.data.DataLoader`]. | |
        Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
""" | |
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: | |
batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size | |
dataloader_params = { | |
"batch_size": batch_size, | |
"collate_fn": self.data_collator, | |
"num_workers": self.args.dataloader_num_workers, | |
"pin_memory": self.args.dataloader_pin_memory, | |
"shuffle": False, | |
} | |
# prepare dataloader | |
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) | |
ref_chosen_logps = [] | |
ref_rejected_logps = [] | |
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): | |
ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) | |
ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( | |
(ref_chosen_logp, ref_rejected_logp) | |
) | |
ref_chosen_logps.append(ref_chosen_logp.cpu()) | |
ref_rejected_logps.append(ref_rejected_logp.cpu()) | |
# Unnecessary cache clearing to avoid OOM | |
empty_cache() | |
self.accelerator.free_memory() | |
all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() | |
all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() | |
self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) | |
self.train_dataset = self.train_dataset.add_column( | |
name="ref_rejected_logps", column=all_ref_rejected_logps | |
) | |
self._precomputed_train_ref_log_probs = True | |
return super().get_train_dataloader() | |
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: | |
""" | |
Returns the evaluation [`~torch.utils.data.DataLoader`]. | |
        Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
Args: | |
eval_dataset (`torch.utils.data.Dataset`, *optional*): | |
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted | |
by the `model.forward()` method are automatically removed. It must implement `__len__`. | |
""" | |
if eval_dataset is None and self.eval_dataset is None: | |
raise ValueError("Trainer: evaluation requires an eval_dataset.") | |
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset | |
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: | |
batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size | |
dataloader_params = { | |
"batch_size": batch_size, | |
"collate_fn": self.data_collator, | |
"num_workers": self.args.dataloader_num_workers, | |
"pin_memory": self.args.dataloader_pin_memory, | |
"shuffle": False, | |
} | |
# prepare dataloader | |
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) | |
ref_chosen_logps = [] | |
ref_rejected_logps = [] | |
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): | |
ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) | |
ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( | |
(ref_chosen_logp, ref_rejected_logp) | |
) | |
ref_chosen_logps.append(ref_chosen_logp.cpu()) | |
ref_rejected_logps.append(ref_rejected_logp.cpu()) | |
all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() | |
all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() | |
eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) | |
eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) | |
# Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs | |
if self.eval_dataset is not None: | |
self.eval_dataset = eval_dataset | |
self._precomputed_eval_ref_log_probs = True | |
return super().get_eval_dataloader(eval_dataset=eval_dataset) | |
    @contextmanager
    def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation).""" | |
with ( | |
self.accelerator.unwrap_model(self.model).disable_adapter() | |
if self.is_peft_model and not self.ref_adapter_name | |
else nullcontext() | |
): | |
if self.ref_adapter_name: | |
self.model.set_adapter(self.ref_adapter_name) | |
yield | |
if self.ref_adapter_name: | |
self.model.set_adapter(self.model_adapter_name or "default") | |
    def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.FloatTensor, torch.FloatTensor]:
"""Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" | |
        compute_ref_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )
        with torch.no_grad(), compute_ref_context_manager:
if self.ref_model is None: | |
with self.null_ref_context(): | |
ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) | |
else: | |
ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) | |
return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] | |
    @staticmethod
    def concatenated_inputs(
        batch: dict[str, Union[list, torch.LongTensor]], padding_value: int
    ) -> dict[str, torch.LongTensor]:
""" | |
Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt | |
and completion sequences. | |
Args: | |
batch (`dict[str, Union[list, torch.LongTensor]]`): | |
A batch of input data. The batch must contain the following keys: | |
- `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input IDs. | |
- `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen completion input IDs. | |
- `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected completion input IDs. | |
- `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. | |
- `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. | |
padding_value (`int`): | |
The padding value to use for the concatenated completion sequences (`chosen_input_ids` and | |
`rejected_input_ids`). | |
Returns: | |
`dict[str, torch.LongTensor]`: A dictionary containing: | |
- `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. | |
- `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * batch_size, max_completion_length)`. | |
- `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, prompt_length)`. | |
- `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * batch_size, max_completion_length)`. | |
- `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. | |
- `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if `"prompt_pixel_attention_mask"` are present. | |
Notes: | |
The completion input IDs and attention masks are padded to the maximum completion length of the chosen | |
or rejected sequences. | |
""" | |
output = {} | |
# For the prompt, the input_ids are the same for both the chosen and rejected responses | |
output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) | |
output["prompt_attention_mask"] = torch.cat( | |
[batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 | |
) | |
if "pixel_values" in batch: | |
output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) | |
if "pixel_attention_mask" in batch: | |
output["pixel_attention_mask"] = torch.cat( | |
[batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 | |
) | |
if "image_sizes" in batch: | |
output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) | |
# Concatenate the chosen and rejected completions | |
max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) | |
output["completion_input_ids"] = torch.cat( | |
( | |
pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), | |
pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), | |
), | |
) | |
output["completion_attention_mask"] = torch.cat( | |
( | |
pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), | |
pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), | |
), | |
) | |
return output | |
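    # Shape illustration (assumed toy sizes, descriptive only): with batch_size=2, prompt_length=3,
    # chosen_length=2 and rejected_length=4, the returned "prompt_input_ids" has shape (4, 3) and
    # "completion_input_ids" has shape (4, 4); rows 0-1 hold the chosen completions, rows 2-3 the rejected ones.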
def dpo_loss( | |
self, | |
chosen_logps: torch.FloatTensor, | |
rejected_logps: torch.FloatTensor, | |
ref_chosen_logps: torch.FloatTensor, | |
ref_rejected_logps: torch.FloatTensor, | |
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: | |
""" | |
Compute the DPO loss for a batch of policy and reference model log probabilities. | |
Args: | |
chosen_logps (`torch.FloatTensor`): | |
Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. | |
rejected_logps (`torch.FloatTensor`): | |
Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. | |
ref_chosen_logps (`torch.FloatTensor`): | |
Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. | |
ref_rejected_logps (`torch.FloatTensor`): | |
Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. | |
Returns: | |
A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. | |
The losses tensor contains the DPO loss for each example in the batch. | |
The `chosen_rewards` and `rejected_rewards` tensors contain the rewards for the chosen and rejected | |
responses, respectively. | |
""" | |
device = self.accelerator.device | |
# Get the log ratios for the chosen and rejected responses | |
chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) | |
rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) | |
if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: | |
# The alpha-divergence formula: (1 - u^-alpha) / alpha | |
# The divergence difference between the chosen and rejected sample is: | |
# (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha | |
# = (u[l]^-alpha - u[w]^-alpha) / alpha | |
# where u[w] and u[l] are the policy/reference probability ratios | |
# for the chosen and rejected samples, respectively. | |
alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT | |
if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: | |
alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) | |
logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef | |
else: | |
logratios = chosen_logps - rejected_logps | |
if self.reference_free: | |
ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) | |
else: | |
ref_logratios = ref_chosen_logps - ref_rejected_logps | |
logratios = logratios.to(self.accelerator.device) | |
ref_logratios = ref_logratios.to(self.accelerator.device) | |
logits = logratios - ref_logratios | |
if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value: | |
# The js-divergence formula: log(2 * u / (1 + u)) | |
# The divergence difference between the chosen and rejected sample is: | |
# log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l])) | |
# = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l])) | |
# where u[w] and u[l] are the policy/reference probability ratios | |
# for the chosen and rejected samples, respectively. | |
logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) | |
# The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. | |
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the | |
# labels and calculates a conservative DPO loss. | |
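        # Writing the default "sigmoid" branch below out explicitly (a restatement of the code, not an extra
        # code path):
        #   logits = (logp(y_w) - logp(y_l)) - (ref_logp(y_w) - ref_logp(y_l))
        #   loss   = -(1 - label_smoothing) * logsigmoid(beta * logits)
        #            - label_smoothing * logsigmoid(-beta * logits)
        # which reduces to the original DPO loss when label_smoothing == 0.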
if self.loss_type == "sigmoid": | |
losses = ( | |
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) | |
- F.logsigmoid(-self.beta * logits) * self.label_smoothing | |
) | |
elif self.loss_type == "robust": | |
losses = ( | |
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) | |
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing | |
) / (1 - 2 * self.label_smoothing) | |
elif self.loss_type == "exo_pair": | |
# eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856 | |
import math | |
if self.label_smoothing == 0: | |
self.label_smoothing = 1e-3 | |
losses = (self.beta * logits).sigmoid() * ( | |
F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) | |
) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) | |
elif self.loss_type == "hinge": | |
losses = torch.relu(1 - self.beta * logits) | |
elif self.loss_type == "ipo": | |
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. | |
losses = (logits - 1 / (2 * self.beta)) ** 2 | |
elif self.loss_type == "bco_pair": | |
chosen_logratios = chosen_logps - ref_chosen_logps | |
rejected_logratios = rejected_logps - ref_rejected_logps | |
chosen_rewards = self.beta * chosen_logratios | |
rejected_rewards = self.beta * rejected_logratios | |
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() | |
self.running.update(rewards) | |
delta = self.running.mean | |
losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( | |
-(self.beta * rejected_logratios - delta) | |
) | |
elif self.loss_type == "sppo_hard": | |
# In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, | |
# estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. | |
# The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is | |
# set to 1 for the winner and 0 for the loser. | |
a = chosen_logps - ref_chosen_logps | |
b = rejected_logps - ref_rejected_logps | |
losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 | |
elif self.loss_type == "nca_pair": | |
chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta | |
rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta | |
losses = ( | |
-F.logsigmoid(chosen_rewards) | |
- 0.5 * F.logsigmoid(-chosen_rewards) | |
- 0.5 * F.logsigmoid(-rejected_rewards) | |
) | |
elif self.loss_type == "aot_pair": | |
chosen_logratios = chosen_logps - ref_chosen_logps | |
rejected_logratios = rejected_logps - ref_rejected_logps | |
chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) | |
rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) | |
delta = chosen_logratios_sorted - rejected_logratios_sorted | |
losses = ( | |
-F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) | |
- F.logsigmoid(-self.beta * delta) * self.label_smoothing | |
) | |
elif self.loss_type == "aot": | |
logratios = chosen_logps - rejected_logps | |
ref_logratios = ref_chosen_logps - ref_rejected_logps | |
logratios_sorted, _ = torch.sort(logratios, dim=0) | |
ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) | |
delta = logratios_sorted - ref_logratios_sorted | |
losses = ( | |
-F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) | |
- F.logsigmoid(-self.beta * delta) * self.label_smoothing | |
) | |
elif self.loss_type == "apo_zero": | |
# Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) | |
# Use this loss when you believe the chosen outputs are better than your model's default output | |
losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood | |
losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood | |
losses = losses_chosen + losses_rejected | |
elif self.loss_type == "apo_down": | |
# Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) | |
# Use this loss when you believe the chosen outputs are worse than your model's default output. | |
# Decrease chosen likelihood and decrease rejected likelihood more | |
losses_chosen = F.sigmoid(self.beta * chosen_logratios) | |
losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) | |
losses = losses_chosen + losses_rejected | |
elif self.loss_type == "discopop": | |
# Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) | |
# This loss was discovered with LLM discovery | |
logratios = chosen_logps - rejected_logps | |
ref_logratios = ref_chosen_logps - ref_rejected_logps | |
logits = logratios - ref_logratios | |
logits = logits * self.beta | |
# Modulate the mixing coefficient based on the log ratio magnitudes | |
log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) | |
logistic_component = -F.logsigmoid(logits) | |
exp_component = torch.exp(-logits) | |
# Blend between logistic and exponential component based on log ratio modulation | |
losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation | |
else: | |
raise ValueError( | |
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " | |
"'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', 'apo_down']" | |
) | |
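        # Descriptive note: the rewards computed below are the implicit DPO rewards
        # r(x, y) = beta * (logp(y | x) - ref_logp(y | x)), detached because they are only used for metrics.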
chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() | |
rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() | |
return losses, chosen_rewards, rejected_rewards | |
def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): | |
unwrapped_model = self.accelerator.unwrap_model(model) | |
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) | |
model_kwargs = {} | |
if self.aux_loss_enabled: | |
model_kwargs["output_router_logits"] = True | |
# Add the pixel values and attention masks for vision models | |
if "pixel_values" in concatenated_batch: | |
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] | |
if "pixel_attention_mask" in concatenated_batch: | |
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] | |
if "image_sizes" in concatenated_batch: | |
model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] | |
prompt_attention_mask = concatenated_batch["prompt_attention_mask"] | |
completion_attention_mask = concatenated_batch["completion_attention_mask"] | |
if self.is_encoder_decoder: | |
# 1. Get encoder outputs | |
encoder_outputs = unwrapped_model.get_encoder()( | |
concatenated_batch["prompt_input_ids"], | |
attention_mask=concatenated_batch["prompt_attention_mask"], | |
return_dict=True, | |
) | |
# 2. Prepare decoder inputs | |
decoder_input_ids = shift_tokens_right( | |
concatenated_batch["completion_input_ids"], | |
unwrapped_model.config.decoder_start_token_id, | |
) | |
# 3. Get decoder outputs | |
decoder_outputs = unwrapped_model.get_decoder()( | |
input_ids=decoder_input_ids, | |
attention_mask=concatenated_batch["completion_attention_mask"], | |
encoder_hidden_states=encoder_outputs.last_hidden_state, | |
encoder_attention_mask=concatenated_batch["prompt_attention_mask"], | |
use_cache=False, | |
) | |
hidden_states = decoder_outputs.last_hidden_state | |
ref_hidden_states = None | |
if not self.reference_free and self.ref_model is not None: | |
unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) | |
ref_encoder_outputs = unwrapped_ref_model.get_encoder()( | |
concatenated_batch["prompt_input_ids"], | |
attention_mask=concatenated_batch["prompt_attention_mask"], | |
return_dict=True, | |
) | |
ref_decoder_outputs = unwrapped_ref_model.get_decoder()( | |
input_ids=decoder_input_ids, | |
attention_mask=concatenated_batch["completion_attention_mask"], | |
encoder_hidden_states=ref_encoder_outputs.last_hidden_state, | |
encoder_attention_mask=concatenated_batch["prompt_attention_mask"], | |
use_cache=False, | |
) | |
ref_hidden_states = ref_decoder_outputs.last_hidden_state | |
elif not self.reference_free: | |
with self.null_ref_context(): | |
ref_encoder_outputs = unwrapped_model.get_encoder()( | |
concatenated_batch["prompt_input_ids"], | |
attention_mask=concatenated_batch["prompt_attention_mask"], | |
return_dict=True, | |
) | |
ref_decoder_outputs = unwrapped_model.get_decoder()( | |
input_ids=decoder_input_ids, | |
attention_mask=concatenated_batch["completion_attention_mask"], | |
encoder_hidden_states=ref_encoder_outputs.last_hidden_state, | |
encoder_attention_mask=concatenated_batch["prompt_attention_mask"], | |
use_cache=False, | |
) | |
ref_hidden_states = ref_decoder_outputs.last_hidden_state | |
labels = concatenated_batch["completion_input_ids"] | |
loss_mask = completion_attention_mask.bool() | |
else: | |
# For decoder-only models | |
input_ids = torch.cat( | |
(concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 | |
) | |
attention_mask = torch.cat( | |
(concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), | |
dim=1, | |
) | |
# Mask the prompt but not the completion for the loss | |
loss_mask = torch.cat( | |
(torch.zeros_like(prompt_attention_mask), completion_attention_mask), | |
dim=1, | |
) | |
# Flush and truncate | |
if self.max_length is not None and self.max_length < attention_mask.size(1): | |
if self.truncation_mode == "keep_start": | |
# Flush left to reduce the memory usage | |
# [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
# [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
attention_mask = attention_mask[:, : self.max_length] | |
input_ids = input_ids[:, : self.max_length] | |
loss_mask = loss_mask[:, : self.max_length] | |
elif self.truncation_mode == "keep_end": | |
# Flush right before truncating left, then flush left | |
# [[0, 0, x, x, x, x], -> [[0, 0, x, x], | |
# [0, x, x, x, 0, 0]] [0, x, x, x]] | |
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) | |
input_ids = input_ids[:, -self.max_length :] | |
attention_mask = attention_mask[:, -self.max_length :] | |
loss_mask = loss_mask[:, -self.max_length :] | |
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
else: | |
raise ValueError( | |
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " | |
"'keep_start']." | |
) | |
else: | |
# Flush left to reduce the memory usage | |
# [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
# [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
# Add logits_to_keep optimization | |
if self.use_logits_to_keep: | |
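                # Same optimization as in `concatenated_forward`: only keep enough trailing positions to cover
                # the completion tokens (from the earliest non-masked position, plus one for the first shifted label).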
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() | |
logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 | |
model_kwargs["logits_to_keep"] = logits_to_keep | |
model_kwargs["output_hidden_states"] = True | |
# Add padding-free training support | |
if self.padding_free: | |
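                # Flatten the batch into a single packed sequence, as in `concatenated_forward`:
                # input_ids    [[a, b, c, 0],   ->  input_ids    [[a, b, c, d, e, f, g]]
                #               [d, e, f, g]]       position_ids [[0, 1, 2, 0, 1, 2, 3]]
                # position_ids restart at 0 for each example so the sequences remain distinguishable.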
input_ids = input_ids[attention_mask.bool()].unsqueeze(0) | |
loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) | |
position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 | |
model_kwargs["position_ids"] = position_ids | |
else: | |
model_kwargs["attention_mask"] = attention_mask | |
# Get the base model outputs (before LM head) | |
if hasattr(unwrapped_model, "get_decoder"): | |
base_model = unwrapped_model.get_decoder() | |
else: | |
base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) | |
outputs = base_model( | |
input_ids, | |
use_cache=False, | |
**model_kwargs, | |
) | |
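            # Drop the hidden state of the last position: it would predict a token beyond the sequence, and the
            # labels below are shifted by one so that position t is scored against the token at t + 1.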
hidden_states = outputs.last_hidden_state[:, :-1] | |
# Get reference hidden states if needed | |
ref_hidden_states = None | |
if not self.reference_free and self.ref_model is not None: | |
unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) | |
if hasattr(unwrapped_ref_model, "get_decoder"): | |
ref_base_model = unwrapped_ref_model.get_decoder() | |
else: | |
ref_base_model = getattr( | |
unwrapped_ref_model, self.args.base_model_attribute_name, unwrapped_ref_model | |
) | |
ref_outputs = ref_base_model( | |
input_ids, | |
use_cache=False, | |
**model_kwargs, | |
) | |
ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] | |
elif not self.reference_free: | |
if hasattr(unwrapped_model, "get_decoder"): | |
ref_base_model = unwrapped_model.get_decoder() | |
else: | |
ref_base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) | |
with self.null_ref_context(): | |
ref_outputs = ref_base_model( | |
input_ids, | |
attention_mask=attention_mask, | |
use_cache=False, | |
**model_kwargs, | |
) | |
ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] | |
            masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id)
            labels = masked_input_ids[:, 1:]  # Shift by one for causal LM
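            # e.g. masked_input_ids: [pad, pad, c0, c1] -> labels: [pad, c0, c1], which lines up with
            # hidden_states = outputs.last_hidden_state[:, :-1] computed above.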
# Get the LM head | |
lm_head = unwrapped_model.get_output_embeddings() | |
# Get reference model weights if needed | |
ref_weight = None | |
ref_bias = None | |
if not self.reference_free: | |
if self.ref_model is not None: | |
unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) | |
ref_lm_head = unwrapped_ref_model.get_output_embeddings() | |
else: | |
with self.null_ref_context(): | |
ref_lm_head = unwrapped_model.get_output_embeddings() | |
ref_weight = ref_lm_head.weight | |
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None | |
# Compute loss using Liger kernel | |
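        # LigerFusedLinearDPOLoss fuses the LM head projection with the DPO loss and works on chunks of the
        # hidden states, so the full (batch, seq, vocab) logits are never materialized at once. The first half
        # of the batch is assumed to hold the chosen completions and the second half the rejected ones, as
        # produced by `concatenated_inputs`.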
loss_output = self.dpo_loss_fn( | |
lm_head.weight, | |
hidden_states, | |
labels, | |
bias=lm_head.bias if hasattr(lm_head, "bias") else None, | |
ref_input=ref_hidden_states if not self.reference_free else None, | |
ref_weight=ref_weight if not self.reference_free else None, | |
ref_bias=ref_bias if not self.reference_free else None, | |
) | |
( | |
loss, | |
(chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), | |
) = loss_output | |
output = { | |
"loss": loss, | |
"chosen_logps": chosen_logps, | |
"rejected_logps": rejected_logps, | |
"mean_chosen_logits": chosen_logits_mean, | |
"mean_rejected_logits": rejected_logits_mean, | |
"nll_loss": nll_loss, | |
"chosen_rewards": aux_outputs[0], | |
"rejected_rewards": aux_outputs[1], | |
} | |
if self.aux_loss_enabled: | |
output["aux_loss"] = outputs.aux_loss | |
return output | |
def concatenated_forward( | |
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False | |
): | |
""" | |
Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. | |
We do this to avoid doing two forward passes, because it's faster for FSDP. | |
Args: | |
model: | |
Model to run the forward pass on. | |
batch: | |
Batch of input data. | |
is_ref_model: | |
Whether this method is being called for the reference model. If `True`, length desensitization is not | |
applied. | |
""" | |
num_examples = batch["prompt_input_ids"].shape[0] | |
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) | |
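        # `concatenated_inputs` stacks the chosen and rejected sequences along the batch dimension, so rows
        # [:num_examples] correspond to the chosen completions and rows [num_examples:] to the rejected ones.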
model_kwargs = {"use_cache": False} | |
if self.aux_loss_enabled: | |
model_kwargs["output_router_logits"] = True | |
# Add the pixel values and attention masks for vision models | |
if "pixel_values" in concatenated_batch: | |
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] | |
if "pixel_attention_mask" in concatenated_batch: | |
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] | |
if "image_sizes" in concatenated_batch: | |
model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] | |
prompt_input_ids = concatenated_batch["prompt_input_ids"] | |
prompt_attention_mask = concatenated_batch["prompt_attention_mask"] | |
completion_input_ids = concatenated_batch["completion_input_ids"] | |
completion_attention_mask = concatenated_batch["completion_attention_mask"] | |
if self.is_encoder_decoder: | |
labels = completion_input_ids | |
labels[completion_attention_mask == 0] = self.label_pad_token_id | |
outputs = model( | |
input_ids=prompt_input_ids, | |
attention_mask=prompt_attention_mask, | |
labels=labels, # we need the labels for the logits to be returned | |
**model_kwargs, | |
) | |
logits = outputs.logits | |
loss_mask = completion_attention_mask.bool() | |
else: | |
# Concatenate the prompt and completion inputs | |
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) | |
attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) | |
# Mask the prompt but not the completion for the loss | |
loss_mask = torch.cat( | |
(torch.zeros_like(prompt_attention_mask), completion_attention_mask), | |
dim=1, | |
) | |
# Flush and truncate | |
if self.max_length is not None and self.max_length < attention_mask.size(1): | |
if self.truncation_mode == "keep_start": | |
# Flush left to reduce the memory usage | |
# [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
# [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
attention_mask = attention_mask[:, : self.max_length] | |
input_ids = input_ids[:, : self.max_length] | |
loss_mask = loss_mask[:, : self.max_length] | |
elif self.truncation_mode == "keep_end": | |
# Flush right before truncating left, then flush left | |
# [[0, 0, x, x, x, x], -> [[0, 0, x, x], | |
# [0, x, x, x, 0, 0]] [0, x, x, x]] | |
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) | |
input_ids = input_ids[:, -self.max_length :] | |
attention_mask = attention_mask[:, -self.max_length :] | |
loss_mask = loss_mask[:, -self.max_length :] | |
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
else: | |
raise ValueError( | |
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " | |
"'keep_start']." | |
) | |
else: | |
# Flush left to reduce the memory usage | |
# [[0, 0, x, x, x, x], -> [[x, x, x, x], | |
# [0, x, x, x, 0, 0]] [x, x, x, 0]] | |
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) | |
if self.use_logits_to_keep: | |
# Compute logits_to_keep based on loss_mask pattern: | |
# [[0, 0, 0, x, x, x, x], | |
# [0, 0, 0, x, x, x, 0]] | |
# ^ start computing logits from here ([:, -(7-3+1):]) | |
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() | |
logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label | |
model_kwargs["logits_to_keep"] = logits_to_keep | |
model_kwargs["output_hidden_states"] = True | |
if self.padding_free: | |
# Flatten the input_ids, position_ids, and loss_mask | |
# input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] | |
# [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] | |
input_ids = input_ids[attention_mask.bool()].unsqueeze(0) | |
loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) | |
position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 | |
model_kwargs["position_ids"] = position_ids | |
else: | |
model_kwargs["attention_mask"] = attention_mask | |
outputs = model(input_ids, **model_kwargs) | |
logits = outputs.logits | |
# Offset the logits by one to align with the labels | |
labels = torch.roll(input_ids, shifts=-1, dims=1) | |
loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() | |
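            # The roll wraps the first token around to the last position, but that position is masked out: the
            # rolled loss_mask places a prompt (zero) entry there, so it never contributes to the loss.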
if self.use_logits_to_keep: | |
# Align labels with logits | |
# logits: -, -, [x2, x3, x4, x5, x6] | |
# ^ --------- ^ after logits[:, :-1, :] | |
# labels: [y0, y1, y2, y3, y4, y5, y6] | |
# ^ --------- ^ with logits_to_keep=4, [:, -4:] | |
# loss_mask: [0, 0, 0, 1, 1, 1, 1] | |
labels = labels[:, -logits_to_keep:] | |
loss_mask = loss_mask[:, -logits_to_keep:] | |
if logits.shape[:2] != labels.shape[:2]: | |
# for llava, the returned logits include the image tokens (placed before the text tokens) | |
seq_len = labels.shape[1] | |
logits = logits[:, -seq_len:] | |
# Compute the log probabilities of the labels | |
labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later | |
per_token_logps = selective_log_softmax(logits, labels) | |
per_token_logps[~loss_mask] = 0 | |
per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) | |
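        # Roll the per-token log-probs forward by one so that index t again refers to token t of input_ids; after
        # the roll, index 0 holds the zeroed wrap-around entry, which is why the sum below starts at index 1.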
if self.padding_free: | |
# Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) | |
batch_size, seq_len = attention_mask.shape | |
per_token_logps_ = torch.zeros( | |
batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype | |
) | |
per_token_logps_[attention_mask.bool()] = per_token_logps | |
per_token_logps = per_token_logps_ | |
all_logps = per_token_logps[:, 1:].sum(-1) | |
output = {} | |
if self.use_weighting: | |
with torch.no_grad(): | |
# Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 | |
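                # Per-sequence weight: mean over completion tokens of [log p(y_t) - log sum_v p(v)^2], i.e. each
                # token's log-prob discounted by the collision probability of the model's distribution at that step.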
logprobs = F.log_softmax(logits, dim=-1) | |
weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space | |
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor | |
all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) | |
chosen_weights = all_weights[:num_examples] | |
rejected_weights = all_weights[num_examples:] | |
output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) | |
if self.args.rpo_alpha is not None: | |
# Only use the chosen logits for the RPO loss | |
chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] | |
chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples] | |
# Compute the log probabilities of the labels | |
output["nll_loss"] = F.cross_entropy( | |
torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 | |
) | |
if self.loss_type == "ipo": | |
all_logps = all_logps / loss_mask.sum(-1) | |
if self.args.ld_alpha is not None and not is_ref_model: | |
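            # LD-DPO: split each completion at the length of the shorter response in the pair ("public length");
            # tokens up to that point are kept as-is, while the excess tokens of the longer response are scaled
            # by ld_alpha, which is intended to reduce the incentive to win simply by being longer.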
# Compute response lengths based on loss_mask | |
completion_lengths = loss_mask.sum(dim=1) | |
chosen_lengths = completion_lengths[:num_examples] | |
rejected_lengths = completion_lengths[num_examples:] | |
public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper | |
public_lengths = torch.cat([public_lengths, public_lengths], dim=0) | |
seq_len = per_token_logps.size(1) | |
position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) | |
ld_mask = position_ids < public_lengths.unsqueeze(1) | |
mask = position_ids < completion_lengths.unsqueeze(1) | |
front_mask = (ld_mask & mask).float() | |
rear_mask = (~ld_mask & mask).float() | |
front_logps = (per_token_logps * front_mask).sum(dim=1) | |
rear_logps = (per_token_logps * rear_mask).sum(dim=1) | |
all_logps = front_logps + self.args.ld_alpha * rear_logps | |
output["chosen_logps"] = all_logps[:num_examples] | |
output["rejected_logps"] = all_logps[num_examples:] | |
# Compute the mean logits | |
if self.padding_free: | |
# position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). | |
# There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, | |
# and the second half to the rejected tokens. | |
# To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. | |
split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] | |
mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() | |
mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() | |
else: | |
mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() | |
mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() | |
output["mean_chosen_logits"] = mean_chosen_logits | |
output["mean_rejected_logits"] = mean_rejected_logits | |
if self.aux_loss_enabled: | |
output["aux_loss"] = outputs.aux_loss | |
return output | |
def get_batch_loss_metrics( | |
self, | |
model, | |
batch: dict[str, Union[list, torch.LongTensor]], | |
train_eval: Literal["train", "eval"] = "train", | |
): | |
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" | |
metrics = {} | |
if self.args.use_liger_loss: | |
model_output = self._compute_loss_liger(model, batch) | |
losses = model_output["loss"] | |
chosen_rewards = model_output["chosen_rewards"] | |
rejected_rewards = model_output["rejected_rewards"] | |
else: | |
model_output = self.concatenated_forward(model, batch) | |
            # If ref_chosen_logps and ref_rejected_logps are already in the batch, use them; otherwise compute
            # them with the reference model
if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: | |
ref_chosen_logps = batch["ref_chosen_logps"] | |
ref_rejected_logps = batch["ref_rejected_logps"] | |
else: | |
ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) | |
losses, chosen_rewards, rejected_rewards = self.dpo_loss( | |
model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps | |
) | |
reward_accuracies = (chosen_rewards > rejected_rewards).float() | |
if self.args.rpo_alpha is not None: | |
losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper | |
if self.use_weighting: | |
losses = losses * model_output["policy_weights"] | |
if self.aux_loss_enabled: | |
losses = losses + self.aux_loss_coef * model_output["aux_loss"] | |
prefix = "eval_" if train_eval == "eval" else "" | |
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() | |
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() | |
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() | |
metrics[f"{prefix}rewards/margins"] = ( | |
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() | |
) | |
metrics[f"{prefix}logps/chosen"] = ( | |
self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() | |
) | |
metrics[f"{prefix}logps/rejected"] = ( | |
self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() | |
) | |
metrics[f"{prefix}logits/chosen"] = ( | |
self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() | |
) | |
metrics[f"{prefix}logits/rejected"] = ( | |
self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() | |
) | |
if self.args.rpo_alpha is not None: | |
metrics[f"{prefix}nll_loss"] = ( | |
self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() | |
) | |
if self.aux_loss_enabled: | |
metrics[f"{prefix}aux_loss"] = ( | |
self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() | |
) | |
return losses.mean(), metrics | |
def compute_loss( | |
self, | |
model: Union[PreTrainedModel, nn.Module], | |
inputs: dict[str, Union[torch.Tensor, Any]], | |
return_outputs=False, | |
num_items_in_batch=None, | |
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: | |
compute_loss_context_manager = ( | |
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() | |
) | |
with compute_loss_context_manager: | |
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") | |
        # Make sure to move the loss to the device on which the `Trainer` class accumulates the overall loss:
loss = loss.to(self.args.device) | |
# force log the metrics | |
self.store_metrics(metrics, train_eval="train") | |
if return_outputs: | |
return loss, metrics | |
return loss | |
    def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[list[str], list[str]]:
        """Generate samples from the model and reference model for the given batch of inputs."""
        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch amp context manager, as some hidden states are silently cast to full precision.
generate_context_manager = ( | |
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() | |
) | |
with generate_context_manager: | |
policy_output = model.generate( | |
input_ids=batch["prompt_input_ids"], | |
attention_mask=batch["prompt_attention_mask"], | |
max_length=self.max_length, | |
do_sample=True, | |
pad_token_id=self.padding_value, | |
) | |
            # If `ref_output` is already in the batch, use it; otherwise generate it with the reference model
if "ref_output" in batch: | |
ref_output = batch["ref_output"] | |
else: | |
if self.ref_model is None: | |
with self.null_ref_context(): | |
ref_output = self.model.generate( | |
input_ids=batch["prompt_input_ids"], | |
attention_mask=batch["prompt_attention_mask"], | |
max_length=self.max_length, | |
do_sample=True, | |
pad_token_id=self.padding_value, | |
) | |
else: | |
ref_output = self.ref_model.generate( | |
input_ids=batch["prompt_input_ids"], | |
attention_mask=batch["prompt_attention_mask"], | |
max_length=self.max_length, | |
do_sample=True, | |
pad_token_id=self.padding_value, | |
) | |
policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) | |
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) | |
ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) | |
ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) | |
return policy_output_decoded, ref_output_decoded | |
def prediction_step( | |
self, | |
model: Union[PreTrainedModel, nn.Module], | |
inputs: dict[str, Union[torch.Tensor, Any]], | |
prediction_loss_only: bool, | |
ignore_keys: Optional[list[str]] = None, | |
): | |
if ignore_keys is None: | |
if hasattr(model, "config"): | |
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) | |
else: | |
ignore_keys = [] | |
prediction_context_manager = ( | |
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() | |
) | |
with torch.no_grad(), prediction_context_manager: | |
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") | |
# force log the metrics | |
self.store_metrics(metrics, train_eval="eval") | |
if prediction_loss_only: | |
return loss.detach(), None, None | |
# logits for the chosen and rejected samples from model | |
logits_dict = { | |
"eval_logits/chosen": metrics["eval_logits/chosen"], | |
"eval_logits/rejected": metrics["eval_logits/rejected"], | |
} | |
logits = [v for k, v in logits_dict.items() if k not in ignore_keys] | |
logits = torch.tensor(logits, device=self.accelerator.device) | |
labels = torch.zeros(logits.shape[0], device=self.accelerator.device) | |
return (loss.detach(), logits, labels) | |
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: | |
for key, value in metrics.items(): | |
self._stored_metrics[train_eval][key].append(value) | |
def evaluation_loop( | |
self, | |
dataloader: DataLoader, | |
description: str, | |
prediction_loss_only: Optional[bool] = None, | |
ignore_keys: Optional[list[str]] = None, | |
metric_key_prefix: str = "eval", | |
) -> EvalLoopOutput: | |
""" | |
Overriding built-in evaluation loop to store metrics for each batch. | |
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. | |
Works both with or without labels. | |
""" | |
# Sample and save to game log if requested (for one batch to save time) | |
if self.generate_during_eval: | |
# Generate random indices within the range of the total number of samples | |
num_samples = len(dataloader.dataset) | |
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) | |
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader | |
random_batch_dataset = dataloader.dataset.select(random_indices) | |
random_batch = self.data_collator(random_batch_dataset) | |
random_batch = self._prepare_inputs(random_batch) | |
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) | |
table = pd.DataFrame( | |
columns=["Prompt", "Policy", "Ref Model"], | |
data=[ | |
[prompt, pol[len(prompt) :], ref[len(prompt) :]] | |
for prompt, pol, ref in zip( | |
random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded | |
) | |
], | |
) | |
if "wandb" in self.args.report_to and self.accelerator.is_main_process: | |
wandb.log({"game_log": wandb.Table(data=table)}) | |
if "comet_ml" in self.args.report_to: | |
log_table_to_comet_experiment( | |
name="game_log.csv", | |
table=table, | |
) | |
# Base evaluation | |
initial_output = super().evaluation_loop( | |
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix | |
) | |
return initial_output | |
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: | |
""" | |
Log `logs` on the various objects watching training, including stored metrics. | |
Args: | |
logs (`dict[str, float]`): | |
The values to log. | |
start_time (`float` or `None`, *optional*, defaults to `None`): | |
Start time of the training. | |
""" | |
# logs either has 'loss' or 'eval_loss' | |
train_eval = "train" if "loss" in logs else "eval" | |
# Add averaged stored metrics to logs | |
for key, metrics in self._stored_metrics[train_eval].items(): | |
logs[key] = torch.tensor(metrics).mean().item() | |
del self._stored_metrics[train_eval] | |
return super().log(logs, start_time) | |
# Ensure the model card is saved along with the checkpoint | |
def _save_checkpoint(self, model, trial): | |
if self.args.hub_model_id is None: | |
model_name = Path(self.args.output_dir).name | |
else: | |
model_name = self.args.hub_model_id.split("/")[-1] | |
self.create_model_card(model_name=model_name) | |
super()._save_checkpoint(model, trial) | |
def create_model_card( | |
self, | |
model_name: Optional[str] = None, | |
dataset_name: Optional[str] = None, | |
tags: Union[str, list[str], None] = None, | |
): | |
""" | |
Creates a draft of a model card using the information available to the `Trainer`. | |
Args: | |
model_name (`str` or `None`, *optional*, defaults to `None`): | |
Name of the model. | |
dataset_name (`str` or `None`, *optional*, defaults to `None`): | |
Name of the dataset used for training. | |
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): | |
Tags to be associated with the model card. | |
""" | |
if not self.is_world_process_zero(): | |
return | |
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): | |
base_model = self.model.config._name_or_path | |
else: | |
base_model = None | |
        # Normalize `tags` to a set so that `.add` and `.update` below also work when a list is passed
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)
        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")
        tags.update(self._tag_names)
citation = textwrap.dedent( | |
"""\ | |
@inproceedings{rafailov2023direct, | |
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, | |
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, | |
year = 2023, | |
booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, | |
url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, | |
editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, | |
}""" | |
) | |
model_card = generate_model_card( | |
base_model=base_model, | |
model_name=model_name, | |
hub_model_id=self.hub_model_id, | |
dataset_name=dataset_name, | |
tags=tags, | |
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, | |
comet_url=get_comet_experiment_url(), | |
trainer_name="DPO", | |
trainer_citation=citation, | |
paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model", | |
paper_id="2305.18290", | |
) | |
model_card.save(os.path.join(self.args.output_dir, "README.md")) | |