Spaces:
Running
on
A100
Running
on
A100
File size: 4,916 Bytes
174ae06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
# 4 block
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice
from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
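# Quantize operator: converts a higher-precision tensor to FP8 using 1 x QB group
# quantization (e.g. QB = 16), i.e. one scale per QB consecutive elements along
# the last dimension. The quantized output and its per-group scales are returned.
# The input can be 2D or 3D, but the calculation is always performed on a 2D view.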
@triton.autotune(
    configs=get_configs_io_block(),
    key=[
        "N",
    ],
)
@triton.heuristics(
    {
        # Number of QB-wide scale groups covered by one BLOCK_N-wide column tile.
        "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
    }
)
@triton.jit
def _fp8_quantize_kernel(
    output_ptr,  # quantized output, logically (M, N)
    output_scale_ptr,  # per-group scales, logically (M, SN)
    input_ptr,  # input, logically (M, N)
    M,
    N,
    SN,  # number of scale groups per row
    QB: tl.constexpr,  # group size: elements sharing one scale
    fp8_max,  # largest finite magnitude of the target FP8 format
    input_stride_0,
    input_stride_1,  # input strides
    output_stride_0,
    output_stride_1,  # output strides
    s_output_stride_0,
    s_output_stride_1,  # scale strides
    SCALE_MIN_THRES: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_SN: tl.constexpr,
):
    """Quantize one (BLOCK_M, BLOCK_N) tile of the input with 1 x QB group scaling.

    For every group of QB consecutive elements in a row, the scale is
    (max|x| + SCALE_MIN_THRES) / fp8_max. The group is divided by that scale and
    cast to the output element type, and the scale is cast to the scale element type.
    """
    # Map the 1D program id to a (row tile, column tile) pair; consecutive ids
    # walk along the N dimension first.
    pid = tl.program_id(0)
    NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
    pid_dim0 = pid // NUM_BLOCK_N
    pid_dim1 = pid % NUM_BLOCK_N

    input_block_ptr = tl.make_block_ptr(
        base=input_ptr,
        shape=(M, N),
        strides=(input_stride_0, input_stride_1),
        offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    # Tiles on the M/N edges may be partial. Bounds-check the load (matching the
    # bounds-checked stores below) so nothing is read out of range.
    # Zero padding cannot raise a group's absmax.
    x = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
    x = x.to(tl.float32)
    # View each row of the tile as BLOCK_SN groups of QB contiguous elements.
    x = tl.reshape(x, (BLOCK_M, BLOCK_SN, QB))

    # Per-group scale. SCALE_MIN_THRES is added to the absmax, presumably a small
    # positive constant that keeps the scale non-zero for all-zero groups.
    max_val = tl.max(tl.abs(x), axis=2) + SCALE_MIN_THRES
    scale = max_val / fp8_max
    scale = tl.reshape(scale, (BLOCK_M, BLOCK_SN, 1))

    # Quantize: divide by the broadcast group scale (still fp32), then cast.
    q = tl.fdiv(x, scale)
    q = q.to(output_ptr.type.element_ty)
    scale = scale.to(output_scale_ptr.type.element_ty)
    scale = tl.reshape(scale, (BLOCK_M, BLOCK_SN))
    q = tl.reshape(q, (BLOCK_M, BLOCK_N))

    output_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(M, N),
        strides=(output_stride_0, output_stride_1),
        offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    scale_output_ptr = tl.make_block_ptr(
        base=output_scale_ptr,
        shape=(M, SN),
        strides=(s_output_stride_0, s_output_stride_1),
        offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
        block_shape=(BLOCK_M, BLOCK_SN),
        order=(1, 0),
    )
    tl.store(output_block_ptr, q, boundary_check=(0, 1))
    tl.store(scale_output_ptr, scale, boundary_check=(0, 1))
def fp8_quantize(x, QB, fp8type):
    """Quantize ``x`` to FP8 with one scale per group of ``QB`` consecutive elements.

    Groups run along the last dimension. Inputs with more than two dimensions are
    flattened to 2D for the kernel (all leading dimensions become rows), and the
    results are reshaped back afterwards.

    Args:
        x: Input tensor with at least 2 dimensions; its last dimension must be
            divisible by ``QB``.
        QB: Group size (number of elements sharing one scale); must be positive.
        fp8type: Target FP8 dtype, either a torch dtype or a string key of
            ``convert_str_to_fp8``.

    Returns:
        ``(y, s_y)``. ``y`` has the shape of ``x`` and dtype ``fp8type``. ``s_y`` is a
        bfloat16 tensor of per-group scales with the same leading dimensions and
        last dimension ``N // QB``.

    Raises:
        ValueError: if ``x`` has fewer than 2 dimensions, ``QB`` is not positive,
            or the last dimension is not divisible by ``QB``.
    """
    if x.dim() < 2:
        raise ValueError(f"fp8_quantize expects an input with at least 2 dimensions, got shape {tuple(x.shape)}")
    if QB <= 0:
        raise ValueError(f"group size QB must be positive, got {QB}")
    lead_shape = tuple(x.shape[:-1])
    N = x.shape[-1]
    # A trailing partial group would be quantized with a scale that is never
    # stored (the scale tensor only has N // QB columns), so reject it up front.
    if N % QB != 0:
        raise ValueError(f"last dimension ({N}) must be divisible by the group size QB ({QB})")

    # The kernel works on a 2D view: collapse all leading dimensions into rows.
    if x.dim() > 2:
        x = x.flatten(0, -2)
    M = x.shape[0]
    SN = N // QB

    if isinstance(fp8type, str):
        fp8type = convert_str_to_fp8[fp8type]
    y = torch.empty_like(x, dtype=fp8type)
    s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
    fp8MaxValue = FP8_MAX_VALUE[fp8type]  # E4M3 and E5M2 have different max values

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
    _fp8_quantize_kernel[grid](
        y,
        s_y,
        x,
        M,
        N,
        SN,
        QB,
        fp8MaxValue,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        s_y.stride(0),
        s_y.stride(1),
        SCALE_MIN_THRES=SCALE_MIN_THRES,
    )

    # Restore the original leading dimensions (explicit sizes, so zero-sized
    # dimensions reshape cleanly).
    if len(lead_shape) > 1:
        y = y.reshape(*lead_shape, N)
        s_y = s_y.reshape(*lead_shape, SN)
    return y, s_y
|