# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import deepspeed
from deepspeed.runtime.utils import partition_uniform as partition
def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension. Adapted from Megatron-LM.
Arguments:
tensor: input tensor.
partitions: list of partition sizes to supply to torch.split
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
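    Example (a minimal sketch; the shapes and split sizes below are illustrative only)::

        >>> t = torch.arange(12).reshape(3, 4)
        >>> chunks = split_tensor_along_last_dim(t, [1, 3])
        >>> [c.shape[-1] for c in chunks]
        [1, 3]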
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
# Split.
tensor_list = torch.split(tensor, partitions, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
class TiledLinear(torch.nn.Module):
def __init__(self,
in_features,
out_features,
bias=True,
in_splits=1,
out_splits=1,
input_is_already_split=False,
combine_out_splits=True,
linear_cls=torch.nn.Linear,
init_linear=None,
**kwargs):
"""A replacement for ``torch.nn.Linear`` that works with ZeRO-3 to reduce
memory requirements via tiling.
TiledLinear breaks the input and output dimensions of a linear layer
into tiles that are processed in sequence. This class enables huge
linear layers when combined with ZeRO-3 because inactive tiles can be
partitioned and offloaded.
.. note::
We recommend using as few tiles as necessary. Tiling
significantly reduces memory usage, but can reduce throughput
            for inexpensive layers. This is due to the smaller kernels having
less parallelism and lower arithmetic intensity, while
introducing more frequent synchronization and communication.
Args:
in_features (int): See ``torch.nn.Linear``
out_features (int): See ``torch.nn.Linear``
bias (bool, optional): See ``torch.nn.Linear``
in_splits (int, optional): The number of tiles along the input dimension. Defaults to 1.
out_splits (int, optional): The number of tiles along the output dimension. Defaults to 1.
            input_is_already_split (bool, optional): If set to ``True``, assume that the ``input_`` passed
                to ``forward()`` is already split into ``in_splits`` chunks. Defaults to ``False``.
combine_out_splits (bool, optional): If set to ``False``, do not combine the ``out_splits`` outputs
into a single tensor. Defaults to ``True``.
linear_cls (class, optional): The underlying class to build individual tiles.
Defaults to ``torch.nn.Linear``.
init_linear (``torch.nn.Linear``, optional): If set, copy the parameters of
``init_linear``. Useful for debugging. Defaults to ``None``.
kwargs (dict, optional): additional keyword arguments to provide to ``linear_cls()``.
Raises:
            RuntimeError: ``in_splits`` must be within the range [1, in_features].
            RuntimeError: ``out_splits`` must be within the range [1, out_features].
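        Example:
            A minimal usage sketch (the tile counts and dimensions below are
            illustrative, not taken from this module):

            .. code-block:: python

                tiled = TiledLinear(in_features=1024,
                                    out_features=1024,
                                    in_splits=4,
                                    out_splits=4)
                x = torch.randn(8, 1024)
                y = tiled(x)  # same shape as torch.nn.Linear(1024, 1024)(x)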
"""
super().__init__()
        if (in_splits < 1) or (in_splits > in_features):
            raise RuntimeError('in_splits must be in the range [1, in_features].')
        if (out_splits < 1) or (out_splits > out_features):
            raise RuntimeError('out_splits must be in the range [1, out_features].')
# global, not necessarily local
self.in_features = in_features
self.out_features = out_features
self.use_bias = bias
self.out_splits = out_splits
self.in_splits = in_splits
self.input_is_already_split = input_is_already_split
self.combine_out_splits = combine_out_splits
        # Build partition lists. These are CSR-style splits: [0, part0, part1, ..., features].
        # For example, self.out_parts[p] gives the start of output partition p and
        # self.out_parts[p + 1] is its exclusive end.
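        # e.g. with out_features=10 and out_splits=3, self.out_parts has four entries of the
        # form [0, b1, b2, 10]; the exact boundaries are chosen by partition_uniform.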
self.in_parts = partition(num_items=in_features, num_parts=in_splits)
self.out_parts = partition(num_items=out_features, num_parts=out_splits)
assert len(self.out_parts) == out_splits + 1
assert len(self.in_parts) == in_splits + 1
assert self.out_parts[0] == 0
assert self.out_parts[out_splits] == out_features
assert self.in_parts[in_splits] == in_features
self.linears = torch.nn.ModuleList()
for out_id in range(out_splits):
self.linears.append(torch.nn.ModuleList())
local_out_dim = self.out_parts[out_id + 1] - self.out_parts[out_id]
for in_id in range(in_splits):
                # If the input dimension is split, only the last input tile carries a bias
                # so that the bias is added exactly once per output split.
local_bias = bias if in_id == (in_splits - 1) else False
local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id]
local = linear_cls(local_in_dim, local_out_dim, bias=local_bias, **kwargs)
self.linears[out_id].append(local)
# Optionally initialize with a known tensor
if init_linear is not None:
self.copy_params_from(init_linear)
def forward(self, input_):
if self.in_splits > 1 and not self.input_is_already_split:
input_parts = partition(input_.shape[-1], self.in_splits)
split_sizes = [input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)]
inputs = self._split_global_input(input_, split_sizes)
elif self.in_splits > 1:
inputs = input_
            assert len(inputs) == self.in_splits, \
                f"Expected {self.in_splits} input splits but got {len(inputs)}"
else:
# no splits
inputs = [input_]
outputs = [None] * self.out_splits
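        # Each output tile accumulates one partial product per input tile; the partial
        # results are merged by _reduce_local_output (assignment on the first input
        # split, summation afterwards).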
for out_id in range(self.out_splits):
for in_id in range(self.in_splits):
local_output = self.linears[out_id][in_id](inputs[in_id])
outputs[out_id] = self._reduce_local_output(in_id=in_id,
out_id=out_id,
current_out=outputs[out_id],
new_out=local_output)
if self.combine_out_splits:
return self._combine_output_splits(outputs)
return outputs
def _split_global_input(self, input, split_sizes):
"""Partition an input tensor along the last dimension, aligned with given splits.
Subclasses should override this method to account for new input types.
Args:
            input (Tensor): The tensor to partition along its last dimension.
split_sizes (List[int]): The size of each partition.
Returns:
List[Any]: A list of the chunks of ``input``.
"""
return split_tensor_along_last_dim(input, split_sizes)
def _reduce_local_output(self, in_id, out_id, current_out, new_out):
"""Reduce (sum) a new local result into the existing local results.
Subclasses should override this method.
        For a given ``out_id``, this method is called once per input split (``in_splits``
        times in total). The first call for each output split is a simple assignment.
        Args:
            in_id (int): The input split that produced ``new_out``.
            out_id (int): The output split that produced ``new_out``.
            current_out (Any): The result reduced from all previous input splits for this
                ``out_id``, or ``None`` on the first call.
            new_out (Any): The local result from forward (``in_id``, ``out_id``).
Returns:
Any: The combined result of ``current_out`` and ``new_out``.
"""
if current_out is None:
            # The clone is necessary to preserve autograd; in-place accumulation into
            # outputs that are views can cause problems in the backward pass.
return new_out.clone()
else:
return current_out + new_out
def _combine_output_splits(self, outputs):
"""Join the splits of the output into a single result.
Args:
outputs (List[Any]): The reduced outputs for each output split.
Returns:
Any: The combined outputs.
"""
assert len(outputs) == self.out_splits
return torch.cat(outputs, dim=-1)
@torch.no_grad()
def copy_params_from(self, other):
"""Copy the weight and bias data from ``other``.
This is especially useful for reproducible initialization and testing.
Equivalent to:
.. code-block:: python
with torch.no_grad():
self.weight.copy_(other.weight)
if self.bias is not None:
self.bias.copy_(other.bias)
.. note::
If ZeRO-3 is enabled, this is a collective operation and the
updated parameters of data-parallel rank 0 will be visible on all
ranks. See :class:`deepspeed.zero.GatheredParameters` for more
information.
Args:
other (``torch.nn.Linear``): the linear layer to copy from.
"""
assert hasattr(other, 'weight')
assert other.weight.size() == (self.out_features, self.in_features)
if self.use_bias:
assert hasattr(other, 'bias')
assert other.bias is not None
assert other.bias.size() == (self.out_features, )
else:
assert other.bias is None
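        # Copy each tile's rectangular slice of the dense weight, plus the matching
        # rows of the dense bias for the tiles that own a bias.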
for row in range(self.out_splits):
rstart = self.out_parts[row]
rstop = self.out_parts[row + 1]
for col in range(self.in_splits):
cstart = self.in_parts[col]
cstop = self.in_parts[col + 1]
local = self.linears[row][col]
global_weight = other.weight[rstart:rstop, cstart:cstop]
with deepspeed.zero.GatheredParameters(local.weight, modifier_rank=0):
local.weight.copy_(global_weight)
if local.bias is not None:
with deepspeed.zero.GatheredParameters(local.bias, modifier_rank=0):
local.bias.data.copy_(other.bias[rstart:rstop].data)
class TiledLinearReturnBias(TiledLinear):
"""Wrapper for a Linear class that returns its own bias parameter, such as
used by Megatron-LM.
"""
def _reduce_local_output(self, in_id, out_id, current_out, new_out):
"""Reduces output tensors, but not the returned bias. """
if current_out is not None:
old_tensor, old_bias = current_out
else:
old_tensor, old_bias = None, None
assert isinstance(new_out, tuple)
assert len(new_out) == 2
tensor, bias = new_out
assert tensor is not None
tensor = super()._reduce_local_output(in_id=in_id, out_id=out_id, current_out=old_tensor, new_out=tensor)
if bias is None:
bias = old_bias
return tensor, bias
def _combine_output_splits(self, outputs):
        # concatenate the output tensors along the last dimension
tensors = [o[0] for o in outputs]
tensor = super()._combine_output_splits(tensors)
        # concatenate the biases from splits that returned one, if any
biases = [o[1] for o in outputs if o[1] is not None]
if len(biases) > 0:
bias = super()._combine_output_splits(biases)
else:
bias = None
return tensor, bias
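

if __name__ == "__main__":
    # Minimal self-check sketch with arbitrary sizes: a TiledLinear initialized from a
    # dense torch.nn.Linear should reproduce its output. With plain torch.nn.Linear
    # tiles and no ZeRO-3 engine, GatheredParameters is expected to act as a no-op
    # context during the parameter copy.
    dense = torch.nn.Linear(6, 4)
    tiled = TiledLinear(6, 4, in_splits=2, out_splits=2, init_linear=dense)
    x = torch.randn(3, 6)
    assert torch.allclose(tiled(x), dense(x), atol=1e-5)
    print("TiledLinear output matches torch.nn.Linear on the sketch input")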