File size: 10,377 Bytes
174ae06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

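"""
FP8 RMSNorm
===========
Fused Triton kernels for RMS normalization (forward and backward) whose
input is supplied as quantized values plus per-group scales.

The code is adapted from the Triton layer-normalization tutorial, which was
written to show a high-performance normalization kernel that runs faster than
the PyTorch implementation. It retains the two techniques that tutorial covers:

* Implementing backward pass in Triton.

* Implementing parallel reduction in Triton.

"""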

import torch
import triton
import triton.language as tl

from ._division_transpose import fp8_division_transpose
from .common import FP8_MAX_VALUE, SCALE_MIN_THRES

"""RMSNorm Forward + Backward"""
"""Forward: Input uses 1 * 16 group quantization"""
"""Forward: Output uses per-tensor quantization"""
"""Backward: Input uses full-precision/BF16."""
"""Backward: Output uses full-precision/BF16."""

"""The input can be 2D or 3D, but the calculation is performed in 2D"""


# Fused RMSNorm forward kernel: one program instance handles one row.
# It dequantizes the row (value * per-group scale), computes rstd, applies the
# weight, stores the un-quantized result to Y and a per-row scale candidate to SY.
@triton.heuristics(
    {
        "BLOCK_SN": lambda args: args["N"] // args["QB"],
        "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
    }
)
@triton.jit
def _rms_norm_fwd_fused(
    X,  # pointer to the (quantized) input
    SX,  # pointer to the per-group scales of the input
    Y,  # pointer to the output
    SY,  # pointer to the per-row scale of the output (one value per row)
    W,  # pointer to the weight
    Rstd,  # pointer to the 1/std (one value per row)
    stride,  # row stride shared by X and Y (elements to skip to move by 1 row)
    scale_stride,  # row stride of SX
    N: tl.constexpr,  # number of columns in X
    N2: tl.constexpr,  # padded column count used for tl.arange (host passes next_power_of_2(N))
    SN: tl.constexpr,  # number of valid scale entries per row
    SN2: tl.constexpr,  # padded scale count used for tl.arange; must equal BLOCK_SN2 for the reshape below
    QB: tl.constexpr,  # group size: QB consecutive elements share one scale
    eps,  # epsilon to avoid division by zero
    fp8_max,  # divisor used to turn the row's max |y| into a scale
    SCALE_MIN_THRES: tl.constexpr,  # small constant added to max |y| before computing the scale
    BLOCK_SN: tl.constexpr,  # N // QB (from heuristics); not referenced in the body
    BLOCK_SN2: tl.constexpr,  # N2 // QB (from heuristics); number of groups in the padded row
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    SX += row * scale_stride

    # Padded lanes (cols >= N, scale_cols >= SN) are loaded as 0.0 so they do not
    # contribute to the sum of squares or to the max below.
    cols = tl.arange(0, N2)
    scale_cols = tl.arange(0, SN2)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)

    # Dequantize: view the row as (groups, QB) and multiply each group by its scale
    scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
    x = tl.reshape(x, (BLOCK_SN2, QB))
    x = x * scale_x
    x = tl.reshape(x, N2)

    # Compute mean of squares (divide by the true column count N, not the padded N2)
    _var = x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Write rstd (RMSNorm has no mean term)
    tl.store(Rstd + row, rstd)

    # Normalize and apply the weight
    w = tl.load(W + cols, mask=cols < N, other=0.0)
    x_hat = x * rstd
    y = x_hat * w

    # Scale calculation: per-row candidate scale = (max |y| + SCALE_MIN_THRES) / fp8_max
    abs_y = tl.abs(y)
    max_val = tl.max(abs_y) + SCALE_MIN_THRES
    scale_output = max_val / fp8_max
    tl.store(SY + row, scale_output)

    # Write output (y itself is not divided by the scale in this kernel)
    tl.store(Y + cols, y, mask=cols < N)


# Fused RMSNorm backward kernel: one program instance handles one row.
# It dequantizes the row of X, computes dx for that row, and accumulates the
# row's contribution to the weight gradient into one of GROUP_SIZE_M partial
# buffers in DW, guarded by a spin lock.
@triton.heuristics(
    {
        "BLOCK_SN": lambda args: args["N"] // args["QB"],
        "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
    }
)
@triton.jit
def _rms_norm_bwd_dx_fused(
    DX,  # pointer to the input gradient
    DY,  # pointer to the output gradient
    DW,  # pointer to the partial sums of the weight gradient (rows indexed by lock_id, row stride N)
    X,  # pointer to the (quantized) input
    SX,  # pointer to the per-group scales of the input
    W,  # pointer to the weight
    Rstd,  # pointer to the 1/std (one value per row)
    Lock,  # pointer to the lock buffer: GROUP_SIZE_M locks followed by GROUP_SIZE_M counters
    stride,  # row stride shared by X, DY and DX (elements to skip to move by 1 row)
    scale_stride,  # row stride of SX
    N: tl.constexpr,  # number of columns in X
    N2: tl.constexpr,  # padded column count used for tl.arange
    SN: tl.constexpr,  # number of valid scale entries per row
    SN2: tl.constexpr,  # padded scale count used for tl.arange; must equal BLOCK_SN2 for the reshape below
    QB: tl.constexpr,  # group size: QB consecutive elements share one scale
    SCALE_MIN_THRES: tl.constexpr,  # not referenced in the body
    BLOCK_SN: tl.constexpr,  # N // QB (from heuristics); not referenced in the body
    BLOCK_SN2: tl.constexpr,  # N2 // QB (from heuristics); number of groups in the padded row
    GROUP_SIZE_M: tl.constexpr,  # number of partial-sum rows in DW / number of locks
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row = tl.program_id(0)
    cols = tl.arange(0, N2)
    scale_cols = tl.arange(0, SN2)
    mask = cols < N
    X += row * stride
    DY += row * stride
    DX += row * stride
    SX += row * scale_stride

    # Offset the lock and the weight-gradient pointer for parallel reduction.
    # Rows with the same (row % GROUP_SIZE_M) share one lock and one DW row.
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols

    # Load data and dequantize x: view the row as (groups, QB) and multiply
    # each group by its scale
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
    scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
    x = tl.reshape(x, (BLOCK_SN2, QB))
    x = x * scale_x
    x = tl.reshape(x, N2)
    dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)

    # Load weight and the saved rstd for this row
    w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd + row).to(tl.float32)

    # Compute dx
    xhat = x * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.0)
    wdy = tl.where(mask, wdy, 0.0)
    c1 = tl.sum(xhat * wdy, axis=0) / N
    dx = (wdy - (xhat * c1)) * rstd  # layer norm has an additional c2 (mean) term; RMSNorm does not

    dx = dx.to(DX.type.element_ty)

    # Write dx
    tl.store(DX + cols, dx, mask=mask)

    # Accumulate partial sums for dw.
    # w was cast to float32 above, so w.dtype here is float32.
    partial_dw = (dy * xhat).to(w.dtype)
    # Spin until this program acquires the lock (CAS 0 -> 1 returns the old value)
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    # First store doesn't accumulate
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)
    tl.store(DW, partial_dw, mask=mask)
    # Release the lock
    tl.atomic_xchg(Lock, 0)


@triton.jit
def _rms_norm_bwd_dwdb(
    DW,  # pointer to the partial sums of the weight gradient, laid out as (M, N) with row stride N
    FINAL_DW,  # pointer to the final weight gradient, N entries
    M,  # number of partial-sum rows to reduce
    N,  # number of columns
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Each program reduces one BLOCK_SIZE_N-wide strip of columns over all M rows.
    col_idx = tl.program_id(0) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    col_in_range = col_idx < N

    # Accumulate tile by tile in float32, then collapse the row axis once at the end.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for row_start in range(0, M, BLOCK_SIZE_M):
        row_idx = row_start + tl.arange(0, BLOCK_SIZE_M)
        tile_offsets = row_idx[:, None] * N + col_idx[None, :]
        tile_in_range = (row_idx[:, None] < M) & col_in_range[None, :]
        acc += tl.load(DW + tile_offsets, mask=tile_in_range, other=0.0)

    # Write the column sums to the output.
    tl.store(FINAL_DW + col_idx, tl.sum(acc, axis=0), mask=col_in_range)


def fp8_rmsnorm_forward(x, s_x, w, QB, eps, transpose_output_2d=False):
    """Run the fused RMSNorm forward kernel on quantized input and quantize the result.

    The work is done on a 2D view: a 3D ``x`` (and its 3D ``s_x``) is flattened
    over the leading two dimensions and the outputs are reshaped back at the end.

    Args:
        x: Quantized input, 2D ``(M, N)`` or 3D ``(BS, L, N)``. Its dtype must be
            a key of ``FP8_MAX_VALUE``.
        s_x: Per-group scales of ``x`` with the same rank as ``x``; the kernel
            multiplies every group of ``QB`` consecutive elements by one scale.
        w: Weight vector applied after normalization.
        QB: Quantization group size along the last dimension.
        eps: Value added to the mean of squares before the square root.
        transpose_output_2d: When ``x`` is 3D and this is False, the transposed
            output is reshaped to 3D as well; when True it is returned exactly
            as produced by ``fp8_division_transpose``.

    Returns:
        ``(qy, s_y_max, qy_t, (w.clone(), rstd, num_warps))`` where ``qy``,
        ``s_y_max`` and ``qy_t`` come from ``fp8_division_transpose`` and the
        trailing tuple carries what ``fp8_rmsnorm_backward`` needs.
    """
    # Change batched 3D input to 2D
    batched = False
    if len(x.shape) == 3:
        assert len(s_x.shape) == 3
        batched = True
        BS = x.shape[0]
        x = x.reshape(-1, x.shape[-1])
        s_x = s_x.reshape(-1, s_x.shape[-1])

    # The kernel indexes X and Y with a single row stride and assumes unit column
    # stride, so the inputs must be row-major contiguous. This is a no-op for
    # tensors that are already contiguous.
    x = x.contiguous()
    s_x = s_x.contiguous()

    # allocate output
    M, N = x.shape
    _, SN = s_x.shape
    y = torch.empty_like(x, dtype=torch.bfloat16)
    s_y = torch.empty((M,), dtype=torch.bfloat16, device=x.device)
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    # fixed warp count; also handed to the backward pass through the return value
    num_warps = 8
    fp8MaxValue = FP8_MAX_VALUE[x.dtype]

    # tl.arange needs power-of-two extents; SN2 must equal the kernel's
    # BLOCK_SN2 (= N2 // QB) for its reshape to be valid.
    N2 = triton.next_power_of_2(N)
    SN2 = N2 // QB

    # enqueue kernel: one program per row
    _rms_norm_fwd_fused[(M,)](  #
        x,
        s_x,
        y,
        s_y,
        w,
        rstd,  #
        x.stride(0),
        s_x.stride(0),
        N,
        N2,
        SN,
        SN2,
        QB,
        eps,
        fp8MaxValue,
        SCALE_MIN_THRES=SCALE_MIN_THRES,  #
        num_warps=num_warps,
        num_ctas=1,
    )

    # reduction: collapse the per-row scales into one per-tensor scale
    s_y_max = s_y.max()
    qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max)

    # Recover 2D to 3D
    if batched:
        qy = qy.reshape(BS, -1, N)
        if not transpose_output_2d:
            qy_t = qy_t.reshape(BS, -1, N)

    return qy, s_y_max, qy_t, (w.clone(), rstd, num_warps)


def fp8_rmsnorm_backward(x, s_x, g, w, v, QB, num_warps):
    """Run the fused RMSNorm backward kernels and return ``(dx, dw)``.

    The work is done on a 2D view: when ``x`` is 3D, ``x``, ``s_x`` and ``g``
    are flattened over the leading two dimensions and ``dx`` is reshaped back.

    Args:
        x: Quantized forward input, 2D ``(M, N)`` or 3D ``(BS, L, N)``.
        s_x: Per-group scales of ``x`` with the same rank as ``x``.
        g: Gradient with respect to the forward output, same layout as ``x``.
        w: Weight vector used in the forward pass.
        v: Per-row ``rstd`` saved by the forward pass.
        QB: Quantization group size along the last dimension.
        num_warps: Warp count for the dx kernel launch.

    Returns:
        ``dx`` in bfloat16 with the shape of ``g`` (3D again if the input was
        3D), and ``dw`` of shape ``(N,)`` in ``w.dtype``.
    """
    # Change batched 3D input to 2D
    batched = False
    if len(x.shape) == 3:
        assert len(s_x.shape) == 3
        batched = True
        BS = x.shape[0]
        x = x.reshape(-1, x.shape[-1])
        s_x = s_x.reshape(-1, s_x.shape[-1])
        g = g.reshape(-1, g.shape[-1])

    # The kernel uses x.stride(0) as the row stride of X, DY and DX alike and
    # assumes unit column stride, so all of them must be row-major contiguous.
    # This is a no-op for tensors that are already contiguous.
    x = x.contiguous()
    s_x = s_x.contiguous()
    g = g.contiguous()

    M, N = g.shape
    _, SN = s_x.shape

    # Number of partial-sum rows (and locks) used for the parallel reduction of dw
    GROUP_SIZE_M = 128
    # First GROUP_SIZE_M entries are locks, the next GROUP_SIZE_M are "written yet?" counters
    locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
    _dw = torch.zeros((GROUP_SIZE_M, N), dtype=w.dtype, device=w.device)
    dw = torch.empty((N,), dtype=w.dtype, device=w.device)

    dx = torch.empty_like(g, dtype=torch.bfloat16)

    # tl.arange needs power-of-two extents. SN2 is derived from N2 exactly as in
    # the forward pass so that it always equals the kernel's BLOCK_SN2
    # (= N2 // QB), which the in-kernel reshape of the scale vector requires.
    N2 = triton.next_power_of_2(N)
    SN2 = N2 // QB

    # compute dx and the partial sums for dw: one program per row
    _rms_norm_bwd_dx_fused[(M,)](  #
        dx,
        g,
        _dw,
        x,
        s_x,
        w,
        v,
        locks,  #
        x.stride(0),
        s_x.stride(0),
        N,
        N2,
        SN,
        SN2,
        QB,
        SCALE_MIN_THRES=SCALE_MIN_THRES,
        num_warps=num_warps,
        GROUP_SIZE_M=GROUP_SIZE_M,
    )

    if batched:
        dx = dx.reshape(BS, -1, dx.shape[-1])

    grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
    # reduce the partial sums in a separate kernel; only the first
    # min(GROUP_SIZE_M, M) rows of _dw can have been written
    _rms_norm_bwd_dwdb[grid](_dw, dw, min(GROUP_SIZE_M, M), N, BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_ctas=1)

    return dx, dw