# Repository page header (extraction residue, kept as comments so the file parses):
# trl-sandbox / trl/__init__.py
# ivangabriele's picture
# feat: initialize project
# 2f5127c verified
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Package version string. It is also handed to `_LazyModule` through
# `extra_objects` at the bottom of this file.
__version__ = "0.19.0.dev0"
from typing import TYPE_CHECKING
from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
# Maps each submodule (named relative to this package) to the public names it
# provides. The mapping is passed to `_LazyModule` at the bottom of this file,
# presumably so the names can be resolved on first access instead of at import
# time (confirm against `import_utils._LazyModule`).
#
# The "models" and "trainer" lists are extended right after this literal when
# `is_diffusers_available()` is true.
_import_structure = {
    "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
    "data_utils": [
        "apply_chat_template",
        "extract_prompt",
        "is_conversational",
        "maybe_apply_chat_template",
        "maybe_convert_to_chatml",
        "maybe_extract_prompt",
        "maybe_unpair_preference_dataset",
        "pack_dataset",
        "pack_examples",
        "truncate_dataset",
        "unpair_preference_dataset",
    ],
    "environment": ["TextEnvironment", "TextHistory"],
    "extras": ["BestOfNSampler"],
    "models": [
        "SUPPORTED_ARCHITECTURES",
        "AutoModelForCausalLMWithValueHead",
        "AutoModelForSeq2SeqLMWithValueHead",
        "PreTrainedModelWrapper",
        "create_reference_model",
        "setup_chat_format",
    ],
    "trainer": [
        "AlignPropConfig",
        "AlignPropTrainer",
        "AllTrueJudge",
        "BaseBinaryJudge",
        "BaseJudge",
        "BasePairwiseJudge",
        "BaseRankJudge",
        "BCOConfig",
        "BCOTrainer",
        "CPOConfig",
        "CPOTrainer",
        "DataCollatorForCompletionOnlyLM",
        "DPOConfig",
        "DPOTrainer",
        "FDivergenceConstants",
        "FDivergenceType",
        "GKDConfig",
        "GKDTrainer",
        "GRPOConfig",
        "GRPOTrainer",
        "HfPairwiseJudge",
        "IterativeSFTConfig",
        "IterativeSFTTrainer",
        "KTOConfig",
        "KTOTrainer",
        "LogCompletionsCallback",
        "MergeModelCallback",
        "ModelConfig",
        "NashMDConfig",
        "NashMDTrainer",
        "OnlineDPOConfig",
        "OnlineDPOTrainer",
        "OpenAIPairwiseJudge",
        "ORPOConfig",
        "ORPOTrainer",
        "PairRMJudge",
        "PPOConfig",
        "PPOTrainer",
        "PRMConfig",
        "PRMTrainer",
        "RewardConfig",
        "RewardTrainer",
        "RLOOConfig",
        "RLOOTrainer",
        "SFTConfig",
        "SFTTrainer",
        "WinRateCallback",
        "XPOConfig",
        "XPOTrainer",
    ],
    # NOTE(review): "MergeModelCallback" is listed here and under "trainer"
    # above, while the TYPE_CHECKING branch imports it from `.trainer` only.
    "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
}
# Register the DDPO names only when `is_diffusers_available()` reports true;
# otherwise the import structure is left exactly as declared above.
try:
    if not is_diffusers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Optional dependency missing: nothing extra to register.
    pass
else:
    # In-place list concatenation: the existing lists are extended, not replaced.
    _import_structure["models"] += [
        "DDPOPipelineOutput",
        "DDPOSchedulerOutput",
        "DDPOStableDiffusionPipeline",
        "DefaultDDPOStableDiffusionPipeline",
    ]
    _import_structure["trainer"] += ["DDPOConfig", "DDPOTrainer"]
# Two views of the same public API:
#   * under TYPE_CHECKING, real imports, so static analysers can see every
#     exported name;
#   * at runtime, this module's entry in `sys.modules` is replaced by a
#     `_LazyModule` built from `_import_structure`.
if TYPE_CHECKING:
    from .data_utils import (
        apply_chat_template,
        extract_prompt,
        is_conversational,
        maybe_apply_chat_template,
        maybe_convert_to_chatml,
        maybe_extract_prompt,
        maybe_unpair_preference_dataset,
        pack_dataset,
        pack_examples,
        truncate_dataset,
        unpair_preference_dataset,
    )
    from .environment import TextEnvironment, TextHistory
    from .extras import BestOfNSampler
    from .models import (
        SUPPORTED_ARCHITECTURES,
        AutoModelForCausalLMWithValueHead,
        AutoModelForSeq2SeqLMWithValueHead,
        PreTrainedModelWrapper,
        create_reference_model,
        setup_chat_format,
    )
    from .scripts import ScriptArguments, TrlParser, init_zero_verbose
    from .trainer import (
        AlignPropConfig,
        AlignPropTrainer,
        AllTrueJudge,
        BaseBinaryJudge,
        BaseJudge,
        BasePairwiseJudge,
        BaseRankJudge,
        BCOConfig,
        BCOTrainer,
        CPOConfig,
        CPOTrainer,
        DataCollatorForCompletionOnlyLM,
        DPOConfig,
        DPOTrainer,
        FDivergenceConstants,
        FDivergenceType,
        GKDConfig,
        GKDTrainer,
        GRPOConfig,
        GRPOTrainer,
        HfPairwiseJudge,
        IterativeSFTConfig,
        IterativeSFTTrainer,
        KTOConfig,
        KTOTrainer,
        LogCompletionsCallback,
        MergeModelCallback,
        ModelConfig,
        NashMDConfig,
        NashMDTrainer,
        OnlineDPOConfig,
        OnlineDPOTrainer,
        OpenAIPairwiseJudge,
        ORPOConfig,
        ORPOTrainer,
        PairRMJudge,
        PPOConfig,
        PPOTrainer,
        PRMConfig,
        PRMTrainer,
        RewardConfig,
        RewardTrainer,
        RLOOConfig,
        RLOOTrainer,
        SFTConfig,
        SFTTrainer,
        WinRateCallback,
        XPOConfig,
        XPOTrainer,
    )
    from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
    from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config

    # The DDPO objects are imported under the same availability check that
    # guards their registration in `_import_structure`.
    try:
        if not is_diffusers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .models import (
            DDPOPipelineOutput,
            DDPOSchedulerOutput,
            DDPOStableDiffusionPipeline,
            DefaultDDPOStableDiffusionPipeline,
        )
        from .trainer import DDPOConfig, DDPOTrainer
else:
    import sys

    # Replace this module's entry in `sys.modules` with the lazy proxy.
    # `__version__` is forwarded through `extra_objects`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
        extra_objects={"__version__": __version__},
    )