Spaces:
Paused
Paused
File size: 19,260 Bytes
2f5127c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 |
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import warnings
from collections import defaultdict
from dataclasses import FrozenInstanceError, replace
from pathlib import Path
from typing import Any, Callable, Optional, Union
import pandas as pd
import torch
import torch.nn as nn
from accelerate import PartialState
from accelerate.utils import gather_object
from datasets import Dataset
from transformers import (
BaseImageProcessor,
DataCollator,
FeatureExtractionMixin,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available, is_rich_available
from ..data_utils import maybe_apply_chat_template
from .reward_config import RewardConfig
from .utils import (
RewardDataCollatorWithPadding,
compute_accuracy,
decode_and_strip_padding,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
print_rich_table,
)
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
if is_wandb_available():
import wandb
def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]:
    """
    Tokenize a batch of (chosen, rejected) pairs from a reward modelling dataset.

    Args:
        batch: Batched examples holding parallel `"chosen"` and `"rejected"` lists.
        tokenizer: Callable applied to one example at a time; its result must expose
            `"input_ids"` and `"attention_mask"` entries.

    Returns:
        A dict with the keys `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and
        `attention_mask_rejected`, each a list with one entry per pair in the batch.
    """
    sides = ("chosen", "rejected")
    fields = ("input_ids", "attention_mask")
    columns = {f"{field}_{side}": [] for side in sides for field in fields}

    # Pairs are walked in lockstep; within a pair the chosen text is tokenized before the rejected one.
    for pair in zip(batch["chosen"], batch["rejected"]):
        for side, text in zip(sides, pair):
            encoded = tokenizer(text)
            for field in fields:
                columns[f"{field}_{side}"].append(encoded[field])
    return columns
