trl-sandbox / trl /trainer /orpo_trainer.py
ivangabriele's picture
feat: initialize project
2f5127c verified
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import random
import textwrap
import warnings
from collections import defaultdict
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
from datasets import Dataset
from torch import autocast
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
BaseImageProcessor,
DataCollator,
FeatureExtractionMixin,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
is_comet_available,
is_torch_xla_available,
is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .orpo_config import ORPOConfig
from .utils import (
DPODataCollatorWithPadding,
add_bos_token_if_needed,
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
if is_wandb_available():
import wandb
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
class ORPOTrainer(Trainer):
r"""
Initialize ORPOTrainer.
Args:
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForCausalLM`.
args (`ORPOConfig`):
The ORPO config arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
callbacks (`list[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
"""
_tag_names = ["trl", "orpo"]
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: Optional[ORPOConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
):
    """
    Build the (optionally PEFT-wrapped) model, resolve length and collator settings, tokenize the
    datasets and finally delegate to `Trainer.__init__`.

    Raises:
        ValueError: If `args.model_init_kwargs` is given together with an already instantiated model, if
            `torch_dtype` is invalid, if `peft_config` is passed without PEFT installed, if
            `generate_during_eval` is requested without wandb/comet, if `is_encoder_decoder` cannot be
            determined, or if `processing_class` is missing.
        AttributeError: If the parent `Trainer` did not create an `accelerator`.
    """
    # NOTE(review): `args` defaults to None in the signature but is dereferenced unconditionally below.
    if args.model_init_kwargs is None:
        model_init_kwargs = {}
    elif not isinstance(model, str):
        # Init kwargs only make sense when the model still has to be loaded from a name/path.
        # NOTE(review): the message says `model_kwargs` while the config field is `model_init_kwargs`.
        raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
    else:
        model_init_kwargs = args.model_init_kwargs
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            model_init_kwargs["torch_dtype"] = torch_dtype

    # A string is treated as a model name/path and loaded as a causal LM.
    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

    # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
    # has been called in order to properly call autocast if needed.
    self._peft_has_been_casted_to_bf16 = False

    if not is_peft_available() and peft_config is not None:
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            # Only forward `gradient_checkpointing_kwargs` when both the config and the installed
            # `prepare_model_for_kbit_training` know about it.
            _support_gc_kwargs = hasattr(
                args, "gradient_checkpointing_kwargs"
            ) and "gradient_checkpointing_kwargs" in list(
                inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if _support_gc_kwargs:
                prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
        elif args.gradient_checkpointing:
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                # Fallback: force the embedding output to require grad via a forward hook.
                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # get peft model with the given config
        model = get_peft_model(model, peft_config)
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
            # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            self._peft_has_been_casted_to_bf16 = True

    # For models that use gradient_checkpointing, we need to attach a hook that enables input
    # to explicitly have `requires_grad=True`, otherwise training will either silently
    # fail or completely fail.
    elif args.gradient_checkpointing:
        # For backward compatibility with older versions of transformers
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            # Fallback: force the embedding output to require grad via a forward hook.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
            " Please install `wandb` or `comet-ml` to resolve."
        )

    # The model config wins; `args.is_encoder_decoder` is only consulted when no model was given.
    if model is not None:
        self.is_encoder_decoder = model.config.is_encoder_decoder
    elif args.is_encoder_decoder is None:
        raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
    else:
        self.is_encoder_decoder = args.is_encoder_decoder

    # NOTE(review): `model.config` is dereferenced here and again further down, so `model=None`
    # does not look usable in practice even though the branch above allows it - confirm.
    if self.is_encoder_decoder:
        self.decoder_start_token_id = model.config.decoder_start_token_id
        self.pad_token_id = model.config.pad_token_id

    if processing_class is None:
        raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")

    # Fall back to 512 / 128 / 128 (with a warning) when the length limits are not configured.
    if args.max_length is None:
        warnings.warn(
            "`max_length` is not set in the ORPOConfig's init"
            " it will default to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_length = 512
    else:
        max_length = args.max_length
    if args.max_prompt_length is None:
        warnings.warn(
            "`max_prompt_length` is not set in the ORPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_prompt_length = 128
    else:
        max_prompt_length = args.max_prompt_length

    if args.max_completion_length is None and self.is_encoder_decoder:
        warnings.warn(
            "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        self.max_completion_length = 128
    else:
        self.max_completion_length = args.max_completion_length

    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(
            pad_token_id=processing_class.pad_token_id,
            label_pad_token_id=args.label_pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        # The default collator is used with `remove_unused_columns=False`; the setting is forced here.
        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_dpo_data_collator = True
    else:
        self.use_dpo_data_collator = False

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(model)

    # These attributes are read by `tokenize_row`, so they must be set before the datasets are mapped.
    self.max_length = max_length
    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
    self.max_prompt_length = max_prompt_length
    self.truncation_mode = args.truncation_mode
    self.processing_class = processing_class

    self.beta = args.beta
    # Auxiliary (router) loss settings are read from the model config when present.
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        warnings.warn(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
            UserWarning,
        )

    # Nested mapping of metric lists, indexed first by "train"/"eval" (see `compute_loss`).
    self._stored_metrics = defaultdict(lambda: defaultdict(list))

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
    # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
    # of the input, floating-point operations will not be computed." To suppress this warning, we set the
    # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
    # that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    # NOTE(review): `train_dataset` defaults to None but `.map` is called on it unconditionally, and
    # `eval_dataset.map` is called directly even though the annotation also allows a plain dict - confirm.
    with PartialState().main_process_first():
        # Extract the prompt if needed, and apply the chat template if needed
        train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
        train_dataset = train_dataset.map(
            maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
        )
        train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
            eval_dataset = eval_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
            )
            eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    # The loss/metric code relies on `self.accelerator`, which the parent constructor is expected to set.
    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )
def build_tokenized_answer(self, prompt, answer):
    """
    Tokenize `prompt + answer` in one pass and split the result into a prompt part and an answer part.

    Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
    It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
    Reference:
        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Args:
        prompt: Prompt text.
        answer: Answer text; it is appended directly to `prompt` before tokenizing.

    Returns:
        A dict with `prompt_input_ids` / `prompt_attention_mask` (the prompt slice of the joint
        tokenization) and `input_ids` / `attention_mask` (the remaining answer slice).

    Raises:
        ValueError: If the prompt tokenized on its own is longer than the joint tokenization, or if the
            prompt ids and the prompt attention mask end up with different lengths.
    """
    full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    full_input_ids = full_tokenized["input_ids"]
    full_attention_mask = full_tokenized["attention_mask"]

    # `enc(a) + enc(a + b)[len(enc(a)):]` must be exactly as long as `enc(a + b)`; this only fails
    # when the stand-alone prompt is longer than the joint tokenization.
    tail_length = len(full_input_ids[len(prompt_input_ids) :])
    if len(full_input_ids) != len(prompt_input_ids) + tail_length:
        raise ValueError("Prompt input ids and answer input ids should have the same length.")

    # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This could result
    # on the last token from the prompt being different when tokenized on its own
    # vs when done as prompt+answer.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If the stand-alone prompt differs from the prefix of prompt+answer, the last prompt token
    # was merged with the answer, so that token is moved over to the answer side.
    if prompt_input_ids != full_input_ids[:response_token_ids_start_idx]:
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_input_ids[:response_token_ids_start_idx]
    prompt_attention_mask = full_attention_mask[:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=full_input_ids[response_token_ids_start_idx:],
        attention_mask=full_attention_mask[response_token_ids_start_idx:],
    )
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
    """Tokenize a single row from a ORPO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
    in case the prompt + chosen or prompt + rejected responses is/are too long. First
    we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

    We also create the labels for the chosen/rejected responses, which are of length equal to
    the sum of the length of the prompt and the chosen/rejected response, with
    label_pad_token_id for the prompt tokens.

    Args:
        feature: Mapping with the keys `"prompt"`, `"chosen"` and `"rejected"`.
        model: Optional model; only used in the encoder-decoder branch, to build decoder input ids
            when it exposes `prepare_decoder_input_ids_from_labels`.

    Returns:
        A dict of token lists. Decoder-only: `chosen_*` / `rejected_*` (`input_ids`, `attention_mask`,
        `labels`) plus the `prompt_*` entries. Encoder-decoder: `chosen_labels`, `rejected_labels`,
        `prompt_input_ids`, `prompt_attention_mask` and optionally the decoder input ids.

    Raises:
        ValueError: If prompt/chosen/rejected are not strings (decoder-only), if the chosen and rejected
            prompts differ by more than the last token, or if `truncation_mode` is unknown.
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["chosen"]
    rejected = feature["rejected"]

    if not self.is_encoder_decoder:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens
        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

        prompt_tokens = {k: v[:prompt_len_input_ids] for k, v in prompt_tokens.items()}

        # Make sure prompts only have one different token at most
        # and length only differs by 1 at most
        num_diff_tokens = sum(
            a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
        )
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt. Avoid adding if it's already there
        prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
            self.processing_class.bos_token_id,
            prompt_len_input_ids,
            prompt_tokens,
            chosen_prompt_len_input_ids,
            chosen_tokens,
            rejected_prompt_len_input_ids,
            rejected_tokens,
        )

        # add EOS token to end of answer. Avoid adding if it's already there
        chosen_tokens, rejected_tokens = add_eos_token_if_needed(
            self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
        )

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                if self.truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

        # Create labels: full prompt+response sequences with the prompt positions masked out
        chosen_sequence_tokens = {
            k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        rejected_sequence_tokens = {
            k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(chosen_tokens["prompt_input_ids"])
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(rejected_tokens["prompt_input_ids"])

        for k, toks in {
            "chosen_": chosen_sequence_tokens,
            "rejected_": rejected_sequence_tokens,
            "": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}{type_key}"] = tokens

    else:
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["rejected_labels"])
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["chosen_labels"])
            )

    if is_torch_xla_available():
        # Pad the sequences to global max_length to avoid TorchXLA recompilation
        for k in batch:
            if "labels" in k or self.is_encoder_decoder:
                pad_value = self.label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = self.padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            else:
                # Keys such as `prompt_token_type_ids` match none of the rules above. They used to be
                # padded with whatever value the previous key selected (or hit a NameError when
                # first); leave them untouched instead.
                continue
            batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))

    return batch
@staticmethod
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> dict[str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of a batch along the batch dimension.

    Every tensor whose key starts with `chosen` is padded to a common length and stored under the
    matching `concatenated...` key; the `rejected` counterpart is padded the same way, appended below
    it (dim 0) and the result is moved to `device`.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids'
            ('chosen_labels' and 'rejected_labels' for encoder-decoder models), which are 2-D tensors.
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: Pad value used for label tensors (and for every tensor when
            `is_encoder_decoder` is set).
        padding_value: Pad value used for `*_input_ids` tensors.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary with the `concatenated_*` tensors. For encoder-decoder models the prompt ids and
        attention mask are repeated twice and stored as `concatenated_input_ids` /
        `concatenated_attention_mask`.
    """
    length_key = "labels" if is_encoder_decoder else "input_ids"
    max_length = max(batch[f"chosen_{length_key}"].shape[1], batch[f"rejected_{length_key}"].shape[1])

    concatenated_batch = {}
    # First pass stores the padded chosen tensors, second pass appends the rejected ones.
    for side in ("chosen", "rejected"):
        for key, value in batch.items():
            if not (key.startswith(side) and isinstance(value, torch.Tensor)):
                continue

            if is_encoder_decoder or "labels" in key:
                pad_value = label_pad_token_id
            elif key.endswith("_input_ids"):
                pad_value = padding_value
            elif key.endswith("_attention_mask"):
                pad_value = 0

            target_key = key.replace(side, "concatenated")
            if side == "chosen":
                concatenated_batch[target_key] = pad_to_length(value, max_length, pad_value=pad_value)
            else:
                concatenated_batch[target_key] = torch.cat(
                    (
                        concatenated_batch[target_key],
                        pad_to_length(value, max_length, pad_value=pad_value),
                    ),
                    dim=0,
                ).to(device=device)

    if is_encoder_decoder:
        # Both halves share the same encoder input, so the prompt is simply duplicated.
        doubled_ids = batch["prompt_input_ids"].repeat(2, 1)
        concatenated_batch["concatenated_input_ids"] = doubled_ids.to(device=device)
        doubled_mask = batch["prompt_attention_mask"].repeat(2, 1)
        concatenated_batch["concatenated_attention_mask"] = doubled_mask.to(device=device)

    return concatenated_batch
def odds_ratio_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute ORPO's odds ratio (OR) term for a batch of policy log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of five tensors, in this order:
        - `beta * logsigmoid(log_odds)` for each example in the batch,
        - the chosen rewards (`beta * policy_chosen_logps`, detached),
        - the rejected rewards (`beta * policy_rejected_logps`, detached),
        - the batch mean of `logsigmoid(log_odds)`, for logging purposes,
        - the batch mean of `log_odds`, for logging purposes,
        where `log_odds` is the log odds of the chosen response minus the log odds of the rejected one.
    """
    # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
    logp_gap = policy_chosen_logps - policy_rejected_logps
    log1m_gap = torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
    log_odds = logp_gap - log1m_gap
    log_sigmoid_odds = F.logsigmoid(log_odds)

    target_device = self.accelerator.device
    scaled_term = self.beta * log_sigmoid_odds
    reward_chosen = self.beta * policy_chosen_logps.to(target_device).detach()
    reward_rejected = self.beta * policy_rejected_logps.to(target_device).detach()

    return scaled_term, reward_chosen, reward_rejected, log_sigmoid_odds.mean(), log_odds.mean()
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Score the label tokens under the given logits, one value per sequence.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels to score. Positions equal to `label_pad_token_id` are ignored.
            Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the mean log probability over the non-masked tokens,
            otherwise their sum.
        label_pad_token_id: The label pad token id.
        is_encoder_decoder: Whether the model is an encoder-decoder model. When False, labels and
            logits are shifted by one position so that position `t` predicts token `t + 1`.

    Returns:
        A tensor of shape (batch_size,) with the average/sum log probabilities of the labels.

    Raises:
        ValueError: If the batch/sequence dimensions of `logits` and `labels` differ.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if not is_encoder_decoder:
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

    keep = labels != label_pad_token_id

    # Masked positions get a dummy token id; their contribution is zeroed out by `keep` below.
    gather_labels = labels.masked_fill(~keep, 0)
    token_logps = selective_log_softmax(logits, gather_labels)

    summed = (token_logps * keep).sum(-1)
    if not average_log_prob:
        return summed
    return summed / keep.sum(-1)
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, ...]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Args:
        model: The model to run.
        batch: Collated batch holding the `chosen_*` / `rejected_*` (and `prompt_*`) tensors.

    Returns:
        `(chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)`, followed by
        `outputs.aux_loss` as a sixth element when the auxiliary loss is enabled. The log probabilities
        are averaged over the non-masked label tokens.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    # The first `len_chosen` rows of every concatenated tensor belong to the chosen responses.
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
        }
        if self.is_encoder_decoder
        else {}
    )

    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    def cross_entropy_loss(logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Flatten the tokens. Masked positions carry `label_pad_token_id`, so that is the index
        # to ignore (the default of -100 only matches when the config keeps the default pad id).
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    if self.is_encoder_decoder:
        labels = concatenated_batch["concatenated_labels"].clone()
    else:
        # Decoder-only: the NLL targets are the input ids themselves, with padding masked out.
        labels = concatenated_batch["concatenated_input_ids"].clone()
        attention_mask = concatenated_batch["concatenated_attention_mask"]
        labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

    # orpo chosen nll loss is computed over the full prompt and response
    chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=True,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    if not self.is_encoder_decoder:
        # Drop the last position, which has no next-token target.
        chosen_logits = all_logits[:len_chosen, :-1, :]
        rejected_logits = all_logits[len_chosen:, :-1, :]
    else:
        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the ORPO loss and the logging metrics for one batch.

    Args:
        model: The model to run the concatenated forward pass with.
        batch: Collated batch of chosen/rejected inputs.
        train_eval: `"eval"` prefixes every metric name with `eval_`; `"train"` adds no prefix.

    Returns:
        `(loss, metrics)` where `loss` is the NLL loss minus the mean odds-ratio term (plus the weighted
        auxiliary loss when enabled) and `metrics` maps metric names to Python floats.
    """
    forward_output = self.concatenated_forward(model, batch)
    chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss = forward_output[:5]
    if self.aux_loss_enabled:
        aux_loss = forward_output[5]

    odds_terms, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
        chosen_logps, rejected_logps
    )

    # full ORPO loss
    loss = nll_loss - odds_terms.mean()

    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "" if train_eval != "eval" else "eval_"
    gather = self.accelerator.gather_for_metrics
    # The gathers run in the order listed here.
    metrics = {
        f"{prefix}rewards/chosen": gather(chosen_rewards).mean(),
        f"{prefix}rewards/rejected": gather(rejected_rewards).mean(),
        f"{prefix}rewards/accuracies": gather(reward_accuracies).mean(),
        f"{prefix}rewards/margins": gather(chosen_rewards - rejected_rewards).mean(),
        f"{prefix}logps/rejected": gather(rejected_logps).detach().mean(),
        f"{prefix}logps/chosen": gather(chosen_logps).detach().mean(),
        f"{prefix}logits/rejected": gather(rejected_logits.detach().mean()).mean(),
        f"{prefix}logits/chosen": gather(chosen_logits.detach().mean()).mean(),
        f"{prefix}nll_loss": gather(nll_loss).detach().mean(),
        f"{prefix}log_odds_ratio": gather(log_odds_ratio).detach().mean(),
        f"{prefix}log_odds_chosen": gather(log_odds_chosen).detach().mean(),
    }

    if is_torch_xla_available():
        xm.mark_step()  # needed because .item() calls

    # Turn every metric tensor into a plain Python number.
    metrics = {name: value.item() for name, value in metrics.items()}

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """
    Compute the training loss for one batch and record the batch metrics.

    Args:
        model (`PreTrainedModel` or `nn.Module`):
            Model forwarded to `get_batch_loss_metrics`.
        inputs (`dict[str, Union[torch.Tensor, Any]]`):
            Batch forwarded to `get_batch_loss_metrics`.
        return_outputs (`bool`, *optional*, defaults to `False`):
            Whether to return the metrics together with the loss.
        num_items_in_batch (*optional*, defaults to `None`):
            Accepted for signature compatibility; not read by this method.

    Returns:
        The loss moved to `self.args.device`, or the tuple `(loss, metrics)` when `return_outputs` is truthy.
    """
    # Autocast is only entered when the PEFT model has been cast to bf16; otherwise a no-op context is used.
    if self._peft_has_been_casted_to_bf16:
        amp_context = autocast(self.accelerator.device.type)
    else:
        amp_context = nullcontext()

    with amp_context:
        batch_loss, batch_metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # The `Trainer` class accumulates the loss on `self.args.device`, so the loss has to live there as well.
    batch_loss = batch_loss.to(self.args.device)

    # Stored on every call so that `log` can average the metrics later.
    self.store_metrics(batch_metrics, train_eval="train")

    return (batch_loss, batch_metrics) if return_outputs else batch_loss
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> list[str]:
    """
    Generate samples from the model for the prompts of the given batch.

    Args:
        model:
            Model whose `generate` method is called.
        batch (`dict[str, torch.LongTensor]`):
            Batch providing the `"prompt_input_ids"` and `"prompt_attention_mask"` entries.

    Returns:
        `list[str]`: The output of `processing_class.batch_decode` on the generations padded to `self.max_length`,
        decoded with special tokens skipped.
    """
    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch amp context manager as some hidden states are silently casted to full precision.
    generate_context_manager = (
        autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
    )

    pad_token_id = self.processing_class.pad_token_id
    with generate_context_manager:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # Pad every generation to the same length before decoding the whole batch at once.
    policy_output = pad_to_length(policy_output, self.max_length, pad_token_id)
    return self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """
    Run one evaluation step: compute the batch loss and metrics without gradients and store the metrics.

    Args:
        model (`PreTrainedModel` or `nn.Module`):
            Model forwarded to `get_batch_loss_metrics`.
        inputs (`dict[str, Union[torch.Tensor, Any]]`):
            Batch forwarded to `get_batch_loss_metrics`.
        prediction_loss_only (`bool`):
            When truthy, only the loss is returned and the other two tuple entries are `None`.
        ignore_keys (`list[str]` or `None`, *optional*, defaults to `None`):
            Metric keys to leave out of the returned logits. When `None`, falls back to
            `model.config.keys_to_ignore_at_inference` if the model has a config, else to an empty list.

    Returns:
        A tuple `(loss, logits, labels)` where `loss` is detached, `logits` holds the `"eval_logits/chosen"` and
        `"eval_logits/rejected"` metric values that are not ignored, and `labels` is a zero tensor of the same
        length.
    """
    if not self.use_dpo_data_collator:
        warnings.warn(
            "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
            "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
        )

    if ignore_keys is None:
        ignore_keys = (
            getattr(model.config, "keys_to_ignore_at_inference", []) if hasattr(model, "config") else []
        )

    # Autocast is only entered when the PEFT model has been cast to bf16.
    if self._peft_has_been_casted_to_bf16:
        amp_context = autocast(self.accelerator.device.type)
    else:
        amp_context = nullcontext()

    with torch.no_grad(), amp_context:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # Stored on every call so that `log` can average the metrics later.
    self.store_metrics(metrics, train_eval="eval")

    detached_loss = loss.detach()
    if prediction_loss_only:
        return (detached_loss, None, None)

    # Mean logits of the chosen and rejected samples, in that order; both entries are always looked up.
    candidates = (
        ("eval_logits/chosen", metrics["eval_logits/chosen"]),
        ("eval_logits/rejected", metrics["eval_logits/rejected"]),
    )
    device = self.accelerator.device
    logits = torch.tensor([value for name, value in candidates if name not in ignore_keys], device=device)
    labels = torch.zeros(logits.shape[0], device=device)

    return (detached_loss, logits, labels)
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """
    Append each metric value to the history kept in `self._stored_metrics` for the given split.

    Args:
        metrics (`dict[str, float]`):
            Metric names mapped to the values to record.
        train_eval (`"train"` or `"eval"`, *optional*, defaults to `"train"`):
            Split under which the values are recorded.
    """
    for name in metrics:
        self._stored_metrics[train_eval][name].append(metrics[name])
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.

    When `self.generate_during_eval` is set, one random batch is sampled from the evaluation dataset, completions
    are generated for it, and the prompt/completion table is logged to the reporting integrations listed in
    `self.args.report_to` before the base loop runs.

    Args:
        dataloader (`DataLoader`):
            Evaluation dataloader; its `dataset` is also used to draw the random generation batch.
        description (`str`):
            Forwarded to the parent evaluation loop.
        prediction_loss_only (`bool` or `None`, *optional*, defaults to `None`):
            Forwarded to the parent evaluation loop.
        ignore_keys (`list[str]` or `None`, *optional*, defaults to `None`):
            Forwarded to the parent evaluation loop.
        metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
            Forwarded to the parent evaluation loop.

    Returns:
        `EvalLoopOutput`: The output of the parent evaluation loop.
    """
    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples. The sample size is capped at
        # the dataset size because `random.sample` raises `ValueError` when asked for more items than exist.
        num_samples = len(dataloader.dataset)
        sample_size = min(self.args.eval_batch_size, num_samples)
        random_indices = random.sample(range(num_samples), k=sample_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded = self.generate_from_model(self.model, random_batch)

        # Each decoded output is sliced by the prompt's length so that only the continuation is kept.
        table = pd.DataFrame(
            columns=["Prompt", "Policy"],
            data=[
                [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Merge the averaged stored metrics into `logs`, then hand the result to the parent `log` method.

    Args:
        logs (`dict[str, float]`):
            The values to log. Updated in place with the mean of every stored metric of the matching split.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # `logs` carries a "loss" key for training; otherwise it is treated as an evaluation log.
    split = "train" if "loss" in logs else "eval"

    # Collapse the values accumulated since the last call into one mean per metric.
    pending = self._stored_metrics[split]
    for name in pending:
        logs[name] = torch.tensor(pending[name]).mean().item()

    # Drop the consumed history so that the next logging window starts empty.
    del self._stored_metrics[split]

    return super().log(logs, start_time)
def _shift_right(self, input_ids):
    """
    Shift `input_ids` one position to the right along the last dimension.

    The first position is filled with `self.decoder_start_token_id`, the last token is dropped, and any `-100`
    value in the result is replaced by `self.pad_token_id`.

    Raises:
        ValueError: If `self.decoder_start_token_id` or `self.pad_token_id` is `None`.
    """
    start_token_id = self.decoder_start_token_id
    if start_token_id is None:
        raise ValueError(
            "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
        )

    if not is_torch_fx_proxy(input_ids):
        # Regular tensors: write the shifted tokens into a zero tensor of the same shape.
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[..., 1:] = input_ids[..., :-1].clone()
        shifted[..., 0] = start_token_id
    else:
        # Item assignment is not supported natively for proxies, so the result is built by concatenation.
        leading_shape = input_ids.shape[:-1] + (1,)
        start_column = torch.full(leading_shape, start_token_id)
        shifted = torch.cat([start_column, input_ids[..., :-1]], dim=-1)

    pad_id = self.pad_token_id
    if pad_id is None:
        raise ValueError("model.config.pad_token_id has to be defined.")

    # replace possible -100 values in labels by `pad_token_id`
    shifted.masked_fill_(shifted == -100, pad_id)

    return shifted
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
    """
    Create the model card, then delegate the actual checkpoint saving to the parent class.

    The model name is the last path component of `self.args.output_dir` when no `hub_model_id` is configured,
    otherwise the part of `hub_model_id` after its last `/`.
    """
    hub_model_id = self.args.hub_model_id
    if hub_model_id is not None:
        card_name = hub_model_id.rpartition("/")[2]
    else:
        card_name = Path(self.args.output_dir).name
    self.create_model_card(model_name=card_name)
    super()._save_checkpoint(model, trial)
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to `README.md` inside `self.args.output_dir`. Nothing is done on processes other than the
    world process zero.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # A `_name_or_path` that points to a local directory is not reported as the base model.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalize `tags` to a fresh set: a list has no `add`/`update`, and the caller's own collection must not be
    # mutated by the tags appended below.
    if not tags:
        tags = set()
    elif isinstance(tags, str):
        tags = {tags}
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    citation = textwrap.dedent("""\
        @article{hong2024orpo,
        title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
        author = {Jiwoo Hong and Noah Lee and James Thorne},
        year = 2024,
        eprint = {arXiv:2403.07691}
        }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="ORPO",
        trainer_citation=citation,
        paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
        paper_id="2403.07691",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))