# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
import torch.distributed as dist
from megatron.core import mpu, parallel_state
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Function
from torch.distributed import broadcast, get_process_group_ranks
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE
from transformer_engine.pytorch.module.rmsnorm import _RMSNorm

from cosmos_predict1.utils import log


def get_batch_on_this_cp_rank(inputs):
"""Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
# With causal masking, each token only attends to its prior tokens. Simply split
# sequence into CP chunks can result in severe load imbalance. That's to say, chunks
# at the end of sequence have bigger workload than others. To address this issue,
# we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
# and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
# that we can get balanced workload among GPUs in a context parallel group.
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
        seq_dim = 1  # sequence dimension of `inputs`: (batch, seq, ...)
inputs = inputs.view(
*inputs.shape[0:seq_dim],
2 * cp_size,
inputs.shape[seq_dim] // (2 * cp_size),
*inputs.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
non_blocking=True
)
inputs = inputs.index_select(seq_dim, index)
inputs = inputs.view(*inputs.shape[0:seq_dim], -1, *inputs.shape[(seq_dim + 2) :])
return inputs
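

# Illustrative sketch for get_batch_on_this_cp_rank (shapes are assumptions, not taken
# from the module): with cp_size=2 and `inputs` of shape (B, 8, D), the sequence is
# viewed as 2 * cp_size = 4 chunks of length 2. Rank 0 keeps chunks [0, 3] and rank 1
# keeps chunks [1, 2], so each rank returns a (B, 4, D) tensor whose causal-attention
# workload is roughly balanced.

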
def gather_batch_from_cp_ranks(outputs):
"""
Gather and reconstruct the full batch from chunks distributed across GPUs in a context parallel group.
"""
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = mpu.get_context_parallel_rank()
if cp_size > 1:
seq_dim = 1 # Assuming sequence dimension is 1
try:
# Reshape output to separate the two chunks
chunk_size = outputs.shape[seq_dim] // 2
outputs = outputs.view(*outputs.shape[:seq_dim], 2, chunk_size, *outputs.shape[seq_dim + 1 :])
# Prepare a list to gather all chunks from all ranks
gathered_chunks = [torch.zeros_like(outputs) for _ in range(cp_size)]
# Gather all chunks
dist.barrier()
dist.all_gather(gathered_chunks, outputs, group=parallel_state.get_context_parallel_group())
dist.barrier()
# Reorder chunks
reordered_chunks = [None] * (2 * cp_size)
for i in range(cp_size):
reordered_chunks[i] = gathered_chunks[i].select(seq_dim, 0)
reordered_chunks[2 * cp_size - 1 - i] = gathered_chunks[i].select(seq_dim, 1)
# Concatenate all chunks
outputs = torch.cat(reordered_chunks, dim=seq_dim)
except Exception as e:
log.info(f"[Rank {cp_rank}] Error in gather_batch_from_cp_ranks: {str(e)}")
raise
return outputs
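

# Round-trip sketch (hypothetical names; assumes each CP rank holds the shard produced
# by get_batch_on_this_cp_rank):
#
#     local = get_batch_on_this_cp_rank(full_inputs)   # (B, S // cp_size, D) per rank
#     restored = gather_batch_from_cp_ranks(local)     # (B, S, D) on every rank
#
# The reordering loop inverts the [cp_rank, 2 * cp_size - cp_rank - 1] chunk selection,
# so `restored` reproduces the original sequence order of `full_inputs`.

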
def broadcast_data_batch_in_tp_cp_group(data_batch):
"""
Broadcast data batch across tensor model parallel and context parallel groups.
"""
keys = sorted(data_batch.keys())
tp_size = parallel_state.get_tensor_model_parallel_world_size()
cp_size = parallel_state.get_context_parallel_world_size()
tp_group = parallel_state.get_tensor_model_parallel_group() if tp_size > 1 else None
cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None
tp_ranks = get_process_group_ranks(tp_group) if tp_size > 1 else None
cp_ranks = get_process_group_ranks(cp_group) if cp_size > 1 else None
if tp_size > 1 or cp_size > 1:
for key in keys:
tensor = data_batch[key]
if isinstance(tensor, torch.Tensor):
                # `.contiguous()` may return a copy, so write it back into the batch to
                # make the broadcast result visible to the caller for non-contiguous inputs.
                tensor = tensor.contiguous()
                data_batch[key] = tensor
                if tp_size > 1:
                    broadcast(tensor, min(tp_ranks), group=tp_group)
                if cp_size > 1:
                    broadcast(tensor, min(cp_ranks), group=cp_group)
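

# Usage sketch (hypothetical batch; tensor names are placeholders):
#
#     data_batch = {"video": video_tensor, "timesteps": t_tensor}
#     broadcast_data_batch_in_tp_cp_group(data_batch)
#
# Afterwards, every rank in a TP or CP group holds the tensors read by the lowest
# global rank of that group, keeping replicated ranks consistent when only one of
# them actually loads the data.

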
def allreduce_layernorm_grads(model: List[torch.nn.Module], tensor_model_parallel_size: int, sequence_parallel: bool):
"""
All-reduce layernorm grads (for sequence parallelism).
Note:
    - We skip QK-normalization layers and the last normalization layer of the Transformer,
      since those use AllReduceBWDRMSNormTE, which already applies the all-reduce in the backward pass.
    - TransformerEngine's LayernormLinear and LayernormMLP modules have `*.layer_norm_weight` parameters
      that must also be all-reduced in the backward pass; this function covers those parameters.
"""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if tensor_model_parallel_size > 1 and sequence_parallel:
grads = []
for model_chunk in model:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
                if name.endswith(".layer_norm_weight"):  # LayernormLinear / LayernormMLP norm weights
grad = param.grad
if grad is not None:
grads.append(grad.data)
if grads:
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
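

# The flatten / all-reduce / unflatten pattern above coalesces all matching grads into
# a single collective call over the TP group; it is numerically equivalent to
# all-reducing each `.layer_norm_weight` grad separately, just with less launch overhead.

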
def sync_1d_parameters(model: torch.nn.Module, process_group=None):
"""
    Synchronize layernorm parameters (1D) across ranks by performing an all-reduce with the mean operation.
    LayerNorm parameters are identified by having ndim == 1.
    Note: if parameters other than LayerNorm are 1D, they will also be synchronized.
Args:
model (torch.nn.Module): The model containing layernorm parameters
process_group (optional): The process group to perform all-reduce.
If None, uses the default process group.
"""
if not torch.distributed.is_initialized():
return
# Synchronize each 1D parameter (layernorm parameters)
for name, param in model.named_parameters():
if param.ndim == 1 and param.requires_grad: # LayerNorm weights/biases are 1D
torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.AVG, group=process_group)
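

# Usage sketch (assumed call site; the model and group are placeholders):
#
#     sync_1d_parameters(model, process_group=parallel_state.get_tensor_model_parallel_group())
#
# This averages LayerNorm/RMSNorm weights (and any other 1D parameters) across the given
# ranks, e.g. after initialization or checkpoint resharding, so replicated copies do not
# drift apart.

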
class AllReduceBWD(Function):
"""
Custom autograd Function that performs an all-reduce operation during the backward pass.
Args:
tensor (Tensor): The input tensor.
process_group: The process group to perform the all-reduce operation.
Returns:
Tensor: The input tensor in the forward pass, and the all-reduced gradient in the backward pass.
"""
@staticmethod
def forward(ctx, tensor, process_group):
ctx.process_group = process_group
        return tensor

@staticmethod
def backward(ctx, grad_output):
dist.all_reduce(grad_output, group=ctx.process_group)
return grad_output, None
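

# Minimal sketch of the autograd behavior (tensors and group are placeholders): the
# forward pass returns the tensor unchanged, while the backward pass sums the incoming
# gradient over the given process group before it reaches `weight.grad`:
#
#     out = AllReduceBWD.apply(weight, group)  # forward: identity
#     out.sum().backward()                     # backward: grad all-reduced over `group`

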
class AllReduceBWDRMSNormTE(RMSNormTE):
"""
A custom RMSNorm layer that applies all-reduce operation during backward pass.
Used in tensor parallel training with Transformer Engine.
    Args:
        hidden_size (int): The size of the hidden dimension.
        process_group: The process group (e.g. Megatron Core's tensor model parallel group)
            over which the weight gradient is all-reduced in the backward pass.
        **kwargs: Additional arguments passed to RMSNormTE.
    """

def __init__(self, hidden_size, process_group, **kwargs):
super().__init__(hidden_size, **kwargs)
        self.process_group = process_group

@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""RMSNorm FWD"""
# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)
        if torch.is_grad_enabled():
            fwd_fn = _RMSNorm.apply
            args = []
        else:
            # Outside autograd, call the kernel directly; it expects a ctx placeholder (None).
            fwd_fn = _RMSNorm.forward
            args = [None]
        args += (
            inp,
            # Wrap the weight so its gradient is all-reduced over `process_group` in backward.
            AllReduceBWD.apply(self.weight, self.process_group),
self.eps,
self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin,
self.inf_rmsnorm_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
)
return fwd_fn(*args)
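

# Construction sketch (hyperparameter values are assumptions, not defaults from this module):
#
#     norm = AllReduceBWDRMSNormTE(
#         hidden_size=4096,
#         process_group=parallel_state.get_tensor_model_parallel_group(),
#         eps=1e-6,
#     )
#
# Because the weight is wrapped with AllReduceBWD, its gradient is summed across the
# given group during backward, which is why allreduce_layernorm_grads skips these layers.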