# Source: trl-sandbox/trl/trainer/model_config.py
# (uploaded by ivangabriele, "feat: initialize project", commit 2f5127c)
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelConfig:
    """
    Settings that control how the base model is loaded and, optionally, wrapped with a LoRA adapter.

    Each field carries `help` metadata (and `choices` where relevant), so
    [`~transformers.HfArgumentParser`] can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) options on the command line.

    Parameters:
        model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
            Model checkpoint for weights initialization.
        model_revision (`str`, *optional*, defaults to `"main"`):
            Specific model version to use: a branch name, a tag name, or a commit id.
        torch_dtype (`Literal["auto", "bfloat16", "float16", "float32"]` or `None`, *optional*, defaults to `None`):
            Override the default `torch.dtype` and load the model under this dtype. Possible values are:

            - `"bfloat16"`: `torch.bfloat16`
            - `"float16"`: `torch.float16`
            - `"float32"`: `torch.float32`
            - `"auto"`: Automatically derive the dtype from the model's weights.

            The allowed values are declared as argparse `choices`. They are therefore checked when parsing the
            command line. `__post_init__` does not re-check them when the class is constructed directly in Python.
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Whether to allow custom models defined on the Hub in their own modeling files. Only set this to `True`
            for repositories you trust and whose code you have read, because it executes code from the Hub on your
            local machine.
        attn_implementation (`str` or `None`, *optional*, defaults to `None`):
            Which attention implementation to use. With `--attn_implementation=flash_attention_2` you must first
            install the kernel manually: `pip install flash-attn --no-build-isolation`.
        use_peft (`bool`, *optional*, defaults to `False`):
            Whether to use PEFT for training.
        lora_r (`int`, *optional*, defaults to `16`):
            LoRA R value.
        lora_alpha (`int`, *optional*, defaults to `32`):
            LoRA alpha.
        lora_dropout (`float`, *optional*, defaults to `0.05`):
            LoRA dropout.
        lora_target_modules (`Union[str, list[str]]` or `None`, *optional*, defaults to `None`):
            LoRA target modules. A list holding exactly one element is unwrapped to that element in
            `__post_init__` (e.g. `["all-linear"]` becomes `"all-linear"`). `None` and longer lists are kept as-is.
        lora_modules_to_save (`list[str]` or `None`, *optional*, defaults to `None`):
            Model layers to unfreeze & train.
        lora_task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`):
            Task type to pass for LoRA (use `"SEQ_CLS"` for reward modeling).
        use_rslora (`bool`, *optional*, defaults to `False`):
            Whether to use Rank-Stabilized LoRA. It scales the adapter by `lora_alpha/√r` instead of the original
            default `lora_alpha/r`.
        use_dora (`bool`, *optional*, defaults to `False`):
            Enable [Weight-Decomposed Low-Rank Adaptation (DoRA)](https://huggingface.co/papers/2402.09353).
            DoRA splits each weight update into a magnitude and a direction. The direction is learned by normal LoRA
            and the magnitude by a separate learnable parameter, which can help LoRA especially at low ranks.
            DoRA currently supports only linear and Conv2D layers. It costs more than plain LoRA, so merging the
            weights before inference is recommended.
        load_in_8bit (`bool`, *optional*, defaults to `False`):
            Whether to use 8 bit precision for the base model. Works only with LoRA. Cannot be combined with
            `load_in_4bit`; enabling both raises a `ValueError`.
        load_in_4bit (`bool`, *optional*, defaults to `False`):
            Whether to use 4 bit precision for the base model. Works only with LoRA. Cannot be combined with
            `load_in_8bit`; enabling both raises a `ValueError`.
        bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`):
            Quantization type, either `"fp4"` or `"nf4"`.
        use_bnb_nested_quant (`bool`, *optional*, defaults to `False`):
            Whether to use nested quantization.
    """

    # --- Base model loading ---------------------------------------------------------------------------
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Model checkpoint for weights initialization."}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "Specific model version to use. It can be a branch name, a tag name, or a commit id."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": "Override the default `torch.dtype` and load the model under this dtype.",
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in "
                "which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
            )
        },
    )

    # --- PEFT / LoRA adapter --------------------------------------------------------------------------
    use_peft: bool = field(default=False, metadata={"help": "Whether to use PEFT for training."})
    lora_r: int = field(default=16, metadata={"help": "LoRA R value."})
    lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha."})
    lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."})
    lora_target_modules: Optional[list[str]] = field(default=None, metadata={"help": "LoRA target modules."})
    lora_modules_to_save: Optional[list[str]] = field(
        default=None, metadata={"help": "Model layers to unfreeze & train."}
    )
    lora_task_type: str = field(
        default="CAUSAL_LM",
        metadata={"help": "Task type to pass for LoRA (use 'SEQ_CLS' for reward modeling)."},
    )
    use_rslora: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, "
                "instead of the original default value of `lora_alpha/r`."
            )
        },
    )
    use_dora: bool = field(
        default=False,
        metadata={
            "help": (
                "Enable Weight-Decomposed Low-Rank Adaptation (DoRA). This technique decomposes the updates of "
                "the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the "
                "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
                "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a "
                "bigger overhead than pure LoRA, so it is recommended to merge weights for inference."
            )
        },
    )

    # --- bitsandbytes quantization of the base model --------------------------------------------------
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Whether to use 8 bit precision for the base model. Works only with LoRA."},
    )
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Whether to use 4 bit precision for the base model. Works only with LoRA."},
    )
    bnb_4bit_quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]},
    )
    use_bnb_nested_quant: bool = field(
        default=False, metadata={"help": "Whether to use nested quantization."}
    )

    def __post_init__(self):
        """Reject conflicting quantization flags and normalize `lora_target_modules`.

        Raises:
            ValueError: If `load_in_8bit` and `load_in_4bit` are both enabled.
        """
        wants_both_precisions = self.load_in_8bit and self.load_in_4bit
        if wants_both_precisions:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")

        targets = self.lora_target_modules
        # HfArgumentParser turns list fields into multi-value CLI arguments, so a single value such as
        # `--lora_target_modules all-linear` presumably arrives as ["all-linear"]. Any sized container with exactly
        # one entry is unwrapped to that entry. `None` has no `__len__` and is left alone.
        is_singleton = hasattr(targets, "__len__") and len(targets) == 1
        if is_singleton:
            self.lora_target_modules = targets[0]