Spaces:
Paused
Paused
# Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
__version__ = "0.19.0.dev0" | |
from typing import TYPE_CHECKING | |
from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available | |
_import_structure = { | |
"scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"], | |
"data_utils": [ | |
"apply_chat_template", | |
"extract_prompt", | |
"is_conversational", | |
"maybe_apply_chat_template", | |
"maybe_convert_to_chatml", | |
"maybe_extract_prompt", | |
"maybe_unpair_preference_dataset", | |
"pack_dataset", | |
"pack_examples", | |
"truncate_dataset", | |
"unpair_preference_dataset", | |
], | |
"environment": ["TextEnvironment", "TextHistory"], | |
"extras": ["BestOfNSampler"], | |
"models": [ | |
"SUPPORTED_ARCHITECTURES", | |
"AutoModelForCausalLMWithValueHead", | |
"AutoModelForSeq2SeqLMWithValueHead", | |
"PreTrainedModelWrapper", | |
"create_reference_model", | |
"setup_chat_format", | |
], | |
"trainer": [ | |
"AlignPropConfig", | |
"AlignPropTrainer", | |
"AllTrueJudge", | |
"BaseBinaryJudge", | |
"BaseJudge", | |
"BasePairwiseJudge", | |
"BaseRankJudge", | |
"BCOConfig", | |
"BCOTrainer", | |
"CPOConfig", | |
"CPOTrainer", | |
"DataCollatorForCompletionOnlyLM", | |
"DPOConfig", | |
"DPOTrainer", | |
"FDivergenceConstants", | |
"FDivergenceType", | |
"GKDConfig", | |
"GKDTrainer", | |
"GRPOConfig", | |
"GRPOTrainer", | |
"HfPairwiseJudge", | |
"IterativeSFTConfig", | |
"IterativeSFTTrainer", | |
"KTOConfig", | |
"KTOTrainer", | |
"LogCompletionsCallback", | |
"MergeModelCallback", | |
"ModelConfig", | |
"NashMDConfig", | |
"NashMDTrainer", | |
"OnlineDPOConfig", | |
"OnlineDPOTrainer", | |
"OpenAIPairwiseJudge", | |
"ORPOConfig", | |
"ORPOTrainer", | |
"PairRMJudge", | |
"PPOConfig", | |
"PPOTrainer", | |
"PRMConfig", | |
"PRMTrainer", | |
"RewardConfig", | |
"RewardTrainer", | |
"RLOOConfig", | |
"RLOOTrainer", | |
"SFTConfig", | |
"SFTTrainer", | |
"WinRateCallback", | |
"XPOConfig", | |
"XPOTrainer", | |
], | |
"trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], | |
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], | |
} | |
try: | |
if not is_diffusers_available(): | |
raise OptionalDependencyNotAvailable() | |
except OptionalDependencyNotAvailable: | |
pass | |
else: | |
_import_structure["models"].extend( | |
[ | |
"DDPOPipelineOutput", | |
"DDPOSchedulerOutput", | |
"DDPOStableDiffusionPipeline", | |
"DefaultDDPOStableDiffusionPipeline", | |
] | |
) | |
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"]) | |
if TYPE_CHECKING: | |
from .data_utils import ( | |
apply_chat_template, | |
extract_prompt, | |
is_conversational, | |
maybe_apply_chat_template, | |
maybe_convert_to_chatml, | |
maybe_extract_prompt, | |
maybe_unpair_preference_dataset, | |
pack_dataset, | |
pack_examples, | |
truncate_dataset, | |
unpair_preference_dataset, | |
) | |
from .environment import TextEnvironment, TextHistory | |
from .extras import BestOfNSampler | |
from .models import ( | |
SUPPORTED_ARCHITECTURES, | |
AutoModelForCausalLMWithValueHead, | |
AutoModelForSeq2SeqLMWithValueHead, | |
PreTrainedModelWrapper, | |
create_reference_model, | |
setup_chat_format, | |
) | |
from .scripts import ScriptArguments, TrlParser, init_zero_verbose | |
from .trainer import ( | |
AlignPropConfig, | |
AlignPropTrainer, | |
AllTrueJudge, | |
BaseBinaryJudge, | |
BaseJudge, | |
BasePairwiseJudge, | |
BaseRankJudge, | |
BCOConfig, | |
BCOTrainer, | |
CPOConfig, | |
CPOTrainer, | |
DataCollatorForCompletionOnlyLM, | |
DPOConfig, | |
DPOTrainer, | |
FDivergenceConstants, | |
FDivergenceType, | |
GKDConfig, | |
GKDTrainer, | |
GRPOConfig, | |
GRPOTrainer, | |
HfPairwiseJudge, | |
IterativeSFTConfig, | |
IterativeSFTTrainer, | |
KTOConfig, | |
KTOTrainer, | |
LogCompletionsCallback, | |
MergeModelCallback, | |
ModelConfig, | |
NashMDConfig, | |
NashMDTrainer, | |
OnlineDPOConfig, | |
OnlineDPOTrainer, | |
OpenAIPairwiseJudge, | |
ORPOConfig, | |
ORPOTrainer, | |
PairRMJudge, | |
PPOConfig, | |
PPOTrainer, | |
PRMConfig, | |
PRMTrainer, | |
RewardConfig, | |
RewardTrainer, | |
RLOOConfig, | |
RLOOTrainer, | |
SFTConfig, | |
SFTTrainer, | |
WinRateCallback, | |
XPOConfig, | |
XPOTrainer, | |
) | |
from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback | |
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config | |
try: | |
if not is_diffusers_available(): | |
raise OptionalDependencyNotAvailable() | |
except OptionalDependencyNotAvailable: | |
pass | |
else: | |
from .models import ( | |
DDPOPipelineOutput, | |
DDPOSchedulerOutput, | |
DDPOStableDiffusionPipeline, | |
DefaultDDPOStableDiffusionPipeline, | |
) | |
from .trainer import DDPOConfig, DDPOTrainer | |
else: | |
import sys | |
sys.modules[__name__] = _LazyModule( | |
__name__, | |
globals()["__file__"], | |
_import_structure, | |
module_spec=__spec__, | |
extra_objects={"__version__": __version__}, | |
) | |