# Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import dataclasses | |
import importlib.resources as pkg_resources | |
import json | |
import random | |
import warnings | |
from collections import deque | |
from dataclasses import dataclass, field | |
from importlib.metadata import version | |
from typing import Any, Literal, Optional, Union | |
import datasets | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.nn.functional as F | |
import torch.utils.data | |
from accelerate import Accelerator, PartialState | |
from accelerate.state import AcceleratorState | |
from huggingface_hub import ModelCard, ModelCardData | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.utils.data import IterableDataset | |
from transformers import ( | |
BitsAndBytesConfig, | |
DataCollatorForLanguageModeling, | |
EvalPrediction, | |
GenerationConfig, | |
PreTrainedTokenizerBase, | |
TrainerState, | |
TrainingArguments, | |
is_comet_available, | |
) | |
from transformers.utils import ( | |
is_peft_available, | |
is_rich_available, | |
is_torch_mlu_available, | |
is_torch_npu_available, | |
is_torch_xpu_available, | |
) | |
from ..trainer.model_config import ModelConfig | |
if is_rich_available(): | |
from rich.console import Console | |
from rich.panel import Panel | |
from rich.table import Table | |
from rich.text import Text | |
if is_comet_available(): | |
import comet_ml | |
if is_peft_available(): | |
from peft import LoraConfig, PeftConfig | |
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): | |
""" | |
    Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an `ignore_index`
    when they do not come from the assistant. This ensures that the loss is only calculated on the completion made by
    the assistant.
Args: | |
        response_template (`Union[str, list[int]]`):
            The template form that indicates the start of the response, typically something like '### Response:\n'.
            It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the
            response differently if it does not have proper context.
        instruction_template (`Union[str, list[int]]`):
            The template form that indicates the start of the human instruction, typically something like
            '### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
        mlm (`bool`, *optional*, defaults to `False`):
            Whether to use masked language modeling in the underlying `DataCollatorForLanguageModeling` class. Note
            that this option currently has no effect but is present for flexibility and backwards-compatibility.
ignore_index (`int`, *optional*, defaults to `-100`): | |
            The label index used to mask out tokens that should be ignored in the loss computation.
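
    Example (an illustrative sketch; the tokenizer checkpoint and the response template below are arbitrary
    placeholders, not values prescribed by this class):

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> tokenizer.pad_token = tokenizer.eos_token
    >>> collator = DataCollatorForCompletionOnlyLM(response_template="### Response:", tokenizer=tokenizer)
    >>> example = tokenizer("### Question: What is 2+2? ### Response: 4")["input_ids"]
    >>> batch = collator([example])  # labels up to the end of the response template are set to `ignore_index`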
""" | |
def __init__( | |
self, | |
response_template: Union[str, list[int]], | |
instruction_template: Optional[Union[str, list[int]]] = None, | |
*args, | |
mlm: bool = False, | |
ignore_index: int = -100, | |
padding_free: bool = False, | |
**kwargs, | |
): | |
super().__init__(*args, mlm=mlm, **kwargs) | |
warnings.warn( | |
"This class is deprecated and will be removed in version 0.20.0. To train on completion only, please use " | |
"the parameter `completion_only_loss` of `SFTConfig` instead.", | |
DeprecationWarning, | |
) | |
self.instruction_template = instruction_template | |
if isinstance(instruction_template, str): | |
# The user provides a string, must tokenize | |
self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False) | |
else: | |
# The user already provides the token ids | |
self.instruction_token_ids = instruction_template | |
self.response_template = response_template | |
if isinstance(response_template, str): | |
# The user provides a string, must tokenize | |
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False) | |
else: | |
# The user already provides the token ids | |
self.response_token_ids = response_template | |
if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: | |
warnings.warn( | |
"The pad_token_id and eos_token_id values of this tokenizer are identical. " | |
"If you are planning for multi-turn training, " | |
"it can result in the model continuously generating questions and answers without eos token. " | |
"To avoid this, set the pad_token_id to a different value.", | |
UserWarning, | |
) | |
self.ignore_index = ignore_index | |
self.padding_free = padding_free | |
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: | |
batch = super().torch_call(examples) | |
if self.instruction_template is None: | |
for i in range(len(examples)): | |
response_token_ids_start_idx = None | |
for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: | |
# `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match | |
if ( | |
self.response_token_ids | |
== batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist() | |
): | |
response_token_ids_start_idx = idx | |
if response_token_ids_start_idx is None: | |
warnings.warn( | |
f"Could not find response key `{self.response_template}` in the following instance: " | |
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " | |
"calculation. Note, if this happens often, consider increasing the `max_length`.", | |
UserWarning, | |
) | |
batch["labels"][i, :] = self.ignore_index | |
else: | |
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids) | |
# Make pytorch loss function ignore all tokens up through the end of the response key | |
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index | |
else: | |
for i in range(len(examples)): | |
response_token_ids_idxs = [] | |
human_token_ids_idxs = [] | |
for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: | |
# find the indexes of the start of a response. | |
if ( | |
self.response_token_ids | |
== batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist() | |
): | |
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids)) | |
if len(response_token_ids_idxs) == 0: | |
warnings.warn( | |
f"Could not find response key `{self.response_template}` in the following instance: " | |
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " | |
"calculation. Note, if this happens often, consider increasing the `max_length`.", | |
UserWarning, | |
) | |
batch["labels"][i, :] = self.ignore_index | |
human_token_ids = self.instruction_token_ids | |
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]: | |
# find the indexes of the start of a human answer. | |
if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist(): | |
human_token_ids_idxs.append(human_idx) | |
if len(human_token_ids_idxs) == 0: | |
warnings.warn( | |
f"Could not find instruction key `{self.instruction_template}` in the following instance: " | |
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " | |
"calculation. Note, if this happens often, consider increasing the `max_length`.", | |
UserWarning, | |
) | |
batch["labels"][i, :] = self.ignore_index | |
if ( | |
len(human_token_ids_idxs) > 0 | |
and len(response_token_ids_idxs) > 0 | |
and human_token_ids_idxs[0] > response_token_ids_idxs[0] | |
): | |
human_token_ids_idxs = [0] + human_token_ids_idxs | |
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): | |
# Make pytorch loss function ignore all non response tokens | |
if idx != 0: | |
batch["labels"][i, start:end] = self.ignore_index | |
else: | |
batch["labels"][i, :end] = self.ignore_index | |
if len(response_token_ids_idxs) < len(human_token_ids_idxs): | |
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index | |
if self.padding_free: | |
# remove padding, `attention_mask` and add `position_ids` | |
attn_mask = batch.pop("attention_mask") | |
batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0) | |
batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1 | |
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0) | |
batch["labels"][batch["position_ids"] == 0] = self.ignore_index | |
# Calculate cumulative sequence lengths for queries and keys to prevent graph breaks during further computations. | |
flattened_position_ids = batch["position_ids"].flatten() | |
indices_q = torch.arange( | |
flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32 | |
) | |
batch["cu_seq_lens_q"] = torch.cat( | |
( | |
indices_q[flattened_position_ids == 0], | |
torch.tensor( | |
flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32 | |
), | |
) | |
).unsqueeze(0) | |
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] | |
# Determine maximum sequence lengths to prevent graph breaks during further computations. | |
batch["max_length_k"] = torch.tensor([flattened_position_ids.max().item() + 1]) | |
batch["max_length_q"] = batch["max_length_k"] | |
return batch | |
@dataclass
class DataCollatorForChatML:
""" | |
Data collator for ChatML format datasets. | |
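
    Example (an illustrative sketch; the tokenizer checkpoint is an arbitrary placeholder that happens to have a chat
    template and a pad token, not a requirement of this class):

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    >>> collator = DataCollatorForChatML(tokenizer=tokenizer)
    >>> examples = [{"messages": [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]}]
    >>> batch = collator(examples)
    >>> sorted(batch.keys())
    ['attention_mask', 'input_ids', 'labels', 'prompt_attention_mask', 'prompts']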
""" | |
tokenizer: PreTrainedTokenizerBase | |
ignore_index: int = -100 | |
    max_length: Optional[int] = None
prompt_key: str = "prompt" | |
messages_key: str = "messages" | |
def __post_init__(self): | |
if self.tokenizer.pad_token_id is None: | |
raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.") | |
if self.max_length is None: | |
# set a sensible default | |
self.max_length = min(self.tokenizer.model_max_length, 1024) | |
def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: | |
input_ids = [] | |
attention_mask = [] | |
prompts_input_ids = [] | |
prompt_attention_mask = [] | |
labels = [] | |
for example in examples: | |
formatted_prompt = example.get(self.prompt_key, None) | |
if formatted_prompt is None: | |
prompt = example[self.messages_key][:-1] | |
formatted_prompt = self.tokenizer.apply_chat_template( | |
prompt, tokenize=False, add_generation_prompt=True | |
) | |
if "input_ids" not in example: | |
message = example[self.messages_key] | |
formatted_message = self.tokenizer.apply_chat_template( | |
message, tokenize=False, add_generation_prompt=False | |
) | |
tokenized_message = self.tokenizer( | |
formatted_message, | |
truncation=True, | |
max_length=self.max_length, | |
padding=False, | |
return_tensors=None, | |
add_special_tokens=False, | |
) | |
input_ids.append(tokenized_message["input_ids"]) | |
if "attention_mask" in example: | |
attention_mask.append(tokenized_message["attention_mask"]) | |
else: | |
attention_mask.append([1] * len(tokenized_message["input_ids"])) | |
else: | |
input_ids.append(example["input_ids"]) | |
if "attention_mask" in example: | |
attention_mask.append(example["attention_mask"]) | |
else: | |
attention_mask.append([1] * len(example["input_ids"])) | |
tokenized_prompt = self.tokenizer( | |
formatted_prompt, | |
truncation=True, | |
max_length=len(input_ids[-1]), | |
padding=False, | |
return_tensors=None, | |
add_special_tokens=False, | |
) | |
prompts_input_ids.append(tokenized_prompt["input_ids"]) | |
prompt_attention_mask.append(tokenized_prompt["attention_mask"]) | |
# Create the labels that will have all but the completion tokens of the example["input_ids"] set to ignore_index | |
label = [self.ignore_index] * len(input_ids[-1]) | |
completion_start_idx = len(tokenized_prompt["input_ids"]) | |
label[completion_start_idx:] = input_ids[-1][completion_start_idx:] | |
labels.append(label) | |
# convert to list of tensors and pad | |
input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids] | |
attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask] | |
labels = [torch.tensor(label, dtype=torch.long) for label in labels] | |
input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) | |
attention_mask = pad(attention_mask, padding_side="left", padding_value=0) | |
labels = pad(labels, padding_side="left", padding_value=self.ignore_index) | |
prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids] | |
prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask] | |
prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) | |
prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0) | |
return { | |
"input_ids": input_ids, | |
"attention_mask": attention_mask, | |
"labels": labels, | |
"prompts": prompts_input_ids, | |
"prompt_attention_mask": prompt_attention_mask, | |
} | |
@dataclass
class RewardDataCollatorWithPadding:
r""" | |
Reward DataCollator class that pads the inputs to the maximum length of the batch. | |
Args: | |
tokenizer (`PreTrainedTokenizerBase`): | |
The tokenizer used for encoding the data. | |
        padding (`Union[bool, str]`, *optional*, defaults to `True`):
            Padding strategy to pass to the tokenizer.
        pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
            If set, pads the sequence to a multiple of the provided value.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The tensor type to use.
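
    Example (an illustrative sketch; the tokenizer checkpoint and the token IDs are arbitrary placeholders):

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> tokenizer.pad_token = tokenizer.eos_token
    >>> collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)
    >>> features = [
    ...     {
    ...         "input_ids_chosen": [1, 2, 3],
    ...         "attention_mask_chosen": [1, 1, 1],
    ...         "input_ids_rejected": [1, 2],
    ...         "attention_mask_rejected": [1, 1],
    ...     }
    ... ]
    >>> batch = collator(features)
    >>> batch["input_ids_rejected"].shape
    torch.Size([1, 2])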
""" | |
tokenizer: PreTrainedTokenizerBase | |
padding: Union[bool, str] = True | |
pad_to_multiple_of: Optional[int] = None | |
return_tensors: str = "pt" | |
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: | |
features_chosen = [] | |
features_rejected = [] | |
margin = [] | |
# check if we have a margin. If we do, we need to batch it as well | |
has_margin = "margin" in features[0] | |
for feature in features: | |
# check if the keys are named as expected | |
if ( | |
"input_ids_chosen" not in feature | |
or "input_ids_rejected" not in feature | |
or "attention_mask_chosen" not in feature | |
or "attention_mask_rejected" not in feature | |
): | |
raise ValueError( | |
"The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`" | |
) | |
features_chosen.append( | |
{ | |
"input_ids": feature["input_ids_chosen"], | |
"attention_mask": feature["attention_mask_chosen"], | |
} | |
) | |
features_rejected.append( | |
{ | |
"input_ids": feature["input_ids_rejected"], | |
"attention_mask": feature["attention_mask_rejected"], | |
} | |
) | |
if has_margin: | |
margin.append(feature["margin"]) | |
batch_chosen = self.tokenizer.pad( | |
features_chosen, | |
padding=self.padding, | |
pad_to_multiple_of=self.pad_to_multiple_of, | |
return_tensors=self.return_tensors, | |
) | |
batch_rejected = self.tokenizer.pad( | |
features_rejected, | |
padding=self.padding, | |
pad_to_multiple_of=self.pad_to_multiple_of, | |
return_tensors=self.return_tensors, | |
) | |
batch = { | |
"input_ids_chosen": batch_chosen["input_ids"], | |
"attention_mask_chosen": batch_chosen["attention_mask"], | |
"input_ids_rejected": batch_rejected["input_ids"], | |
"attention_mask_rejected": batch_rejected["attention_mask"], | |
"return_loss": True, | |
} | |
if has_margin: | |
margin = torch.tensor(margin, dtype=torch.float) | |
batch["margin"] = margin | |
return batch | |
def pad( | |
tensors: list[torch.Tensor], | |
padding_value: int = 0, | |
padding_side: str = "right", | |
pad_to_multiple_of: Optional[int] = None, | |
) -> torch.Tensor: | |
""" | |
Pads a list of tensors to the same shape along the first dimension. | |
Args: | |
tensors (`list[torch.Tensor]`): | |
List of input tensors to pad. | |
padding_value (`int`): | |
Value to use for padding. Default is 0. | |
padding_side (`str`): | |
Side on which to add padding. Must be 'left' or 'right'. Default is 'right'. | |
pad_to_multiple_of (`int`, *optional*, defaults to `None`): | |
If set will pad the sequence to a multiple of the provided value. | |
Returns: | |
`torch.Tensor`: | |
A single tensor containing the padded tensors. | |
Examples: | |
>>> import torch | |
>>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]) | |
tensor([[1, 2, 3], | |
[4, 5, 0]]) | |
>>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])]) | |
tensor([[[1, 2], | |
[3, 4]], | |
[[5, 6], | |
[0, 0]]]) | |
""" | |
# Determine the maximum shape for each dimension | |
output_shape = np.max([t.shape for t in tensors], 0).tolist() | |
# Apply pad_to_multiple_of to the first (sequence) dimension | |
if pad_to_multiple_of is not None: | |
remainder = output_shape[0] % pad_to_multiple_of | |
if remainder != 0: | |
output_shape[0] += pad_to_multiple_of - remainder | |
# Create an output tensor filled with the padding value | |
output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device) | |
for i, t in enumerate(tensors): | |
if padding_side == "left": | |
seq_start = output_shape[0] - t.shape[0] | |
elif padding_side == "right": | |
seq_start = 0 | |
else: | |
raise ValueError("padding_side must be 'left' or 'right'") | |
# Define the slices | |
seq_slice = slice(seq_start, seq_start + t.shape[0]) | |
slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:]) | |
output[i][slices] = t | |
return output | |
@dataclass
class DPODataCollatorWithPadding:
r""" | |
DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. | |
Args: | |
        pad_token_id (`int`, defaults to `0`):
            The tokenizer's pad_token_id.
        label_pad_token_id (`int`, defaults to `-100`):
            The label used for masking.
        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `False`):
            Whether your model has an encoder-decoder architecture.
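
    Example (an illustrative sketch; the token IDs are made up):

    >>> collator = DPODataCollatorWithPadding(pad_token_id=0)
    >>> features = [
    ...     {"prompt_input_ids": [5, 6, 7], "chosen_input_ids": [8, 9], "rejected_input_ids": [10]},
    ...     {"prompt_input_ids": [5], "chosen_input_ids": [8, 9, 10], "rejected_input_ids": [11, 12]},
    ... ]
    >>> batch = collator(features)
    >>> batch["prompt_input_ids"]  # prompts are left-padded
    tensor([[5, 6, 7],
            [0, 0, 5]])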
""" | |
pad_token_id: int = 0 | |
label_pad_token_id: int = -100 | |
is_encoder_decoder: Optional[bool] = False | |
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: | |
# first, pad everything to the same length | |
padded_batch = {} | |
for k in features[0].keys(): | |
if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")): | |
if self.is_encoder_decoder: | |
to_pad = [torch.LongTensor(ex[k]) for ex in features] | |
if (k.startswith("prompt")) and (k.endswith("input_ids")): | |
if self.pad_token_id is None: | |
raise ValueError( | |
"Padding is enabled, but the tokenizer is not configured with a padding token." | |
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" | |
" before calling the trainer." | |
) | |
padding_value = self.pad_token_id | |
elif k.endswith("_attention_mask"): | |
padding_value = 0 | |
elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k): | |
padding_value = self.label_pad_token_id | |
else: | |
raise ValueError(f"Unexpected key in batch '{k}'") | |
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) | |
else: | |
# Set padding value based on the key | |
if k.endswith("_input_ids"): | |
if self.pad_token_id is None: | |
raise ValueError( | |
"Padding is enabled, but the tokenizer is not configured with a padding token." | |
" Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" | |
" before calling the trainer." | |
) | |
padding_value = self.pad_token_id | |
elif k.endswith("_labels"): | |
padding_value = self.label_pad_token_id | |
elif k.endswith("_attention_mask"): | |
padding_value = 0 | |
elif k.endswith("_pixel_values"): | |
padding_value = 0 # TODO: check if this is correct | |
else: | |
raise ValueError(f"Unexpected key in batch '{k}'") | |
# Set padding side based on the key | |
if k in ["prompt_input_ids", "prompt_attention_mask"]: | |
padding_side = "left" | |
else: | |
padding_side = "right" | |
# Set the dtype | |
if k.endswith("_pixel_values"): | |
dtype = torch.float32 # will be downcasted if necessary by the Trainer | |
else: | |
dtype = torch.int64 | |
# Convert to tensor and pad | |
to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features] | |
padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side) | |
elif k.endswith("_logps"): | |
# the cached reference model logprobs | |
padded_batch[k] = torch.tensor([ex[k] for ex in features]) | |
else: | |
padded_batch[k] = [ex[k] for ex in features] | |
return padded_batch | |
class ConstantLengthDataset(IterableDataset): | |
""" | |
    Iterable dataset that returns constant-length chunks of tokens from a stream of text files. The dataset also
    formats the text before tokenization with a specific format that is provided by the user.
Args: | |
tokenizer (`transformers.PreTrainedTokenizer`): | |
The processor used for processing the data. | |
dataset (`dataset.Dataset`): | |
Dataset with text files. | |
dataset_text_field (`str` or `None`, *optional*, defaults to `None`): | |
Name of the field in the dataset that contains the text. Only one of `dataset_text_field` and | |
`formatting_func` should be provided. | |
formatting_func (`Callable`, *optional*): | |
Function that formats the text before tokenization. Usually it is recommended to follow a certain | |
pattern such as `"### Question: {question} ### Answer: {answer}"`. Only one of `dataset_text_field` and | |
`formatting_func` should be provided. | |
        infinite (`bool`, *optional*, defaults to `False`):
            If `True`, the iterator restarts from the beginning when the dataset is exhausted; otherwise iteration
            stops.
        seq_length (`int`, *optional*, defaults to `1024`):
            Length of token sequences to return.
        num_of_sequences (`int`, *optional*, defaults to `1024`):
            Number of token sequences to keep in buffer.
        chars_per_token (`float`, *optional*, defaults to `3.6`):
            Number of characters per token used to estimate the number of tokens in the text buffer.
        eos_token_id (`int`, *optional*, defaults to `0`):
            Id of the end-of-sequence token to use if the passed tokenizer does not have an EOS token.
        shuffle (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the examples before they are returned.
        append_concat_token (`bool`, *optional*, defaults to `True`):
            If `True`, appends `eos_token_id` at the end of each sample being packed.
        add_special_tokens (`bool`, *optional*, defaults to `True`):
            If `True`, the tokenizer adds special tokens to each sample being packed.
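
    Example (an illustrative sketch; the tokenizer checkpoint and the toy dataset are arbitrary placeholders):

    >>> from datasets import Dataset
    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> dataset = Dataset.from_dict({"text": ["hello world"] * 64})
    >>> packed = ConstantLengthDataset(tokenizer, dataset, dataset_text_field="text", seq_length=8)
    >>> next(iter(packed))["input_ids"].shape
    torch.Size([8])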
""" | |
def __init__( | |
self, | |
tokenizer, | |
dataset, | |
dataset_text_field=None, | |
formatting_func=None, | |
infinite=False, | |
seq_length=1024, | |
num_of_sequences=1024, | |
chars_per_token=3.6, | |
eos_token_id=0, | |
shuffle=True, | |
append_concat_token=True, | |
add_special_tokens=True, | |
): | |
warnings.warn( | |
"This class is deprecated and will be removed in version 0.20.0. To use packing, use the argument " | |
"`packing` of `SFTConfig` instead.", | |
DeprecationWarning, | |
) | |
self.tokenizer = tokenizer | |
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id | |
self.dataset = dataset | |
self.seq_length = seq_length | |
self.infinite = infinite | |
self.current_size = 0 | |
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences | |
self.shuffle = shuffle | |
self.append_concat_token = append_concat_token | |
self.add_special_tokens = add_special_tokens | |
if dataset_text_field is not None and formatting_func is not None: | |
warnings.warn( | |
"Only one of `dataset_text_field` and `formatting_func` should be provided. " | |
"Ignoring `dataset_text_field` and using `formatting_func`.", | |
UserWarning, | |
) | |
if formatting_func is not None: | |
self.formatting_func = formatting_func | |
elif dataset_text_field is not None: | |
self.formatting_func = lambda x: x[dataset_text_field] | |
else: # neither is provided | |
raise ValueError("Either `dataset_text_field` or `formatting_func` should be provided.") | |
self.pretokenized = False | |
column_names = ( | |
dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None | |
) | |
if column_names is not None and "input_ids" in column_names: | |
self.pretokenized = True | |
# since the dataset is tokenized, the unit of buffer size should be tokens | |
self.max_buffer_size = seq_length * num_of_sequences | |
def __len__(self): | |
return len(self.dataset) | |
def __iter__(self): | |
iterator = iter(self.dataset) | |
more_examples = True | |
while more_examples: | |
buffer, buffer_len = [], 0 | |
while True: | |
if buffer_len >= self.max_buffer_size: | |
break | |
try: | |
buffer.append(self.formatting_func(next(iterator))) | |
buffer_len += len(buffer[-1]) | |
except StopIteration: | |
if self.infinite: | |
iterator = iter(self.dataset) | |
else: | |
more_examples = False | |
break | |
if self.shuffle: | |
random.shuffle(buffer) | |
if self.pretokenized: | |
tokenized_inputs = buffer | |
else: | |
tokenized_inputs = self.tokenizer( | |
buffer, add_special_tokens=self.add_special_tokens, truncation=False | |
)["input_ids"] | |
all_token_ids = [] | |
for tokenized_input in tokenized_inputs: | |
if self.append_concat_token: | |
tokenized_input = tokenized_input + [self.concat_token_id] | |
all_token_ids.extend(tokenized_input) | |
examples = [] | |
for i in range(0, len(all_token_ids), self.seq_length): | |
input_ids = all_token_ids[i : i + self.seq_length] | |
if len(input_ids) == self.seq_length: | |
examples.append(input_ids) | |
if self.shuffle: | |
# Shuffle again, otherwise split examples occur in consecutive tensors. | |
random.shuffle(examples) | |
for example in examples: | |
self.current_size += 1 | |
yield { | |
"input_ids": torch.LongTensor(example), | |
"labels": torch.LongTensor(example), | |
} | |
@dataclass
class RunningMoments:
""" | |
Calculates the running mean and standard deviation of a data stream. Reference: | |
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 | |
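
    Example (an illustrative sketch):

    >>> from accelerate import Accelerator
    >>> rm = RunningMoments(accelerator=Accelerator())
    >>> batch_mean, batch_std = rm.update(torch.randn(64))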
""" | |
accelerator: Accelerator | |
mean: float = 0 | |
std: float = 1 | |
var: float = 1 | |
count: float = 1e-24 | |
def update(self, xs: torch.Tensor) -> tuple[float, float]: | |
""" | |
Updates running moments from batch's moments computed across ranks | |
""" | |
if self.accelerator.use_distributed: | |
xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) | |
else: | |
xs_count = xs.numel() | |
xs_var, xs_mean = torch.var_mean(xs, unbiased=False) | |
xs_mean, xs_var = xs_mean.float(), xs_var.float() | |
delta = xs_mean - self.mean | |
tot_count = self.count + xs_count | |
new_sum = xs_var * xs_count | |
# correct old_sum deviation accounting for the new mean | |
old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count | |
tot_sum = old_sum + new_sum | |
self.mean += (delta * xs_count / tot_count).item() | |
new_var = tot_sum / tot_count | |
self.std = (new_var * tot_count / (tot_count - 1)).float().sqrt().item() | |
self.var = new_var.item() | |
self.count = tot_count | |
return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() | |
def save_to_json(self, json_path: str): | |
"""Save the content of this instance in JSON format inside `json_path`.""" | |
# save everything except accelerator | |
if self.accelerator.is_main_process: | |
save_dict = dataclasses.asdict(self, dict_factory=lambda x: {k: v for (k, v) in x if k != "accelerator"}) | |
json_string = json.dumps(save_dict, indent=2, sort_keys=True) + "\n" | |
with open(json_path, "w", encoding="utf-8") as f: | |
f.write(json_string) | |
    @classmethod
    def load_from_json(cls, accelerator: Accelerator, json_path: str):
"""Create an instance from the content of `json_path`.""" | |
# load everything except accelerator | |
with open(json_path, encoding="utf-8") as f: | |
text = f.read() | |
return cls(accelerator=accelerator, **json.loads(text)) | |
def get_global_statistics( | |
accelerator, xs: torch.Tensor, mask=None, device="cpu" | |
) -> tuple[torch.Tensor, torch.Tensor, int]: | |
""" | |
Computes element-wise mean and variance of the tensor across processes. Reference: | |
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 | |
""" | |
xs = xs.to(accelerator.device) | |
sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device) | |
sum_and_count = accelerator.reduce(sum_and_count) | |
global_sum, count = sum_and_count | |
global_mean = global_sum / count | |
sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) | |
sum_var = accelerator.reduce(sum_var) | |
global_var = sum_var / count | |
return global_mean.to(device), global_var.to(device), count.item() | |
def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]: | |
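    """
    Compute accuracy for an `EvalPrediction`. Handles both the pairwise reward-model case, where `predictions` has
    shape `(batch_size, 2)`, and the token-classification case, where `predictions` has shape
    `(batch_size, seq_len, num_labels)` and positions labelled `-100` are ignored.

    Example (an illustrative sketch; the scores and labels are made up):

    >>> rewards = np.array([[1.0, 0.5], [0.2, 0.9]])  # one row of (chosen, rejected) scores per pair
    >>> labels = np.array([0, 0])  # label 0 marks the first column as the preferred one
    >>> compute_accuracy((rewards, labels))
    {'accuracy': 0.5}
    """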
predictions, labels = eval_pred | |
if predictions.ndim == 3: | |
# Token classification task. Shapes are (batch_size, seq_len, num_labels) and (batch_size, seq_len) | |
# Used to compute the accuracy in the prm_trainer. | |
predictions = np.argmax(predictions, axis=2) | |
# Flatten the predictions and labels to remove the ignored tokens. | |
predictions = np.array( | |
[p for prediction, label in zip(predictions, labels) for (p, lbl) in zip(prediction, label) if lbl != -100] | |
) | |
labels = np.array([lbl for label in labels for lbl in label if lbl != -100]) | |
else: | |
# Here, predictions is rewards_chosen and rewards_rejected. Shapes are (batch_size, 2) and (batch_size,) | |
# We want to see how much of the time rewards_chosen > rewards_rejected. | |
equal_mask = predictions[:, 0] == predictions[:, 1] | |
equal_predictions_count = int(equal_mask.sum()) | |
if equal_predictions_count > 0: | |
warnings.warn( | |
f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions " | |
"for both options are equal. These instances are ignored in the accuracy computation.", | |
UserWarning, | |
) | |
# Filter out equal predictions | |
predictions = predictions[~equal_mask] | |
labels = labels[~equal_mask] | |
# Use the remaining predictions for accuracy calculation | |
predictions = np.argmax(predictions, axis=1) | |
accuracy = np.array(predictions == labels, dtype=float).mean().item() | |
return {"accuracy": accuracy} | |
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: | |
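    """
    Pad `tensor` along dimension `dim` with `pad_value` up to `length`; the tensor is returned unchanged if it is
    already at least that long.

    Example (an illustrative sketch):

    >>> pad_to_length(torch.tensor([1, 2]), length=4, pad_value=0)
    tensor([1, 2, 0, 0])
    """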
if tensor.size(dim) >= length: | |
return tensor | |
else: | |
pad_size = list(tensor.shape) | |
pad_size[dim] = length - tensor.size(dim) | |
return torch.cat( | |
[ | |
tensor, | |
pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), | |
], | |
dim=dim, | |
) | |
def disable_dropout_in_model(model: torch.nn.Module) -> None: | |
for module in model.modules(): | |
if isinstance(module, torch.nn.Dropout): | |
module.p = 0 | |
def exact_div(a, b, custom_error_message=""): | |
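    """
    Return `a // b`, raising a `ValueError` if `b` does not divide `a` exactly.

    Example (an illustrative sketch):

    >>> exact_div(12, 4)
    3
    """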
q = a // b | |
if a != q * b: | |
raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}") | |
return q | |
# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5 | |
class PerPromptStatTracker: | |
r""" | |
    Class for tracking statistics per prompt. Mainly used to calculate advantages for the DDPO algorithm.
Args: | |
buffer_size (`int`): | |
Size of the buffer to keep for each prompt. | |
min_count (`int`): | |
Minimum number of samples to keep in the buffer before calculating the mean and std. | |
""" | |
def __init__(self, buffer_size, min_count): | |
self.buffer_size = buffer_size | |
self.min_count = min_count | |
self.stats = {} | |
def update(self, prompts, rewards): | |
prompts = np.array(prompts) | |
rewards = np.array(rewards) | |
unique = np.unique(prompts) | |
advantages = np.empty_like(rewards) | |
for prompt in unique: | |
prompt_rewards = rewards[prompts == prompt] | |
if prompt not in self.stats: | |
self.stats[prompt] = deque(maxlen=self.buffer_size) | |
self.stats[prompt].extend(prompt_rewards) | |
if len(self.stats[prompt]) < self.min_count: | |
mean = np.mean(rewards) | |
std = np.std(rewards) + 1e-6 | |
else: | |
mean = np.mean(self.stats[prompt]) | |
std = np.std(self.stats[prompt]) + 1e-6 | |
advantages[prompts == prompt] = (prompt_rewards - mean) / std | |
return advantages | |
def get_stats(self): | |
return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} | |
def peft_module_casting_to_bf16(model): | |
for name, module in model.named_modules(): | |
if isinstance(module, torch.nn.LayerNorm) or "norm" in name: | |
module = module.to(torch.float32) | |
elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): | |
if hasattr(module, "weight"): | |
if module.weight.dtype == torch.float32: | |
module = module.to(torch.bfloat16) | |
def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]: | |
if model_args.load_in_4bit: | |
quantization_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` | |
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, | |
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, | |
bnb_4bit_quant_storage=model_args.torch_dtype, | |
) | |
elif model_args.load_in_8bit: | |
quantization_config = BitsAndBytesConfig( | |
load_in_8bit=True, | |
) | |
else: | |
quantization_config = None | |
return quantization_config | |
def get_kbit_device_map() -> Optional[dict[str, int]]: | |
if torch.cuda.is_available() or is_torch_xpu_available(): | |
return {"": PartialState().local_process_index} | |
else: | |
return None | |
def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]": | |
if model_args.use_peft is False: | |
return None | |
if not is_peft_available(): | |
raise ValueError( | |
"You need to have PEFT library installed in your environment, make sure to install `peft`. " | |
"Make sure to run `pip install -U peft`." | |
) | |
peft_config = LoraConfig( | |
task_type=model_args.lora_task_type, | |
r=model_args.lora_r, | |
target_modules=model_args.lora_target_modules, | |
lora_alpha=model_args.lora_alpha, | |
lora_dropout=model_args.lora_dropout, | |
bias="none", | |
use_rslora=model_args.use_rslora, | |
use_dora=model_args.use_dora, | |
modules_to_save=model_args.lora_modules_to_save, | |
) | |
return peft_config | |
def get_exp_cap(value, decimal=4): | |
""" | |
Get the exponent cap of a value. This is used to cap the exponent of a value to avoid overflow. | |
The formula is : log(value.dtype.max) | |
E.g. | |
For float32 data type, the maximum exponent value is 88.7228 to 4 decimal points. | |
``` | |
Args: | |
value (`torch.Tensor`): | |
The input tensor to obtain the data type | |
decimal (`int`): | |
The number of decimal points of the output exponent cap. | |
eg: direct calling exp(log(torch.float32.max)) will result in inf | |
so we cap the exponent to 88.7228 to avoid overflow. | |
""" | |
vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max | |
vdtype_log_max = torch.log(vdtype_max).to(value.device) | |
return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max | |
def cap_exp(value, cap=-1): | |
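    """
    Exponentiate `value` with the exponent clamped to `cap` (or to the dtype's safe maximum, via `get_exp_cap`, when
    `cap < 0`), so the result does not overflow to `inf`.

    Example (an illustrative sketch):

    >>> cap_exp(torch.tensor([1.0, 1000.0])).isfinite().all().item()
    True
    """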
# Cap the exponent value below the upper-bound to avoid overflow, before calling torch.exp | |
cap = get_exp_cap(value) if cap < 0 else cap | |
return torch.exp(torch.clamp(value, max=cap)) | |
def print_rich_table(df: pd.DataFrame) -> None: | |
if not is_rich_available(): | |
raise ImportError( | |
"The function `print_rich_table` requires the `rich` library. Please install it with `pip install rich`." | |
) | |
console = Console() | |
table = Table(show_lines=True) | |
for column in df.columns: | |
table.add_column(column) | |
for _, row in df.iterrows(): | |
table.add_row(*row.astype(str).tolist()) | |
console.print(table) | |
SIMPLE_SFT_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}" | |
# SIMPLE_SFT_CHAT_TEMPLATE simply ends things with an EOS token; this helps the SFT model learn to end completions with EOS tokens
SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" | |
@dataclass
class OnlineTrainerState(TrainerState):
episode: int = 0 | |
@dataclass
class OnPolicyConfig(TrainingArguments):
r""" | |
Base configuration class for on-policy trainers. | |
This class includes only the parameters that are specific to some on-policy training. For a full list of training | |
arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this | |
class may differ from those in [`~transformers.TrainingArguments`]. | |
Using [`~transformers.HfArgumentParser`] we can turn this class into | |
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the | |
command line. | |
Parameters: | |
run_name (`str` or `None`, *optional*, defaults to `None`): | |
Name of the run. | |
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): | |
Number of processes to use for processing the dataset. | |
num_mini_batches (`int`, *optional*, defaults to `1`): | |
Number of minibatches to split a batch into. | |
total_episodes (`int` or `None`, *optional*, defaults to `None`): | |
Total number of episodes in the dataset. | |
local_rollout_forward_batch_size (`int`, *optional*, defaults to `64`): | |
Per rank no grad forward pass in the rollout phase. | |
num_sample_generations (`int`, *optional*, defaults to `10`): | |
Number of debugging samples generations (i.e., `generate_completions` calls) throughout training. | |
response_length (`int`, *optional*, defaults to `53`): | |
Length of the response. | |
stop_token (`str` or `None`, *optional*, defaults to `None`): | |
Specifies the stop token to use for text generation. This parameter is mutually exclusive with | |
`stop_token_id`. | |
- `None`: No stop token is applied, unless `stop_token_id` is specified. | |
- `'eos'`: Uses the tokenizer's `eos_token`. | |
stop_token_id (`int` or `None`, *optional*, defaults to `None`): | |
Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, | |
unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. | |
temperature (`float`, *optional*, defaults to `0.7`): | |
Sampling temperature. | |
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): | |
Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage | |
to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive | |
value. | |
sft_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): | |
Path to the SFT model. | |
world_size (`int` or `None`, *optional*, defaults to `None`): | |
Number of processes (GPUs) to use for the training. | |
num_total_batches (`int` or `None`, *optional*, defaults to `None`): | |
Number of total batches to train. | |
micro_batch_size (`int` or `None`, *optional*, defaults to `None`): | |
Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`). | |
local_batch_size (`int` or `None`, *optional*, defaults to `None`): | |
Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`). | |
batch_size (`int` or `None`, *optional*, defaults to `None`): | |
Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`). | |
local_mini_batch_size (`int` or `None`, *optional*, defaults to `None`): | |
Mini batch size per GPU. | |
mini_batch_size (`int` or `None`, *optional*, defaults to `None`): | |
Mini batch size across GPUs. | |
push_to_hub (`bool`, *optional*, defaults to `False`): | |
Whether to push the model to the Hub after training. | |
""" | |
# Parameters whose default values are overridden from TrainingArguments | |
logging_steps: float = field( | |
default=10, | |
metadata={ | |
"help": ( | |
"Log every X updates steps. Should be an integer or a float in range `[0,1)`. " | |
"If smaller than 1, will be interpreted as ratio of total training steps." | |
) | |
}, | |
) | |
bf16: bool = field( | |
default=True, | |
metadata={ | |
"help": ( | |
"Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " | |
"architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." | |
) | |
}, | |
) | |
run_name: Optional[str] = field( | |
default=None, | |
metadata={"help": "Name of the run."}, | |
) | |
dataset_num_proc: Optional[int] = field( | |
default=None, | |
metadata={"help": "Number of processes to use for processing the dataset."}, | |
) | |
num_mini_batches: int = field( | |
default=1, | |
metadata={"help": "Number of minibatches to split a batch into."}, | |
) | |
total_episodes: Optional[int] = field( | |
default=None, | |
metadata={"help": "Total number of episodes in the dataset."}, | |
) | |
local_rollout_forward_batch_size: int = field( | |
default=64, | |
metadata={"help": "Per rank no grad forward pass in the rollout phase."}, | |
) | |
num_sample_generations: int = field( | |
default=10, | |
metadata={ | |
"help": "Number of debugging samples generations (i.e., `generate_completions` calls) throughout training." | |
}, | |
) | |
response_length: int = field( | |
default=53, | |
metadata={"help": "Length of the response."}, | |
) | |
stop_token: Optional[Literal["eos"]] = field( | |
default=None, | |
metadata={ | |
"help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " | |
"`stop_token_id`." | |
}, | |
) | |
stop_token_id: Optional[int] = field( | |
default=None, | |
metadata={ | |
"help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " | |
"applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." | |
}, | |
) | |
temperature: float = field( | |
default=0.7, | |
metadata={"help": "Sampling temperature."}, | |
) | |
missing_eos_penalty: Optional[float] = field( | |
default=None, | |
metadata={ | |
"help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to " | |
"encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be " | |
"a positive value." | |
}, | |
) | |
sft_model_path: str = field( | |
default="EleutherAI/pythia-160m", | |
metadata={"help": "Path to the SFT model."}, | |
) | |
world_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "Number of processes (GPUs) to use for the training."}, | |
) | |
num_total_batches: Optional[int] = field( | |
default=None, | |
metadata={"help": "Number of total batches to train."}, | |
) | |
micro_batch_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)."}, | |
) | |
local_batch_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)."}, | |
) | |
batch_size: Optional[int] = field( | |
default=None, | |
metadata={ | |
"help": "Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * " | |
"`gradient_accumulation_steps`)." | |
}, | |
) | |
local_mini_batch_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "Mini batch size per GPU."}, | |
) | |
mini_batch_size: Optional[int] = field( | |
default=None, | |
metadata={"help": "Mini batch size across GPUs."}, | |
) | |
push_to_hub: bool = field( | |
default=False, | |
metadata={"help": "Whether to push the model to the Hub after training."}, | |
) | |
def first_true_indices(bools: torch.Tensor, dtype=torch.long): | |
""" | |
Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving | |
the position of the first True in each "row". | |
Returns the length of the rows (bools.size(-1)) if no element is True in a given row. | |
Args: | |
bools (`torch.Tensor`): | |
An N-dimensional boolean tensor. | |
dtype (`torch.dtype`, optional): | |
The desired data type of the output tensor. Defaults to `torch.long`. | |
Returns: | |
`torch.Tensor`: | |
An (N-1)-dimensional tensor of integers indicating the position of the first True | |
in each row. If no True value is found in a row, returns the length of the row. | |
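
    Example (an illustrative sketch):

    >>> bools = torch.tensor([[False, True, True], [False, False, False]])
    >>> first_true_indices(bools)
    tensor([1, 3])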
""" | |
row_len = bools.size(-1) | |
zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) | |
return torch.min(zero_or_index, dim=-1).values | |
def get_reward( | |
model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int | |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
""" | |
Computes the reward logits and the rewards for a given model and query responses. | |
Args: | |
model (`torch.nn.Module`): | |
The model used to compute the reward logits. | |
query_responses (`torch.Tensor`): | |
The tensor containing the query responses. | |
pad_token_id (`int`): | |
The token ID representing the pad token. | |
context_length (`int`): | |
The length of the context in the query responses. | |
Returns: | |
tuple: | |
- `reward_logits` (`torch.Tensor`): | |
The logits for the reward model. | |
- `final_rewards` (`torch.Tensor`): | |
The final rewards for each query response. | |
- `sequence_lengths` (`torch.Tensor`): | |
The lengths of the sequences in the query responses. | |
""" | |
attention_mask = query_responses != pad_token_id | |
position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum | |
lm_backbone = getattr(model, model.base_model_prefix) | |
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) | |
output = lm_backbone( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
position_ids=position_ids, | |
return_dict=True, | |
output_hidden_states=True, | |
use_cache=False, # otherwise mistral-based RM would error out | |
) | |
reward_logits = model.score(output.hidden_states[-1]) | |
sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length | |
# https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 | |
return ( | |
reward_logits, | |
reward_logits[ | |
torch.arange(reward_logits.size(0), device=reward_logits.device), | |
sequence_lengths, | |
].squeeze(-1), | |
sequence_lengths, | |
) | |
def forward( | |
model: torch.nn.Module, | |
query_responses: torch.Tensor, | |
pad_token_id: int, | |
) -> torch.nn.Module: | |
""" | |
Performs a forward pass through the model with the given query responses and pad token ID. | |
Args: | |
model (`torch.nn.Module`): | |
The model to perform the forward pass. | |
query_responses (`torch.Tensor`): | |
The tensor containing the query responses. | |
pad_token_id (`int`): | |
The token ID representing the pad token. | |
Returns: | |
`torch.nn.Module`: | |
The output of the model, including hidden states. | |
""" | |
attention_mask = query_responses != pad_token_id | |
position_ids = attention_mask.cumsum(1) - attention_mask.long() | |
input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) | |
return model( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
position_ids=position_ids, | |
return_dict=True, | |
output_hidden_states=True, | |
) | |
def prepare_deepspeed( | |
model: torch.nn.Module, per_device_train_batch_size: int, fp16: bool = False, bf16: bool = False | |
): | |
""" | |
Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings based on the model and | |
batch size. | |
Args: | |
model (`torch.nn.Module`): | |
The model to be prepared for DeepSpeed training. | |
        per_device_train_batch_size (`int`):
            The training batch size per device.
        fp16 (`bool`, *optional*, defaults to `False`):
            Whether to prepare the model in fp16 mixed precision.
        bf16 (`bool`, *optional*, defaults to `False`):
            Whether to prepare the model in bf16 mixed precision.
Returns: | |
`torch.nn.Module`: | |
The model initialized and configured with DeepSpeed for training. | |
""" | |
import deepspeed | |
deepspeed_plugin = AcceleratorState().deepspeed_plugin | |
config_kwargs = deepspeed_plugin.deepspeed_config | |
if config_kwargs["zero_optimization"]["stage"] != 3: | |
config_kwargs["train_micro_batch_size_per_gpu"] = per_device_train_batch_size | |
config_kwargs = { | |
"train_micro_batch_size_per_gpu": config_kwargs["train_micro_batch_size_per_gpu"], | |
"prescale_gradients": False, | |
"wall_clock_breakdown": False, | |
} | |
if bf16: | |
config_kwargs["bf16"] = {"enabled": True} | |
elif fp16: | |
config_kwargs["fp16"] = {"enabled": True} | |
else: | |
if hasattr(model, "config"): | |
hidden_size = ( | |
max(model.config.hidden_sizes) | |
if getattr(model.config, "hidden_sizes", None) | |
else getattr(model.config, "hidden_size", None) | |
) | |
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: | |
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` | |
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 | |
config_kwargs.update( | |
{ | |
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size, | |
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, | |
"zero_optimization.stage3_prefetch_bucket_size": 0, | |
} | |
) | |
model, *_ = deepspeed.initialize(model=model, config=config_kwargs) | |
model.eval() | |
return model | |
def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Tensor): | |
""" | |
Truncates the responses at the first occurrence of the stop token, filling the rest with pad tokens. | |
Args: | |
stop_token_id (`int`): | |
The token ID representing the stop token where truncation occurs. | |
pad_token_id (`int`): | |
The token ID representing the pad token used to fill the truncated responses. | |
responses (`torch.Tensor`): | |
The tensor containing the responses to be truncated. | |
Returns: | |
`torch.Tensor`: | |
The truncated responses tensor with pad tokens filled after the stop token. | |
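
    Example (an illustrative sketch; the token IDs are made up, with `2` as the stop token and `0` as the pad token):

    >>> responses = torch.tensor([[5, 2, 7, 8], [5, 6, 7, 8]])
    >>> truncate_response(stop_token_id=2, pad_token_id=0, responses=responses)
    tensor([[5, 2, 0, 0],
            [5, 6, 7, 8]])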
""" | |
trunc_idxs = first_true_indices(responses == stop_token_id).unsqueeze(-1) | |
new_size = [1] * (len(responses.size()) - 1) + [responses.shape[1]] | |
idxs = torch.arange(responses.shape[1], device=responses.device).view(*new_size) | |
postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id) | |
return postprocessed_responses | |
def generate( | |
lm_backbone: torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: GenerationConfig | |
) -> tuple[torch.Tensor, torch.Tensor]: | |
""" | |
Generates sequences from the language model backbone in a way that does not affect padding tokens. | |
Args: | |
lm_backbone (`torch.nn.Module`): | |
The language model backbone used for generation. | |
queries (`torch.Tensor`): | |
The tensor containing the input queries. | |
pad_token_id (`int`): | |
The token ID representing the pad token. | |
generation_config (`GenerationConfig`): | |
The configuration for the generation process. | |
Returns: | |
tuple: | |
- `generated_sequences` (`torch.Tensor`): | |
The concatenated tensor of input queries and generated sequences. | |
- `logits` (`torch.Tensor`): | |
The logits output from the generation process. | |
""" | |
context_length = queries.shape[1] | |
attention_mask = queries != pad_token_id | |
input_ids = torch.masked_fill(queries, ~attention_mask, 0) | |
output = lm_backbone.generate( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
# position_ids=attention_mask.cumsum(1) - attention_mask.long(), # not needed: already adjusted in generations | |
# https://github.com/huggingface/transformers/blob/ac33aeeeee2a7a89b89c93c2962e6feb90daef0a/src/transformers/models/gpt2/modeling_gpt2.py#L1227-L1250 | |
generation_config=generation_config, | |
return_dict_in_generate=True, | |
output_scores=True, | |
) | |
logits = torch.stack(output.scores, 1) | |
return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits | |
def batch_generation( | |
model: torch.nn.Module, | |
queries: torch.Tensor, | |
local_rollout_forward_batch_size: int, | |
pad_token_id: int, | |
generation_config: GenerationConfig, | |
): | |
query_responses = [] | |
logitss = [] | |
batch_size = queries.shape[0] | |
for i in range(0, batch_size, local_rollout_forward_batch_size): | |
query = queries[i : i + local_rollout_forward_batch_size] | |
query_response, logits = generate( | |
model, | |
query, | |
pad_token_id, | |
generation_config, | |
) | |
query_responses.append(query_response) | |
logitss.append(logits) | |
# padding tensors | |
padded_query_responses = pad(query_responses, padding_value=pad_token_id, padding_side="right") | |
padded_logitss = pad(logitss, padding_value=0, padding_side="right") | |
# reshaping | |
padded_query_responses = padded_query_responses.view(-1, padded_query_responses.shape[-1])[:batch_size] | |
padded_logitss = padded_logitss.view(-1, *padded_logitss.shape[2:])[:batch_size] | |
return padded_query_responses, padded_logitss | |
def add_bos_token_if_needed( | |
bos_token_id: Optional[int], | |
prompt_len_input_ids: int, | |
prompt_tokens: dict[str, list[int]], | |
chosen_prompt_len_input_ids: int, | |
chosen_tokens: dict[str, list[int]], | |
rejected_prompt_len_input_ids: int, | |
rejected_tokens: dict[str, list[int]], | |
): | |
if bos_token_id is not None: | |
if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]: | |
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"] | |
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] | |
if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]: | |
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"] | |
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] | |
if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]: | |
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"] | |
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] | |
return prompt_tokens, chosen_tokens, rejected_tokens | |
def add_eos_token_if_needed( | |
eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]] | |
): | |
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]: | |
chosen_tokens["input_ids"].append(eos_token_id) | |
chosen_tokens["attention_mask"].append(1) | |
if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]: | |
rejected_tokens["input_ids"].append(eos_token_id) | |
rejected_tokens["attention_mask"].append(1) | |
return chosen_tokens, rejected_tokens | |
def truncate_right( | |
input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int | |
) -> tuple[torch.Tensor, torch.Tensor]: | |
""" | |
Truncates the input tensor from the right side after the first occurrence of the stop token. | |
Args: | |
input_ids (`torch.Tensor`): | |
The tensor containing the responses to be truncated | |
stop_token_id (`int`): | |
The token ID representing the stop token where truncation occurs | |
pad_token_id (`int`): | |
The token ID representing the pad token used to fill the truncated responses | |
Returns: | |
tuple: | |
- `output_ids` (`torch.Tensor`): | |
The truncated responses tensor with pad tokens filled after the stop token | |
- `mask` (`torch.Tensor`): | |
The mask tensor to indicate the padding tokens | |
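
    Example (an illustrative sketch; the token IDs are made up, with `2` as the stop token and `0` as the pad token):

    >>> input_ids = torch.tensor([[5, 2, 7], [5, 6, 7]])
    >>> output_ids, mask = truncate_right(input_ids, stop_token_id=2, pad_token_id=0)
    >>> output_ids
    tensor([[5, 2, 0],
            [5, 6, 7]])
    >>> mask
    tensor([[1, 1, 0],
            [1, 1, 1]])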
""" | |
trunc_idxs = first_true_indices(input_ids == stop_token_id).unsqueeze(-1) | |
new_size = [1] * (len(input_ids.size()) - 1) + [input_ids.shape[1]] | |
idxs = torch.arange(input_ids.shape[1], device=input_ids.device).view(*new_size) | |
output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id) | |
mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0) | |
return output_ids, mask | |
def empty_cache() -> None: | |
"""Empties the cache of the available torch device. | |
This function checks for the availability of different torch devices (XPU, MLU, NPU, CUDA) | |
and empties the cache of the first available device it finds. | |
If none of the specific devices are available, it defaults to emptying the CUDA cache. | |
""" | |
if is_torch_xpu_available(): | |
torch.xpu.empty_cache() | |
elif is_torch_mlu_available(): | |
torch.mlu.empty_cache() | |
elif is_torch_npu_available(): | |
torch.npu.empty_cache() | |
else: | |
torch.cuda.empty_cache() | |
def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenizerBase) -> list[str]: | |
""" | |
Decodes the input tensor and strips the padding tokens. | |
Args: | |
inputs (`torch.Tensor`): | |
The input tensor to be decoded. | |
tokenizer (`transformers.PreTrainedTokenizerBase`): | |
The tokenizer used to decode the input tensor. | |
Returns: | |
`list[str]`: | |
The list of decoded strings with padding tokens stripped. | |
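Example (illustrative; assumes `tokenizer` is any tokenizer with a defined `pad_token`): | |
```python | |
>>> inputs = tokenizer(["Hello world", "Hello"], padding=True, return_tensors="pt") | |
>>> decode_and_strip_padding(inputs["input_ids"], tokenizer) # e.g. ['Hello world', 'Hello'] | |
``` | |
Note that other special tokens are kept in the output, since decoding uses `skip_special_tokens=False`. | |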
""" | |
decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False) | |
return [d.replace(tokenizer.pad_token, "") for d in decoded] | |
def generate_model_card( | |
base_model: Optional[str], | |
model_name: str, | |
hub_model_id: str, | |
dataset_name: Optional[str], | |
tags: list[str], | |
wandb_url: Optional[str], | |
trainer_name: str, | |
trainer_citation: Optional[str] = None, | |
paper_title: Optional[str] = None, | |
paper_id: Optional[str] = None, | |
comet_url: Optional[str] = None, | |
) -> ModelCard: | |
""" | |
Generate a `ModelCard` from a template. | |
Args: | |
base_model (`str` or `None`): | |
Base model name. | |
model_name (`str`): | |
Model name. | |
hub_model_id (`str`): | |
Hub model ID as `username/model_id`. | |
dataset_name (`str` or `None`): | |
Dataset name. | |
tags (`list[str]`): | |
Tags. | |
wandb_url (`str` or `None`): | |
Weights & Biases run URL. | |
comet_url (`str` or `None`): | |
Comet experiment URL. | |
trainer_name (`str`): | |
Trainer name. | |
trainer_citation (`str` or `None`, defaults to `None`): | |
Trainer citation as a BibTeX entry. | |
paper_title (`str` or `None`, defaults to `None`): | |
Paper title. | |
paper_id (`str` or `None`, defaults to `None`): | |
ArXiv paper ID as `YYMM.NNNNN`. | |
Returns: | |
`ModelCard`: | |
A ModelCard object. | |
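Example (illustrative; the argument values below are made up): | |
```python | |
>>> card = generate_model_card( | |
... base_model="Qwen/Qwen2.5-0.5B", | |
... model_name="my-model", | |
... hub_model_id="username/my-model", | |
... dataset_name="trl-lib/tldr", | |
... tags=["trl", "sft"], | |
... wandb_url=None, | |
... trainer_name="SFT", | |
... ) | |
>>> card.save("README.md") | |
``` | |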
""" | |
card_data = ModelCardData( | |
base_model=base_model, | |
datasets=dataset_name, | |
library_name="transformers", | |
licence="license", | |
model_name=model_name, | |
tags=["generated_from_trainer", *tags], | |
) | |
card = ModelCard.from_template( | |
card_data, | |
template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")), | |
base_model=base_model, | |
model_name=model_name, | |
hub_model_id=hub_model_id, | |
dataset_name=dataset_name, | |
wandb_url=wandb_url, | |
comet_url=comet_url, | |
trainer_name=trainer_name, | |
trainer_citation=trainer_citation, | |
paper_title=paper_title, | |
paper_id=paper_id, | |
trl_version=version("trl"), | |
transformers_version=version("transformers"), | |
pytorch_version=version("torch"), | |
datasets_version=version("datasets"), | |
tokenizers_version=version("tokenizers"), | |
) | |
return card | |
def get_comet_experiment_url() -> Optional[str]: | |
""" | |
If Comet integration is enabled, return the URL of the current Comet experiment; otherwise, return `None`. | |
""" | |
if not is_comet_available(): | |
return None | |
experiment = comet_ml.get_running_experiment() | |
return experiment.url if experiment is not None else None | |
def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: | |
""" | |
If Comet integration is enabled and an experiment is currently running, log the given table to that experiment. | |
Args: | |
name (`str`): | |
Table name. | |
table (`pd.DataFrame`): | |
The Pandas DataFrame containing the table to log. | |
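Example (illustrative; requires `comet_ml` to be installed and an experiment to be running): | |
```python | |
>>> df = pd.DataFrame({"prompt": ["The sky is"], "completion": [" blue."]}) | |
>>> log_table_to_comet_experiment("completions.csv", df) | |
``` | |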
""" | |
if not is_comet_available(): | |
raise ModuleNotFoundError("The comet-ml package is not installed. Please install it first: pip install comet-ml") | |
experiment = comet_ml.get_running_experiment() | |
if experiment is not None: | |
experiment.log_table(tabular_data=table, filename=name) | |
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: | |
""" | |
Shift non-zero elements in the mask and corresponding tensors to the left. | |
This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask. | |
For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across | |
all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows: | |
``` | |
[[0, 0, x, x, x, x], -> [[x, x, x, x], | |
[0, x, x, x, 0, 0]] [x, x, x, 0]] | |
``` | |
Args: | |
mask (`torch.Tensor`): | |
2D tensor (binary mask) with shape `(N, M)`. | |
*tensors (`torch.Tensor`): | |
One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`, | |
with non-zero values shifted and excess zero columns truncated in the same manner. | |
Returns: | |
`torch.Tensor`: | |
Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. | |
`*torch.Tensor`: | |
Updated tensors, processed in the same way as the mask. | |
Example: | |
```python | |
>>> mask = torch.tensor([[0, 0, 1, 1, 1], | |
... [0, 1, 1, 0, 0]]) | |
>>> tensor = torch.tensor([[9, 9, 2, 3, 4], | |
... [9, 5, 6, 9, 9]]) | |
>>> new_mask, new_tensor = flush_left(mask, tensor) | |
>>> print(new_mask) | |
tensor([[1, 1, 1], | |
[1, 1, 0]]) | |
>>> print(new_tensor) | |
tensor([[2, 3, 4], | |
[5, 6, 9]]) | |
``` | |
""" | |
_, M = mask.shape | |
# Create copy of mask and tensors | |
mask_copy = mask.clone() | |
tensors = [t.clone() for t in tensors] | |
# Shift non-zero values to the left | |
first_non_zero = mask_copy.argmax(dim=1) | |
pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) | |
idx_roll = (pos + first_non_zero.unsqueeze(1)) % M | |
mask_roll = mask_copy.gather(1, idx_roll) | |
rolled_tensors = [t.gather(1, idx_roll) for t in tensors] | |
# Truncate trailing columns that are all zeros in mask_roll | |
col_sums = mask_roll.sum(dim=0) | |
empty_cols = col_sums == 0 | |
first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M | |
flushed_mask = mask_roll[:, :first_empty_col] | |
flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors] | |
if not flushed_tensors: | |
return flushed_mask | |
return flushed_mask, *flushed_tensors | |
def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: | |
""" | |
Shift non-zero elements in the mask and corresponding tensors to the right. See `flush_left` for details. | |
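Example (illustrative; mirrors the `flush_left` example): | |
```python | |
>>> mask = torch.tensor([[1, 1, 1, 0, 0], | |
... [0, 0, 1, 1, 0]]) | |
>>> tensor = torch.tensor([[2, 3, 4, 9, 9], | |
... [9, 9, 5, 6, 9]]) | |
>>> new_mask, new_tensor = flush_right(mask, tensor) | |
>>> print(new_mask) | |
tensor([[1, 1, 1], | |
[0, 1, 1]]) | |
>>> print(new_tensor) | |
tensor([[2, 3, 4], | |
[9, 5, 6]]) | |
``` | |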
""" | |
_, M = mask.shape | |
# Create copy of mask and tensors | |
mask_copy = mask.clone() | |
tensors = [t.clone() for t in tensors] | |
# Shift non-zero values to the right | |
flipped_mask = torch.fliplr(mask_copy) | |
first_non_zero = flipped_mask.argmax(dim=1) | |
pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) | |
idx_roll = (pos - first_non_zero.unsqueeze(1)) % M | |
mask_roll = mask_copy.gather(1, idx_roll) | |
rolled_tensors = [t.gather(1, idx_roll) for t in tensors] | |
# Truncate leading columns that are all zeros in mask_roll | |
col_sums = mask_roll.sum(dim=0) | |
non_empty_cols = col_sums != 0 | |
first_non_empty_col = int(non_empty_cols.to(torch.int8).argmax()) if non_empty_cols.any() else M | |
flushed_mask = mask_roll[:, first_non_empty_col:] | |
flushed_tensors = [t[:, first_non_empty_col:] for t in rolled_tensors] | |
if not flushed_tensors: | |
return flushed_mask | |
return flushed_mask, *flushed_tensors | |
def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor: | |
""" | |
A memory-efficient implementation of the common `log_softmax -> gather` operation. | |
This function is equivalent to the following naive implementation: | |
```python | |
logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) | |
``` | |
Args: | |
logits (`torch.Tensor`): | |
Logits tensor of shape `(..., num_classes)`. | |
index (`torch.Tensor`): | |
Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. | |
Returns: | |
`torch.Tensor`: | |
Gathered log probabilities with the same shape as `index`. | |
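Example (illustrative shapes): | |
```python | |
>>> logits = torch.randn(2, 5, 10) # (batch_size, seq_len, vocab_size) | |
>>> index = torch.randint(0, 10, (2, 5)) # (batch_size, seq_len) | |
>>> per_token_logps = selective_log_softmax(logits, index) # shape (2, 5) | |
``` | |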
""" | |
if logits.dtype in [torch.float32, torch.float64]: | |
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) | |
# loop to reduce peak mem consumption | |
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) | |
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) | |
else: | |
# logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach | |
per_token_logps = [] | |
for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption | |
row_logps = F.log_softmax(row_logits, dim=-1) | |
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) | |
per_token_logps.append(row_per_token_logps) | |
per_token_logps = torch.stack(per_token_logps) | |
return per_token_logps | |
def print_prompt_completions_sample( | |
prompts: list[str], | |
completions: list[str], | |
rewards: dict[str, list[float]], | |
advantages: list[float], | |
step: int, | |
num_samples: Optional[int] = None, | |
) -> None: | |
""" | |
Print out a sample of model completions to the console with multiple reward metrics. | |
This function creates a nicely formatted table showing prompt-completion pairs, useful for monitoring model outputs | |
during training. It requires the `rich` library to be installed. | |
Args: | |
prompts (`list[str]`): | |
List of prompts. | |
completions (`list[str]`): | |
List of completions corresponding to the prompts. | |
rewards (`dict[str, list[float]]`): | |
Dictionary where keys are reward names and values are lists of rewards. | |
advantages (`list[float]`): | |
List of advantages corresponding to the prompts and completions. | |
step (`int`): | |
Current training step number, used in the output title. | |
num_samples (`int` or `None`, *optional*, defaults to `None`): | |
Number of random samples to display. If `None` (default), all items will be displayed. | |
Example: | |
```python | |
>>> from trl.trainer.utils import print_prompt_completions_sample | |
>>> prompts = ["The sky is", "The sun is"] | |
>>> completions = [" blue.", " in the sky."] | |
>>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} | |
>>> advantages = [0.987, 0.654] | |
>>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42) | |
โญโโโโโโโโโโโโโโโโโโโโโโโโโโโโ Step 42 โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฎ | |
โ โโโโโโโโโโโโโโณโโโโโโโโโโโโโโโณโโโโโโโโโโโโโโณโโโโโโโโโณโโโโโโโโโโโโ โ | |
โ โ Prompt โ Completion โ Correctness โ Format โ Advantage โ โ | |
โ โกโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฉ โ | |
โ โ The sky is โ blue. โ 0.12 โ 0.79 โ 0.99 โ โ | |
โ โโโโโโโโโโโโโโผโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโผโโโโโโโโโผโโโโโโโโโโโโค โ | |
โ โ The sun is โ in the sky. โ 0.46 โ 0.10 โ 0.65 โ โ | |
โ โโโโโโโโโโโโโโดโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโดโโโโโโโโโดโโโโโโโโโโโโ โ | |
โฐโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโฏ | |
``` | |
""" | |
if not is_rich_available(): | |
raise ImportError( | |
"The function `print_prompt_completions_sample` requires the `rich` library. Please install it with " | |
"`pip install rich`." | |
) | |
console = Console() | |
table = Table(show_header=True, header_style="bold white", expand=True) | |
# Add columns | |
table.add_column("Prompt", style="bright_yellow") | |
table.add_column("Completion", style="bright_green") | |
for reward_name in rewards.keys(): | |
table.add_column(reward_name, style="bold cyan", justify="right") | |
table.add_column("Advantage", style="bold magenta", justify="right") | |
# Some basic input validation | |
if num_samples is not None: | |
if num_samples >= len(prompts): | |
num_samples = None | |
elif num_samples <= 0: | |
return | |
# Subsample data if num_samples is specified | |
if num_samples is not None: | |
indices = random.sample(range(len(prompts)), num_samples) | |
prompts = [prompts[i] for i in indices] | |
completions = [completions[i] for i in indices] | |
rewards = {key: [val[i] for i in indices] for key, val in rewards.items()} | |
advantages = [advantages[i] for i in indices] | |
for i in range(len(prompts)): | |
reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] # 2 decimals | |
table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values, f"{advantages[i]:.2f}") | |
table.add_section() # Adds a separator between rows | |
panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") | |
console.print(panel) | |