# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch

# 4 block
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice

from ._division import _stochastic_rounding
from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block

"""Division and Transpose Operator"""
"""Input uses full-precision/BF16"""
"""Output uses per tensor quantization"""
"""Output_t uses per tensor quantization and is transposed, but is flattened to 2D"""
"""The input can be 2D or 3D, but the calculation is performed in 2D"""
# Autotuned over the standard I/O block configs; BLOCK_SN is derived as BLOCK_N // QB
# so the per-group reshape inside the kernel is consistent with the tile size.
@triton.autotune(
    configs=get_configs_io_block(),
    key=["N"],
)
@triton.heuristics(
    {
        "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
    }
)
@triton.jit
def _fp8_division_transpose_kernel(
    output_ptr,
    output_t_ptr,  # output
    input_ptr,
    input_scale_ptr,  # input
    noise_ptr,  # noise for stochastic rounding
    M,
    N,
    SN,
    QB: tl.constexpr,
    fp8_max,
    e_bit,
    m_bit,  # shape
    input_stride_0,
    input_stride_1,  # input stride
    output_stride_0,
    output_stride_1,  # output stride
    output_t_stride_0,
    output_t_stride_1,  # output_t stride
    SCALE_MIN_THRES: tl.constexpr,  # not used here: SCALE_MIN_THRES is applied in the kernel that computes the scaling factor
    STOCHASTIC: tl.constexpr,
    ONLY_TRANSPOSED: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_SN: tl.constexpr,
):  # CUDA block size
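    # Each program instance handles one (BLOCK_M, BLOCK_N) tile of the input: it divides
    # the tile by the per-tensor scale, casts it to FP8 (optionally with stochastic
    # rounding), and stores the tile and/or its transpose depending on ONLY_TRANSPOSED.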
    # Block PID: map the 1D program id onto a 2D (BLOCK_M, BLOCK_N) tile index
    pid = tl.program_id(0)
    NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
    pid_dim0 = pid // NUM_BLOCK_N
    pid_dim1 = pid % NUM_BLOCK_N

    # pointers
    input_block_ptr = tl.make_block_ptr(
        base=input_ptr,
        shape=(M, N),
        strides=(input_stride_0, input_stride_1),
        offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    input = tl.load(input_block_ptr)
    input = input.to(tl.float32)
    scale_output = tl.load(input_scale_ptr)  # single per-tensor scale
    scale_output = scale_output.to(tl.float32)

    output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))

    # Quantize: divide by the per-tensor scale
    output = tl.fdiv(output, scale_output)
    output = tl.reshape(output, (BLOCK_M, BLOCK_N))

    if STOCHASTIC:
        # The rounding noise is generated in-kernel from the element offsets
        # (instead of being loaded from a pre-allocated noise_ptr tensor).
        offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
        noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
        noise = tl.rand(0, noise_offset)
        output = _stochastic_rounding(output, noise, e_bit, m_bit)

    output = output.to(output_ptr.type.element_ty)
    output_t = tl.trans(output)

    # pointers
    output_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(M, N),
        strides=(output_stride_0, output_stride_1),
        offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    output_t_block_ptr = tl.make_block_ptr(
        base=output_t_ptr,
        shape=(N, M),
        strides=(output_t_stride_0, output_t_stride_1),
        offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
        block_shape=(BLOCK_N, BLOCK_M),
        order=(1, 0),
    )

    if not ONLY_TRANSPOSED:
        tl.store(output_block_ptr, output, boundary_check=(0, 1))
    tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1))
def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False, only_transposed=False):
    # Flatten a batched 3D input to 2D
    batched = False
    if len(x.shape) == 3:
        batched = True
        BS = x.shape[0]
        x = x.reshape(-1, x.shape[-1])

    # Stochastic-rounding noise is generated inside the kernel (tl.rand), so no
    # noise tensor needs to be allocated here.
    noise = None

    # define the input and output tensors
    M, N = x.shape
    SN = N // QB

    if isinstance(fp8type, str):
        fp8type = convert_str_to_fp8[fp8type]
    y = torch.empty_like(x, dtype=fp8type)
    y_t = torch.empty((N, M), dtype=fp8type, device=x.device)
    fp8MaxValue = FP8_MAX_VALUE[fp8type]  # E4M3 and E5M2 have different max values
    e_bit, m_bit = convert_fp8_to_embit[fp8type]

    if s_y is None:
        # Per-tensor scale derived from the global absolute maximum of the input
        s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
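        # Worked example: for E4M3 (fp8_max = 448), an input with max|x| ≈ 3.2 gives
        # s_y ≈ 3.2 / 448 ≈ 0.00714, so x / s_y fills the representable FP8 range.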
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    _fp8_division_transpose_kernel[grid](
        y,
        y_t,
        x,
        s_y,
        noise,
        M,
        N,
        SN,
        QB,
        fp8MaxValue,
        e_bit,
        m_bit,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        y_t.stride(0),
        y_t.stride(1),
        SCALE_MIN_THRES=SCALE_MIN_THRES,
        STOCHASTIC=stochastic,
        ONLY_TRANSPOSED=only_transposed,
    )
    if not only_transposed:
        # Restore the batched 3D shape for y; y_t is always returned as a 2D tensor
        if batched:
            y = y.reshape(BS, -1, y.shape[-1])

        return y, s_y, y_t
    else:
        return y_t, s_y
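

if __name__ == "__main__":
    # Minimal smoke test (illustrative only: the shape, group size QB, and FP8 dtype are
    # arbitrary choices, and a CUDA device is assumed since this is a Triton kernel).
    x = torch.randn(4, 1024, 2048, dtype=torch.bfloat16, device="cuda")
    y, s_y, y_t = fp8_division_transpose(x, QB=16, fp8type=torch.float8_e4m3fn)
    print(y.shape, y_t.shape, s_y.item())  # (4, 1024, 2048), (2048, 4096), scalar per-tensor scale
    # Dequantize and check the round-trip error against the original input
    err = (y.to(torch.float32) * s_y - x.float()).abs().max()
    print(f"max abs quantization error: {err.item():.6f}")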