Spaces:
Build error
Build error
File size: 8,381 Bytes
b6af722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
import torch
from megatron.core import ModelParallelConfig, parallel_state
from megatron.core.tensor_parallel import ColumnParallelLinear as McoreColumnParallelLinear
from megatron.core.tensor_parallel import RowParallelLinear as McoreRowParallelLinear
from megatron.core.tensor_parallel import VocabParallelEmbedding as McoreVocabParallelEmbedding
from megatron.core.tensor_parallel.mappings import (
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from torch.distributed import _functional_collectives as funcol
from torch.distributed._functional_collectives import all_reduce
class VocabParallelEmbedding(torch.nn.Module):
    """
    Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept. Each tensor-parallel rank holds rows
    ``[vocab_start_index, vocab_end_index)`` of the full table; ids outside
    that range produce zero rows locally, and the partial results are summed
    across the tensor-parallel group with PyTorch's functional all-reduce.

    Args:
        num_embeddings (int): vocabulary size.
        embedding_dim (int): size of hidden state.
        precision (str): name of the ``torch`` dtype used for the weight
            (e.g. ``"bfloat16"``); resolved with ``getattr(torch, precision)``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        precision: str = "bfloat16",
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        (self.vocab_start_index, self.vocab_end_index) = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings,
            parallel_state.get_tensor_model_parallel_rank(),
            self.tensor_model_parallel_size,
        )
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
        # torch.empty leaves the weight uninitialized; presumably it is filled
        # from a checkpoint by the caller — confirm.
        self.weight = torch.nn.Parameter(
            torch.empty(
                self.num_embeddings_per_partition,
                self.embedding_dim,
                device=torch.cuda.current_device(),
                dtype=getattr(torch, precision),
            )
        )

    def forward(self, input_):
        """Look up embeddings for token ids and combine shards across TP ranks.

        Args:
            input_ (torch.Tensor): Integer token ids in the global vocabulary.

        Returns:
            torch.Tensor: Embeddings of shape ``input_.shape + (embedding_dim,)``.
        """
        if self.tensor_model_parallel_size > 1:
            # Mark ids that belong to another rank's shard.
            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
            # Shift to local row indices. The out-of-place subtraction already
            # yields a new tensor, so the in-place masking never touches input_.
            masked_input = input_ - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output = self.weight[masked_input]
        if self.tensor_model_parallel_size > 1:
            # Zero rows looked up for foreign ids, then sum partial results.
            output[input_mask, :] = 0.0
            output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group())
        return output
class ColumnParallelLinear(McoreColumnParallelLinear):
    """
    A modified version of Mcore's ColumnParallelLinear that only returns the output tensor.

    Mcore's ``forward`` returns a two-element tuple; this wrapper keeps only
    the first element (the output) and discards the second. Construction is
    inherited unchanged from Mcore's ``ColumnParallelLinear``.
    """

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the column parallel linear layer.

        Args:
            input_ (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after the linear transformation.
        """
        output, _ = super().forward(input_)
        return output
class RowParallelLinear(McoreRowParallelLinear):
    """
    A modified version of Mcore's RowParallelLinear that only returns the output tensor.

    Mcore's ``forward`` returns a two-element tuple; this wrapper keeps only
    the first element (the output) and discards the second. Construction is
    inherited unchanged from Mcore's ``RowParallelLinear``.
    """

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the Row Parallel linear layer.

        Args:
            input_ (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after the linear transformation.
        """
        output, _ = super().forward(input_)
        return output
class TrainingVocabParallelEmbedding(McoreVocabParallelEmbedding):
    """
    Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.

    Args:
        num_embeddings (int): vocabulary size.
        embedding_dim (int): size of hidden state.

    Keyword Args:
        init_method (Callable): weight initializer, forwarded to Mcore's VocabParallelEmbedding.
        sequence_parallel (bool): Decides whether to perform ReduceScatter after embedding lookup.
            When True, ``batch_first`` is forced to False.
        batch_first (bool): If True, then output tensor shape is [batch, seq, feature]. If False, then shape becomes
            [seq, batch, feature]. Note: the input tensor is expected in the shape [batch, seq]; the
            transposes in ``forward`` assume this layout.
        config: A megatron.core.ModelParallelConfig object
        use_inference_allreduce (bool): If True, then Megatron's allreduce in the forward pass is disabled, and
            PyTorch's functional all-reduce is used instead (inference mode only). Combined with
            ``sequence_parallel``, no reduce-scatter is performed: the full [seq, batch, feature]
            output is all-reduced.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *,
        init_method: Callable,
        sequence_parallel: bool = False,
        batch_first: bool = False,
        config: ModelParallelConfig,
        use_inference_allreduce: bool = False,
    ):
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            init_method=init_method,
            config=config,
        )
        self.sequence_parallel = sequence_parallel
        # If sequence parallel, then the output tensor should be in the shape of [seq, batch, feature].
        self.batch_first = False if sequence_parallel else batch_first
        self.use_inference_allreduce = use_inference_allreduce

    def forward(self, input_):
        """Look up embeddings and combine the vocabulary shards across TP ranks.

        Args:
            input_ (torch.Tensor): Integer token ids, expected as [batch, seq].

        Returns:
            torch.Tensor: [seq, batch, hidden] by default (split along seq when
            ``sequence_parallel`` is set and ``use_inference_allreduce`` is not),
            or [batch, seq, hidden] when ``batch_first`` is True.
        """
        tp_size = self.tensor_model_parallel_size
        if tp_size > 1:
            # Mark ids that belong to another rank's shard.
            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
            # Shift to local row indices. The out-of-place subtraction already
            # yields a new tensor, so the in-place masking never touches input_.
            masked_input = input_ - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings: [b, s] -> [b, s, h].
        output = self.weight[masked_input]
        # Mask the output embedding.
        if tp_size > 1:
            output[input_mask, :] = 0.0
        if self.sequence_parallel:
            assert not self.batch_first
            # Data format change to avoid explicit transposes: [b, s, h] --> [s, b, h].
            output = output.transpose(0, 1).contiguous()
            if not self.use_inference_allreduce:
                output = reduce_scatter_to_sequence_parallel_region(output)
        else:
            # Reduce across all the model parallel GPUs.
            if not self.use_inference_allreduce:
                output = reduce_from_tensor_model_parallel_region(output)
            if not self.batch_first:
                # Shape: [b, s, h] --> [s, b, h]
                output = output.transpose(0, 1).contiguous()
        if self.use_inference_allreduce and tp_size > 1:
            # A single-rank group has nothing to sum; skip the collective, as
            # the inference VocabParallelEmbedding in this module does.
            output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group())
        return output
|