# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import is_bitsandbytes_available

from ..core import flatten_dict


@dataclass
class DDPOConfig:
r"""
    Configuration class for the [`DDPOTrainer`].

    Using [`~transformers.HfArgumentParser`], we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
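
    Example (a minimal usage sketch; assumes `DDPOConfig` is importable from the top-level `trl` package):

    ```python
    from transformers import HfArgumentParser

    from trl import DDPOConfig  # assumed import path

    # Build a command-line parser from the dataclass fields; any field can then be set on the
    # command line, e.g. `python train.py --num_epochs 10 --train_learning_rate 1e-5`
    # (`train.py` is a placeholder script name).
    parser = HfArgumentParser(DDPOConfig)
    (config,) = parser.parse_args_into_dataclasses()
    ```
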
Parameters:
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
            Name of this experiment (defaults to the running script's name without the `.py` extension).
run_name (`str`, *optional*, defaults to `""`):
Name of this run.
seed (`int`, *optional*, defaults to `0`):
Random seed.
        log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
            Log with either 'wandb' or 'tensorboard'; check
            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
Keyword arguments for the tracker (e.g. wandb_project).
accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator.
project_kwargs (`Dict`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator project config (e.g. `logging_dir`).
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
Name of project to use for tracking.
logdir (`str`, *optional*, defaults to `"logs"`):
Top-level logging directory for checkpoint saving.
num_epochs (`int`, *optional*, defaults to `100`):
Number of epochs to train.
save_freq (`int`, *optional*, defaults to `1`):
Number of epochs between saving model checkpoints.
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
Number of checkpoints to keep before overwriting old ones.
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
            Mixed precision mode passed to Accelerate (typically `"no"`, `"fp16"`, or `"bf16"`).
allow_tf32 (`bool`, *optional*, defaults to `True`):
Allow `tf32` on Ampere GPUs.
resume_from (`str`, *optional*, defaults to `""`):
Resume training from a checkpoint.
sample_num_steps (`int`, *optional*, defaults to `50`):
Number of sampler inference steps.
sample_eta (`float`, *optional*, defaults to `1.0`):
Eta parameter for the DDIM sampler.
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
Classifier-free guidance weight.
sample_batch_size (`int`, *optional*, defaults to `1`):
Batch size (per GPU) to use for sampling.
sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
Number of batches to sample per epoch.
train_batch_size (`int`, *optional*, defaults to `1`):
Batch size (per GPU) to use for training.
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
Use 8bit Adam optimizer from bitsandbytes.
train_learning_rate (`float`, *optional*, defaults to `3e-4`):
Learning rate.
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
Adam beta1.
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
Adam beta2.
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
Adam weight decay.
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
Adam epsilon.
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
Number of gradient accumulation steps.
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
Maximum gradient norm for gradient clipping.
train_num_inner_epochs (`int`, *optional*, defaults to `1`):
Number of inner epochs per outer epoch.
train_cfg (`bool`, *optional*, defaults to `True`):
Whether to use classifier-free guidance during training.
train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
            Clip advantages to the range `[-train_adv_clip_max, train_adv_clip_max]`.
train_clip_range (`float`, *optional*, defaults to `1e-4`):
PPO clip range.
train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
Fraction of timesteps to train on.
per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
Whether to track statistics for each prompt separately.
per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
Number of reward values to store in the buffer for each prompt.
per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
            Minimum number of reward values per prompt before per-prompt statistics are used (batch statistics
            are used otherwise).
async_reward_computation (`bool`, *optional*, defaults to `False`):
Whether to compute rewards asynchronously.
max_workers (`int`, *optional*, defaults to `2`):
Maximum number of workers to use for async reward computation.
negative_prompts (`str`, *optional*, defaults to `""`):
Comma-separated list of prompts to use as negative examples.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the final model checkpoint to the Hub.
"""
exp_name: str = field(
default=os.path.basename(sys.argv[0])[: -len(".py")],
metadata={"help": "Name of this experiment (by default is the file name without the extension name)."},
)
run_name: str = field(
default="",
metadata={"help": "Name of this run."},
)
seed: int = field(
default=0,
metadata={"help": "Random seed."},
)
log_with: Optional[str] = field(
default=None,
metadata={
"help": "Log with either 'wandb' or 'tensorboard'.",
"choices": ["wandb", "tensorboard"],
},
)
tracker_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Keyword arguments for the tracker (e.g. wandb_project)."},
)
accelerator_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Keyword arguments for the accelerator."},
)
project_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Keyword arguments for the accelerator project config (e.g. `logging_dir`)."},
)
tracker_project_name: str = field(
default="trl",
metadata={"help": "Name of project to use for tracking."},
)
logdir: str = field(
default="logs",
metadata={"help": "Top-level logging directory for checkpoint saving."},
)
num_epochs: int = field(
default=100,
metadata={"help": "Number of epochs to train."},
)
save_freq: int = field(
default=1,
metadata={"help": "Number of epochs between saving model checkpoints."},
)
num_checkpoint_limit: int = field(
default=5,
metadata={"help": "Number of checkpoints to keep before overwriting old ones."},
)
mixed_precision: str = field(
default="fp16",
metadata={"help": "Mixed precision training."},
)
allow_tf32: bool = field(
default=True,
metadata={"help": "Allow `tf32` on Ampere GPUs."},
)
resume_from: str = field(
default="",
metadata={"help": "Resume training from a checkpoint."},
)
sample_num_steps: int = field(
default=50,
metadata={"help": "Number of sampler inference steps."},
)
sample_eta: float = field(
default=1.0,
metadata={"help": "Eta parameter for the DDIM sampler."},
)
sample_guidance_scale: float = field(
default=5.0,
metadata={"help": "Classifier-free guidance weight."},
)
sample_batch_size: int = field(
default=1,
metadata={"help": "Batch size (per GPU) to use for sampling."},
)
sample_num_batches_per_epoch: int = field(
default=2,
metadata={"help": "Number of batches to sample per epoch."},
)
train_batch_size: int = field(
default=1,
metadata={"help": "Batch size (per GPU) to use for training."},
)
train_use_8bit_adam: bool = field(
default=False,
metadata={"help": "Use 8bit Adam optimizer from bitsandbytes."},
)
train_learning_rate: float = field(
default=3e-4,
metadata={"help": "Learning rate."},
)
train_adam_beta1: float = field(
default=0.9,
metadata={"help": "Adam beta1."},
)
train_adam_beta2: float = field(
default=0.999,
metadata={"help": "Adam beta2."},
)
train_adam_weight_decay: float = field(
default=1e-4,
metadata={"help": "Adam weight decay."},
)
train_adam_epsilon: float = field(
default=1e-8,
metadata={"help": "Adam epsilon."},
)
train_gradient_accumulation_steps: int = field(
default=1,
metadata={"help": "Number of gradient accumulation steps."},
)
train_max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm for gradient clipping."},
)
train_num_inner_epochs: int = field(
default=1,
metadata={"help": "Number of inner epochs per outer epoch."},
)
train_cfg: bool = field(
default=True,
metadata={"help": "Whether to use classifier-free guidance during training."},
)
train_adv_clip_max: float = field(
default=5.0,
metadata={"help": "Clip advantages to the range."},
)
train_clip_range: float = field(
default=1e-4,
metadata={"help": "PPO clip range."},
)
train_timestep_fraction: float = field(
default=1.0,
metadata={"help": "Fraction of timesteps to train on."},
)
per_prompt_stat_tracking: bool = field(
default=False,
metadata={"help": "Whether to track statistics for each prompt separately."},
)
per_prompt_stat_tracking_buffer_size: int = field(
default=16,
metadata={"help": "Number of reward values to store in the buffer for each prompt."},
)
per_prompt_stat_tracking_min_count: int = field(
default=16,
metadata={"help": "Minimum number of reward values to store in the buffer."},
)
async_reward_computation: bool = field(
default=False,
metadata={"help": "Whether to compute rewards asynchronously."},
)
max_workers: int = field(
default=2,
metadata={"help": "Maximum number of workers to use for async reward computation."},
)
negative_prompts: str = field(
default="",
metadata={"help": "Comma-separated list of prompts to use as negative examples."},
)
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the final model checkpoint to the Hub."},
)

    def to_dict(self):
        # Copy the dataclass fields and flatten any nested dicts (e.g. `tracker_kwargs`) into a
        # single-level dict, e.g. for logging.
        return flatten_dict(dict(self.__dict__))

def __post_init__(self):
if self.train_use_8bit_adam and not is_bitsandbytes_available():
raise ImportError(
"You need to install bitsandbytes to use 8bit Adam. "
"You can install it with `pip install bitsandbytes`."
)