# VITA-Audio/tools/trainer_v4_48_3.py
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""
import contextlib
import copy
import functools
import glob
import importlib.metadata
import inspect
import json
import math
import os
import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
# Integrations must be imported before ML frameworks:
# isort: off
from transformers.integrations import (
get_reporting_integration_callbacks,
hp_params,
)
# isort: on
import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from transformers import __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from transformers.integrations.tpu import tpu_spmd_dataloader
from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from transformers.optimization import Adafactor, get_scheduler
from transformers.processing_utils import ProcessorMixin
from transformers.pytorch_utils import (
ALL_LAYERNORM_LAYERS,
is_torch_greater_or_equal_than_2_3,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
ExportableState,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from transformers.trainer_pt_utils import (
DistributedTensorGatherer,
EvalLoopContainer,
IterableDatasetShard,
LabelSmoother,
LayerWiseDummyOptimizer,
LengthGroupedSampler,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_model_param_count,
get_module_class_from_name,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
reissue_pt_warnings,
remove_dummy_checkpoint,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
HubStrategy,
PredictionOutput,
RemoveColumnsCollator,
SaveStrategy,
TrainerMemoryTracker,
TrainOutput,
check_target_module_exists,
default_compute_objective,
denumpify_detensorize,
enable_full_determinism,
find_executable_batch_size,
get_last_checkpoint,
has_length,
neftune_post_forward_hook,
number_of_arguments,
seed_worker,
set_seed,
speed_metrics,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
ADAPTER_CONFIG_NAME,
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
XLA_FSDPV2_MIN_VERSION,
PushInProgress,
PushToHubMixin,
can_return_loss,
find_labels,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_grokadamw_available,
is_in_notebook,
is_ipex_available,
is_liger_kernel_available,
is_lomo_available,
is_peft_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_schedulefree_available,
is_torch_compile_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
logging,
strtobool,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.quantization_config import QuantizationMethod

try:
    # `is_torch_hpu_available` is used in `training_step` below but is only provided by
    # newer transformers releases; fall back to "not available" on older versions so the
    # HPU branch never raises a NameError.
    from transformers.utils import is_torch_hpu_available
except ImportError:
    def is_torch_hpu_available():
        return False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
from transformers.utils.notebook import NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
if is_apex_available():
from apex import amp
if is_datasets_available():
import datasets
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla import __version__ as XLA_VERSION
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
else:
IS_XLA_FSDPV2_POST_2_2 = False
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
IS_SAGEMAKER_MP_POST_1_10 = False
if is_safetensors_available():
import safetensors.torch
if is_peft_available():
from peft import PeftModel
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches
from accelerate import __version__ as accelerate_version
from accelerate.state import AcceleratorState
from accelerate.utils import (
DistributedDataParallelKwargs,
DistributedType,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
)
DATA_SAMPLERS = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
from accelerate.data_loader import SeedableRandomSampler
DATA_SAMPLERS += [SeedableRandomSampler]
if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper
if is_accelerate_available("0.28.0"):
from accelerate.utils import DataLoaderConfiguration
def _is_peft_model(model):
if is_peft_available():
classes_to_check = (PeftModel,) if is_peft_available() else ()
# Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
from peft import PeftMixedModel
classes_to_check = (*classes_to_check, PeftMixedModel)
return isinstance(model, classes_to_check)
return False
def _get_fsdp_ckpt_kwargs():
# TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
return {"adapter_only": True}
else:
return {}
def safe_globals():
# Starting from version 2.4 PyTorch introduces a check for the objects loaded
# with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
# a default and requires allowlisting of objects being loaded.
# See: https://github.com/pytorch/pytorch/pull/137602
# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
# See: https://github.com/huggingface/accelerate/pull/3036
if version.parse(torch.__version__).release < version.parse("2.6").release:
return contextlib.nullcontext()
np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
# numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
# all versions of numpy
allowlist += [type(np.dtype(np.uint32))]
return torch.serialization.safe_globals(allowlist)
if TYPE_CHECKING:
import optuna
if is_datasets_available():
import datasets
logger = logging.get_logger(__name__)
logger.setLevel("INFO")
# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
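# Debug helpers: `print_batch` below dumps one batch per rank to
# `<output_dir>/print_batch_<rank>.log` (token ids, decoded text, labels, masks).
# `DATA_PRINT_ONCE` gates the dump so it happens only once per run, and `BATCH` caches
# the last seen batch so a later `print_batch(None, ...)` call can re-dump it.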
DATA_PRINT_ONCE = True
BATCH = None
def print_batch(batch, tokenizer, args):
global DATA_PRINT_ONCE
global BATCH
if batch is not None:
BATCH = batch
else:
batch = BATCH
DATA_PRINT_ONCE = True
if batch is None:
return
if DATA_PRINT_ONCE:
global_rank = torch.distributed.get_rank()
f = open(os.path.join(args.output_dir, f"print_batch_{global_rank}.log"), "a")
torch.set_printoptions(threshold=100_000)
if "loss_mask" in batch and batch["loss_mask"] is not None:
loss_mask = batch["loss_mask"]
print(f"loss_mask {loss_mask} {loss_mask.size()}", file=f)
if "position_ids" in batch and batch["position_ids"] is not None:
position_ids = batch["position_ids"]
print(f"position_ids {position_ids} {position_ids.size()}", file=f)
if "attention_mask" in batch and batch["attention_mask"] is not None:
attention_mask = batch["attention_mask"]
if isinstance(attention_mask, list):
attention_mask = attention_mask[0]
print(f"attention_mask {attention_mask} {attention_mask.size()}", file=f)
if "input_ids" in batch and batch["input_ids"] is not None:
tokens = batch["input_ids"]
print(f"tokens {tokens} {tokens.size()}", file=f)
tokens_ = tokens.cpu().clone().detach()
tokens_ = tokenizer.batch_decode(tokens_.tolist(), skip_special_tokens=False)
print(f"tokens_ {tokens_[:]}", file=f)
if "labels" in batch and batch["labels"] is not None:
labels = batch["labels"]
print(f"labels {labels} {labels.size()}", file=f)
labels_ = labels.cpu().clone().detach()
labels_[labels_==-100] = tokenizer("-", add_special_tokens=False).input_ids[0]
labels_ = tokenizer.batch_decode(labels_.tolist(), skip_special_tokens=False)
print(f"labels {labels_}", file=f)
# labels__ = labels.cpu().clone().detach()
# labels__[loss_mask.to(torch.int64)==0] = tokenizer("-", add_special_tokens=False).input_ids[0]
# labels__ = tokenizer.batch_decode(labels__.tolist(), skip_special_tokens=False)
# print(f"labels__ {labels__}", file=f)
for k, v in batch.items():
if isinstance(v, torch.Tensor):
print(f"{k} {v} {v.size()}", file=f)
else:
print(f"{k} {v}", file=f)
f.close()
DATA_PRINT_ONCE = False
from transformers import Trainer as HFTrainer
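# This subclass keeps the stock HF Trainer behaviour except for a few overrides:
# `get_train_dataloader` (spawn-based dataloader workers), `create_optimizer`
# (separate LR / weight-decay groups for vision-tower and MTP parameters),
# `training_step` (one-time batch dump plus device-specific cache clearing), and a
# `get_batch_samples` variant that skips batches without `input_ids`.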
class Trainer(HFTrainer):
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            # Use `spawn` workers (avoids fork-related issues with CUDA / threaded
            # tokenizers); `multiprocessing_context` is only valid with worker
            # processes, so pass None when `dataloader_num_workers` is 0.
            "multiprocessing_context": "spawn" if self.args.dataloader_num_workers > 0 else None,
        }
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or override this method in a subclass.
"""
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
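        # Parameter-group strategy used below: parameters are first split into
        # weight-decay vs. no-decay groups, excluding vision-tower and MTP parameters;
        # the excluded parameters are then re-added in their own groups with either a
        # per-layer decayed LR (`vision_model_lr_decay_rate`) or a flat multiplier
        # (`vision_model_lr_mult`, `mtp_model_lr_mult`).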
if self.optimizer is None:
decay_parameters = self.get_decay_parameter_names(opt_model)
if self.args.vision_model_lr_mult != 1.0 or self.args.vision_model_lr_decay_rate != 1.0:
vision_parameters = [name for name, _ in opt_model.named_parameters() if "vision_model" in name]
logger.info(f"{vision_parameters=}")
else:
vision_parameters = []
if self.args.mtp_model_lr_mult != 1.0:
mtp_parameters = []
mtp_names = ["mtp"]
num_nextn_predict_layers = self.model.config.num_nextn_predict_layers
num_hidden_layers = self.model.config.num_hidden_layers
for mtp_idx in range(num_nextn_predict_layers):
layer_idx = num_hidden_layers - num_nextn_predict_layers + mtp_idx
mtp_names.append(f"model.layers.{layer_idx}")
for name, param in opt_model.named_parameters():
if any([x in name for x in mtp_names]):
mtp_parameters.append(name)
logger.info(f"{mtp_parameters=}")
else:
mtp_parameters = []
exclude_parameters = vision_parameters + mtp_parameters
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n not in exclude_parameters)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n not in exclude_parameters)
],
"weight_decay": 0.0,
},
]
if self.args.vision_model_lr_decay_rate != 1.0:
for n, p in opt_model.named_parameters():
if p.requires_grad and n in vision_parameters:
pass
else:
continue
if n in decay_parameters:
weight_decay = self.args.weight_decay
else:
weight_decay = 0.0
lr = self.args.learning_rate * get_vit_lr_decay_rate(n, opt_model.config.visual.num_hidden_layers, self.args.vision_model_lr_decay_rate)
optimizer_grouped_parameters.append(
{
"params": [p],
"weight_decay": weight_decay,
"lr": lr,
}
)
logger.info(f"create_optimizer name {n} weight_decay {weight_decay} lr {lr}")
elif self.args.vision_model_lr_mult != 1.0:
optimizer_grouped_parameters.extend(
[
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in vision_parameters)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.learning_rate * self.args.vision_model_lr_mult,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in vision_parameters)
],
"weight_decay": 0.0,
"lr": self.args.learning_rate * self.args.vision_model_lr_mult,
},
]
)
logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.vision_model_lr_mult}")
logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {0.0} lr_mult {self.args.vision_model_lr_mult}")
if self.args.mtp_model_lr_mult != 1.0:
optimizer_grouped_parameters.extend(
[
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in mtp_parameters)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.learning_rate * self.args.mtp_model_lr_mult,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in mtp_parameters)
],
"weight_decay": 0.0,
"lr": self.args.learning_rate * self.args.mtp_model_lr_mult,
},
]
)
logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.mtp_model_lr_mult}")
logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {0.0} lr_mult {self.args.mtp_model_lr_mult}")
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
return self.optimizer
def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
print_batch(inputs, self.processing_class, self.args)
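        # The call above dumps the batch once per rank for inspection; everything that
        # follows mirrors the stock training step (prepare inputs, forward, then
        # backward via SageMaker MP, APEX, or `accelerator.backward`).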
model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
self.args.torch_empty_cache_steps is not None
and self.state.global_step % self.args.torch_empty_cache_steps == 0
):
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_mlu_available():
torch.mlu.empty_cache()
elif is_torch_musa_available():
torch.musa.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available(min_version="2.0"):
torch.mps.empty_cache()
elif is_torch_hpu_available():
logger.warning(
"`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
)
else:
torch.cuda.empty_cache()
kwargs = {}
        # For LOMO optimizers you need to explicitly use the learning rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
# https://github.com/huggingface/transformers/pull/35808
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs["scale_wrt_gas"] = False
self.accelerator.backward(loss, **kwargs)
return loss.detach()
def get_batch_samples(self, epoch_iterator, num_batches):
batch_samples = []
num_items_in_batch = None
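        # Unlike the stock implementation, draw from the iterator until a batch that
        # actually contains `input_ids` is found, so the token count below only covers
        # real language-model batches.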
for _ in range(num_batches):
try:
while True:
batch_sample = next(epoch_iterator)
if "input_ids" in batch_sample:
break
batch_samples += [batch_sample]
except StopIteration:
break
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
except (TypeError, AttributeError):
pass
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.item()
return batch_samples, num_items_in_batch
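# Layer-wise LR decay for the vision tower: the patch conv and position embedding get
# the deepest decay (layer 0), transformer block `N` decays as
# `lr_decay_rate ** (num_layers - N)`, and any other parameter keeps the base LR.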
def get_vit_lr_decay_rate(name, num_layers, lr_decay_rate):
layer_id = num_layers + 1
if "vision_model." in name:
if ".position_embedding." in name or ".conv1." in name:
layer_id = 0
elif ".layers." in name:
layer_id = int(name[name.find(".layers.") :].split(".")[2]) + 1
return lr_decay_rate ** (num_layers + 1 - layer_id)
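# Illustrative usage (a minimal sketch, not executed by this module). It assumes a
# hypothetical `TrainingArguments` subclass that carries the extra fields read above
# (`vision_model_lr_mult`, `vision_model_lr_decay_rate`, `mtp_model_lr_mult`) and a
# tokenizer passed as `processing_class` so `print_batch` can decode inputs:
#
#     args = CustomTrainingArguments(   # hypothetical subclass of TrainingArguments
#         output_dir="output",
#         vision_model_lr_mult=0.1,
#         vision_model_lr_decay_rate=1.0,
#         mtp_model_lr_mult=0.5,
#     )
#     trainer = Trainer(
#         model=model,
#         args=args,
#         train_dataset=train_dataset,
#         data_collator=data_collator,
#         processing_class=tokenizer,
#     )
#     trainer.train()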