# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
import torch.distributed as dist
from megatron.core import mpu, parallel_state
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Function
from torch.distributed import broadcast, get_process_group_ranks
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE
from transformer_engine.pytorch.module.rmsnorm import _RMSNorm

from cosmos_predict1.utils import log


def get_batch_on_this_cp_rank(inputs):
"""Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
    # With causal masking, each token only attends to the tokens before it, so naively
    # splitting the sequence into CP chunks results in severe load imbalance: chunks at
    # the end of the sequence carry a bigger attention workload than earlier ones. To
    # address this, we split the sequence into 2*CP chunks. Assuming CP=2, we get 4
    # chunks; chunk_0 and chunk_3 are assigned to GPU0 and chunk_1 and chunk_2 to GPU1,
    # so that the workload is balanced among GPUs in the context parallel group.
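    # For example, with cp_size=2 a (B, S, ...) input is viewed as 4 sequence chunks;
    # rank 0 selects chunks [0, 3] and rank 1 selects chunks [1, 2], matching the
    # index = [cp_rank, 2*cp_size - cp_rank - 1] selection below.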
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
seq_dim = 1 # if key != 'attention_mask' else 2
inputs = inputs.view(
*inputs.shape[0:seq_dim],
2 * cp_size,
inputs.shape[seq_dim] // (2 * cp_size),
*inputs.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
non_blocking=True
)
inputs = inputs.index_select(seq_dim, index)
inputs = inputs.view(*inputs.shape[0:seq_dim], -1, *inputs.shape[(seq_dim + 2) :])
return inputs


def gather_batch_from_cp_ranks(outputs):
"""
Gather and reconstruct the full batch from chunks distributed across GPUs in a context parallel group.
"""
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = mpu.get_context_parallel_rank()
if cp_size > 1:
seq_dim = 1 # Assuming sequence dimension is 1
try:
# Reshape output to separate the two chunks
chunk_size = outputs.shape[seq_dim] // 2
outputs = outputs.view(*outputs.shape[:seq_dim], 2, chunk_size, *outputs.shape[seq_dim + 1 :])
# Prepare a list to gather all chunks from all ranks
gathered_chunks = [torch.zeros_like(outputs) for _ in range(cp_size)]
# Gather all chunks
dist.barrier()
dist.all_gather(gathered_chunks, outputs, group=parallel_state.get_context_parallel_group())
dist.barrier()
# Reorder chunks
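            # Rank i holds chunks i and (2*cp_size - 1 - i) from the load-balanced split,
            # so place its first chunk at position i and its second at the mirrored position.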
reordered_chunks = [None] * (2 * cp_size)
for i in range(cp_size):
reordered_chunks[i] = gathered_chunks[i].select(seq_dim, 0)
reordered_chunks[2 * cp_size - 1 - i] = gathered_chunks[i].select(seq_dim, 1)
# Concatenate all chunks
outputs = torch.cat(reordered_chunks, dim=seq_dim)
except Exception as e:
log.info(f"[Rank {cp_rank}] Error in gather_batch_from_cp_ranks: {str(e)}")
raise
return outputs


def broadcast_data_batch_in_tp_cp_group(data_batch):
"""
Broadcast data batch across tensor model parallel and context parallel groups.
"""
keys = sorted(data_batch.keys())
tp_size = parallel_state.get_tensor_model_parallel_world_size()
cp_size = parallel_state.get_context_parallel_world_size()
tp_group = parallel_state.get_tensor_model_parallel_group() if tp_size > 1 else None
cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None
tp_ranks = get_process_group_ranks(tp_group) if tp_size > 1 else None
cp_ranks = get_process_group_ranks(cp_group) if cp_size > 1 else None
if tp_size > 1 or cp_size > 1:
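        # Broadcast each tensor from the lowest global rank of its group so that all
        # ranks in the TP and CP groups see an identical data batch.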
for key in keys:
tensor = data_batch[key]
if isinstance(tensor, torch.Tensor):
tensor = tensor.contiguous()
if tp_size > 1:
broadcast(tensor, min(tp_ranks), group=tp_group)
if cp_size > 1:
broadcast(tensor, min(cp_ranks), group=cp_group)


def allreduce_layernorm_grads(model: List[torch.nn.Module], tensor_model_parallel_size: int, sequence_parallel: bool):
    """
    All-reduce layernorm grads (for sequence parallelism).

    Note:
        - We skip the QK-normalization layers and the final normalization layer of the Transformer,
          since those use AllReduceBWDRMSNormTE, which already all-reduces their gradients in the backward pass.
        - TransformerEngine's LayernormLinear and LayernormMLP modules have `*.layer_norm_weight` parameters
          whose gradients must also be all-reduced; this function covers those parameters.
    """
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if tensor_model_parallel_size > 1 and sequence_parallel:
grads = []
for model_chunk in model:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
                if name.endswith(".layer_norm_weight"):  # TE LayernormLinear / LayernormMLP norm weights
grad = param.grad
if grad is not None:
grads.append(grad.data)
if grads:
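            # Coalesce all norm-weight grads into a single flat buffer, all-reduce once
            # across the tensor-model-parallel group, then copy the results back in place.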
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)


def sync_1d_parameters(model: torch.nn.Module, process_group=None):
    """
    Synchronize layernorm parameters (1D) across ranks by performing an all-reduce with the mean (AVG) op.

    LayerNorm parameters are identified by having ndim == 1.
    Note: any other 1D parameters that require grad will also be synchronized.

    Args:
        model (torch.nn.Module): The model containing the layernorm parameters.
        process_group (optional): The process group for the all-reduce.
            If None, the default process group is used.
    """
if not torch.distributed.is_initialized():
return
# Synchronize each 1D parameter (layernorm parameters)
for name, param in model.named_parameters():
if param.ndim == 1 and param.requires_grad: # LayerNorm weights/biases are 1D
torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.AVG, group=process_group)


class AllReduceBWD(Function):
    """
    Custom autograd Function that performs an all-reduce on the gradient during the backward pass.

    Args:
        tensor (Tensor): The input tensor.
        process_group: The process group over which to perform the all-reduce.

    Returns:
        Tensor: The input tensor in the forward pass, and the all-reduced gradient in the backward pass.
    """

    @staticmethod
def forward(ctx, tensor, process_group):
ctx.process_group = process_group
return tensor

    @staticmethod
def backward(ctx, grad_output):
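        # dist.all_reduce defaults to SUM, so each rank's partial gradient for the
        # (replicated) tensor is accumulated across the process group.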
dist.all_reduce(grad_output, group=ctx.process_group)
return grad_output, None


class AllReduceBWDRMSNormTE(RMSNormTE):
    """
    A custom RMSNorm layer that applies an all-reduce to its weight gradient during the backward pass.
    Used in tensor parallel training with Transformer Engine.

    Args:
        hidden_size (int): The size of the hidden dimension.
        process_group: Megatron Core's process group.
        **kwargs: Additional arguments to be passed to RMSNormTE.
    """

    def __init__(self, hidden_size, process_group, **kwargs):
super().__init__(hidden_size, **kwargs)
self.process_group = process_group

    @no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""RMSNorm FWD"""
# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)
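        # When grads are enabled, go through autograd via _RMSNorm.apply; otherwise call
        # the kernel's static forward directly, passing None in place of the ctx argument.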
if torch.is_grad_enabled():
fwd_fn = _RMSNorm.apply
args = []
else:
fwd_fn = _RMSNorm.forward
args = [None]
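        # Wrap the RMSNorm weight with AllReduceBWD so that its gradient is all-reduced
        # across `self.process_group` during the backward pass.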
args += (
inp,
AllReduceBWD.apply(self.weight, self.process_group),
self.eps,
self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin,
self.inf_rmsnorm_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
)
return fwd_fn(*args)
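

# Illustrative usage sketch (comments only, not executed): assuming Megatron's parallel
# state has already been initialized, the helpers above would typically be combined
# roughly as follows; `data_batch`, `key`, `model`, `tp_size`, and `local_output` are
# hypothetical names, not part of this module.
#
#     broadcast_data_batch_in_tp_cp_group(data_batch)                # same batch on TP/CP ranks
#     data_batch[key] = get_batch_on_this_cp_rank(data_batch[key])   # shard along the sequence dim
#     local_output = model(data_batch)                               # forward/backward as usual
#     allreduce_layernorm_grads([model], tp_size, sequence_parallel=True)
#     full_output = gather_batch_from_cp_ranks(local_output)         # reassemble the full sequence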