# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# The file has been adapted from the following Megatron-LM file:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
# Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796
# We retain the following copyright from the original files:

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import deepspeed

from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank,
                                 bwc_tensor_model_parallel_group)


def _gather_tokens(input_, dim=0):
    """Gather tensors from all tensor parallel ranks and concatenate them along a dimension."""
    mpu = deepspeed.utils.groups.mpu

    input_ = input_.contiguous()
    world_size = bwc_tensor_model_parallel_world_size(mpu)
    if world_size == 1:
        return input_

    # Flat buffer that receives every rank's (identically shaped) shard.
    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
    deepspeed.comm.all_gather_into_tensor(gather_buffer, input_, group=bwc_tensor_model_parallel_group(mpu))
    if dim == 0:
        # For dim 0 the flat buffer already has the desired layout; a view avoids a copy.
        shape = list(input_.size())
        shape[0] = shape[0] * world_size
        output = gather_buffer.view(shape)
    else:
        # For other dims, slice each rank's shard back out of the flat buffer
        # and concatenate along the requested dimension.
        tensor_list = [
            gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
        ]
        # Note: torch.cat already creates a contiguous tensor.
        output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


def _drop_tokens(input_, dim=0):
    """Divide a tensor among the tensor parallel ranks."""
    mpu = deepspeed.utils.groups.mpu

    total_chunks = bwc_tensor_model_parallel_world_size(mpu)
    if total_chunks == 1:
        return input_
    this_chunk = bwc_tensor_model_parallel_rank(mpu)
    assert input_.shape[dim] % total_chunks == 0, \
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
    chunk_size = input_.shape[dim] // total_chunks

    # Each rank keeps only its own slice; torch.narrow returns a view, so no copy is made.
    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)


class _GatherTokens(torch.autograd.Function):
    """All gather tokens among the tensor parallel ranks."""

    @staticmethod
    def symbolic(graph, input_, dim):
        return _gather_tokens(input_, dim)

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _gather_tokens(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of an all-gather is the matching drop (slice) on each rank.
        return _drop_tokens(grad_output, ctx.dim), None


class _DropTokens(torch.autograd.Function):
    """Divide tokens equally among the tensor parallel ranks."""

    @staticmethod
    def symbolic(graph, input_, dim):
        return _drop_tokens(input_, dim)

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _drop_tokens(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of a drop is the matching all-gather across ranks.
        return _gather_tokens(grad_output, ctx.dim), None


def gather_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
        # no tensor parallelism for non-experts
        return input_
    return _GatherTokens.apply(input_, dim)


def drop_tokens(input_, dim=0):
    mpu = deepspeed.utils.groups.mpu
    if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
        # no tensor parallelism for non-experts
        return input_
    return _DropTokens.apply(input_, dim)
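

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): how drop_tokens and
# gather_tokens pair up around a tensor-parallel region. This assumes
# deepspeed.init_distributed() has been called and deepspeed.utils.groups.mpu
# has been configured; with no mpu, or a tensor parallel world size of 1,
# both calls are no-ops. `expert_mlp` is a hypothetical per-rank computation.
#
#   hidden = torch.randn(seq_len, batch, dim, device='cuda')
#   # Each TP rank keeps only its 1/world_size slice of the token dimension;
#   # seq_len must be divisible by the tensor parallel world size.
#   local = drop_tokens(hidden, dim=0)
#   local = expert_mlp(local)
#   # Reassemble the full token dimension. Because both ops are autograd
#   # Functions, the backward pass applies the inverse mapping: the gradient
#   # of the gather is re-dropped, and the gradient of the drop is re-gathered.
#   hidden = gather_tokens(local, dim=0)
# ---------------------------------------------------------------------------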