# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. | |
""" | |
import contextlib
import copy
import functools
import glob
import importlib.metadata
import inspect
import json
import math
import os
import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union


# Integrations must be imported before ML frameworks:
# isort: off
from transformers.integrations import (
    get_reporting_integration_callbacks,
    hp_params,
)

# isort: on

import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler

from transformers import __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from transformers.integrations.tpu import tpu_spmd_dataloader
from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from transformers.optimization import Adafactor, get_scheduler
from transformers.processing_utils import ProcessorMixin
from transformers.pytorch_utils import (
    ALL_LAYERNORM_LAYERS,
    is_torch_greater_or_equal_than_2_3,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    ExportableState,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    DistributedTensorGatherer,
    EvalLoopContainer,
    IterableDatasetShard,
    LabelSmoother,
    LayerWiseDummyOptimizer,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_model_param_count,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
    remove_dummy_checkpoint,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    HPSearchBackend,
    HubStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    SaveStrategy,
    TrainerMemoryTracker,
    TrainOutput,
    check_target_module_exists,
    default_compute_objective,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    neftune_post_forward_hook,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    XLA_FSDPV2_MIN_VERSION,
    PushInProgress,
    PushToHubMixin,
    can_return_loss,
    find_labels,
    is_accelerate_available,
    is_apex_available,
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_grokadamw_available,
    is_in_notebook,
    is_ipex_available,
    is_liger_kernel_available,
    is_lomo_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_torch_compile_available,
    is_torch_hpu_available,  # used by the empty_cache branch in `training_step` below
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_musa_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchao_available,
    logging,
    strtobool,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.quantization_config import QuantizationMethod


DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    from torch_xla import __version__ as XLA_VERSION

    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
    if IS_XLA_FSDPV2_POST_2_2:
        import torch_xla.distributed.spmd as xs
        import torch_xla.runtime as xr
else:
    IS_XLA_FSDPV2_POST_2_2 = False

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

if is_safetensors_available():
    import safetensors.torch

if is_peft_available():
    from peft import PeftModel

if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.state import AcceleratorState
    from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )

    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]

    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available("0.28.0"):
    from accelerate.utils import DataLoaderConfiguration


def _is_peft_model(model):
    if is_peft_available():
        classes_to_check = (PeftModel,) if is_peft_available() else ()
        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
            from peft import PeftMixedModel

            classes_to_check = (*classes_to_check, PeftMixedModel)
        return isinstance(model, classes_to_check)
    return False


def _get_fsdp_ckpt_kwargs():
    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
        return {"adapter_only": True}
    else:
        return {}


def safe_globals():
    # Starting from version 2.4 PyTorch introduces a check for the objects loaded
    # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
    # a default and requires allowlisting of objects being loaded.
    # See: https://github.com/pytorch/pytorch/pull/137602
    # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
    # See: https://github.com/huggingface/accelerate/pull/3036
    if version.parse(torch.__version__).release < version.parse("2.6").release:
        return contextlib.nullcontext()

    np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
    allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
    # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
    # all versions of numpy
    allowlist += [type(np.dtype(np.uint32))]

    return torch.serialization.safe_globals(allowlist)
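

# Illustrative usage sketch for `safe_globals` (the checkpoint path below is hypothetical):
# wrap `torch.load` calls on files that contain pickled numpy objects (e.g. the Trainer's
# RNG state), so that PyTorch >= 2.6 -- where `weights_only=True` is the default -- can
# deserialize them with the allowlist above.
#
#     with safe_globals():
#         rng_state = torch.load("checkpoint-500/rng_state.pth", weights_only=True)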


if TYPE_CHECKING:
    import optuna

    if is_datasets_available():
        import datasets

logger = logging.get_logger(__name__)
logger.setLevel("INFO")

# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"

DATA_PRINT_ONCE = True
BATCH = None


def print_batch(batch, tokenizer, args):
    """Debug helper: dump the first batch seen by each rank to `print_batch_<rank>.log` in `args.output_dir`.

    Call with `batch=None` to re-enable printing and dump the last cached batch again.
    """
    global DATA_PRINT_ONCE
    global BATCH
    if batch is not None:
        BATCH = batch
    else:
        batch = BATCH
        DATA_PRINT_ONCE = True
    if batch is None:
        return
    if DATA_PRINT_ONCE:
        # Fall back to rank 0 when torch.distributed is not initialized (e.g. single-GPU runs).
        global_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        f = open(os.path.join(args.output_dir, f"print_batch_{global_rank}.log"), "a")
        torch.set_printoptions(threshold=100_000)
        if "loss_mask" in batch and batch["loss_mask"] is not None:
            loss_mask = batch["loss_mask"]
            print(f"loss_mask {loss_mask} {loss_mask.size()}", file=f)
        if "position_ids" in batch and batch["position_ids"] is not None:
            position_ids = batch["position_ids"]
            print(f"position_ids {position_ids} {position_ids.size()}", file=f)
        if "attention_mask" in batch and batch["attention_mask"] is not None:
            attention_mask = batch["attention_mask"]
            if isinstance(attention_mask, list):
                attention_mask = attention_mask[0]
            print(f"attention_mask {attention_mask} {attention_mask.size()}", file=f)
        if "input_ids" in batch and batch["input_ids"] is not None:
            tokens = batch["input_ids"]
            print(f"tokens {tokens} {tokens.size()}", file=f)
            tokens_ = tokens.cpu().clone().detach()
            tokens_ = tokenizer.batch_decode(tokens_.tolist(), skip_special_tokens=False)
            print(f"tokens_ {tokens_[:]}", file=f)
        if "labels" in batch and batch["labels"] is not None:
            labels = batch["labels"]
            print(f"labels {labels} {labels.size()}", file=f)
            labels_ = labels.cpu().clone().detach()
            # Replace the ignore index (-100) with a printable token before decoding.
            labels_[labels_ == -100] = tokenizer("-", add_special_tokens=False).input_ids[0]
            labels_ = tokenizer.batch_decode(labels_.tolist(), skip_special_tokens=False)
            print(f"labels {labels_}", file=f)
            # labels__ = labels.cpu().clone().detach()
            # labels__[loss_mask.to(torch.int64) == 0] = tokenizer("-", add_special_tokens=False).input_ids[0]
            # labels__ = tokenizer.batch_decode(labels__.tolist(), skip_special_tokens=False)
            # print(f"labels__ {labels__}", file=f)
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(f"{k} {v} {v.size()}", file=f)
            else:
                print(f"{k} {v}", file=f)
        f.close()
        DATA_PRINT_ONCE = False


from transformers import Trainer as HFTrainer


class Trainer(HFTrainer):
    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            # A multiprocessing context is only valid for multi-process loading; DataLoader raises otherwise.
            "multiprocessing_context": "spawn" if self.args.dataloader_num_workers > 0 else None,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
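
    # Illustrative sketch (not executed; values are hypothetical): the dataloader settings
    # consumed above come straight from `TrainingArguments`, e.g.
    #
    #     args = TrainingArguments(
    #         output_dir="out",
    #         per_device_train_batch_size=4,
    #         dataloader_num_workers=4,        # > 0, so the "spawn" context is used
    #         dataloader_pin_memory=True,
    #         dataloader_prefetch_factor=2,
    #     )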

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)
            if self.args.vision_model_lr_mult != 1.0 or self.args.vision_model_lr_decay_rate != 1.0:
                vision_parameters = [name for name, _ in opt_model.named_parameters() if "vision_model" in name]
                logger.info(f"{vision_parameters=}")
            else:
                vision_parameters = []
            if self.args.mtp_model_lr_mult != 1.0:
                mtp_parameters = []
                mtp_names = ["mtp"]
                num_nextn_predict_layers = self.model.config.num_nextn_predict_layers
                num_hidden_layers = self.model.config.num_hidden_layers
                for mtp_idx in range(num_nextn_predict_layers):
                    layer_idx = num_hidden_layers - num_nextn_predict_layers + mtp_idx
                    mtp_names.append(f"model.layers.{layer_idx}")
                for name, param in opt_model.named_parameters():
                    if any(x in name for x in mtp_names):
                        mtp_parameters.append(name)
                logger.info(f"{mtp_parameters=}")
            else:
                mtp_parameters = []
            # Vision and MTP parameters get their own groups below, so keep them out of the default ones.
            exclude_parameters = vision_parameters + mtp_parameters
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad and n not in exclude_parameters)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad and n not in exclude_parameters)
                    ],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.vision_model_lr_decay_rate != 1.0:
                # A layer-wise decay rate takes precedence over `vision_model_lr_mult`: each vision
                # parameter gets its own group with a depth-dependent learning rate.
                for n, p in opt_model.named_parameters():
                    if not (p.requires_grad and n in vision_parameters):
                        continue
                    if n in decay_parameters:
                        weight_decay = self.args.weight_decay
                    else:
                        weight_decay = 0.0
                    lr = self.args.learning_rate * get_vit_lr_decay_rate(
                        n, opt_model.config.visual.num_hidden_layers, self.args.vision_model_lr_decay_rate
                    )
                    optimizer_grouped_parameters.append(
                        {
                            "params": [p],
                            "weight_decay": weight_decay,
                            "lr": lr,
                        }
                    )
                    logger.info(f"create_optimizer name {n} weight_decay {weight_decay} lr {lr}")
            elif self.args.vision_model_lr_mult != 1.0:
                optimizer_grouped_parameters.extend(
                    [
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n in decay_parameters and p.requires_grad and n in vision_parameters)
                            ],
                            "weight_decay": self.args.weight_decay,
                            "lr": self.args.learning_rate * self.args.vision_model_lr_mult,
                        },
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n not in decay_parameters and p.requires_grad and n in vision_parameters)
                            ],
                            "weight_decay": 0.0,
                            "lr": self.args.learning_rate * self.args.vision_model_lr_mult,
                        },
                    ]
                )
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.vision_model_lr_mult}")
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {0.0} lr_mult {self.args.vision_model_lr_mult}")
            if self.args.mtp_model_lr_mult != 1.0:
                optimizer_grouped_parameters.extend(
                    [
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n in decay_parameters and p.requires_grad and n in mtp_parameters)
                            ],
                            "weight_decay": self.args.weight_decay,
                            "lr": self.args.learning_rate * self.args.mtp_model_lr_mult,
                        },
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n not in decay_parameters and p.requires_grad and n in mtp_parameters)
                            ],
                            "weight_decay": 0.0,
                            "lr": self.args.learning_rate * self.args.mtp_model_lr_mult,
                        },
                    ]
                )
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.mtp_model_lr_mult}")
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {0.0} lr_mult {self.args.mtp_model_lr_mult}")

            if self.optimizer_cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid arguments conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
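
    # Illustrative sketch (assumes you want to bypass `create_optimizer` entirely): as the
    # docstring above notes, a custom optimizer/scheduler pair can instead be passed at
    # construction time, in which case `self.optimizer` is already set and left untouched.
    #
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)
    #     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    #     trainer = Trainer(model=model, args=args, train_dataset=train_ds,
    #                       optimizers=(optimizer, scheduler))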

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        print_batch(inputs, self.processing_class, self.args)
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            if is_torch_xpu_available():
                torch.xpu.empty_cache()
            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
            elif is_torch_musa_available():
                torch.musa.empty_cache()
            elif is_torch_npu_available():
                torch.npu.empty_cache()
            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            elif is_torch_hpu_available():
                logger.warning(
                    "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
                )
            else:
                torch.cuda.empty_cache()

        kwargs = {}

        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # Finally we need to normalize the loss for reporting
            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                loss = loss / self.args.gradient_accumulation_steps

            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
            # https://github.com/huggingface/transformers/pull/35808
            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs["scale_wrt_gas"] = False

            self.accelerator.backward(loss, **kwargs)

        return loss.detach()
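
    # Note on the normalization above (illustrative arithmetic, not executed): for models that
    # do not accept loss kwargs and with gradient_accumulation_steps=4, a micro-batch loss of
    # 2.0 is scaled to 0.5 before backward, so the four accumulated backward passes together
    # contribute the average micro-batch loss rather than the sum. Under DeepSpeed,
    # `scale_wrt_gas=False` keeps Accelerate from applying that scaling a second time.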

    def get_batch_samples(self, epoch_iterator, num_batches):
        batch_samples = []
        num_items_in_batch = None
        for _ in range(num_batches):
            try:
                # Skip items from the iterator that do not contain `input_ids`.
                while True:
                    batch_sample = next(epoch_iterator)
                    if "input_ids" in batch_sample:
                        break
                batch_samples += [batch_sample]
            except StopIteration:
                break

        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
            # For now we don't support object detection
            try:
                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
            except (TypeError, AttributeError):
                pass

        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()

        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.item()

        return batch_samples, num_items_in_batch


def get_vit_lr_decay_rate(name, num_layers, lr_decay_rate):
    """Return the layer-wise learning-rate decay factor for a vision-tower parameter.

    The embedding/conv stem gets the smallest factor (`lr_decay_rate ** (num_layers + 1)`), deeper
    transformer layers get progressively larger factors, and parameters that do not belong to the
    vision model fall through with `layer_id = num_layers + 1`, i.e. a factor of 1.0.
    """
    layer_id = num_layers + 1
    if "vision_model." in name:
        if ".position_embedding." in name or ".conv1." in name:
            layer_id = 0
        elif ".layers." in name:
            layer_id = int(name[name.find(".layers.") :].split(".")[2]) + 1
    return lr_decay_rate ** (num_layers + 1 - layer_id)
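

# Minimal usage sketch (assumptions: `MyTrainingArguments` is a project-specific
# `TrainingArguments` subclass that defines the extra fields read in `create_optimizer` --
# `vision_model_lr_mult`, `vision_model_lr_decay_rate`, `mtp_model_lr_mult` -- and that
# `model`, `tokenizer`, and `train_ds` are built elsewhere):
#
#     args = MyTrainingArguments(
#         output_dir="out",
#         learning_rate=1e-5,
#         vision_model_lr_mult=0.1,        # vision tower trains with a 10x smaller LR
#         vision_model_lr_decay_rate=1.0,  # set < 1.0 to enable layer-wise decay instead
#         mtp_model_lr_mult=1.0,
#     )
#     trainer = Trainer(model=model, args=args, train_dataset=train_ds, processing_class=tokenizer)
#     trainer.train()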