# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
import torch.distributed as dist
from megatron.core import mpu, parallel_state
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Function
from torch.distributed import broadcast, get_process_group_ranks
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE
from transformer_engine.pytorch.module.rmsnorm import _RMSNorm

from cosmos_predict1.utils import log


def get_batch_on_this_cp_rank(inputs):
    """Slice batch input along sequence dimension into multiple chunks,
    which are parallelized across GPUs in a context parallel group.
    """

    # With causal masking, each token only attends to its prior tokens. Simply split
    # sequence into CP chunks can result in severe load imbalance. That's to say, chunks
    # at the end of sequence have bigger workload than others. To address this issue,
    # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
    # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
    # that we can get balanced workload among GPUs in a context parallel group.
    cp_size = parallel_state.get_context_parallel_world_size()

    if cp_size > 1:
        cp_rank = mpu.get_context_parallel_rank()
        seq_dim = 1  # sequence dimension (an attention mask would use dim 2 instead)
        inputs = inputs.view(
            *inputs.shape[0:seq_dim],
            2 * cp_size,
            inputs.shape[seq_dim] // (2 * cp_size),
            *inputs.shape[(seq_dim + 1) :],
        )
        index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
            non_blocking=True
        )
        inputs = inputs.index_select(seq_dim, index)
        inputs = inputs.view(*inputs.shape[0:seq_dim], -1, *inputs.shape[(seq_dim + 2) :])

    return inputs
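
# A minimal sketch of the chunk assignment above (hypothetical shapes, not part of the
# original code), assuming cp_size == 2 and an input of shape [B, S, ...]:
#
#     x = x.view(B, 4, S // 4, ...)   # 2 * cp_size chunks along the sequence dimension
#     # rank 0 keeps chunks (0, 3), rank 1 keeps chunks (1, 2)
#     local = x.index_select(1, index).view(B, S // 2, ...)
#
# Every rank ends up with S // cp_size tokens and a comparable causal-attention workload,
# since an "early" (cheap) chunk is always paired with a "late" (expensive) one.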


def gather_batch_from_cp_ranks(outputs):
    """
    Gather and reconstruct the full batch from chunks distributed across GPUs in a context parallel group.
    """
    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()

    if cp_size > 1:
        seq_dim = 1  # Assuming sequence dimension is 1

        try:
            # Reshape output to separate the two chunks
            chunk_size = outputs.shape[seq_dim] // 2
            outputs = outputs.view(*outputs.shape[:seq_dim], 2, chunk_size, *outputs.shape[seq_dim + 1 :])

            # Prepare a list to gather all chunks from all ranks
            gathered_chunks = [torch.zeros_like(outputs) for _ in range(cp_size)]

            # Gather all chunks
            dist.barrier()
            dist.all_gather(gathered_chunks, outputs, group=parallel_state.get_context_parallel_group())
            dist.barrier()

            # Reorder chunks
            reordered_chunks = [None] * (2 * cp_size)
            for i in range(cp_size):
                reordered_chunks[i] = gathered_chunks[i].select(seq_dim, 0)
                reordered_chunks[2 * cp_size - 1 - i] = gathered_chunks[i].select(seq_dim, 1)

            # Concatenate all chunks
            outputs = torch.cat(reordered_chunks, dim=seq_dim)
        except Exception as e:
            log.info(f"[Rank {cp_rank}] Error in gather_batch_from_cp_ranks: {str(e)}")
            raise

    return outputs
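
# Illustrative round trip (an assumption about intended usage, not taken from the original
# code): gather_batch_from_cp_ranks is the inverse of get_batch_on_this_cp_rank, so for a
# full-sequence tensor x that is identical on every CP rank,
#
#     local = get_batch_on_this_cp_rank(x)      # [B, S // cp_size, ...] on each rank
#     full = gather_batch_from_cp_ranks(local)  # [B, S, ...] reconstructed on each rank
#
# should reproduce x, because the reordering above undoes the paired chunk assignment.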


def broadcast_data_batch_in_tp_cp_group(data_batch):
    """
    Broadcast data batch across tensor model parallel and context parallel groups.
    """
    keys = sorted(data_batch.keys())
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    cp_size = parallel_state.get_context_parallel_world_size()
    tp_group = parallel_state.get_tensor_model_parallel_group() if tp_size > 1 else None
    cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None
    tp_ranks = get_process_group_ranks(tp_group) if tp_size > 1 else None
    cp_ranks = get_process_group_ranks(cp_group) if cp_size > 1 else None
    if tp_size > 1 or cp_size > 1:
        for key in keys:
            tensor = data_batch[key]
            if isinstance(tensor, torch.Tensor):
                tensor = tensor.contiguous()
                if tp_size > 1:
                    broadcast(tensor, min(tp_ranks), group=tp_group)
                if cp_size > 1:
                    broadcast(tensor, min(cp_ranks), group=cp_group)
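
# Hypothetical call site (the dictionary keys are assumptions, not from this file): the
# broadcast is in place, so every rank in a TP/CP group must enter with tensors of matching
# shape and dtype, and after the call all ranks hold the values from the group's lowest rank.
#
#     data_batch = {"video": video_tensor, "timesteps": t}
#     broadcast_data_batch_in_tp_cp_group(data_batch)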


def allreduce_layernorm_grads(model: List[torch.nn.Module], tensor_model_parallel_size: int, sequence_parallel: bool):
    """
    All-reduce layernorm gradients (needed when sequence parallelism is used).
    Note:
    - We skip the QK-normalization layers and the last normalization layer of the Transformer,
      since they use AllReduceBWDRMSNormTE, which already applies an all-reduce in the backward pass.
    - TransformerEngine's LayernormLinear and LayernormMLP modules have `*.layer_norm_weight` parameters
      whose gradients must also be all-reduced across the tensor model parallel group; this function covers them.
    """
    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used
    if tensor_model_parallel_size > 1 and sequence_parallel:
        grads = []
        for model_chunk in model:
            for name, param in model_chunk.named_parameters():
                if not param.requires_grad:
                    continue
                if name.endswith(".layer_norm_weight"):  # LayerNorm weights inside TE's LayernormLinear / LayernormMLP
                    grad = param.grad
                    if grad is not None:
                        grads.append(grad.data)

        if grads:
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group())
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
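
# A hedged usage sketch (surrounding names are assumptions, not from this file): this is
# typically called once per training step, after the backward pass and before the optimizer
# step, alongside the data-parallel gradient reduction.
#
#     loss.backward()
#     allreduce_layernorm_grads([model], tensor_model_parallel_size=tp_size, sequence_parallel=True)
#     optimizer.step()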


def sync_1d_parameters(model: torch.nn.Module, process_group=None):
    """
    Synchronize 1D (layernorm) parameters across ranks by performing an all-reduce with the mean operation.
    LayerNorm parameters are identified by having ndim == 1.
    Note: any other 1D parameters that require gradients will also be synchronized.

    Args:
        model (torch.nn.Module): The model containing layernorm parameters
        process_group (optional): The process group to perform all-reduce.
                                If None, uses the default process group.
    """
    if not torch.distributed.is_initialized():
        return
    # Synchronize each 1D parameter (layernorm parameters)
    for name, param in model.named_parameters():
        if param.ndim == 1 and param.requires_grad:  # LayerNorm weights/biases are 1D
            torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.AVG, group=process_group)
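
# Hedged usage sketch (call pattern is an assumption, not from this file): averaging the 1D
# parameters keeps replicated norm weights numerically identical across ranks, e.g. after
# loading a checkpoint or periodically during long training runs.
#
#     sync_1d_parameters(model, process_group=parallel_state.get_tensor_model_parallel_group())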


class AllReduceBWD(Function):
    """
    Custom autograd Function that performs an all-reduce operation during the backward pass.

    Args:
        tensor (Tensor): The input tensor.
        process_group: The process group to perform the all-reduce operation.

    Returns:
        Tensor: The input tensor in the forward pass, and the all-reduced gradient in the backward pass.
    """

    @staticmethod
    def forward(ctx, tensor, process_group):
        ctx.process_group = process_group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, group=ctx.process_group)
        return grad_output, None
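
# Minimal illustration of the identity-forward / all-reduce-backward behaviour (a sketch,
# not part of the original code): for a weight replicated across a tensor parallel group,
#
#     w = AllReduceBWD.apply(weight, parallel_state.get_tensor_model_parallel_group())
#     out = x * w  # forward is unchanged; the weight's gradient is summed over the group in backward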


class AllReduceBWDRMSNormTE(RMSNormTE):
    """
    A custom RMSNorm layer that applies an all-reduce to the weight gradient during the backward pass.
    Used in tensor parallel training with Transformer Engine.

    Args:
        hidden_size (int): The size of the hidden dimension.
        process_group: Megatron Core's process group.
        **kwargs: Additional arguments to be passed to RMSNormTE.
    """

    def __init__(self, hidden_size, process_group, **kwargs):
        super().__init__(hidden_size, **kwargs)
        self.process_group = process_group

    @no_torch_dynamo()
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """RMSNorm FWD"""

        # Set the activation type for AMP.
        TransformerEngineBaseModule.set_activation_dtype(self, inp)

        if torch.is_grad_enabled():
            fwd_fn = _RMSNorm.apply
            args = []
        else:
            fwd_fn = _RMSNorm.forward
            args = [None]

        args += (
            inp,
            AllReduceBWD.apply(self.weight, self.process_group),
            self.eps,
            self.fwd_rmsnorm_sm_margin,
            self.bwd_rmsnorm_sm_margin,
            self.inf_rmsnorm_sm_margin,
            self.zero_centered_gamma,
            torch.is_grad_enabled(),
            self.activation_dtype,
        )

        return fwd_fn(*args)
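

# Hypothetical construction example (argument values are assumptions, not from this file):
# AllReduceBWDRMSNormTE is a drop-in replacement for RMSNormTE when the norm weight is
# replicated across a tensor parallel group and its gradient therefore needs an all-reduce.
#
#     norm = AllReduceBWDRMSNormTE(
#         hidden_size=4096,
#         process_group=parallel_state.get_tensor_model_parallel_group(),
#         eps=1e-6,
#     )
#     y = norm(x)  # forward matches RMSNormTE; the weight gradient is all-reduced in backward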