# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import types
import torch
import numpy as np

from deepspeed import comm as dist
from deepspeed.utils.torch import required_torch_version
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.accelerator import get_accelerator


class OnebitLamb(torch.optim.Optimizer):
    """Implements the 1-bit Lamb algorithm. Currently GPU-only.
    For usage example please see https://www.deepspeed.ai/tutorials/onebit-lamb/
    For technical details please see our paper https://arxiv.org/abs/2104.06069.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        freeze_step (int, optional): number of steps for the warmup (uncompressed)
            stage before we start using compressed communication. (default: 100000)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        max_coeff (float, optional): maximum value of the lamb coefficient (default: 10.0)
        min_coeff (float, optional): minimum value of the lamb coefficient (default: 0.01)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in 1-bit Lamb!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)
        cuda_aware (boolean, required): Set True if the underlying MPI implementation
            supports CUDA-Aware communication. (default: False)
        comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl')
        coeff_beta (float, optional): coefficient used for computing
            running averages of the lamb coefficient (default: 0.9). Note that you may
            want to increase or decrease this beta depending on the freeze_step you
            choose, as 1/(1 - coeff_beta) should be smaller than or equal to freeze_step.
        factor_max (float, optional): maximum value of scaling factor to the frozen lamb
            coefficient during compression stage (default: 4.0)
        factor_min (float, optional): minimum value of scaling factor to the frozen lamb
            coefficient during compression stage (default: 0.5)
        factor_threshold (float, optional): threshold of how much the scaling factor can
            fluctuate between steps (default: 0.1)
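
    Example (a minimal sketch following the tutorial linked above; keys other than
    the optimizer parameters documented here are illustrative)::

        import torch
        import deepspeed

        model = torch.nn.Linear(10, 10)  # any torch.nn.Module
        ds_config = {
            "train_batch_size": 8,
            "optimizer": {
                "type": "OneBitLamb",
                "params": {
                    "lr": 1e-3,
                    "freeze_step": 1000,
                    "cuda_aware": False,
                    "comm_backend_name": "nccl"
                }
            }
        }
        model_engine, optimizer, _, _ = deepspeed.initialize(
            model=model, model_parameters=model.parameters(), config=ds_config)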

    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(self,
                 params,
                 deepspeed=None,
                 lr=1e-3,
                 freeze_step=100000,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 max_coeff=10.0,
                 min_coeff=0.01,
                 amsgrad=False,
                 cuda_aware=False,
                 comm_backend_name='nccl',
                 coeff_beta=0.9,
                 factor_max=4.0,
                 factor_min=0.5,
                 factor_threshold=0.1):

        if amsgrad:
            raise RuntimeError('1-bit Lamb does not support the AMSGrad variant.')

        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        max_coeff=max_coeff,
                        min_coeff=min_coeff)

        super(OnebitLamb, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self.deepspeed = deepspeed
        self.lamb_freeze_key = False
        self.initialize = False
        self.freeze_step = freeze_step
        self.cuda_aware = cuda_aware
        self.coeff_beta = coeff_beta
        self.factor_max = factor_max
        self.factor_min = factor_min
        self.factor_threshold = factor_threshold
        self.using_pipeline = False

        self.comm_backend_name = comm_backend_name

        assert dist.is_initialized(), "Please initialize the torch distributed backend."
        # Empty initializer. Set handle based on the comm backend as follows.
        self.comm_backend_handle = None
        if self.comm_backend_name == 'nccl':
            assert (
                required_torch_version(min_version=1.8)
            ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Lamb. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
            from deepspeed.runtime.comm.nccl import NcclBackend
            self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
        elif self.comm_backend_name == 'mpi':
            from deepspeed.runtime.comm.mpi import MpiBackend
            self.comm_backend_handle = MpiBackend(cuda_aware)
        elif self.comm_backend_name == 'hccl':
            from deepspeed.runtime.comm.hccl import HcclBackend
            self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = HcclBackend(self.deepspeed.mpu)

        self.size = self.comm_backend_handle.size

        self.divider = int(self.size * 8 / np.gcd(self.size, 8))
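        # self.divider == lcm(world size, 8). The fused momentum buffer is later padded
        # to a multiple of (size * divider), so it splits into equal per-rank chunks
        # whose lengths are multiples of 8, as needed by the 1-bit backends to pack
        # signs into bytes.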

        self.exp_avg_flat = []
        self.dummy_exp_avg = {}
        self.corrected_tensor_sizes = []
        self.server_chunk_sizes = []
        self.worker_errors = []
        self.server_errors = []

        self.lamb_coeffs = []

    def step(self, closure=None, grads=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        # remove the previous stats
        del self.lamb_coeffs[:]

        if self.lamb_freeze_key:
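            # Snapshot each momentum before it is updated and (compressed-)allreduced
            # below, so that the averaged gradient can be reconstructed from the momentum
            # delta afterwards (see grad_reconstruct in the update loop at the end of
            # this method).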
            exp_avg_last_step = []
            for group in self.param_groups:
                exp_avg_last_step.append([self.state[p]['exp_avg'].detach().clone() for p in group['params']])
            if 'scaling_coeff' not in self.state[self.param_groups[0]['params'][0]]:
                # Compute the scaling_coeff for each momentum at the end of warmup stage.
                # This is used to reduce compression error during compression stage.
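                # Each tensor's momentum is rescaled by united_scale / its own RMS, so
                # that all momenta enter 1-bit compression with roughly the same
                # magnitude; the scaling is undone (div_ by scaling_coeff) after the
                # allreduce.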
                momentum_scales = []
                for group in self.param_groups:
                    momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
                                             np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                            for p in group['params']])
                united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
                for i, group in enumerate(self.param_groups):
                    for j, p in enumerate(group['params']):
                        self.state[p]['scaling_coeff'] = united_scale / momentum_scales[i][j]

        for group, grads_this_group in zip(self.param_groups, grads_group):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])
            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad in zip(group['params'], grads_this_group):
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('1-bit Lamb does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0 or (len(state) == 1 and 'scaling_coeff' in state.keys()):
                    state['step'] = 0
                    state['lamb_coeff_freeze'] = 0.0
                    state['last_factor'] = 1.0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    state['exp_avg_sq_fresh'] = torch.zeros_like(p.data)

                if not self.initialize:
                    self.lamb_freeze_key = True

                exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
                    'exp_avg_sq_fresh']
                beta1, beta2 = group['betas']
                max_coeff = group['max_coeff']
                min_coeff = group['min_coeff']

                state['step'] += 1

                if self.lamb_freeze_key is False:
                    # warmup stage, baseline Lamb optimization
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                    if state['step'] == self.freeze_step:
                        exp_avg_sq_fresh.data = exp_avg_sq.detach().clone()
                    grad = None
                    if self.initialize:
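                        # LAMB trust ratio: scale the Adam-style update by
                        # ||w|| / ||update||, clamped to [min_coeff, max_coeff], so that
                        # layers with large weights are not under-updated at large batch
                        # sizes.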
                        weight_norm = p.data.pow(2).sum().sqrt()
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
                        if group['weight_decay'] > 0.0:
                            update += group['weight_decay'] * p.data
                        update_norm = update.pow(2).sum().sqrt()
                        lamb_coeff = 1.0
                        if weight_norm != 0 and update_norm != 0:
                            lamb_coeff = (weight_norm / update_norm).item()
                            if lamb_coeff > max_coeff:
                                lamb_coeff = max_coeff
                            if lamb_coeff < min_coeff:
                                lamb_coeff = min_coeff
                        if lamb_coeff != 1.0:
                            state['lamb_coeff_freeze'] = self.coeff_beta * state['lamb_coeff_freeze'] + (
                                1 - self.coeff_beta) * lamb_coeff
                        self.lamb_coeffs.append(lamb_coeff)
                        with torch.no_grad():
                            p.add_(-group['lr'] * lamb_coeff * update)
                else:
                    # compression stage, update each momentum locally, then
                    # communicate based on the compressed_allreduce below
                    if self.initialize:
                        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                        exp_avg.mul_(self.state[p]['scaling_coeff'])
                    grad = None

        # init fused momentum
        if len(self.exp_avg_flat) == 0:
            momentum_groups = []
            tensor_size = 0
            for group in self.param_groups:
                for p in group['params']:
                    momentum_groups.append(self.state[p]['exp_avg'])
                    tensor_size += torch.numel(p.data)
            corrected_tensor_size = tensor_size
            if tensor_size % (self.size * self.divider) != 0:
                difference = ((self.size * self.divider) - (tensor_size % (self.size * self.divider)))
                corrected_tensor_size += difference
                self.dummy_exp_avg[0] = torch.zeros(difference, device=momentum_groups[0].data.device)
                momentum_groups.append(self.dummy_exp_avg[0])
            self.corrected_tensor_sizes.append(corrected_tensor_size)
            self.server_chunk_sizes.append(corrected_tensor_size // self.size)
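
            # Flatten all momenta (plus padding) into one contiguous buffer, then rebind
            # each per-parameter exp_avg to the corresponding view of that buffer. Local
            # momentum updates and the fused allreduce below therefore share storage.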
            self.exp_avg_flat.append(_flatten_dense_tensors([p.detach().clone() for p in momentum_groups]))
            updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0], momentum_groups)
            for p, q in zip(momentum_groups, updated_params):
                p.data = q.data

        if self.initialize and len(self.worker_errors) == 0:
            get_accelerator().empty_cache()
            for i in range(len(self.exp_avg_flat)):
                self.worker_errors.append(
                    torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
                self.server_errors.append(torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
            get_accelerator().empty_cache()

        if self.lamb_freeze_key:
            if self.size > 1:
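                # worker_errors/server_errors carry the 1-bit quantization residual from
                # the previous step; compressed_allreduce adds the residual back in before
                # re-quantizing (error-feedback compensation), so compression errors do
                # not accumulate across steps.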
                for i in range(len(self.exp_avg_flat)):
                    if not self.initialize:
                        get_accelerator().empty_cache()
                        self.worker_errors.append(
                            torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device))
                        self.server_errors.append(
                            torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device))
                        get_accelerator().empty_cache()
                        if dist.get_rank() == 0:
                            print("Cupy Buffers Initialized Successfully.")

                        self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[0],
                                                                      self.server_errors[0], self.deepspeed.local_rank)

                        if dist.get_rank() == 0:
                            print('Pop out errors', flush=True)
                        del self.worker_errors[:]
                        del self.server_errors[:]
                    else:
                        self.comm_backend_handle.compressed_allreduce(self.exp_avg_flat[i], self.worker_errors[i],
                                                                      self.server_errors[i], self.deepspeed.local_rank)

        if self.lamb_freeze_key and self.initialize:
            for i, group in enumerate(self.param_groups):
                bias_correction = 1 if group['bias_correction'] else 0

                for j, p in enumerate(group['params']):
                    state = self.state[p]
                    exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state[
                        'exp_avg_sq_fresh']
                    beta1, beta2 = group['betas']
                    exp_avg.div_(self.state[p]['scaling_coeff'])
                    # Because 1-bit compression cannot represent exact zero, it is required to
                    # provide a momentum mask for those params that have constant exact zeros in their
                    # momentums, otherwise the compression error would keep accumulating.
                    # For example, for BERT pre-training seq 128, bert.embeddings.position_embeddings.weight
                    # always has exact zeros in its momentum for rows 129 to 512, because it only
                    # learns up to seq length 128 while the model supports up to 512 seq length.
                    # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py about how
                    # to add this exp_avg_mask for BERT pre-training.)
                    if 'exp_avg_mask' in group:
                        if exp_avg.device != group['exp_avg_mask'].device:
                            group['exp_avg_mask'] = group['exp_avg_mask'].to(device=exp_avg.device)
                        exp_avg.mul_(group['exp_avg_mask'])
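
                    # Invert the momentum update to recover the (allreduce-averaged)
                    # gradient: from exp_avg_new = beta1 * exp_avg_old + (1 - beta1) * grad
                    # it follows that grad = (exp_avg_new - beta1 * exp_avg_old) / (1 - beta1).
                    # This keeps exp_avg_sq_fresh up to date even though raw gradients are
                    # no longer communicated during the compression stage.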
                    grad_reconstruct = ((exp_avg - exp_avg_last_step[i][j] * beta1) / (1 - beta1))
                    exp_avg_sq_fresh.mul_(beta2).addcmul_(grad_reconstruct, grad_reconstruct, value=1 - beta2)
                    denom = exp_avg_sq.sqrt() + group['eps']
                    update_prelim = exp_avg / denom

                    if group['weight_decay'] > 0.0:
                        update = update_prelim + group['weight_decay'] * p.data
                    else:
                        update = update_prelim

                    lamb_coeff = 1.0
                    update_norm = update.pow(2).sum().sqrt()
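
                    # During the compression stage the lamb coefficient is frozen at its
                    # warmup average (lamb_coeff_freeze) and only rescaled by factor: the
                    # largest ratio between the frozen denominator and the fresh one,
                    # clamped to [factor_min, factor_max] and limited to at most
                    # +/- factor_threshold change per step.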
                    denom_real = exp_avg_sq_fresh.sqrt() + group['eps']
                    factor = (denom / denom_real).max().item()
                    if group['weight_decay'] > 0.0:
                        update_ratio = min(1.0, (update_prelim.pow(2).sum().sqrt() / update_norm).item())
                        factor = factor * update_ratio + (1.0 - update_ratio)
                    if factor > self.factor_max:
                        factor = self.factor_max
                    if factor < self.factor_min:
                        factor = self.factor_min
                    if factor > state['last_factor'] * (1.0 + self.factor_threshold):
                        factor = state['last_factor'] * (1.0 + self.factor_threshold)
                    if factor < state['last_factor'] * (1.0 - self.factor_threshold):
                        factor = state['last_factor'] * (1.0 - self.factor_threshold)
                    state['last_factor'] = factor
                    lamb_coeff = state['lamb_coeff_freeze'] * factor
                    self.lamb_coeffs.append(lamb_coeff)
                    with torch.no_grad():
                        p.add_(-group['lr'] * lamb_coeff * update)
            del exp_avg_last_step[:]
            exp_avg_last_step = None

        if not self.initialize:
            self.lamb_freeze_key = False
            self.initialize = True
            print(f"Finished the initialization step at rank {dist.get_rank()}")
            return loss

        if self.lamb_freeze_key is False:
            if state['step'] >= self.freeze_step:
                print('OnebitLamb - starting compressed communication')
                self.lamb_freeze_key = True
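                # From this point on the optimizer itself exchanges compressed momenta,
                # so DeepSpeed's regular gradient allreduce is disabled to avoid
                # redundant communication.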
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        return loss

    def load_state_dict(self, state_dict):
        """
        Overrides load_state_dict() to add special handling when loading checkpoints
        """
        # Because at different stages exp_avg_mask may change (e.g.,
        # BERT pre-training seqlen 128 and 512), we don't use the exp_avg_mask
        # in checkpoints but always use the one the user provided in the training script.
        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
        # Thus here we keep the exp_avg_mask unchanged when loading the checkpoint.
        for i, group in enumerate(self.param_groups):
            if 'exp_avg_mask' in group:
                state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
            elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
                state_dict['param_groups'][i].pop('exp_avg_mask')
        super().load_state_dict(state_dict)

        # need to reset the fused momentum since loading states will break the linking
        del self.exp_avg_flat[:]
        self.dummy_exp_avg.clear()
        del self.corrected_tensor_sizes[:]
        del self.server_chunk_sizes[:]

        if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
            if dist.get_rank() == 0:
                print("Checkpoint loaded and OnebitLamb warmup stage starts/continues.")
            if self.lamb_freeze_key is True:
                self.lamb_freeze_key = False
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = True
                else:
                    self.deepspeed.enable_backward_allreduce = True
            for group in self.param_groups:
                for p in group['params']:
                    self.state[p]['lamb_coeff_freeze'] = 0.0
                    self.state[p]['last_factor'] = 1.0
                    if 'scaling_coeff' in self.state[p]:
                        self.state[p].pop('scaling_coeff')
        else:
            if dist.get_rank() == 0:
                print("Checkpoint loaded and OnebitLamb compression stage starts/continues.")
            if self.lamb_freeze_key is False:
                self.lamb_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        # We reset the compression errors when loading checkpoints for 3 reasons:
        # 1) The worker and server errors at each GPU are distinct, so in the current
        #    implementation only rank 0's errors are saved in the checkpoint. Thus we
        #    have to reset the errors. If we wanted to save them correctly we would need
        #    O(num_gpu * model_size) memory to gather all the errors, which is a very
        #    large memory requirement. It's possible to save them in a distributed way,
        #    but that would make checkpoint saving/loading much more complicated.
        # 2) Even if we were able to save the compression errors correctly, we would need
        #    the exact same number of GPUs in order to load them correctly.
        # 3) We verified on BERT pre-training that occasionally resetting the compression
        #    error at checkpoint loading does not affect the convergence.
        # However, please avoid frequent checkpoint loading, which could break the error
        # compensation mechanism and thus affect the convergence.
        del self.worker_errors[:]
        del self.server_errors[:]

    def get_lamb_coeffs(self):
        """Return the lamb coefficients computed during the most recent step."""
        return self.lamb_coeffs