# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import sys
from typing import Optional
from enum import Enum
from deepspeed.pydantic_v1 import Field, validator, root_validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
# ZeRO optimization. By default, this optimization is not enabled.
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
ZERO_FORMAT = """
ZeRO optimization should be enabled as:
"session_params": {
"zero_optimization": {
"stage": [0|1|2],
"stage3_max_live_parameters" : 1000000000,
"stage3_max_reuse_distance" : 1000000000,
"allgather_partitions": [true|false],
"use_multi_rank_bucket_allreduce": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000,
"load_from_fp32_weights": [true|false],
"cpu_offload": [true|false] (deprecated),
"cpu_offload_params" : [true|false] (deprecated),
"cpu_offload_use_pin_memory": [true|false] (deprecated),
"sub_group_size" : 1000000000000,
"offload_param": {...},
"offload_optimizer": {...},
"ignore_unused_parameters": [true|false],
"round_robin_gradients": [true|false],
"zero_hpz_partition_size": 1,
"zero_quantized_weights": [true|false],
"zero_quantized_nontrainable_weights": [true|false],
"zero_quantized_gradients": [true|false],
"memory_efficient_linear": [true|false],
"override_module_apply": [true|false],
}
}
"""
ZERO_OPTIMIZATION = "zero_optimization"


def read_zero_config_deprecated(param_dict):
    zero_config_dict = {}
    zero_config_dict["stage"] = 1 if param_dict[ZERO_OPTIMIZATION] else 0
    if zero_config_dict["stage"] > 0:
        zero_config_dict["allgather_bucket_size"] = get_scalar_param(param_dict, "allgather_size", 5e8)

    logger.warning(
        "DeepSpeedConfig: this format of ZeRO optimization setup is deprecated. Please use the following format: {}".
        format(ZERO_FORMAT))

    return zero_config_dict


def get_zero_config(param_dict):
    if ZERO_OPTIMIZATION in param_dict:
        zero_config_dict = param_dict[ZERO_OPTIMIZATION]
        if isinstance(zero_config_dict, bool):
            zero_config_dict = read_zero_config_deprecated(param_dict)
    else:
        zero_config_dict = {}
    return DeepSpeedZeroConfig(**zero_config_dict)
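

# Illustrative sketch (not part of the module's API): how a user-supplied
# DeepSpeed config dict is parsed by get_zero_config(). Keys follow the
# ZERO_FORMAT example above; the values shown are placeholders.
#
#   ds_config = {"zero_optimization": {"stage": 3, "reduce_bucket_size": 500000000}}
#   zero_cfg = get_zero_config(ds_config)        # -> DeepSpeedZeroConfig(stage=3, ...)
#
#   legacy_config = {"zero_optimization": True}  # deprecated boolean form
#   zero_cfg = get_zero_config(legacy_config)    # logs a warning, maps to stage 1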


class ZeroStageEnum(int, Enum):
    """ Enum class for possible zero stages """
    disabled = 0
    optimizer_states = 1
    gradients = 2
    weights = 3
    max_stage = 3
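

# Illustrative sketch: because ZeroStageEnum subclasses int, stage values
# compare directly with plain integers and with each other, e.g.:
#
#   assert ZeroStageEnum.gradients == 2
#   assert ZeroStageEnum.weights >= ZeroStageEnum.optimizer_states
#   assert ZeroStageEnum.max_stage is ZeroStageEnum.weights  # alias, same value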


class DeepSpeedZeroConfig(DeepSpeedConfigModel):
    """
    Sets parameters for ZeRO optimizations.
    """

    stage: ZeroStageEnum = 0
    """
    Chooses different stages of ZeRO Optimizer. Stages 0, 1, 2, and 3 refer to
    disabled, optimizer state partitioning, optimizer+gradient state
    partitioning, and optimizer+gradient+parameter partitioning, respectively.
    """

    contiguous_gradients: bool = True
    """
    Copies the gradients to a contiguous buffer as they are produced. Avoids
    memory fragmentation during backward pass.
    """

    reduce_scatter: bool = True
    """
    Uses reduce or reduce scatter instead of allreduce to average gradients.
    """

    reduce_bucket_size: int = Field(pp_int(5e8), ge=0)
    """
    Number of elements reduced/allreduced at a time. Limits the memory required
    for these operations for large model sizes.
    """

    use_multi_rank_bucket_allreduce: bool = True
    """
    Combine the reduce buckets of the different ranks and do an all-reduce
    instead of multiple reduce ops. This feature is useful when the model is
    small and is scaled across many GPUs, where the per-rank message sizes
    would otherwise become very small.
    """

    allgather_partitions: bool = True
    """
    Chooses between allgather collective or a series of broadcast collectives
    to gather updated parameters from all the GPUs at the end of each step.
    """

    allgather_bucket_size: int = Field(pp_int(5e8), ge=0)
    """
    Number of elements allgathered at a time. Limits the memory required for
    the allgather for large model sizes.
    """

    overlap_comm: Optional[bool] = None  # None for dynamic default value (see validator `overlap_comm_valid` below)
    """
    Attempts to overlap the reduction of the gradients with backward computation.
    """

    load_from_fp32_weights: bool = True
    """
    Boolean indicating whether to initialize fp32 master weights from fp32
    copies in the checkpoint (no precision loss) or from the model's fp16
    copies (with precision loss). This can be used to initialize optimizer
    state even when the checkpoint is missing optimizer state.
    """

    elastic_checkpoint: bool = False
    """
    Enable loading a checkpoint that was saved by a job with a different GPU
    count. No longer supported.
    """

    offload_param: Optional[DeepSpeedZeroOffloadParamConfig] = None
    """
    Enable offloading of model parameters to CPU or NVMe. This frees up GPU
    memory for larger models or batch sizes. Valid only with stage 3. Expects a
    dictionary containing values for :any:`DeepSpeedZeroOffloadParamConfig`.
    """

    offload_optimizer: Optional[DeepSpeedZeroOffloadOptimizerConfig] = None
    """
    Enable offloading of optimizer state to CPU or NVMe, and optimizer
    computation to CPU. This frees up GPU memory for larger models or batch
    sizes. Valid for ZeRO stage 1, 2, 3. Expects a dictionary containing values
    for :any:`DeepSpeedZeroOffloadOptimizerConfig`.
    """

    sub_group_size: int = Field(pp_int(1e9), ge=0)
    """
    Tile size for parameter processing to fit massive models (with trillions of
    parameters). Used by ZeRO3-Offload and ZeRO-Infinity.
    """

    cpu_offload_param: Optional[bool] = Field(
        None,
        deprecated=True,
        new_param="offload_param",
        new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
    )
    """ Deprecated, please use ``offload_param`` """

    cpu_offload_use_pin_memory: Optional[bool] = Field(
        None,
        deprecated=True,
        new_param="offload_param or offload_optimizer",
        set_new_param=False,
    )
    """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """

    cpu_offload: Optional[bool] = Field(
        None,
        deprecated=True,
        new_param="offload_optimizer",
        new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None),
    )
    """ Deprecated, please use ``offload_optimizer`` """

    prefetch_bucket_size: int = Field(pp_int(5e7), ge=0, alias="stage3_prefetch_bucket_size")
    """
    Maximum number of parameter elements to fetch ahead of use. Used by ZeRO3,
    ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
    """

    param_persistence_threshold: int = Field(pp_int(1e5), ge=0, alias="stage3_param_persistence_threshold")
    """
    Do not partition parameters smaller than this threshold. Smaller values use
    less memory, but can greatly increase communication (especially
    latency-bound messages).
    """

    model_persistence_threshold: int = Field(pp_int(sys.maxsize, "sys.maxsize"),
                                             ge=0,
                                             alias="stage3_model_persistence_threshold")
    """
    Maximum number of parameter elements that can be persisted in GPU and not
    partitioned. This imposes an upper bound on the number of unpartitioned
    parameters resulting from the param_persistence_threshold setting. Used by
    ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
    """

    max_live_parameters: int = Field(pp_int(1e9), ge=0, alias="stage3_max_live_parameters")
    """
    The maximum number of parameters resident per GPU before releasing. Smaller
    values use less memory, but perform more communication.
    """

    max_reuse_distance: int = Field(pp_int(1e9), ge=0, alias="stage3_max_reuse_distance")
    """
    Do not release a parameter if it will be reused within this threshold of
    parameters. Smaller values use less memory, but perform more communication.
    """

    gather_16bit_weights_on_model_save: bool = Field(False, alias="stage3_gather_16bit_weights_on_model_save")
    """
    Consolidate the weights before saving the model by ``save_16bit_model()``.
    Since the weights are partitioned across GPUs, they aren't part of
    ``state_dict``, so this function automatically gathers the weights when
    this option is enabled and then saves the fp16 model weights.
    """

    stage3_gather_fp16_weights_on_model_save: bool = Field(False,
                                                           deprecated=True,
                                                           new_param="gather_16bit_weights_on_model_save")
    """ Deprecated, please use ``gather_16bit_weights_on_model_save`` """

    ignore_unused_parameters: bool = True
    """
    Unused parameters in modules may be unexpected in static networks, but
    could be normal in dynamic networks. This controls whether or not training
    should terminate with an error message when unused parameters are detected.
    This is set to ``True`` by default, which means unused parameters are
    ignored and training continues. Currently this is only used in ZeRO stage 2.
    """

    legacy_stage1: bool = False
    """
    For backward compatibility, enable the old ZeRO stage 1 implementation. Use
    at your own risk; it will be deprecated soon.
    """

    round_robin_gradients: bool = False
    """
    Stage 1 and 2 optimization for CPU offloading that parallelizes gradient
    copying to CPU memory among ranks by fine-grained gradient partitioning.
    Performance benefit grows with gradient accumulation steps (more copying
    between optimizer steps) or GPU count (increased parallelism).
    """

    zero_hpz_partition_size: int = Field(1, ge=0)
    """
    Number of ranks in the zero parameters partitioning secondary group (hpZ).
    """

    zero_quantized_weights: bool = False
    """
    Boolean indicating whether to quantize zero parameters (weights)
    for efficient all_gather comm.
    """

    zero_quantized_nontrainable_weights: bool = False
    """
    Boolean indicating whether to quantize non-trainable zero parameters
    (weights) for efficient memory usage and communication. Unlike
    ``zero_quantized_weights``, which stores the weights in their original
    precision and only performs quantization during communication, this flag
    stores the weights in quantized precision. This is useful for LoRA training.
    """

    zero_quantized_gradients: bool = False
    """
    Boolean indicating whether to use quantized zero gradients
    for efficient all_2_all_reduce comm.
    """

    mics_shard_size: int = Field(-1, new_param="mics_shard_size")

    mics_hierarchical_params_gather: bool = False

    memory_efficient_linear: bool = True
    """
    Use memory efficient linear implementation, for Stage 3.
    """

    pipeline_loading_checkpoint: bool = False
    """
    Whether to force loading the checkpoint in pipeline mode, currently only
    for Stage 3.
    """

    override_module_apply: bool = True
    """
    Override nn.Module apply function, for Stage 3.
    """

    # Validators
    @validator("overlap_comm")
    def overlap_comm_valid(cls, field_value, values):
        if field_value is None:
            assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
            field_value = values["stage"] == ZeroStageEnum.weights
        return field_value

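    # Illustrative sketch: the dynamic default above means overlap_comm is only
    # enabled automatically for ZeRO stage 3, e.g.:
    #
    #   DeepSpeedZeroConfig(stage=3).overlap_comm                      # True (set by the validator)
    #   DeepSpeedZeroConfig(stage=2).overlap_comm                      # False
    #   DeepSpeedZeroConfig(stage=2, overlap_comm=True).overlap_comm   # explicit value wins
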
    @root_validator
    def offload_ratio_check(cls, values):
        # `values` is a dict here, so use .get() rather than getattr() to look
        # up the parsed offload_optimizer config.
        offload_config = values.get("offload_optimizer", None)
        if offload_config and offload_config.ratio < 1.0:
            assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
        return values
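

# Illustrative sketch: the root validator above ties partial optimizer
# offloading to ZeRO stage 3. Assuming the optimizer offload config exposes a
# ``ratio`` field, a partial offload like the one below is only accepted with
# stage 3:
#
#   DeepSpeedZeroConfig(stage=3, offload_optimizer={"device": "cpu", "ratio": 0.5})  # accepted
#   DeepSpeedZeroConfig(stage=2, offload_optimizer={"device": "cpu", "ratio": 0.5})  # assertion error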