File size: 18,029 Bytes
734b6a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from FP16_Optimizer in NVIDIA/apex
"""
from deepspeed.moe.utils import split_params_grads_into_shared_and_expert_params
import torch
from torch._utils import _flatten_dense_tensors
from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
from deepspeed.utils.torch import required_torch_version
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
    """
    FP16 Optimizer without weight fusion to support LAMB optimizer.

    The model's fp16 parameters are kept in ``fp16_groups``. Per-parameter fp32
    master copies are kept in ``fp32_groups``, and the wrapped optimizer's
    param groups are redirected to those fp32 copies. After each successful
    step, the fp32 values are copied back into the fp16 parameters.

    The loss scale can come from three sources:

    * static (``static_loss_scale``);
    * dynamic, where the scale is divided by ``scale_factor`` on overflow and
      multiplied by it after ``scale_window`` overflow-free iterations;
    * an external value supplied via ``override_loss_scale``.

    For usage example please see, TODO: DeepSpeed V2 Tutorial
    """

    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_lamb_legacy=False):
        """
        Args:
            init_optimizer: Optimizer whose param groups hold the model's fp16
                params. Its ``param_groups[*]['params']`` are replaced in place
                by fp32 master copies.
            deepspeed: Forwarded to ``CheckOverflow``.
            static_loss_scale: Loss scale used when ``dynamic_loss_scale`` is False.
            dynamic_loss_scale: Enable dynamic loss scaling.
            dynamic_loss_args: Optional mapping with ``INITIAL_LOSS_SCALE``,
                ``SCALE_WINDOW`` and ``MIN_LOSS_SCALE`` keys. When None, the
                defaults are 2**16, 1000 and 0.25 respectively.
            verbose: Log loss-scale changes and skipped steps.
            mpu: Model-parallel unit forwarded to the norm/overflow helpers.
            clip_grad: Maximum global gradient norm; values <= 0 disable clipping.
            fused_lamb_legacy: Route ``step()`` through ``step_fused_lamb()``.

        Raises:
            SystemError: If no accelerator is available.
        """
        self.fused_lamb_legacy = fused_lamb_legacy
        self._global_grad_norm = 0.
        if dist.get_rank() == 0:
            logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # param groups
        self.fp16_groups = []
        self.fp32_groups = []
        for param_group in self.optimizer.param_groups:
            # fp16 weights that represent the actual model weights
            self.fp16_groups.append(param_group['params'])
            # fp32 copy of the weights: updated first by the optimizer, then
            # copied back into the fp16 weights
            fp32_group = [p.clone().float().detach() for p in param_group['params']]
            # in case the internal optimizer needs it
            for p in fp32_group:
                p.requires_grad = True
            # Point the wrapped optimizer at the fp32 copies. These are not the
            # weights used by the model; the model keeps using the fp16
            # tensors stored in fp16_groups.
            self.fp32_groups.append(fp32_group)
            param_group['params'] = fp32_group

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2.0
            if dynamic_loss_args is None:
                self.cur_scale = 1.0 * 2**16
                self.scale_window = 1000
                self.min_loss_scale = 0.25
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

        # Set by override_loss_scale(). While active, backward() scales the loss
        # by external_loss_scale instead of cur_scale.
        self.custom_loss_scaler = False
        self.external_loss_scale = None

        self.verbose = verbose

        self.clip_grad = clip_grad
        self.norm_type = 2

        if required_torch_version(max_version=0.4):
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed)

        self.initialize_optimizer_states()

    def zero_grad(self, set_to_none=True):
        """
        Zero FP16 parameter grads.

        Args:
            set_to_none: If True (default), drop each fp16 ``.grad`` by setting
                it to None. Otherwise, detach and zero existing grads in place.
        """
        # FP32 grad should never exist outside of the step function
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_to_none:
                    p.grad = None
                elif p.grad is not None:
                    p.grad.detach_()
                    p.grad.zero_()

    def _sync_fp16_from_fp32(self):
        """Drop fp32 master grads and copy fp32 master weights into the fp16 params."""
        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for fp32_param, fp16_param in zip(fp32_group, fp16_group):
                # remove the fp32 grad
                fp32_param.grad = None
                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

    def step_fused_lamb(self, closure=None):
        """
        Step path for the legacy fused LAMB optimizer. Not supporting closure.

        Computes per-group norms of shared and expert gradients to detect
        overflow, then updates the loss scale. If there is no overflow, it calls
        the wrapped optimizer with the fp16 gradients, the fp16 params as
        outputs, and a combined unscale/clip factor. Finally it copies the fp32
        masters back into the fp16 params.

        Returns:
            bool: True if an overflow was detected and the step was skipped.
        """
        # First compute norm for all groups so we know if there is overflow
        grads_groups = []
        norm_groups = []
        expert_norm_groups = []
        for group in self.fp16_groups:
            # Missing grads are replaced by zeros so every param has a grad tensor.
            grads = [
                torch.zeros(p.size(), dtype=p.dtype, device=p.device) if p.grad is None else p.grad for p in group
            ]
            grads_groups.append(grads)

            grads_for_norm, expert_grads_for_norm = split_params_grads_into_shared_and_expert_params(group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(_flatten_dense_tensors(grads_for_norm), mpu=self.mpu)
            norm_groups.append(norm_group_value)

            expert_norm_group_value = 0.0
            if len(expert_grads_for_norm) > 0:
                expert_norm_group_value = get_weight_norm(_flatten_dense_tensors(expert_grads_for_norm), mpu=self.mpu)
            expert_norm_groups.append(expert_norm_group_value)

        # Overflow detection considers both shared and expert norms.
        self.overflow = self.overflow_checker.check_using_norm(norm_groups + expert_norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        # Only the shared (non-expert) norms feed the global norm used for clipping.
        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        # apply_scale=False: only compute the factor; it is handed to the
        # optimizer via `scale=` instead of being applied to fp32 grads here.
        combined_scale = self.unscale_and_clip_grads(self._global_grad_norm, apply_scale=False)
        self.optimizer.step(grads=grads_groups, output_params=self.fp16_groups, scale=combined_scale)

        self._sync_fp16_from_fp32()

        return self.overflow

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]["lr"]

    def override_loss_scale(self, loss_scale):
        """Use ``loss_scale`` in ``backward()`` instead of the internal ``cur_scale``."""
        if loss_scale != self.external_loss_scale:
            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale

    def step(self, closure=None):
        """
        Not supporting closure.

        Checks for overflow and updates the loss scale. If there is no
        overflow, it copies the fp16 grads into fp32 master grads, unscales and
        optionally clips them, and steps the wrapped optimizer. It then copies
        the fp32 masters back into the fp16 params.

        Returns:
            bool: True if an overflow was detected and the step was skipped.
        """
        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
            return self.overflow

        norm_groups = []
        for fp16_group, fp32_group in zip(self.fp16_groups, self.fp32_groups):
            grads_for_norm, _ = split_params_grads_into_shared_and_expert_params(fp16_group)
            norm_group_value = 0.0
            if len(grads_for_norm) > 0:
                norm_group_value = get_weight_norm(grads_for_norm, mpu=self.mpu)
            norm_groups.append(norm_group_value)

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(fp32_group, fp16_group):
                if fp16_param.grad is None:
                    fp32_param.grad = torch.zeros(fp16_param.size(), dtype=fp32_param.dtype, device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self._global_grad_norm = get_global_norm(norm_list=norm_groups)
        self.unscale_and_clip_grads(self._global_grad_norm)

        self.optimizer.step()

        self._sync_fp16_from_fp32()

        return self.overflow

    def unscale_and_clip_grads(self, total_norm, apply_scale=True):
        """
        Compute (and optionally apply) the factor that unscales and clips grads.

        Args:
            total_norm: Global grad norm, still multiplied by the loss scale.
            apply_scale: If True, divide every non-None fp32 master grad by the
                combined factor in place.

        Returns:
            The combined factor: ``cur_scale``, or ``cur_scale * clip`` when the
            unscaled norm exceeds ``clip_grad``.
        """
        # compute combined scale factor for this group
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for group in self.fp32_groups:
                for param in group:
                    if param.grad is not None:
                        param.grad.data.mul_(1. / combined_scale)

        return combined_scale

    def backward(self, loss, create_graph=False, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves

        When ``override_loss_scale`` has been called, the external scale is
        used instead of ``cur_scale``. ``create_graph`` and ``retain_graph``
        are forwarded to ``Tensor.backward`` in both cases.
        """
        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
        else:
            scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)

    def _update_scale(self, skip):
        """
        Advance the iteration counter and adjust the loss scale.

        Args:
            skip: True if the current step overflowed.
        """
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info("Grad overflow on iteration: %s", self.cur_iter)
                    logger.info(f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}")
            else:
                # Ensure self.scale_window updates since last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}")
        elif skip:
            logger.info("Grad overflow on iteration %s", self.cur_iter)
            logger.info("Using static loss scale of %s", self.cur_scale)
        self.cur_iter += 1

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        if self.custom_loss_scaler:
            return self.external_loss_scale
        else:
            return self.cur_scale

    def _set_loss_scale(self, value):
        # Mirror _get_loss_scale so that reading the property back returns the
        # value just written. This class has no `loss_scaler` attribute; the
        # live scale is either external_loss_scale or cur_scale.
        if self.custom_loss_scaler:
            self.external_loss_scale = value
        else:
            self.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_UnfusedOptimizer` instance.

        This dict contains the loss-scaler attributes of
        :class:`FP16_UnfusedOptimizer`, the fp32 master params, and the
        state_dict of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
        state_dict['fp32_groups'] = self.fp32_groups
        return state_dict

    # Refresh fp32 master params from fp16 copies
    def refresh_fp32_params(self):
        """Overwrite each fp32 master param with the values of its fp16 counterpart."""
        for current_group, saved_group in zip(self.fp32_groups, self.fp16_groups):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().

        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Args:
            state_dict: Dict produced by :meth:`state_dict`.
            load_optimizer_states: If True, also restore the wrapped optimizer's state.

        Example::

            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_UnfusedOptimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']

        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date.  There are two options.
        # 1:  Refresh the master params from the model's fp16 params.
        # This requires less storage but incurs precision loss.
        # 2:  Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance.  In our case, as long as the current FP16_UnfusedOptimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current_group, saved_group in zip(self.fp32_groups, state_dict['fp32_groups']):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def __repr__(self):
        return repr(self.optimizer)

    def initialize_optimizer_states(self):
        """
        Run one wrapped-optimizer step with all-zero gradients, then clear all grads.

        Zero grads are attached to both the fp16 and fp32 params on the current
        accelerator device. This presumably lets the wrapped optimizer create
        its per-parameter state before training starts; confirm this against
        the optimizers used with this class.
        """
        for group in self.fp16_groups:
            for param in group:
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=get_accelerator().current_device_name())

        for group in self.fp32_groups:
            for param in group:
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=get_accelerator().current_device_name())

        self.optimizer.step()

        for group in self.fp16_groups:
            for param in group:
                param.grad = None

        for group in self.fp32_groups:
            for param in group:
                param.grad = None
|