# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import torch.nn as nn
from packaging import version
from transformers import PreTrainedModel, PreTrainedTokenizer

from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead


SUPPORTED_ARCHITECTURES = (
    AutoModelForCausalLMWithValueHead,
    AutoModelForSeq2SeqLMWithValueHead,
)

if TYPE_CHECKING:
    from accelerate import Accelerator
    from deepspeed.runtime.engine import DeepSpeedEngine
    from torch.nn import Module
    from torch.nn.parallel.distributed import DistributedDataParallel

# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
    """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""

    bos_token: str = "<|im_start|>"
    eos_token: str = "<|im_end|>"
    pad_token: str = "<|im_end|>"

    @property
    def system(self):
        return f"{self.bos_token}system"

    @property
    def user(self):
        return f"{self.bos_token}user"

    @property
    def assistant(self):
        return f"{self.bos_token}assistant"

    @property
    def chat_template(self):
        return (
            "{% for message in messages %}"
            f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            f"{{{{ '{self.assistant}\n' }}}}"
            "{% endif %}"
        )
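
# For reference, the ChatML template above renders a two-turn conversation roughly as follows
# (an illustrative sketch derived from the template, not output produced by this module):
#
#     <|im_start|>user
#     Hello!<|im_end|>
#     <|im_start|>assistant
#     Hi there!<|im_end|>
#     <|im_start|>assistant
#
# where the trailing "<|im_start|>assistant" line is only appended when `add_generation_prompt` is true.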


FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}


def setup_chat_format(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    format: Optional[Literal["chatml"]] = "chatml",
    resize_to_multiple_of: Optional[int] = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
""" | |
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. | |
If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`. | |
Args: | |
model (`~transformers.PreTrainedModel`): The model to be modified. | |
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. | |
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml". | |
resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None. | |
Returns: | |
model (`~transformers.PreTrainedModel`): The modified model. | |
tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. | |
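
    Example:

        A minimal usage sketch. The checkpoint name below is an illustrative placeholder, not something this
        function requires:

        ```python
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model = AutoModelForCausalLM.from_pretrained("your-base-model")  # hypothetical checkpoint name
        tokenizer = AutoTokenizer.from_pretrained("your-base-model")

        # Add the ChatML special tokens, set the chat template, and resize the embeddings.
        model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)
        ```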
""" | |
    # check if the tokenizer already has a chat template
    if tokenizer.chat_template is not None:
        raise ValueError(
            "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
        )

    # check if the requested format is available and retrieve it
    if format not in FORMAT_MAPPING:
        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

    chat_format = FORMAT_MAPPING[format]()

    # set the special tokens and add them to the tokenizer
    tokenizer.eos_token = chat_format.eos_token
    tokenizer.pad_token = chat_format.pad_token
    tokenizer.bos_token = chat_format.bos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
    # set the chat format for the tokenizer
    tokenizer.chat_template = chat_format.chat_template

    # resize the embedding layer, optionally padding its size to a multiple of `resize_to_multiple_of`
    # (e.g. 64), which can improve throughput: https://x.com/karpathy/status/1621578354024677377
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of)

    # Update the model config to use the new eos & bos tokens
    if getattr(model, "config", None) is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
    # Update the generation config to use the new eos & bos tokens
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.bos_token_id = tokenizer.bos_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer


def remove_hooks(model: "DeepSpeedEngine") -> None:
    """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
        return
    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
        optimizer_offload = model.optimizer.parameter_offload
    elif model.optimizer is not None:
        optimizer_offload = model.optimizer
    else:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")

    for param in iter_params(optimizer_offload.module, recurse=True):
        param.ds_active_sub_modules.clear()

    for hook in optimizer_offload.forward_hooks:
        hook.remove()
    for hook in optimizer_offload.backward_hooks:
        hook.remove()

    optimizer_offload.forward_hooks = []
    optimizer_offload.backward_hooks = []


def get_all_parameters(sub_module, recurse=False):
    return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())


def iter_params(module, recurse=False):
    return [param for _, param in get_all_parameters(module, recurse)]


def add_hooks(model: "DeepSpeedEngine") -> None:
    """Adds the optimizer hooks back to a DeepSpeed ZeRO-3 model."""
    import deepspeed

    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
        return
    if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
        optimizer_offload = model.optimizer.parameter_offload
    elif model.optimizer is not None:
        optimizer_offload = model.optimizer
    else:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
        # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
        optimizer_offload._register_deepspeed_module(optimizer_offload.module)
    else:
        optimizer_offload._register_hooks_recursively(optimizer_offload.module)


@contextmanager
def unwrap_model_for_generation(
    model: Union["DistributedDataParallel", "DeepSpeedEngine"],
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
):
    """
    Context manager to unwrap distributed or accelerated models for generation tasks.

    Args:
        model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
            Model to be unwrapped.
        accelerator (`~accelerate.Accelerator`):
            Accelerator instance managing the model.
        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
            can be more memory-efficient but may lead to slower generation times.

    Yields:
        Unwrapped model.

    Example:
    ```python
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        generated_outputs = unwrapped_model.generate(input_ids)
    ```
    """
    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
        if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
        else:
            import deepspeed

            # Gather the ZeRO-3 sharded parameters and temporarily drop the partitioning hooks so generation
            # sees the full weights; the hooks are restored once generation is done.
            with deepspeed.zero.GatheredParameters(model.parameters()):
                remove_hooks(model)
                yield accelerator.unwrap_model(model)
                add_hooks(model)
    else:
        yield unwrapped_model


def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
    """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.

    Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
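
    Example:

        A minimal sketch of typical use, assuming `accelerator` was created with a DeepSpeed plugin and `ref_model`
        is a hypothetical frozen reference model used only for inference:

        ```python
        ref_model = prepare_deepspeed(ref_model, accelerator)  # returns an eval-mode DeepSpeed engine
        ```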
""" | |
    # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm):
    # https://github.com/deepspeedai/DeepSpeed/issues/7252
    import deepspeed

    deepspeed_plugin = accelerator.state.deepspeed_plugin
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
    stage = config_kwargs["zero_optimization"]["stage"]

    if model is not None:
        hidden_size = (
            max(model.config.hidden_sizes)
            if getattr(model.config, "hidden_sizes", None)
            else getattr(model.config, "hidden_size", None)
        )
        if hidden_size is not None and stage == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
            # @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            config_kwargs.update(
                {
                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                }
            )

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
    # disabled (stage 0)
    if stage != 3:
        config_kwargs["zero_optimization"]["stage"] = 0
    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model


def prepare_fsdp(model: "Module", accelerator: "Accelerator"):
    """Prepares the model for FSDP inference or evaluation by wrapping it with the accelerator's FSDP configuration.

    Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1421
    """
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

    # Check if the model is already an FSDP model due to `Manual Wrapping` and if so,
    # don't wrap it again
    if not isinstance(model, FSDP):
        accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
        fsdp_plugin = accelerator.state.fsdp_plugin
        kwargs = {
            "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
            "cpu_offload": fsdp_plugin.cpu_offload,
            "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
            "mixed_precision": fsdp_plugin.mixed_precision_policy,
            "sync_module_states": fsdp_plugin.sync_module_states,
            "backward_prefetch": fsdp_plugin.backward_prefetch,
            "forward_prefetch": fsdp_plugin.forward_prefetch,
            "use_orig_params": fsdp_plugin.use_orig_params,
            "param_init_fn": fsdp_plugin.param_init_fn,
            "ignored_modules": fsdp_plugin.ignored_modules,
            "limit_all_gathers": fsdp_plugin.limit_all_gathers,
            "device_id": accelerator.device,
        }
        model = FSDP(model, **kwargs)
    model.eval()
    return model
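
# A minimal usage sketch for the two prepare_* helpers above (assuming `accelerator` is an
# `accelerate.Accelerator` configured with the matching plugin, `ref_model` is a hypothetical frozen
# reference model used only for inference, and `import torch` plus a prepared `batch` are available):
#
#     ref_model = prepare_fsdp(ref_model, accelerator)  # or prepare_deepspeed(ref_model, accelerator)
#     with torch.no_grad():
#         ref_logits = ref_model(**batch).logits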


class _ForwardRedirection:
    """Implements the `forward-redirection`.

    Taken from Pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602

    A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
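
    Example:

        An illustrative sketch (the names `wrapper`, `module`, and `module.score` are hypothetical): calling the
        redirection object routes `module.score(x)` through `wrapper.forward`, so any wrapper-side logic (e.g. FSDP
        or DeepSpeed parameter gathering) still runs:

        ```python
        redirect = _ForwardRedirection()
        out = redirect(wrapper, module, module.score, x)
        ```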
""" | |

    def __call__(
        self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any, **kwargs: Any
    ):
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
            method: The method that should be called on the `original_module` after inputs get redirected through
                the `wrapper_module`'s `forward` method.
            *args: The positional arguments to `method`. They will get passed to a patched `forward` method instead.
            **kwargs: The keyword arguments to `method`. They will get passed to a patched `forward` method instead.
        """
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method`, because the method itself may want to call the
            # real `forward`.
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method passed in as `method`
            out = method(*_args, **_kwargs)
            self.on_after_inner_forward(wrapper_module, original_module)
            return out

        # Patch the original_module's forward so we can redirect the arguments back to the real method
        original_module.forward = wrapped_forward  # type: ignore[method-assign]

        wrapper_output = wrapper_module(*args, **kwargs)
        self.on_after_outer_forward(wrapper_module, original_module)
        return wrapper_output

    def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        pass

    def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        pass