SreyanG-NVIDIA's picture
Upload 225 files
174ae06 verified
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
# 4 block
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice
from ._division import _stochastic_rounding
from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
"""Division and Transpose Operator"""
"""Input uses full-precision/BF16"""
"""Output uses per tensor quantization"""
"""Output_t uses per tensor quantization and is transposed, but is flattened to 2D"""
"""The input can be 2D or 3D, but the calculation is performed in 2D"""
@triton.autotune(
configs=[] + get_configs_io_block(), # triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)
# configs=[triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)], #
key=[
"N",
],
)
@triton.heuristics(
{
"BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
}
)
@triton.jit
def _fp8_division_transpose_kernel(
output_ptr,
output_t_ptr, # output
input_ptr,
input_scale_ptr, # input
noise_ptr, # noise for stochastic
M,
N,
SN,
QB: tl.constexpr,
fp8_max,
e_bit,
m_bit, # shape
input_stride_0,
input_stride_1, # input stride
output_stride_0,
output_stride_1, # output stride
output_t_stride_0,
output_t_stride_1, # output stride
SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor
STOCHASTIC: tl.constexpr,
ONLY_TRANSPOSED: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_SN: tl.constexpr,
): # CUDA block size
# Block PID
pid = tl.program_id(0)
NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
pid_dim0 = pid // NUM_BLOCK_N
pid_dim1 = pid % NUM_BLOCK_N
# pointers
input_block_ptr = tl.make_block_ptr(
base=input_ptr,
shape=(M, N),
strides=(input_stride_0, input_stride_1),
offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
input = tl.load(input_block_ptr)
input = input.to(tl.float32)
scale_output = tl.load(input_scale_ptr)
scale_output = scale_output.to(tl.float32)
output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
# Quantize Scale calculation
# Quantize
output = tl.fdiv(output, scale_output)
output = tl.reshape(output, (BLOCK_M, BLOCK_N))
if STOCHASTIC:
# noise_block_ptr = tl.make_block_ptr(
# base=noise_ptr,
# shape=(M, N),
# strides=(input_stride_0, input_stride_1),
# offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
# block_shape=(BLOCK_M, BLOCK_N),
# order=(1, 0)
# )
# noise = tl.load(noise_block_ptr)
offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
noise = tl.rand(0, noise_offset)
output = _stochastic_rounding(output, noise, e_bit, m_bit)
output = output.to(output_ptr.type.element_ty)
# tl.device_print("3: ", output)
output_t = tl.trans(output)
# pointers
output_block_ptr = tl.make_block_ptr(
base=output_ptr,
shape=(M, N),
strides=(output_stride_0, output_stride_1),
offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
output_t_block_ptr = tl.make_block_ptr(
base=output_t_ptr,
shape=(N, M),
strides=(output_t_stride_0, output_t_stride_1),
offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
block_shape=(BLOCK_N, BLOCK_M),
order=(1, 0),
)
if not ONLY_TRANSPOSED:
tl.store(output_block_ptr, output, boundary_check=(0, 1))
tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1))
def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False, only_transposed=False):
# Change batched 3D input to 2D
batched = False
if len(x.shape) == 3:
batched = True
BS = x.shape[0]
x = x.reshape(-1, x.shape[-1])
if stochastic:
# noise = torch.empty_like(x, dtype=torch.float32).uniform_(-0.5, 0.5)
noise = None
else:
noise = None
# defining the input and output tensor
M, N = x.shape
SN = N // QB
if isinstance(fp8type, str):
fp8type = convert_str_to_fp8[fp8type]
y = torch.empty_like(x, dtype=fp8type)
y_t = torch.empty((N, M), dtype=fp8type, device=x.device)
fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
e_bit, m_bit = convert_fp8_to_embit[fp8type]
if s_y is None:
# print("Warning: do not specify s_y in fp8_division_transpose")
s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
_fp8_division_transpose_kernel[grid](
y,
y_t,
x,
s_y,
noise,
M,
N,
SN,
QB,
fp8MaxValue,
e_bit,
m_bit,
x.stride(0),
x.stride(1),
y.stride(0),
y.stride(1),
y_t.stride(0),
y_t.stride(1),
SCALE_MIN_THRES=SCALE_MIN_THRES,
STOCHASTIC=stochastic,
ONLY_TRANSPOSED=only_transposed,
)
if not only_transposed:
# Recover 2D to 3D
if batched:
y = y.reshape(BS, -1, y.shape[-1])
return y, s_y, y_t # y_t is expected to be 2D tensor
else:
return y_t, s_y