File size: 13,762 Bytes
a5dc865 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# Container objects that save model-specific tensors obtained through a model's policy.
from abc import ABC
import torch
import deepspeed
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator
# Sentinel for the ``intermediate_size`` argument of
# ``BaseTransformerContainer.set_hidden_heads``: if the intermediate size is set to
# DEFAULT_INTERMEDIATE_SIZE, it is assumed to be 4x the embedding (hidden) dimension.
DEFAULT_INTERMEDIATE_SIZE: int = -1
class BaseConvolutionContainer(ABC):
    """Placeholder base class for convolution-layer containers.

    Not implemented yet: construction takes no arguments and stores no state.
    """

    def __init__(self):
        """Create an empty container; nothing is initialized."""
class BaseTransformerContainer(ABC):
    """Base container holding one transformer layer's tensors for DeepSpeed inference injection.

    Tensors are pulled from ``policy`` (the user's module) by ``initialize_tensors``,
    may be cast (``convert_to_required_dtype``), transposed (``transpose``) and
    partitioned (``apply_tensor_parallelism``), and are finally copied into
    ``self.module``.

    NOTE(review): ``self.module`` is never assigned in this class. It is presumably
    set by subclasses or the injection code; confirm. ``self.quantizer`` and
    ``self.moe`` are only set through ``set_quantization_config`` and ``set_moe``.
    """

    def __init__(self, policy, config, model_config, layer_id, child):
        """Capture the configuration for one layer.

        Args:
            policy: injection policy for the layer. The flags ``is_megatron_v2``,
                ``scale_attention``, ``pre_attn_norm``, ``linear_layer``,
                ``mlp_act_func_type``, ``norm_type`` and ``use_mup`` are read from it.
            config: DeepSpeed inference config (tensor-parallel size, dtype,
                output-token limits, Triton settings, ...).
            model_config: the wrapped model's config. Optional attributes
                ``do_layer_norm_before``, ``attention_layers``, ``window_size`` and
                ``rotary_dim`` are used when present.
            layer_id: index of this layer; used to look up ``attention_layers``.
            child: the original layer module; inspected by ``get_rotary_dim``.
        """
        self.policy = policy
        self.config = config
        self.model_config = model_config
        self.layer_id = layer_id
        self.child = child

        self.megatron_v2 = self.policy.is_megatron_v2
        self.scale_attention = self.policy.scale_attention
        self.ckpt_load_enabled = False

        # configuration for models. todo: can this be moved to a pydantic model config?
        self.hidden_size = None
        self.intermediate_size = None
        self.num_attention_heads = None
        self.mp_size = self.config.tensor_parallel.tp_size
        # Prefer the model's own pre-/post-norm flag; otherwise fall back to the policy.
        self.pre_layer_norm = (self.model_config.do_layer_norm_before if hasattr(
            self.model_config, 'do_layer_norm_before') else self.policy.pre_attn_norm)
        self.dtype = self.config.dtype
        self.attn_linear_layer = self.policy.linear_layer
        self.mlp_linear_layer = self.policy.linear_layer
        self.return_tuple = self.config.return_tuple
        self.triangular_masking = True
        # Models that declare per-layer attention types mark some layers as "local".
        self.local_attention = ((self.model_config.attention_layers[self.layer_id] == "local") if hasattr(
            self.model_config, 'attention_layers') else False)
        self.window_size = getattr(self.model_config, "window_size", 1)
        self.mlp_act_func_type = self.policy.mlp_act_func_type
        self.norm_type = self.policy.norm_type
        self.training_mp_size = self.config.training_mp_size
        self.bigscience_bloom = False
        self.max_out_tokens = self.config.max_out_tokens
        self.min_out_tokens = self.config.min_out_tokens
        # NOTE(review): read from the DeepSpeed ``config``, not ``model_config``; confirm intended.
        self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False)
        self.use_mup = self.policy.use_mup
        self.return_single_tuple = False
        self.rotary_dim = self.get_rotary_dim()
        # get_rotary_dim() returns -1 when no rotary dimension is found.
        self.mlp_after_attn = (self.rotary_dim is None or self.rotary_dim < 0)

        # Attention tensors
        self.qkvw = None
        self.qkvb = None
        self.dense_w = None
        self.dense_b = None
        # MLP tensors
        self._h4h_w = None
        self._h4h_b = None
        self._4hh_w = None
        self._4hh_b = None
        # LayerNorm tensors
        self.attn_nw = None
        self.attn_nb = None
        self.input_nw = None
        self.input_nb = None

        self.mp_group = None

        # Triton is used only when requested in the config AND available in this install.
        self.use_triton = config.use_triton and deepspeed.HAS_TRITON

    def create_ds_model_config(self):
        """Build, store (as ``self.ds_model_config``) and return the ``DeepSpeedInferenceConfig``.

        Calls ``set_hidden_heads`` with the values from ``policy.get_hidden_heads()``
        before building the config.

        Raises:
            AssertionError: if the number of attention heads is not divisible by ``mp_size``.
            NotImplementedError: if Triton is in use and this container is not a
                ``DS_BERTContainer``.
        """
        self.set_hidden_heads(*self.policy.get_hidden_heads())
        assert self.num_attention_heads % self.mp_size == 0, (
            "To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size! "
            "This is because the attention computation is partitioned evenly among the parallel GPUs.")

        self.ds_model_config = DeepSpeedInferenceConfig(
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            heads=self.num_attention_heads,
            layer_norm_eps=self.layernorm_epsilon,
            dtype=self.dtype,
            pre_layer_norm=self.pre_layer_norm,
            norm_type=self.norm_type,
            mp_size=self.mp_size,
            return_tuple=self.return_tuple,
            triangular_masking=self.triangular_masking,
            local_attention=self.local_attention,
            window_size=self.window_size,
            rotary_dim=self.rotary_dim,
            mlp_after_attn=self.mlp_after_attn,
            mlp_act_func_type=self.mlp_act_func_type,
            training_mp_size=self.training_mp_size,
            bigscience_bloom=self.bigscience_bloom,
            max_out_tokens=self.max_out_tokens,
            min_out_tokens=self.min_out_tokens,
            scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx,
            use_mup=self.use_mup,
            return_single_tuple=self.return_single_tuple,
            set_empty_params=self.config.set_empty_params,
            transposed_mode=self.config.transposed_mode,
            use_triton=self.use_triton,
            triton_autotune=self.config.triton_autotune)

        if self.use_triton and deepspeed.HAS_TRITON:
            # Imported lazily: only needed when Triton kernels are actually in use.
            from .bert import DS_BERTContainer
            if not isinstance(self, DS_BERTContainer):
                raise NotImplementedError("Triton kernels are only for BERT-like models yet")

            if not self.config.triton_autotune:
                from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
                fp16_matmul.skip_autotune()

        return self.ds_model_config

    def check_meta_tensor_support(self):
        """Validate meta-tensor usage for ``qkvw``.

        Raises:
            AssertionError: if ``qkvw`` is a meta tensor and ``ckpt_load_enabled`` is False.
            NotImplementedError: if ``qkvw`` has no ``is_meta`` attribute.
        """
        # NOTE(review): a ``None`` qkvw also lacks ``is_meta`` and ends up in the
        # "upgrade torch" branch below; call only after initialize_tensors().
        if hasattr(self.qkvw, 'is_meta'):
            if self.qkvw.is_meta:
                assert self.ckpt_load_enabled, "Meta tensors are not supported for this model currently."
        else:
            raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")

    def initialize_tensors(self, enable_training=False):
        """Pull attention, MLP and layer-norm tensors from the policy into this container.

        Args:
            enable_training: forwarded to ``policy.attention`` and ``policy.mlp``.
        """
        # Set the tensors from policy (user module) to container (DS module)
        self.set_attention(*self.policy.attention(enable_training=enable_training))
        self.set_mlp(*self.policy.mlp(enable_training=enable_training))
        self.set_layernorm(*self.policy.layernorm())
        # NOTE: check_meta_tensor_support() is currently not called here (left disabled).

    def convert_to_required_dtype(self):
        """Cast every tensor attribute, and every list of tensors, to ``self.dtype``.

        Only runs when ``self.dtype`` is ``torch.half`` or ``torch.bfloat16``. Other
        attributes (including lists with non-tensor items) are left untouched.
        """
        # Note: converting tensors to fp16 requires that we do it in-place using self.__dict__
        # and not make a list/dict copy. Only existing keys are rebound, so iterating the
        # live dict is safe.
        if self.dtype not in (torch.half, torch.bfloat16):
            return
        for name, value in self.__dict__.items():
            # torch.nn.Parameter subclasses torch.Tensor, so one isinstance check covers both.
            if isinstance(value, torch.Tensor):
                self.__dict__[name] = value.to(self.dtype)
            elif isinstance(value, list) and all(isinstance(item, torch.Tensor) for item in value):
                # Lists of tensors are used for MoE tensor lists.
                self.__dict__[name] = [item.to(self.dtype) for item in value]

    def get_rotary_dim(self):
        """Return the rotary embedding dimension, or -1 if none is found.

        Looks at ``model_config.rotary_dim`` first, then ``child.attention.rotary_ndims``.
        """
        if hasattr(self.model_config, 'rotary_dim'):
            return self.model_config.rotary_dim
        if hasattr(self.child, 'attention') and hasattr(self.child.attention, 'rotary_ndims'):
            return self.child.attention.rotary_ndims
        return -1

    def set_moe(self, moe=False):
        """Record whether this layer is a mixture-of-experts layer."""
        self.moe = moe

    def set_tensor_parallel_config(self, mp_size, mp_group):
        """Override the tensor-parallel size and set the process group."""
        self.mp_size = mp_size
        self.mp_group = mp_group

    def set_quantization_config(self, quantizer):
        """Store the quantizer used by ``apply_weight_quantization``."""
        self.quantizer = quantizer

    def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon, intermediate_size):
        """
        Args:
            hidden_size: embedding dimension of the model
            num_attention_heads: number of attention heads in the model
            epsilon: epsilon value for layer norm (same value used for all norms)
            intermediate_size: Size of MLP projection. If `DEFAULT_INTERMEDIATE_SIZE` is passed
                it is assumed to be `4 * hidden_size`
        """
        self.hidden_size = hidden_size
        self.intermediate_size = (4 * hidden_size
                                  if intermediate_size == DEFAULT_INTERMEDIATE_SIZE else intermediate_size)
        self.num_attention_heads = num_attention_heads
        self.layernorm_epsilon = epsilon

    def set_attention(self, qkvw, qkvb, dense_w, dense_b):
        """Store the fused QKV weight/bias and the attention output weight/bias."""
        self.qkvw = qkvw
        self.qkvb = qkvb
        self.dense_w = dense_w
        self.dense_b = dense_b

    def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
        """Store the MLP input-projection (h->4h) and output-projection (4h->h) weights/biases."""
        self._h4h_w = _h4h_w
        self._h4h_b = _h4h_b
        self._4hh_w = _4hh_w
        self._4hh_b = _4hh_b

    def set_layernorm(self, attn_nw, attn_nb, input_nw, input_nb):
        """Store the two layer-norm weight/bias pairs."""
        self.attn_nw = attn_nw
        self.attn_nb = attn_nb
        self.input_nw = input_nw
        self.input_nb = input_nb

    def apply_weight_quantization(self):
        """Quantize the attention and MLP weights of ``self.module`` with ``self.quantizer``."""
        # quantize attention weights
        self.attention_quantization()

        # quantize mlp weights
        self.mlp_quantization()

    def attention_quantization(self):
        """Quantize the module's fused QKV and attention-output weights in place (attribute rebinding)."""
        self.module.attention.attn_qkvw = self.quantizer.quantize(self.module.attention.attn_qkvw)
        self.module.attention.attn_ow = self.quantizer.quantize(self.module.attention.attn_ow)

    def mlp_quantization(self):
        """Quantize the module's MLP input- and output-projection weights."""
        self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w)
        self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w)

    def apply_tensor_parallelism(self, mp_replace):
        """Copy this container's attention and MLP tensors into ``self.module`` via ``mp_replace``."""
        # setup the new Attention module
        self.attention_qkv_mp(mp_replace)
        self.attention_o_mp(mp_replace)

        # setup the new MLP module
        self.mlp_inter_mp(mp_replace)
        self.mlp_output_mp(mp_replace)

        # TODO(cmikeh2): Re-enable weight quantization (self.apply_weight_quantization()) once verified.

    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        """Copy ``qkvw``/``qkvb`` into the module's attention using ``mp_replace.strided_copy``.

        ``num_splits=3`` presumably reflects the fused Q, K and V blocks (confirm).
        ``reversed_dim`` is forwarded to ``mp_replace`` as ``int8=``.
        """
        self.module.attention.attn_qkvw = mp_replace.strided_copy(self.module.attention.attn_qkvw,
                                                                  self.qkvw,
                                                                  num_splits=3,
                                                                  int8=reversed_dim)
        self.module.attention.attn_qkvb = mp_replace.strided_copy(self.module.attention.attn_qkvb,
                                                                  self.qkvb,
                                                                  num_splits=3,
                                                                  int8=reversed_dim)

    def attention_o_mp(self, mp_replace, reversed_dim=False):
        """Copy ``dense_w``/``dense_b`` into the module's attention output projection.

        ``reversed_dim`` is forwarded as ``int8=`` (and as ``allocate_tensor=`` for the bias).
        """
        self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, self.dense_w, int8=reversed_dim)
        self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob,
                                                        self.dense_b,
                                                        int8=reversed_dim,
                                                        allocate_tensor=reversed_dim)

    def mlp_inter_mp(self, mp_replace, reversed_dim=False):
        """Copy the h->4h MLP weight/bias into the module; ``reversed_dim`` is forwarded as ``int8=``."""
        self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w, int8=reversed_dim)
        self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b, int8=reversed_dim)

    def mlp_output_mp(self, mp_replace, reversed_dim=False):
        """Copy the 4h->h MLP weight/bias into the module.

        ``reversed_dim`` is forwarded as ``int8=`` (and as ``allocate_tensor=`` for the bias).
        """
        self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w, int8=reversed_dim)
        self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b,
                                                   self._4hh_b,
                                                   int8=reversed_dim,
                                                   allocate_tensor=reversed_dim)

    def copy_data_to_new_module(self):
        """Install the layer-norm tensors on ``self.module`` as Parameters on the current accelerator device.

        ``attn_nw``/``attn_nb`` are set on ``self.module.mlp`` under the same names;
        ``input_nw``/``input_nb`` become ``self.module.norm_w``/``norm_b``. A ``None``
        tensor is stored as ``None``.
        """
        device = get_accelerator().current_device_name()

        def _install(target, params):
            # Set each attribute on ``target``; wrap non-None tensors as Parameters on ``device``.
            for attr, tensor in params.items():
                value = None if tensor is None else torch.nn.parameter.Parameter(tensor.to(device))
                setattr(target, attr, value)

        _install(self.module.mlp, {'attn_nw': self.attn_nw, 'attn_nb': self.attn_nb})
        _install(self.module, {'norm_w': self.input_nw, 'norm_b': self.input_nb})

    def transpose(self):
        """Transpose the attention and MLP weights where the policy marks them as linear layers."""
        self.transpose_attention()
        self.transpose_mlp()

    def transpose_attention(self):
        """Transpose ``qkvw`` and ``dense_w`` if the policy's ``linear_layer`` flag is set."""
        if self.attn_linear_layer:
            self.qkvw = self.transpose_impl(self.qkvw.data)
            self.dense_w = self.transpose_impl(self.dense_w.data)

    def transpose_mlp(self):
        """Transpose the two MLP weights if the policy's ``linear_layer`` flag is set."""
        if self.mlp_linear_layer:
            self._h4h_w = self.transpose_impl(self._h4h_w.data)
            self._4hh_w = self.transpose_impl(self._4hh_w.data)

    def transpose_impl(self, data):
        """Return the 2-D tensor ``data`` transposed, reusing its storage.

        The transposed values are written back into the contiguous storage of
        ``data`` and a reshaped view of that storage is returned. If ``data`` is
        already contiguous, ``contiguous()`` returns the same tensor, so the
        caller's storage is overwritten in place.

        Note: the result stays on the device ``data`` came from; it is not moved
        to the accelerator here.
        """
        data = data.contiguous()
        data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
        return data.reshape(data.shape[-1], data.shape[-2])

    def get_all_params(self):
        """Return layer-norm tensors, then attention tensors, then MLP tensors, as one list."""
        params = [
            self.attn_nw,
            self.attn_nb,
            self.input_nw,
            self.input_nb,
        ]

        params.extend(self.get_attn_params())
        params.extend(self.get_mlp_params())

        return params

    def get_attn_params(self):
        """Return ``[qkvw, qkvb, dense_w, dense_b]``."""
        return [self.qkvw, self.qkvb, self.dense_w, self.dense_b]

    def get_mlp_params(self):
        """Return ``[_h4h_w, _h4h_b, _4hh_w, _4hh_b]``."""
        return [self._h4h_w, self._h4h_b, self._4hh_w, self._4hh_b]
|