class RewardTrainer(Trainer):
_tag_names = ["trl", "reward-trainer"]
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module]] = None,
    args: Optional[RewardConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
        None,
        None,
    ),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
):
    """
    Initialize RewardTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        args (`RewardConfig`):
            The arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator
            (`RewardDataCollatorWithPadding`) is built from `processing_class`, and
            `args.remove_unused_columns` is forced to `False` (with a warning).
        train_dataset (`datasets.Dataset`):
            The dataset to use for training. If it has no `input_ids_chosen` column, it is passed through
            `maybe_apply_chat_template`, tokenized, and filtered so that both sequences of every pair are at
            most `args.max_length` tokens long.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation. It is preprocessed like `train_dataset`, but only when
            `train_dataset` itself required preprocessing.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. It is used to tokenize the datasets and to build the
            default data collator, and is forwarded to the parent `Trainer`.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training; forwarded to the parent `Trainer`.
        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
            The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.

    Raises:
        ValueError: If `peft_config` is given but PEFT is not installed, or if neither `data_collator` nor
            `processing_class` is provided.
    """
    if peft_config is not None:
        if not is_peft_available():
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        if not isinstance(model, PeftModel):
            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                # Older peft releases do not accept `gradient_checkpointing_kwargs`; probe the signature.
                _supports_gc_kwargs = (
                    "gradient_checkpointing_kwargs" in inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if args.gradient_checkpointing_kwargs is not None:
                    if _supports_gc_kwargs:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
                    else:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
                            UserWarning,
                        )

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

            model = get_peft_model(model, peft_config)

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(model)

    if compute_metrics is None:
        compute_metrics = compute_accuracy

    if data_collator is None:
        if processing_class is None:
            raise ValueError(
                "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
            )

        data_collator = RewardDataCollatorWithPadding(processing_class)

        if args.remove_unused_columns:
            try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                args.remove_unused_columns = False
            except FrozenInstanceError:
                args = replace(args, remove_unused_columns=False)
            # warn users
            warnings.warn(
                "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_reward_data_collator = True
    else:
        self.use_reward_data_collator = False

    # The length filter below applies whichever collator is used, so `max_length` must be bound
    # unconditionally (binding it only in the default-collator branch left it undefined, and the filter
    # raised a NameError, when a custom `data_collator` was combined with an un-tokenized dataset).
    max_length = args.max_length

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
    # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
    # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
    # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
    # issued.
    model.warnings_issued["estimate_tokens"] = True

    if "input_ids_chosen" not in train_dataset.column_names:
        fn_kwargs = {"tokenizer": processing_class}

        def _within_max_length(example):
            return len(example["input_ids_chosen"]) <= max_length and len(example["input_ids_rejected"]) <= max_length

        def _preprocess(dataset):
            """Apply the chat template, tokenize both sides of each pair, then drop over-long pairs."""
            dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
            dataset = dataset.map(
                _tokenize,
                batched=True,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
            )
            # This filter is important because otherwise you get samples that exceed the model's context length and
            # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
            # user might get surprised if N samples are missing from training.
            return dataset.filter(_within_max_length, num_proc=args.dataset_num_proc)

        with PartialState().main_process_first():
            train_dataset = _preprocess(train_dataset)
            if eval_dataset is not None:
                eval_dataset = _preprocess(eval_dataset)

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """
    Compute the pairwise reward loss for a batch of chosen/rejected sequences.

    The model is called once on the chosen inputs and once on the rejected inputs. The loss is the mean of
    `-logsigmoid(reward_chosen - reward_rejected)`, with `inputs["margin"]` subtracted from the difference
    when that key is present. If `self.args.center_rewards_coefficient` is not `None`, the mean squared sum
    of both rewards, scaled by that coefficient, is added to the loss.

    Args:
        model: Callable returning a mapping with a `"logits"` entry.
        inputs: Batch containing `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected`,
            `attention_mask_rejected` and optionally `margin`.
        return_outputs: Whether to also return the two reward tensors.
        num_items_in_batch: Accepted for signature compatibility; not used here.

    Returns:
        The loss, or `(loss, {"rewards_chosen": ..., "rewards_rejected": ...})` when `return_outputs` is true.
    """

    def _score(side: str) -> torch.Tensor:
        # One forward pass for the requested side of the pair.
        outputs = model(
            input_ids=inputs[f"input_ids_{side}"],
            attention_mask=inputs[f"attention_mask_{side}"],
            return_dict=True,
        )
        return outputs["logits"]

    rewards_chosen = _score("chosen")
    rewards_rejected = _score("rejected")

    # calculate loss, optionally modulate with margin
    reward_gap = rewards_chosen - rewards_rejected
    if "margin" in inputs:
        reward_gap = reward_gap - inputs["margin"]
    loss = -nn.functional.logsigmoid(reward_gap).mean()

    center_coefficient = self.args.center_rewards_coefficient
    if center_coefficient is not None:
        loss = loss + center_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)

    if not return_outputs:
        return loss
    return loss, {
        "rewards_chosen": rewards_chosen,
        "rewards_rejected": rewards_rejected,
    }
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Run one evaluation step and return the loss, the chosen-vs-rejected preferences and the labels.

    Args:
        model: The model passed on to `compute_loss`.
        inputs: The batch; it is passed through `self._prepare_inputs` first.
        prediction_loss_only: If true, only the loss is returned (as `(loss, None, None)`).
        ignore_keys: Keys of the `compute_loss` outputs to leave out. When `None`, falls back to
            `self.model.config.keys_to_ignore_at_inference` if the model has a config, else to an empty list.

    Returns:
        `(loss, logits, labels)`, where `logits` are the softmax-normalized preferences and `labels` is an
        all-zero tensor with one entry per row of `logits`.
    """
    inputs = self._prepare_inputs(inputs)

    if ignore_keys is None:
        has_config = hasattr(self.model, "config")
        ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) if has_config else []

    with torch.no_grad():
        loss, reward_outputs = self.compute_loss(model, inputs, return_outputs=True)

    if prediction_loss_only:
        return (loss, None, None)

    loss = loss.detach()
    kept_rewards = nested_detach(tuple(value for key, value in reward_outputs.items() if key not in ignore_keys))

    # Stack accepted against rejected, mean over logits
    # and softmax to get preferences between accepted and rejected to sum to 1
    stacked = torch.stack(kept_rewards)
    preferences = stacked.mean(dim=2).softmax(dim=0).T

    labels = self._prepare_inputs(torch.zeros(preferences.shape[0]))
    return loss, preferences, labels
def evaluate(self, *args, **kwargs):
    """
    Show a few sample predictions, then run the parent class's evaluation.

    The optional keyword argument `num_print_samples` (default `4`) is removed from `kwargs` and handed to
    `visualize_samples`; every remaining argument is forwarded unchanged to the parent `evaluate`.
    """
    # Pop first so the parent never sees the extra keyword.
    self.visualize_samples(kwargs.pop("num_print_samples", 4))
    result = super().evaluate(*args, **kwargs)
    return result
def visualize_samples(self, num_print_samples: int):
    """
    Visualize the reward model logits prediction

    Batches from the evaluation dataloader are run through `prediction_step`; the decoded chosen/rejected
    texts and the rounded logits are gathered across processes into a table. On process index 0 the table
    is printed (when `rich` is available) and logged to Weights & Biases and/or Comet if those are listed
    in `self.args.report_to`.

    Args:
        num_print_samples (`int`):
            The number of samples to print. Set to `-1` (or any negative value) to print all samples.
    """
    eval_dataloader = self.get_eval_dataloader()
    table = defaultdict(list)
    for inputs in eval_dataloader:
        _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
        chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
        rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
        table["chosen_text"].extend(gather_object(chosen_text))
        table["rejected_text"].extend(gather_object(rejected_text))
        table["logits"].extend(
            gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
        )
        # Stop as soon as enough samples were collected; a negative count means "collect everything".
        if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
            break
    df = pd.DataFrame(table)
    if self.accelerator.process_index == 0:
        if is_rich_available():
            # `df[:-1]` would silently drop the last row, so a negative count must bypass the slice.
            rows_to_print = df if num_print_samples < 0 else df[:num_print_samples]
            print_rich_table(rows_to_print)
        if "wandb" in self.args.report_to:
            import wandb

            if wandb.run is not None:
                wandb.log({"completions": wandb.Table(dataframe=df)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="completions.csv",
                table=df,
            )
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
    """
    Write the model card, then delegate to the parent class's checkpoint saving.

    The model name on the card is the last `/`-separated component of `self.args.hub_model_id` when that
    is set, otherwise the final path component of `self.args.output_dir`.
    """
    hub_model_id = self.args.hub_model_id
    card_name = Path(self.args.output_dir).name if hub_model_id is None else hub_model_id.split("/")[-1]
    self.create_model_card(model_name=card_name)
    super()._save_checkpoint(model, trial)
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Nothing is done on processes other than the world process zero. The card is written to `README.md`
    inside `self.args.output_dir`.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card. The trainer's own tags are always added, plus
            `"unsloth"` when the model config has an `unsloth_version` attribute.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalize `tags` to a fresh set: a list has no `.add`/`.update`, so passing `list[str]` (as the
    # signature allows) used to raise AttributeError. Copying also avoids mutating the caller's object.
    if not tags:
        tags = set()
    elif isinstance(tags, str):
        tags = {tags}
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="Reward",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|