Add GPTQ-Marlin
- build.toml +14 -0
- ext-torch/torch_binding.cpp +26 -0
- ext-torch/torch_binding.h +28 -0
- gptq_marlin/awq_marlin_repack.cu +258 -0
- gptq_marlin/gptq_marlin.cu +2423 -0
- gptq_marlin/gptq_marlin_repack.cu +333 -0
build.toml
CHANGED
@@ -5,6 +5,7 @@ version = "0.0.1"
 name = "quantization"
 src = [
   "core/registration.h",
+  "core/scalar_type.hpp",
   "ext-torch/torch_binding.cpp",
   "ext-torch/torch_binding.h"
 ]
@@ -69,3 +70,16 @@ src = [
 ]
 include = [ "." ]
 depends = [ "torch" ]
+
+[kernel.gptq_marlin]
+capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "core/scalar_type.hpp",
+  "gptq_marlin/awq_marlin_repack.cu",
+  "gptq_marlin/gptq_marlin.cu",
+  "gptq_marlin/gptq_marlin_repack.cu",
+  "gptq_marlin/marlin.cuh",
+  "gptq_marlin/marlin_dtypes.cuh"
+]
+include = [ "." ]
+depends = [ "torch" ]
ext-torch/torch_binding.cpp
CHANGED
@@ -66,6 +66,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
   ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
+
+  // awq_marlin repack from AWQ.
+  ops.def(
+      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+      "SymInt size_n, int num_bits) -> Tensor");
+  ops.impl("awq_marlin_repack", &awq_marlin_repack);
+
+  // gptq_marlin Optimized Quantized GEMM for GPTQ.
+  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.def(
+      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+
+  // gptq_marlin repack from GPTQ.
+  ops.def(
+      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
+  ops.impl("gptq_marlin_repack", &gptq_marlin_repack);
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
+  ops.impl("awq_marlin_repack", &awq_marlin_repack_meta);
+  ops.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
ext-torch/torch_binding.h
CHANGED
@@ -2,6 +2,8 @@
 
 #include <torch/torch.h>
 
+#include <core/scalar_type.hpp>
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
@@ -46,3 +48,29 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k);
+
+// GPTQ-Marlin
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits);
+
+torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace,
+                               vllm::ScalarTypeId const& b_q_type_id,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce, bool is_zp_float);
+
+torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
+                                 int64_t size_k, int64_t size_n,
+                                 int64_t num_bits);
+
+torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                      torch::Tensor& perm, c10::SymInt size_k,
+                                      c10::SymInt size_n, int64_t num_bits);
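For orientation only (this sketch is not part of the commit): a minimal host-side example of driving the newly declared awq_marlin_repack entry point, assuming the shape and dtype requirements enforced by the TORCH_CHECKs in gptq_marlin/awq_marlin_repack.cu below. The concrete sizes, the include path, and the zero-filled dummy weight are illustrative assumptions, not part of the API.

// Hypothetical usage sketch (not from the commit): repack an AWQ-packed
// 4-bit weight of shape [size_k, size_n / pack_factor] into Marlin layout.
#include <torch/torch.h>
#include "torch_binding.h"  // assumed include path for ext-torch/torch_binding.h

torch::Tensor repack_awq_weight_example() {
  int64_t size_k = 4096;    // must be divisible by the 16-row marlin tile
  int64_t size_n = 4096;    // must be divisible by the 64-column marlin tile
  int64_t num_bits = 4;     // 4 or 8, per the checks below
  int64_t pack_factor = 32 / num_bits;

  // AWQ stores weights as contiguous int32 words holding `pack_factor`
  // values each along the N dimension; zeros here stand in for real weights.
  auto b_q_weight = torch::zeros(
      {size_k, size_n / pack_factor},
      torch::dtype(torch::kInt32).device(torch::kCUDA));

  // Result shape is {size_k / tile_size, size_n * tile_size / pack_factor},
  // mirroring the allocation in awq_marlin_repack below.
  return awq_marlin_repack(b_q_weight, size_k, size_n, num_bits);
}

The gptq_marlin_gemm entry point follows the same pattern but additionally needs correctly packed scales, zero points, group indices, and a workspace tensor, so it is not sketched here.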
gptq_marlin/awq_marlin_repack.cu
ADDED
@@ -0,0 +1,258 @@
#include "marlin.cuh"

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS)                                                    \
  else if (num_bits == NUM_BITS) {                                          \
    cudaFuncSetAttribute(                                                    \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);       \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>      \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(       \
            b_q_weight_ptr, out_ptr, size_k, size_n);                       \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                     c10::SymInt size_k, c10::SymInt size_n,
                                     int64_t num_bits) {
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}
gptq_marlin/gptq_marlin.cu
ADDED
|
@@ -0,0 +1,2423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Modified by Neural Magic
|
| 3 |
+
* Copyright (C) Marlin.2024 Elias Frantar
|
| 4 |
+
*
|
| 5 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
* you may not use this file except in compliance with the License.
|
| 7 |
+
* You may obtain a copy of the License at
|
| 8 |
+
*
|
| 9 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
*
|
| 11 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
* See the License for the specific language governing permissions and
|
| 15 |
+
* limitations under the License.
|
| 16 |
+
*/
|
| 17 |
+
|
| 18 |
+
/*
|
| 19 |
+
* Adapted from https://github.com/IST-DASLab/marlin
|
| 20 |
+
*/
|
| 21 |
+
|
| 22 |
+
#include "marlin.cuh"
|
| 23 |
+
#include "marlin_dtypes.cuh"
|
| 24 |
+
#include "core/scalar_type.hpp"
|
| 25 |
+
|
| 26 |
+
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
| 27 |
+
static_assert(std::is_same<scalar_t, half>::value || \
|
| 28 |
+
std::is_same<scalar_t, nv_bfloat16>::value, \
|
| 29 |
+
"only float16 and bfloat16 is supported");
|
| 30 |
+
|
| 31 |
+
template <typename T>
|
| 32 |
+
inline std::string str(T x) {
|
| 33 |
+
return std::to_string(x);
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
namespace marlin {
|
| 37 |
+
|
| 38 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
| 39 |
+
|
| 40 |
+
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
| 41 |
+
int const* __restrict__ perm_int_ptr,
|
| 42 |
+
int4* __restrict__ out_int4_ptr, int size_m,
|
| 43 |
+
int size_k, int block_rows) {}
|
| 44 |
+
|
| 45 |
+
template <typename scalar_t, // compute dtype, half or nv_float16
|
| 46 |
+
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
| 47 |
+
const int threads, // number of threads in a threadblock
|
| 48 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
| 49 |
+
// dimension (batchsize) of the
|
| 50 |
+
// threadblock
|
| 51 |
+
const int thread_n_blocks, // same for n dimension (output)
|
| 52 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
| 53 |
+
const int stages, // number of stages for the async global->shared
|
| 54 |
+
// fetch pipeline
|
| 55 |
+
const bool has_act_order, // whether act_order is enabled
|
| 56 |
+
const int group_blocks = -1, // number of consecutive 16x16 blocks
|
| 57 |
+
// with a separate quantization scale
|
| 58 |
+
const bool is_zp_float // is zero point of float16 type?
|
| 59 |
+
>
|
| 60 |
+
__global__ void Marlin(
|
| 61 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
| 62 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
| 63 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
| 64 |
+
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
| 65 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
| 66 |
+
// (k/groupsize)xn
|
| 67 |
+
const int* __restrict__ g_idx, // int32 group indices of shape k
|
| 68 |
+
int num_groups, // number of scale groups per output channel
|
| 69 |
+
int prob_m, // batch dimension m
|
| 70 |
+
int prob_n, // output dimension n
|
| 71 |
+
int prob_k, // reduction dimension k
|
| 72 |
+
int* locks, // extra global storage for barrier synchronization
|
| 73 |
+
bool use_fp32_reduce // whether to use fp32 global reduce
|
| 74 |
+
) {}
|
| 75 |
+
|
| 76 |
+
} // namespace marlin
|
| 77 |
+
|
| 78 |
+
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
| 79 |
+
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
| 80 |
+
torch::Tensor& g_idx, torch::Tensor& perm,
|
| 81 |
+
torch::Tensor& workspace,
|
| 82 |
+
vllm::ScalarTypeId const b_q_type_id,
|
| 83 |
+
int64_t size_m, int64_t size_n, int64_t size_k,
|
| 84 |
+
bool is_k_full, bool has_zp, bool is_zp_float) {
|
| 85 |
+
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
| 86 |
+
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
| 87 |
+
return torch::empty({1, 1});
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
#else
|
| 91 |
+
|
| 92 |
+
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
|
| 93 |
+
// output/accumulation.
|
| 94 |
+
template <typename scalar_t>
|
| 95 |
+
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
|
| 96 |
+
const typename ScalarType<scalar_t>::FragB& frag_b,
|
| 97 |
+
typename ScalarType<scalar_t>::FragC& frag_c) {
|
| 98 |
+
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
|
| 99 |
+
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
| 100 |
+
float* c = reinterpret_cast<float*>(&frag_c);
|
| 101 |
+
if constexpr (std::is_same<scalar_t, half>::value) {
|
| 102 |
+
asm volatile(
|
| 103 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
|
| 104 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 105 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
| 106 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
| 107 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
| 108 |
+
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
| 109 |
+
asm volatile(
|
| 110 |
+
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
| 111 |
+
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
|
| 112 |
+
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
| 113 |
+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
|
| 114 |
+
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
|
| 115 |
+
} else {
|
| 116 |
+
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
| 121 |
+
// memory, directly in tensor core layout.
|
| 122 |
+
template <typename scalar_t>
|
| 123 |
+
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
|
| 124 |
+
const void* smem_ptr) {
|
| 125 |
+
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
| 126 |
+
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
| 127 |
+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
| 128 |
+
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
| 129 |
+
: "r"(smem));
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
// Lookup-table based 3-input logical operation; explicitly used for
|
| 133 |
+
// dequantization as the compiler does not seem to automatically recognize it in
|
| 134 |
+
// all cases.
|
| 135 |
+
template <int lut>
|
| 136 |
+
__device__ inline int lop3(int a, int b, int c) {
|
| 137 |
+
int res;
|
| 138 |
+
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
| 139 |
+
: "=r"(res)
|
| 140 |
+
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
| 141 |
+
return res;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// Constructs destination register by taking bytes from 2 sources (based on
|
| 145 |
+
// mask)
|
| 146 |
+
template <int start_byte, int mask>
|
| 147 |
+
__device__ inline uint32_t prmt(uint32_t a) {
|
| 148 |
+
uint32_t res;
|
| 149 |
+
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
| 150 |
+
: "=r"(res)
|
| 151 |
+
: "r"(a), "n"(start_byte), "n"(mask));
|
| 152 |
+
return res;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
template <typename scalar_t, vllm::ScalarTypeId w_type_id>
|
| 156 |
+
__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
|
| 157 |
+
|
| 158 |
+
//
|
| 159 |
+
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
| 160 |
+
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
| 161 |
+
// with some small changes:
|
| 162 |
+
// - FP16:
|
| 163 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
| 164 |
+
// - BF16:
|
| 165 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
| 166 |
+
//
|
| 167 |
+
template <>
|
| 168 |
+
__device__ inline typename ScalarType<half>::FragB
|
| 169 |
+
dequant<half, vllm::kU4B8.id()>(int q) {
|
| 170 |
+
const int LO = 0x000f000f;
|
| 171 |
+
const int HI = 0x00f000f0;
|
| 172 |
+
const int EX = 0x64006400;
|
| 173 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
| 174 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
| 175 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
| 176 |
+
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
| 177 |
+
// directly into `SUB` and `ADD`.
|
| 178 |
+
const int SUB = 0x64086408;
|
| 179 |
+
const int MUL = 0x2c002c00;
|
| 180 |
+
const int ADD = 0xd480d480;
|
| 181 |
+
typename ScalarType<half>::FragB frag_b;
|
| 182 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
| 183 |
+
*reinterpret_cast<const half2*>(&SUB));
|
| 184 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
| 185 |
+
*reinterpret_cast<const half2*>(&MUL),
|
| 186 |
+
*reinterpret_cast<const half2*>(&ADD));
|
| 187 |
+
return frag_b;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
template <>
|
| 191 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
| 192 |
+
dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
|
| 193 |
+
static constexpr uint32_t MASK = 0x000f000f;
|
| 194 |
+
static constexpr uint32_t EX = 0x43004300;
|
| 195 |
+
|
| 196 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
| 197 |
+
|
| 198 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
| 199 |
+
q >>= 4;
|
| 200 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
| 201 |
+
|
| 202 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
| 203 |
+
static constexpr uint32_t MUL = 0x3F803F80;
|
| 204 |
+
static constexpr uint32_t ADD = 0xC308C308;
|
| 205 |
+
|
| 206 |
+
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
| 207 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
| 208 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
| 209 |
+
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
| 210 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
| 211 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
| 212 |
+
return frag_b;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
template <>
|
| 216 |
+
__device__ inline typename ScalarType<half>::FragB
|
| 217 |
+
dequant<half, vllm::kU4.id()>(int q) {
|
| 218 |
+
const int LO = 0x000f000f;
|
| 219 |
+
const int HI = 0x00f000f0;
|
| 220 |
+
const int EX = 0x64006400;
|
| 221 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
| 222 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
| 223 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
| 224 |
+
|
| 225 |
+
const int SUB = 0x64006400;
|
| 226 |
+
const int MUL = 0x2c002c00;
|
| 227 |
+
const int ADD = 0xd400d400;
|
| 228 |
+
typename ScalarType<half>::FragB frag_b;
|
| 229 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
| 230 |
+
*reinterpret_cast<const half2*>(&SUB));
|
| 231 |
+
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
| 232 |
+
*reinterpret_cast<const half2*>(&MUL),
|
| 233 |
+
*reinterpret_cast<const half2*>(&ADD));
|
| 234 |
+
return frag_b;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
template <>
|
| 238 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
| 239 |
+
dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
|
| 240 |
+
static constexpr uint32_t MASK = 0x000f000f;
|
| 241 |
+
static constexpr uint32_t EX = 0x43004300;
|
| 242 |
+
|
| 243 |
+
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
| 244 |
+
|
| 245 |
+
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
| 246 |
+
q >>= 4;
|
| 247 |
+
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
| 248 |
+
|
| 249 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
| 250 |
+
static constexpr uint32_t MUL = 0x3F803F80;
|
| 251 |
+
static constexpr uint32_t ADD = 0xC300C300;
|
| 252 |
+
|
| 253 |
+
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
| 254 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
| 255 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
| 256 |
+
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
| 257 |
+
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
| 258 |
+
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
| 259 |
+
return frag_b;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
//
|
| 263 |
+
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
|
| 264 |
+
// bf16 Reference:
|
| 265 |
+
// - FP16:
|
| 266 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
| 267 |
+
// - BF16:
|
| 268 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
| 269 |
+
//
|
| 270 |
+
template <>
|
| 271 |
+
__device__ inline typename ScalarType<half>::FragB
|
| 272 |
+
dequant<half, vllm::kU8B128.id()>(int q) {
|
| 273 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
| 274 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
| 275 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
| 276 |
+
|
| 277 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
| 278 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
| 279 |
+
|
| 280 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
| 281 |
+
|
| 282 |
+
typename ScalarType<half>::FragB frag_b;
|
| 283 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
| 284 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
| 285 |
+
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
| 286 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
| 287 |
+
return frag_b;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
template <>
|
| 291 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
| 292 |
+
dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
|
| 293 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
| 294 |
+
|
| 295 |
+
float fp32_intermediates[4];
|
| 296 |
+
uint32_t* fp32_intermediates_casted =
|
| 297 |
+
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
| 298 |
+
|
| 299 |
+
static constexpr uint32_t fp32_base = 0x4B000000;
|
| 300 |
+
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
| 301 |
+
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
| 302 |
+
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
| 303 |
+
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
| 304 |
+
|
| 305 |
+
fp32_intermediates[0] -= 8388736.f;
|
| 306 |
+
fp32_intermediates[1] -= 8388736.f;
|
| 307 |
+
fp32_intermediates[2] -= 8388736.f;
|
| 308 |
+
fp32_intermediates[3] -= 8388736.f;
|
| 309 |
+
|
| 310 |
+
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
| 311 |
+
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
| 312 |
+
fp32_intermediates_casted[1], 0x7632);
|
| 313 |
+
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
| 314 |
+
fp32_intermediates_casted[3], 0x7632);
|
| 315 |
+
|
| 316 |
+
return frag_b;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
template <>
|
| 320 |
+
__device__ inline typename ScalarType<half>::FragB
|
| 321 |
+
dequant<half, vllm::kU8.id()>(int q) {
|
| 322 |
+
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
| 323 |
+
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
| 324 |
+
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
| 325 |
+
|
| 326 |
+
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
| 327 |
+
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
| 328 |
+
|
| 329 |
+
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
| 330 |
+
|
| 331 |
+
typename ScalarType<half>::FragB frag_b;
|
| 332 |
+
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
| 333 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
| 334 |
+
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
| 335 |
+
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
| 336 |
+
return frag_b;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
template <>
|
| 340 |
+
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
| 341 |
+
dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
|
| 342 |
+
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
| 343 |
+
|
| 344 |
+
float fp32_intermediates[4];
|
| 345 |
+
uint32_t* fp32_intermediates_casted =
|
| 346 |
+
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
| 347 |
+
|
| 348 |
+
static constexpr uint32_t fp32_base = 0x4B000000;
|
| 349 |
+
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
| 350 |
+
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
| 351 |
+
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
| 352 |
+
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
| 353 |
+
|
| 354 |
+
fp32_intermediates[0] -= 8388608.f;
|
| 355 |
+
fp32_intermediates[1] -= 8388608.f;
|
| 356 |
+
fp32_intermediates[2] -= 8388608.f;
|
| 357 |
+
fp32_intermediates[3] -= 8388608.f;
|
| 358 |
+
|
| 359 |
+
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
| 360 |
+
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
| 361 |
+
fp32_intermediates_casted[1], 0x7632);
|
| 362 |
+
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
| 363 |
+
fp32_intermediates_casted[3], 0x7632);
|
| 364 |
+
|
| 365 |
+
return frag_b;
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
// Multiply dequantized values by the corresponding quantization scale; used
|
| 369 |
+
// only for grouped quantization.
|
| 370 |
+
template <typename scalar_t>
|
| 371 |
+
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
|
| 372 |
+
typename ScalarType<scalar_t>::FragS& frag_s,
|
| 373 |
+
int i) {
|
| 374 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
| 375 |
+
scalar_t2 s =
|
| 376 |
+
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
|
| 377 |
+
frag_b[0] = __hmul2(frag_b[0], s);
|
| 378 |
+
frag_b[1] = __hmul2(frag_b[1], s);
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
template <typename scalar_t>
|
| 382 |
+
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
|
| 383 |
+
typename ScalarType<scalar_t>::scalar_t2& frag_zp,
|
| 384 |
+
int i) {
|
| 385 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
| 386 |
+
scalar_t2 zp =
|
| 387 |
+
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
|
| 388 |
+
frag_b[0] = __hsub2(frag_b[0], zp);
|
| 389 |
+
frag_b[1] = __hsub2(frag_b[1], zp);
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
// Same as above, but for act_order (each K is multiplied individually)
|
| 393 |
+
template <typename scalar_t>
|
| 394 |
+
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
|
| 395 |
+
typename ScalarType<scalar_t>::FragS& frag_s_1,
|
| 396 |
+
typename ScalarType<scalar_t>::FragS& frag_s_2,
|
| 397 |
+
typename ScalarType<scalar_t>::FragS& frag_s_3,
|
| 398 |
+
typename ScalarType<scalar_t>::FragS& frag_s_4,
|
| 399 |
+
int i) {
|
| 400 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
| 401 |
+
scalar_t2 s_val_1_2;
|
| 402 |
+
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
|
| 403 |
+
s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
|
| 404 |
+
|
| 405 |
+
scalar_t2 s_val_3_4;
|
| 406 |
+
s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];
|
| 407 |
+
s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];
|
| 408 |
+
|
| 409 |
+
frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
|
| 410 |
+
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
// Given 2 floats multiply by 2 scales (halves)
|
| 414 |
+
template <typename scalar_t>
|
| 415 |
+
__device__ inline void scale_float(float* c,
|
| 416 |
+
typename ScalarType<scalar_t>::FragS& s) {
|
| 417 |
+
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
|
| 418 |
+
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
|
| 419 |
+
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
// Wait until barrier reaches `count`, then lock for current threadblock.
|
| 423 |
+
__device__ inline void barrier_acquire(int* lock, int count) {
|
| 424 |
+
if (threadIdx.x == 0) {
|
| 425 |
+
int state = -1;
|
| 426 |
+
do
|
| 427 |
+
// Guarantee that subsequent writes by this threadblock will be visible
|
| 428 |
+
// globally.
|
| 429 |
+
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
| 430 |
+
: "=r"(state)
|
| 431 |
+
: "l"(lock));
|
| 432 |
+
while (state != count);
|
| 433 |
+
}
|
| 434 |
+
__syncthreads();
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
// Release barrier and increment visitation count.
|
| 438 |
+
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
| 439 |
+
__syncthreads();
|
| 440 |
+
if (threadIdx.x == 0) {
|
| 441 |
+
if (reset) {
|
| 442 |
+
lock[0] = 0;
|
| 443 |
+
return;
|
| 444 |
+
}
|
| 445 |
+
int val = 1;
|
| 446 |
+
// Make sure that all writes since acquiring this barrier are visible
|
| 447 |
+
// globally, while releasing the barrier.
|
| 448 |
+
asm volatile("fence.acq_rel.gpu;\n");
|
| 449 |
+
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
| 450 |
+
:
|
| 451 |
+
: "l"(lock), "r"(val));
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
| 456 |
+
// on the given "perm" indices.
|
| 457 |
+
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
| 458 |
+
int const* __restrict__ perm_int_ptr,
|
| 459 |
+
int4* __restrict__ out_int4_ptr, int size_m,
|
| 460 |
+
int size_k, int block_rows) {
|
| 461 |
+
int start_row = block_rows * blockIdx.x;
|
| 462 |
+
int finish_row = start_row + block_rows;
|
| 463 |
+
if (finish_row > size_m) {
|
| 464 |
+
finish_row = size_m;
|
| 465 |
+
}
|
| 466 |
+
int cur_block_rows = finish_row - start_row;
|
| 467 |
+
|
| 468 |
+
int row_stride = size_k * sizeof(half) / 16;
|
| 469 |
+
|
| 470 |
+
auto permute_row = [&](int row) {
|
| 471 |
+
int iters = size_k / default_threads;
|
| 472 |
+
int rest = size_k % default_threads;
|
| 473 |
+
|
| 474 |
+
int offset = row * row_stride;
|
| 475 |
+
|
| 476 |
+
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
|
| 477 |
+
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
|
| 478 |
+
|
| 479 |
+
int base_k = 0;
|
| 480 |
+
|
| 481 |
+
for (int i = 0; i < iters; i++) {
|
| 482 |
+
int cur_k = base_k + threadIdx.x;
|
| 483 |
+
int src_pos = perm_int_ptr[cur_k];
|
| 484 |
+
|
| 485 |
+
out_half[cur_k] = a_row_half[src_pos];
|
| 486 |
+
|
| 487 |
+
base_k += default_threads;
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
if (rest) {
|
| 491 |
+
if (threadIdx.x < rest) {
|
| 492 |
+
int cur_k = base_k + threadIdx.x;
|
| 493 |
+
int src_pos = perm_int_ptr[cur_k];
|
| 494 |
+
|
| 495 |
+
out_half[cur_k] = a_row_half[src_pos];
|
| 496 |
+
}
|
| 497 |
+
}
|
| 498 |
+
};
|
| 499 |
+
|
| 500 |
+
for (int i = 0; i < cur_block_rows; i++) {
|
| 501 |
+
int cur_row = start_row + i;
|
| 502 |
+
if (cur_row < size_m) {
|
| 503 |
+
permute_row(cur_row);
|
| 504 |
+
}
|
| 505 |
+
}
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
template <typename scalar_t, // compute dtype, half or nv_float16
|
| 509 |
+
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
| 510 |
+
const int threads, // number of threads in a threadblock
|
| 511 |
+
const int thread_m_blocks, // number of 16x16 blocks in the m
|
| 512 |
+
// dimension (batchsize) of the
|
| 513 |
+
// threadblock
|
| 514 |
+
const int thread_n_blocks, // same for n dimension (output)
|
| 515 |
+
const int thread_k_blocks, // same for k dimension (reduction)
|
| 516 |
+
const int stages, // number of stages for the async global->shared
|
| 517 |
+
// fetch pipeline
|
| 518 |
+
const bool has_act_order, // whether act_order is enabled
|
| 519 |
+
const bool has_zp, // whether zero-points are enabled
|
| 520 |
+
const int group_blocks = -1, // number of consecutive 16x16 blocks
|
| 521 |
+
// with a separate quantization scale
|
| 522 |
+
const bool is_zp_float // is zero point of float16 type?
|
| 523 |
+
>
|
| 524 |
+
__global__ void Marlin(
|
| 525 |
+
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
| 526 |
+
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
| 527 |
+
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
| 528 |
+
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
| 529 |
+
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
| 530 |
+
// (k/groupsize)xn
|
| 531 |
+
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
| 532 |
+
// (k/groupsize)x(n/pack_factor)
|
| 533 |
+
const int* __restrict__ g_idx, // int32 group indices of shape k
|
| 534 |
+
int num_groups, // number of scale groups per output channel
|
| 535 |
+
int prob_m, // batch dimension m
|
| 536 |
+
int prob_n, // output dimension n
|
| 537 |
+
int prob_k, // reduction dimension k
|
| 538 |
+
int* locks, // extra global storage for barrier synchronization
|
| 539 |
+
bool use_fp32_reduce // whether to use fp32 global reduce
|
| 540 |
+
) {
|
| 541 |
+
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
| 542 |
+
// same size, which might involve multiple column "slices" (of width 16 *
|
| 543 |
+
// `thread_n_blocks`). Stripes are defined as shown in the 3x3-tile, 5-SM
|
| 544 |
+
// example:
|
| 545 |
+
// 0 1 3
|
| 546 |
+
// 0 2 3
|
| 547 |
+
// 1 2 4
|
| 548 |
+
// While this kind of partitioning makes things somewhat more complicated, it
|
| 549 |
+
// ensures good utilization of all SMs for many kinds of shape and GPU
|
| 550 |
+
// configurations, while requiring as few slow global cross-threadblock
|
| 551 |
+
// reductions as possible.
|
| 552 |
+
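// In the example above, k_tiles = n_tiles = 3 and gridDim.x = 5, so each
// threadblock gets iters = ceil(9 / 5) = 2 tiles: block 0 takes the first two
// tiles of column 0, block 1 finishes column 0 and starts column 1, and so on,
// leaving only block 4 with a single tile.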
using Dtype = ScalarType<scalar_t>;
|
| 553 |
+
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
| 554 |
+
using FragA = typename ScalarType<scalar_t>::FragA;
|
| 555 |
+
using FragB = typename ScalarType<scalar_t>::FragB;
|
| 556 |
+
using FragC = typename ScalarType<scalar_t>::FragC;
|
| 557 |
+
using FragS = typename ScalarType<scalar_t>::FragS;
|
| 558 |
+
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
| 559 |
+
|
| 560 |
+
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
| 561 |
+
|
| 562 |
+
constexpr int pack_factor = 32 / w_type.size_bits();
|
| 563 |
+
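// e.g. a 4-bit weight type packs 32 / 4 = 8 values per 32-bit word, an 8-bit
// type packs 4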
|
| 564 |
+
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
| 565 |
+
// better partitioning with fewer reductions
|
| 566 |
+
int parallel = 1;
|
| 567 |
+
if (prob_m > 16 * thread_m_blocks) {
|
| 568 |
+
parallel = prob_m / (16 * thread_m_blocks);
|
| 569 |
+
prob_m = 16 * thread_m_blocks;
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
int k_tiles = prob_k / 16 / thread_k_blocks;
|
| 573 |
+
int n_tiles = prob_n / 16 / thread_n_blocks;
|
| 574 |
+
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
|
| 575 |
+
|
| 576 |
+
if constexpr (!has_act_order && group_blocks != -1) {
|
| 577 |
+
if (group_blocks >= thread_k_blocks) {
|
| 578 |
+
// Ensure that the number of tiles in each stripe is a multiple of the
|
| 579 |
+
// groupsize; this avoids an annoying special case where a stripe starts
|
| 580 |
+
// in the middle of a group.
|
| 581 |
+
iters = (group_blocks / thread_k_blocks) *
|
| 582 |
+
div_ceil(iters, (group_blocks / thread_k_blocks));
|
| 583 |
+
}
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
int slice_row = (iters * blockIdx.x) % k_tiles;
|
| 587 |
+
int slice_col_par = (iters * blockIdx.x) / k_tiles;
|
| 588 |
+
int slice_col = slice_col_par;
|
| 589 |
+
int slice_iters; // number of threadblock tiles in the current slice
|
| 590 |
+
int slice_count =
|
| 591 |
+
0; // total number of active threadblocks in the current slice
|
| 592 |
+
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
| 593 |
+
// top
|
| 594 |
+
|
| 595 |
+
int par_id = 0;
|
| 596 |
+
|
| 597 |
+
// We can easily implement parallel problem execution by just remapping
|
| 598 |
+
// indices and advancing global pointers
|
| 599 |
+
if (slice_col_par >= n_tiles) {
|
| 600 |
+
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
|
| 601 |
+
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
| 602 |
+
locks += (slice_col_par / n_tiles) * n_tiles;
|
| 603 |
+
slice_col = slice_col_par % n_tiles;
|
| 604 |
+
par_id = slice_col_par / n_tiles;
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
// Compute all information about the current slice which is required for
|
| 608 |
+
// synchronization.
|
| 609 |
+
auto init_slice = [&]() {
|
| 610 |
+
slice_iters =
|
| 611 |
+
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
|
| 612 |
+
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
|
| 613 |
+
if (slice_iters == 0) return;
|
| 614 |
+
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
|
| 615 |
+
slice_count = 1;
|
| 616 |
+
slice_idx = 0;
|
| 617 |
+
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
|
| 618 |
+
if (col_first <= k_tiles * (slice_col_par + 1)) {
|
| 619 |
+
int col_off = col_first - k_tiles * slice_col_par;
|
| 620 |
+
slice_count = div_ceil(k_tiles - col_off, iters);
|
| 621 |
+
if (col_off > 0) slice_count++;
|
| 622 |
+
int delta_first = iters * blockIdx.x - col_first;
|
| 623 |
+
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
|
| 624 |
+
slice_idx = slice_count - 1;
|
| 625 |
+
else {
|
| 626 |
+
slice_idx = slice_count - 1 - delta_first / iters;
|
| 627 |
+
if (col_off > 0) slice_idx--;
|
| 628 |
+
}
|
| 629 |
+
}
|
| 630 |
+
if (slice_col == n_tiles) {
|
| 631 |
+
A += 16 * thread_m_blocks * prob_k / 8;
|
| 632 |
+
C += 16 * thread_m_blocks * prob_n / 8;
|
| 633 |
+
locks += n_tiles;
|
| 634 |
+
slice_col = 0;
|
| 635 |
+
par_id++;
|
| 636 |
+
}
|
| 637 |
+
};
|
| 638 |
+
init_slice();
|
| 639 |
+
|
| 640 |
+
// A sizes/strides
|
| 641 |
+
|
| 642 |
+
// stride of the A matrix in global memory
|
| 643 |
+
int a_gl_stride = prob_k / 8;
|
| 644 |
+
// stride of an A matrix tile in shared memory
|
| 645 |
+
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
|
| 646 |
+
// delta between subsequent A tiles in global memory
|
| 647 |
+
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
|
| 648 |
+
// between subsequent accesses within a tile
|
| 649 |
+
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
|
| 650 |
+
// between shared memory writes
|
| 651 |
+
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
|
| 652 |
+
// between shared memory tile reads
|
| 653 |
+
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
|
| 654 |
+
// within a shared memory tile
|
| 655 |
+
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
|
| 656 |
+
// overall size of a tile
|
| 657 |
+
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
|
| 658 |
+
// number of shared write iterations for a tile
|
| 659 |
+
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
|
| 660 |
+
|
| 661 |
+
// B sizes/strides
|
| 662 |
+
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
|
| 663 |
+
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
|
| 664 |
+
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
|
| 665 |
+
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
|
| 666 |
+
|
| 667 |
+
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
| 668 |
+
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
|
| 669 |
+
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
|
| 670 |
+
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
|
| 671 |
+
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
|
| 672 |
+
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
| 673 |
+
|
| 674 |
+
// Scale sizes/strides without act_order
|
| 675 |
+
int s_gl_stride = prob_n / 8;
|
| 676 |
+
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
| 677 |
+
constexpr int s_tb_groups =
|
| 678 |
+
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
| 679 |
+
? thread_k_blocks / group_blocks
|
| 680 |
+
: 1;
|
| 681 |
+
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
| 682 |
+
int s_gl_rd_delta = s_gl_stride;
|
| 683 |
+
|
| 684 |
+
// Scale size/strides with act_order
|
| 685 |
+
constexpr int tb_k = 16 * thread_k_blocks;
|
| 686 |
+
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
|
| 687 |
+
// constexpr int act_s_row_stride = 1;
|
| 688 |
+
// int act_s_col_stride = act_s_row_stride * num_groups;
|
| 689 |
+
int act_s_col_stride = 1;
|
| 690 |
+
int act_s_col_warp_stride = act_s_col_stride * 8;
|
| 691 |
+
int tb_n_warps = thread_n_blocks / 4;
|
| 692 |
+
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
| 693 |
+
|
| 694 |
+
// Zero-points sizes/strides
|
| 695 |
+
int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
|
| 696 |
+
constexpr int zp_sh_stride = is_zp_float
|
| 697 |
+
? 16 * thread_n_blocks / 8
|
| 698 |
+
: ((16 * thread_n_blocks) / pack_factor) / 4;
|
| 699 |
+
constexpr int zp_tb_groups = s_tb_groups;
|
| 700 |
+
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
|
| 701 |
+
int zp_gl_rd_delta = zp_gl_stride;
|
| 702 |
+
|
| 703 |
+
// Global A read index of current thread.
|
| 704 |
+
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
| 705 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
| 706 |
+
a_gl_rd += a_gl_rd_delta_o * slice_row;
|
| 707 |
+
// Shared write index of current thread.
|
| 708 |
+
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
| 709 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
| 710 |
+
// Shared read index.
|
| 711 |
+
int a_sh_rd =
|
| 712 |
+
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
|
| 713 |
+
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
|
| 714 |
+
|
| 715 |
+
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
|
| 716 |
+
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
| 717 |
+
b_gl_rd += b_sh_stride * slice_col;
|
| 718 |
+
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
| 719 |
+
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
| 720 |
+
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
| 721 |
+
|
| 722 |
+
// For act_order
|
| 723 |
+
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
| 724 |
+
int slice_k_start = tb_k * slice_row;
|
| 725 |
+
int slice_k_finish = slice_k_start + tb_k * slice_iters;
|
| 726 |
+
int slice_k_start_shared_fetch = slice_k_start;
|
| 727 |
+
int slice_n_offset = act_s_col_tb_stride * slice_col;
|
| 728 |
+
|
| 729 |
+
// No act_order
|
| 730 |
+
int s_gl_rd;
|
| 731 |
+
if constexpr (!has_act_order) {
|
| 732 |
+
if constexpr (group_blocks == -1) {
|
| 733 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
| 734 |
+
} else {
|
| 735 |
+
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
| 736 |
+
s_sh_stride * slice_col + threadIdx.x;
|
| 737 |
+
}
|
| 738 |
+
}
|
| 739 |
+
int s_sh_wr = threadIdx.x;
|
| 740 |
+
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
| 741 |
+
|
| 742 |
+
// Zero-points
|
| 743 |
+
int zp_gl_rd;
|
| 744 |
+
if constexpr (has_zp) {
|
| 745 |
+
if constexpr (group_blocks == -1) {
|
| 746 |
+
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
| 747 |
+
} else {
|
| 748 |
+
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
| 749 |
+
zp_sh_stride * slice_col + threadIdx.x;
|
| 750 |
+
}
|
| 751 |
+
}
|
| 752 |
+
int zp_sh_wr = threadIdx.x;
|
| 753 |
+
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
| 754 |
+
|
| 755 |
+
// We use a different scale layout for grouped and column-wise quantization as
|
| 756 |
+
// we scale a `half2` tile in column-major layout in the former and in
|
| 757 |
+
// row-major in the latter case.
|
| 758 |
+
int s_sh_rd;
|
| 759 |
+
if constexpr (group_blocks != -1)
|
| 760 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
| 761 |
+
(threadIdx.x % 32) / 4;
|
| 762 |
+
else
|
| 763 |
+
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
| 764 |
+
(threadIdx.x % 32) % 4;
|
| 765 |
+
|
| 766 |
+
// Zero-points have the same read layout as the scales
|
| 767 |
+
// (without column-wise case)
|
| 768 |
+
constexpr int num_col_threads = 8;
|
| 769 |
+
constexpr int num_row_threads = 4;
|
| 770 |
+
constexpr int num_ints_per_thread = 8 / pack_factor;
|
| 771 |
+
int zp_sh_rd;
|
| 772 |
+
if constexpr (has_zp) {
|
| 773 |
+
if constexpr (is_zp_float) {
|
| 774 |
+
if constexpr (group_blocks != -1) {
|
| 775 |
+
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
| 776 |
+
(threadIdx.x % 32) / 4;
|
| 777 |
+
}
|
| 778 |
+
} else {
|
| 779 |
+
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
| 780 |
+
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
| 781 |
+
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
| 782 |
+
}
|
| 783 |
+
}
|
| 784 |
+
|
| 785 |
+
// Precompute which thread should not read memory in which iterations; this is
|
| 786 |
+
// needed if there are more threads than required for a certain tilesize or
|
| 787 |
+
// when the batchsize is not a multiple of 16.
|
| 788 |
+
bool a_sh_wr_pred[a_sh_wr_iters];
|
| 789 |
+
#pragma unroll
|
| 790 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
| 791 |
+
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
|
| 792 |
+
|
| 793 |
+
// To ensure that writing and reading A tiles to/from shared memory, the
|
| 794 |
+
// latter in fragment format, is fully bank conflict free, we need to use a
|
| 795 |
+
// rather fancy XOR-based layout. The key here is that neither reads nor
|
| 796 |
+
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
|
| 797 |
+
// same shared memory banks. Further, it seems (based on NSight-Compute) that
|
| 798 |
+
// each warp must also write a consecutive memory segment?
|
| 799 |
+
auto transform_a = [&](int i) {
|
| 800 |
+
int row = i / a_gl_rd_delta_o;
|
| 801 |
+
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
|
| 802 |
+
};
|
| 803 |
+
// Since the computation of this remapping is non-trivial and, due to our main
|
| 804 |
+
// loop unrolls, all shared memory accesses are static, we simply precompute
|
| 805 |
+
// both transformed reads and writes.
|
| 806 |
+
int a_sh_wr_trans[a_sh_wr_iters];
|
| 807 |
+
#pragma unroll
|
| 808 |
+
for (int i = 0; i < a_sh_wr_iters; i++)
|
| 809 |
+
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
|
| 810 |
+
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
|
| 811 |
+
#pragma unroll
|
| 812 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
| 813 |
+
#pragma unroll
|
| 814 |
+
for (int j = 0; j < thread_m_blocks; j++)
|
| 815 |
+
a_sh_rd_trans[i][j] =
|
| 816 |
+
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
// Since B-accesses have non-constant stride they have to be computed at
|
| 820 |
+
// runtime; we break dependencies between subsequent accesses with a tile by
|
| 821 |
+
// maintining multiple pointers (we have enough registers), a tiny
|
| 822 |
+
// optimization.
|
| 823 |
+
const int4* B_ptr[b_sh_wr_iters];
|
| 824 |
+
#pragma unroll
|
| 825 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
| 826 |
+
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
|
| 827 |
+
|
| 828 |
+
extern __shared__ int4 sh[];
|
| 829 |
+
// Shared memory storage for global fetch pipelines.
|
| 830 |
+
int4* sh_a = sh;
|
| 831 |
+
int4* sh_b = sh_a + (stages * a_sh_stage);
|
| 832 |
+
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
| 833 |
+
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
| 834 |
+
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
| 835 |
+
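// The dynamic shared memory buffer is thus laid out as
//   [ A tiles | B tiles | g_idx | zero-points | scales ],
// where the A, B, g_idx and zero-point regions hold one slot per pipeline
// stage (the g_idx and zero-point slots collapse to size 0 when act_order or
// zero-points are disabled).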
|
| 836 |
+
// Register storage for double buffer of shared memory reads.
|
| 837 |
+
FragA frag_a[2][thread_m_blocks];
|
| 838 |
+
I4 frag_b_quant[2][b_thread_vecs];
|
| 839 |
+
FragC frag_c[thread_m_blocks][4][2];
|
| 840 |
+
FragS frag_s[2][4]; // No act-order
|
| 841 |
+
FragS act_frag_s[2][4][4]; // For act-order
|
| 842 |
+
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
| 843 |
+
FragZP frag_zp; // Zero-points in fp16
|
| 844 |
+
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
|
| 845 |
+
|
| 846 |
+
// Zero accumulators.
|
| 847 |
+
auto zero_accums = [&]() {
|
| 848 |
+
#pragma unroll
|
| 849 |
+
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
|
| 850 |
+
reinterpret_cast<float*>(frag_c)[i] = 0;
|
| 851 |
+
};
|
| 852 |
+
|
| 853 |
+
int sh_first_group_id = -1;
|
| 854 |
+
int sh_num_groups = -1;
|
| 855 |
+
constexpr int sh_max_num_groups = 32;
|
| 856 |
+
|
| 857 |
+
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
|
| 858 |
+
int last_group_id) {
|
| 859 |
+
sh_first_group_id = first_group_id;
|
| 860 |
+
sh_num_groups = last_group_id - first_group_id + 1;
|
| 861 |
+
|
| 862 |
+
if (sh_num_groups < sh_max_num_groups) {
|
| 863 |
+
sh_num_groups = sh_max_num_groups;
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
+
if (sh_first_group_id + sh_num_groups > num_groups) {
|
| 867 |
+
sh_num_groups = num_groups - sh_first_group_id;
|
| 868 |
+
}
|
| 869 |
+
|
| 870 |
+
int row_offset = first_group_id * s_gl_stride;
|
| 871 |
+
|
| 872 |
+
if (is_async) {
|
| 873 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
| 874 |
+
if (threadIdx.x < s_sh_stride) {
|
| 875 |
+
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
|
| 876 |
+
&scales_ptr[row_offset + (i * s_gl_stride) +
|
| 877 |
+
slice_n_offset + threadIdx.x]);
|
| 878 |
+
}
|
| 879 |
+
}
|
| 880 |
+
} else {
|
| 881 |
+
for (int i = 0; i < sh_num_groups; i++) {
|
| 882 |
+
if (threadIdx.x < s_sh_stride) {
|
| 883 |
+
sh_s[(i * s_sh_stride) + threadIdx.x] =
|
| 884 |
+
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
|
| 885 |
+
threadIdx.x];
|
| 886 |
+
}
|
| 887 |
+
}
|
| 888 |
+
}
|
| 889 |
+
};
|
| 890 |
+
// Asynchronously fetch the next A, B and s tile from global to the next
|
| 891 |
+
// shared memory pipeline location.
|
| 892 |
+
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
|
| 893 |
+
if (pred) {
|
| 894 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
| 895 |
+
#pragma unroll
|
| 896 |
+
for (int i = 0; i < a_sh_wr_iters; i++) {
|
| 897 |
+
cp_async4_pred(
|
| 898 |
+
&sh_a_stage[a_sh_wr_trans[i]],
|
| 899 |
+
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
|
| 900 |
+
a_sh_wr_pred[i]);
|
| 901 |
+
}
|
| 902 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
| 903 |
+
#pragma unroll
|
| 904 |
+
for (int i = 0; i < b_sh_wr_iters; i++) {
|
| 905 |
+
#pragma unroll
|
| 906 |
+
for (int j = 0; j < b_thread_vecs; j++) {
|
| 907 |
+
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
|
| 908 |
+
}
|
| 909 |
+
|
| 910 |
+
B_ptr[i] += b_gl_rd_delta_o;
|
| 911 |
+
}
|
| 912 |
+
|
| 913 |
+
if constexpr (has_act_order) {
|
| 914 |
+
// Fetch g_idx thread-block portion
|
| 915 |
+
int full_pipe = a_off;
|
| 916 |
+
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
|
| 917 |
+
if (cur_k < prob_k && cur_k < slice_k_finish) {
|
| 918 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
| 919 |
+
|
| 920 |
+
int4 const* cur_g_idx_stage_ptr =
|
| 921 |
+
reinterpret_cast<int4 const*>(&g_idx[cur_k]);
|
| 922 |
+
|
| 923 |
+
if (threadIdx.x < g_idx_stage) {
|
| 924 |
+
cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
|
| 925 |
+
&cur_g_idx_stage_ptr[threadIdx.x]);
|
| 926 |
+
}
|
| 927 |
+
}
|
| 928 |
+
} else {
|
| 929 |
+
if constexpr (group_blocks != -1) {
|
| 930 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
| 931 |
+
|
| 932 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
| 933 |
+
// Only fetch scales if this tile starts a new group
|
| 934 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
| 935 |
+
if (s_sh_wr_pred) {
|
| 936 |
+
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
| 937 |
+
}
|
| 938 |
+
s_gl_rd += s_gl_rd_delta;
|
| 939 |
+
}
|
| 940 |
+
} else {
|
| 941 |
+
for (int i = 0; i < s_tb_groups; i++) {
|
| 942 |
+
if (s_sh_wr_pred) {
|
| 943 |
+
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
|
| 944 |
+
&scales_ptr[s_gl_rd]);
|
| 945 |
+
}
|
| 946 |
+
s_gl_rd += s_gl_rd_delta;
|
| 947 |
+
}
|
| 948 |
+
}
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
if constexpr (has_zp && group_blocks != -1) {
|
| 952 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
| 953 |
+
|
| 954 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
| 955 |
+
// Only fetch zero-points if this tile starts a new group
|
| 956 |
+
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
| 957 |
+
if (zp_sh_wr_pred) {
|
| 958 |
+
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
| 959 |
+
}
|
| 960 |
+
zp_gl_rd += zp_gl_rd_delta;
|
| 961 |
+
}
|
| 962 |
+
} else {
|
| 963 |
+
for (int i = 0; i < zp_tb_groups; i++) {
|
| 964 |
+
if (zp_sh_wr_pred) {
|
| 965 |
+
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
| 966 |
+
&zp_ptr[zp_gl_rd]);
|
| 967 |
+
}
|
| 968 |
+
zp_gl_rd += zp_gl_rd_delta;
|
| 969 |
+
}
|
| 970 |
+
}
|
| 971 |
+
}
|
| 972 |
+
}
|
| 973 |
+
}
|
| 974 |
+
// Insert a fence even when we are winding down the pipeline to ensure that
|
| 975 |
+
// waiting is also correct at this point.
|
| 976 |
+
cp_async_fence();
|
| 977 |
+
};
|
| 978 |
+
|
| 979 |
+
auto fetch_zp_to_shared = [&]() {
|
| 980 |
+
if (zp_sh_wr_pred) {
|
| 981 |
+
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
| 982 |
+
}
|
| 983 |
+
};
|
| 984 |
+
|
| 985 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
| 986 |
+
auto wait_for_stage = [&]() {
|
| 987 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
| 988 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
| 989 |
+
// shared memory load is fully complete (as it may otherwise be
|
| 990 |
+
// overwritten).
|
| 991 |
+
cp_async_wait<stages - 2>();
|
| 992 |
+
__syncthreads();
|
| 993 |
+
};
|
| 994 |
+
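// cp_async_wait<stages - 2> waits until at most `stages - 2` async-copy groups
// remain in flight, i.e. until the oldest outstanding fetch has landed in
// shared memory; the __syncthreads() then makes that stage visible to every
// thread before it is read.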
|
| 995 |
+
// Load the next sub-tile from the current location in the shared memory pipe
|
| 996 |
+
// into the current register buffer.
|
| 997 |
+
auto fetch_to_registers = [&](int k, int pipe) {
|
| 998 |
+
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
|
| 999 |
+
#pragma unroll
|
| 1000 |
+
for (int i = 0; i < thread_m_blocks; i++)
|
| 1001 |
+
ldsm4<scalar_t>(frag_a[k % 2][i],
|
| 1002 |
+
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
|
| 1003 |
+
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
|
| 1004 |
+
|
| 1005 |
+
#pragma unroll
|
| 1006 |
+
for (int i = 0; i < b_thread_vecs; i++) {
|
| 1007 |
+
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
|
| 1008 |
+
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
|
| 1009 |
+
}
|
| 1010 |
+
};
|
| 1011 |
+
|
| 1012 |
+
bool is_same_group[stages];
|
| 1013 |
+
int same_group_id[stages];
|
| 1014 |
+
|
| 1015 |
+
auto init_same_group = [&](int pipe) {
|
| 1016 |
+
if constexpr (!has_act_order) {
|
| 1017 |
+
is_same_group[pipe] = false;
|
| 1018 |
+
same_group_id[pipe] = 0;
|
| 1019 |
+
return;
|
| 1020 |
+
}
|
| 1021 |
+
|
| 1022 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
| 1023 |
+
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
|
| 1024 |
+
|
| 1025 |
+
int group_id_1 = sh_g_idx_int_ptr[0];
|
| 1026 |
+
int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
|
| 1027 |
+
|
| 1028 |
+
is_same_group[pipe] = group_id_1 == group_id_2;
|
| 1029 |
+
same_group_id[pipe] = group_id_1;
|
| 1030 |
+
};
|
| 1031 |
+
|
| 1032 |
+
auto fetch_scales_to_registers = [&](int k, int full_pipe) {
|
| 1033 |
+
int pipe = full_pipe % stages;
|
| 1034 |
+
|
| 1035 |
+
if constexpr (!has_act_order) {
|
| 1036 |
+
// No act-order case
|
| 1037 |
+
if constexpr (group_blocks != -1) {
|
| 1038 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
| 1039 |
+
int4* sh_s_stage =
|
| 1040 |
+
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
| 1041 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
| 1042 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
| 1043 |
+
} else {
|
| 1044 |
+
int warp_id = threadIdx.x / 32;
|
| 1045 |
+
int n_warps = thread_n_blocks / 4;
|
| 1046 |
+
|
| 1047 |
+
int warp_row = warp_id / n_warps;
|
| 1048 |
+
|
| 1049 |
+
int cur_k = warp_row * 16;
|
| 1050 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
| 1051 |
+
|
| 1052 |
+
int k_blocks = cur_k / 16;
|
| 1053 |
+
int cur_group_id = k_blocks / group_blocks;
|
| 1054 |
+
|
| 1055 |
+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
| 1056 |
+
|
| 1057 |
+
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
| 1058 |
+
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
| 1059 |
+
}
|
| 1060 |
+
}
|
| 1061 |
+
|
| 1062 |
+
return;
|
| 1063 |
+
}
|
| 1064 |
+
|
| 1065 |
+
// Act-order case
|
| 1066 |
+
|
| 1067 |
+
// Determine K of the "current" thread-block
|
| 1068 |
+
int cur_k = slice_k_start + tb_k * full_pipe;
|
| 1069 |
+
if (cur_k >= prob_k || cur_k >= slice_k_finish) {
|
| 1070 |
+
return;
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
+
// Reset (to current thread-block) since we read the g_idx portion from
|
| 1074 |
+
// shared memory
|
| 1075 |
+
cur_k = 0;
|
| 1076 |
+
|
| 1077 |
+
// Progress to current iteration
|
| 1078 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
| 1079 |
+
|
| 1080 |
+
// Determine "position" inside the thread-block (based on warp and
|
| 1081 |
+
// thread-id)
|
| 1082 |
+
int warp_id = threadIdx.x / 32;
|
| 1083 |
+
int n_warps =
|
| 1084 |
+
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
| 1085 |
+
|
| 1086 |
+
int warp_row = warp_id / n_warps;
|
| 1087 |
+
int warp_col = warp_id % n_warps;
|
| 1088 |
+
|
| 1089 |
+
cur_k += warp_row * 16;
|
| 1090 |
+
|
| 1091 |
+
int th_id = threadIdx.x % 32;
|
| 1092 |
+
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
| 1093 |
+
|
| 1094 |
+
int s_col_shift =
|
| 1095 |
+
/*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
|
| 1096 |
+
(th_id / 4) * act_s_col_stride;
|
| 1097 |
+
|
| 1098 |
+
if (is_same_group[pipe]) {
|
| 1099 |
+
if (k % 2 == 0) {
|
| 1100 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
|
| 1101 |
+
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
|
| 1102 |
+
s_col_shift];
|
| 1103 |
+
} else {
|
| 1104 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
|
| 1105 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
|
| 1106 |
+
}
|
| 1107 |
+
|
| 1108 |
+
for (int i = 1; i < 4; i++) {
|
| 1109 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
|
| 1110 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
|
| 1111 |
+
}
|
| 1112 |
+
return;
|
| 1113 |
+
}
|
| 1114 |
+
|
| 1115 |
+
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
|
| 1116 |
+
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
|
| 1117 |
+
|
| 1118 |
+
constexpr int k_frag_offsets[4] = {0, 1, 8,
|
| 1119 |
+
9}; // Tensor core offsets per thread
|
| 1120 |
+
|
| 1121 |
+
#pragma unroll
|
| 1122 |
+
for (int i = 0; i < 4; i++) {
|
| 1123 |
+
int actual_k = cur_k + k_frag_offsets[i];
|
| 1124 |
+
|
| 1125 |
+
int group_id = sh_g_idx_int_ptr[actual_k];
|
| 1126 |
+
int rel_group_id = group_id - sh_first_group_id;
|
| 1127 |
+
|
| 1128 |
+
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =
|
| 1129 |
+
sh_s[rel_group_id * s_sh_stride + s_col_shift];
|
| 1130 |
+
}
|
| 1131 |
+
};
|
| 1132 |
+
|
| 1133 |
+
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
|
| 1134 |
+
// This code does not handle group_blocks == 0,
|
| 1135 |
+
// which signifies act_order.
|
| 1136 |
+
// has_zp implies AWQ, which doesn't have act_order,
|
| 1137 |
+
static_assert(!has_zp || group_blocks != 0);
|
| 1138 |
+
|
| 1139 |
+
if constexpr (has_zp && !is_zp_float) {
|
| 1140 |
+
int pipe = full_pipe % stages;
|
| 1141 |
+
|
| 1142 |
+
if constexpr (group_blocks == -1) {
|
| 1143 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
| 1144 |
+
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
|
| 1145 |
+
}
|
| 1146 |
+
|
| 1147 |
+
} else if constexpr (group_blocks >= thread_k_blocks) {
|
| 1148 |
+
int4* sh_zp_stage =
|
| 1149 |
+
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
| 1150 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
| 1151 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
| 1152 |
+
frag_qzp[k % 2][i] =
|
| 1153 |
+
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
| 1154 |
+
}
|
| 1155 |
+
} else {
|
| 1156 |
+
int warp_id = threadIdx.x / 32;
|
| 1157 |
+
int n_warps = thread_n_blocks / 4;
|
| 1158 |
+
|
| 1159 |
+
int warp_row = warp_id / n_warps;
|
| 1160 |
+
|
| 1161 |
+
int cur_k = warp_row * 16;
|
| 1162 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
| 1163 |
+
|
| 1164 |
+
int k_blocks = cur_k / 16;
|
| 1165 |
+
int cur_group_id = 0;
|
| 1166 |
+
|
| 1167 |
+
// Suppress bogus and persistent divide-by-zero warning
|
| 1168 |
+
#pragma nv_diagnostic push
|
| 1169 |
+
#pragma nv_diag_suppress divide_by_zero
|
| 1170 |
+
cur_group_id = k_blocks / group_blocks;
|
| 1171 |
+
#pragma nv_diagnostic pop
|
| 1172 |
+
|
| 1173 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
| 1174 |
+
|
| 1175 |
+
sh_zp_stage += cur_group_id * zp_sh_stride;
|
| 1176 |
+
|
| 1177 |
+
for (int i = 0; i < num_ints_per_thread; i++) {
|
| 1178 |
+
frag_qzp[k % 2][i] =
|
| 1179 |
+
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
| 1180 |
+
}
|
| 1181 |
+
}
|
| 1182 |
+
}
|
| 1183 |
+
|
| 1184 |
+
else if constexpr (has_zp && is_zp_float) {
|
| 1185 |
+
int pipe = full_pipe % stages;
|
| 1186 |
+
|
| 1187 |
+
if constexpr (group_blocks != -1) {
|
| 1188 |
+
if constexpr (group_blocks >= thread_k_blocks) {
|
| 1189 |
+
int4* sh_zp_stage =
|
| 1190 |
+
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
| 1191 |
+
(pipe / (group_blocks / thread_k_blocks)));
|
| 1192 |
+
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
|
| 1193 |
+
} else {
|
| 1194 |
+
int warp_id = threadIdx.x / 32;
|
| 1195 |
+
int n_warps = thread_n_blocks / 4;
|
| 1196 |
+
|
| 1197 |
+
int warp_row = warp_id / n_warps;
|
| 1198 |
+
|
| 1199 |
+
int cur_k = warp_row * 16;
|
| 1200 |
+
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
| 1201 |
+
|
| 1202 |
+
int k_blocks = cur_k / 16;
|
| 1203 |
+
// Suppress bogus and persistent divide-by-zero warning
|
| 1204 |
+
#pragma nv_diagnostic push
|
| 1205 |
+
#pragma nv_diag_suppress divide_by_zero
|
| 1206 |
+
int cur_group_id = k_blocks / group_blocks;
|
| 1207 |
+
#pragma nv_diagnostic pop
|
| 1208 |
+
|
| 1209 |
+
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
| 1210 |
+
|
| 1211 |
+
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
|
| 1212 |
+
sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
|
| 1213 |
+
}
|
| 1214 |
+
}
|
| 1215 |
+
}
|
| 1216 |
+
};
|
| 1217 |
+
|
| 1218 |
+
// Execute the actual tensor core matmul of a sub-tile.
|
| 1219 |
+
auto matmul = [&](int k) {
|
| 1220 |
+
if constexpr (has_zp && !is_zp_float) {
|
| 1221 |
+
FragB frag_zp_0;
|
| 1222 |
+
FragB frag_zp_1;
|
| 1223 |
+
int zp_quant_0, zp_quant_1;
|
| 1224 |
+
|
| 1225 |
+
if constexpr (w_type.size_bits() == 4) {
|
| 1226 |
+
zp_quant_0 = frag_qzp[k % 2][0];
|
| 1227 |
+
zp_quant_1 = zp_quant_0 >> 8;
|
| 1228 |
+
} else {
|
| 1229 |
+
static_assert(w_type.size_bits() == 8);
|
| 1230 |
+
zp_quant_0 = frag_qzp[k % 2][0];
|
| 1231 |
+
zp_quant_1 = frag_qzp[k % 2][1];
|
| 1232 |
+
}
|
| 1233 |
+
|
| 1234 |
+
frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
|
| 1235 |
+
frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
|
| 1236 |
+
|
| 1237 |
+
frag_zp[0] = frag_zp_0[0];
|
| 1238 |
+
frag_zp[1] = frag_zp_0[1];
|
| 1239 |
+
frag_zp[2] = frag_zp_1[0];
|
| 1240 |
+
frag_zp[3] = frag_zp_1[1];
|
| 1241 |
+
}
|
| 1242 |
+
|
| 1243 |
+
// We have the m dimension as the inner loop in order to encourage overlapping
|
| 1244 |
+
// dequantization and matmul operations.
|
| 1245 |
+
#pragma unroll
|
| 1246 |
+
for (int j = 0; j < 4; j++) {
|
| 1247 |
+
FragB frag_b0;
|
| 1248 |
+
FragB frag_b1;
|
| 1249 |
+
int b_quant_0, b_quant_1;
|
| 1250 |
+
|
| 1251 |
+
if constexpr (w_type.size_bits() == 4) {
|
| 1252 |
+
b_quant_0 = frag_b_quant[k % 2][0][j];
|
| 1253 |
+
b_quant_1 = b_quant_0 >> 8;
|
| 1254 |
+
} else {
|
| 1255 |
+
static_assert(w_type.size_bits() == 8);
|
| 1256 |
+
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
| 1257 |
+
b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
| 1258 |
+
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
| 1259 |
+
}
|
| 1260 |
+
|
| 1261 |
+
frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
|
| 1262 |
+
frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
|
| 1263 |
+
|
| 1264 |
+
// Apply zero-point to frag_b0
|
| 1265 |
+
if constexpr (has_zp && !is_zp_float) {
|
| 1266 |
+
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
| 1270 |
+
sub_zp<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
|
| 1271 |
+
}
|
| 1272 |
+
|
| 1273 |
+
// Apply scale to frag_b0
|
| 1274 |
+
if constexpr (has_act_order) {
|
| 1275 |
+
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
|
| 1276 |
+
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
|
| 1277 |
+
act_frag_s[k % 2][3][j], 0);
|
| 1278 |
+
} else {
|
| 1279 |
+
if constexpr (group_blocks != -1) {
|
| 1280 |
+
scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);
|
| 1281 |
+
}
|
| 1282 |
+
}
|
| 1283 |
+
|
| 1284 |
+
// Apply zero-point to frag_b1
|
| 1285 |
+
if constexpr (has_zp && !is_zp_float) {
|
| 1286 |
+
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
| 1287 |
+
}
|
| 1288 |
+
|
| 1289 |
+
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
| 1290 |
+
sub_zp<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
|
| 1291 |
+
}
|
| 1292 |
+
|
| 1293 |
+
// Apply scale to frag_b1
|
| 1294 |
+
if constexpr (has_act_order) {
|
| 1295 |
+
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
|
| 1296 |
+
act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],
|
| 1297 |
+
act_frag_s[k % 2][3][j], 1);
|
| 1298 |
+
|
| 1299 |
+
} else {
|
| 1300 |
+
if constexpr (group_blocks != -1) {
|
| 1301 |
+
scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);
|
| 1302 |
+
}
|
| 1303 |
+
}
|
| 1304 |
+
|
| 1305 |
+
#pragma unroll
|
| 1306 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
| 1307 |
+
mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
|
| 1308 |
+
mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
|
| 1309 |
+
}
|
| 1310 |
+
}
|
| 1311 |
+
};
|
| 1312 |
+
|
| 1313 |
+
// Since we slice across the k dimension of a tile in order to increase the
|
| 1314 |
+
// number of warps while keeping the n dimension of a tile reasonable, we have
|
| 1315 |
+
// multiple warps that accumulate their partial sums of the same output
|
| 1316 |
+
// location, which we have to reduce over in the end. We do this in shared memory.
|
| 1317 |
+
auto thread_block_reduce = [&]() {
|
| 1318 |
+
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
| 1319 |
+
if (red_off >= 1) {
|
| 1320 |
+
int red_idx = threadIdx.x / b_sh_stride_threads;
|
| 1321 |
+
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
| 1322 |
+
constexpr int red_sh_delta = b_sh_stride_threads;
|
| 1323 |
+
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
| 1324 |
+
(threadIdx.x % b_sh_stride_threads);
|
| 1325 |
+
|
| 1326 |
+
// Parallel logarithmic shared memory reduction. We make sure to avoid any
|
| 1327 |
+
// unnecessary read or write iterations, e.g., for two warps, warp 1 writes
|
| 1328 |
+
// only once and warp 0 reads only once.
|
| 1329 |
+
|
| 1330 |
+
#pragma unroll
|
| 1331 |
+
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
|
| 1332 |
+
#pragma unroll
|
| 1333 |
+
for (int i = red_off; i > 0; i /= 2) {
|
| 1334 |
+
if (i <= red_idx && red_idx < 2 * i) {
|
| 1335 |
+
#pragma unroll
|
| 1336 |
+
for (int j = 0; j < 4 * 2; j++) {
|
| 1337 |
+
int red_sh_wr =
|
| 1338 |
+
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
| 1339 |
+
if (i < red_off) {
|
| 1340 |
+
float* c_rd =
|
| 1341 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
| 1342 |
+
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
| 1343 |
+
#pragma unroll
|
| 1344 |
+
for (int k = 0; k < 4; k++)
|
| 1345 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
| 1346 |
+
c_rd[k] + c_wr[k];
|
| 1347 |
+
}
|
| 1348 |
+
sh[red_sh_wr] =
|
| 1349 |
+
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
| 1350 |
+
}
|
| 1351 |
+
}
|
| 1352 |
+
__syncthreads();
|
| 1353 |
+
}
|
| 1354 |
+
if (red_idx == 0) {
|
| 1355 |
+
#pragma unroll
|
| 1356 |
+
for (int i = 0; i < 4 * 2; i++) {
|
| 1357 |
+
float* c_rd =
|
| 1358 |
+
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
| 1359 |
+
#pragma unroll
|
| 1360 |
+
for (int j = 0; j < 4; j++)
|
| 1361 |
+
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
| 1362 |
+
c_rd[j];
|
| 1363 |
+
}
|
| 1364 |
+
}
|
| 1365 |
+
__syncthreads();
|
| 1366 |
+
}
|
| 1367 |
+
}
|
| 1368 |
+
};
|
| 1369 |
+
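// The reduction above is logarithmic over warp groups: each step halves the
// number of groups (red_off, red_off / 2, ..., 1) that still hold partial
// sums, and the group with red_idx == 0 ends up with the fully reduced
// fragments.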
|
| 1370 |
+
// Since multiple threadblocks may process parts of the same column slice, we
|
| 1371 |
+
// finally have to globally reduce over the results. As the striped
|
| 1372 |
+
// partitioning minimizes the number of such reductions and our outputs are
|
| 1373 |
+
// usually rather small, we perform this reduction serially in L2 cache.
|
| 1374 |
+
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
|
| 1375 |
+
// We are very careful here to reduce directly in the output buffer to
|
| 1376 |
+
// maximize L2 cache utilization in this step. To do this, we write out
|
| 1377 |
+
// results in FP16 (but still reduce with FP32 compute).
|
| 1378 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
| 1379 |
+
if (threadIdx.x < active_threads) {
|
| 1380 |
+
int c_gl_stride = prob_n / 8;
|
| 1381 |
+
int c_gl_wr_delta_o = 8 * c_gl_stride;
|
| 1382 |
+
int c_gl_wr_delta_i = 4 * (active_threads / 32);
|
| 1383 |
+
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
|
| 1384 |
+
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
| 1385 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
| 1386 |
+
constexpr int c_sh_wr_delta = active_threads;
|
| 1387 |
+
int c_sh_wr = threadIdx.x;
|
| 1388 |
+
|
| 1389 |
+
int row = (threadIdx.x % 32) / 4;
|
| 1390 |
+
|
| 1391 |
+
if (!first) {
|
| 1392 |
+
// Interestingly, doing direct global accesses here really seems to mess up
|
| 1393 |
+
// the compiler and lead to slowdowns, hence we also use async-copies even
|
| 1394 |
+
// though these fetches are not actually asynchronous.
|
| 1395 |
+
#pragma unroll
|
| 1396 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
| 1397 |
+
cp_async4_pred(
|
| 1398 |
+
&sh[c_sh_wr + c_sh_wr_delta * i],
|
| 1399 |
+
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
| 1400 |
+
c_gl_wr_delta_i * (i % 2)],
|
| 1401 |
+
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
| 1402 |
+
}
|
| 1403 |
+
cp_async_fence();
|
| 1404 |
+
cp_async_wait<0>();
|
| 1405 |
+
}
|
| 1406 |
+
|
| 1407 |
+
#pragma unroll
|
| 1408 |
+
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
| 1409 |
+
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
| 1410 |
+
if (!first) {
|
| 1411 |
+
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
| 1412 |
+
#pragma unroll
|
| 1413 |
+
for (int j = 0; j < 2 * 4; j++) {
|
| 1414 |
+
reinterpret_cast<float*>(
|
| 1415 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
|
| 1416 |
+
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
|
| 1417 |
+
}
|
| 1418 |
+
}
|
| 1419 |
+
if (!last) {
|
| 1420 |
+
int4 c;
|
| 1421 |
+
#pragma unroll
|
| 1422 |
+
for (int j = 0; j < 2 * 4; j++) {
|
| 1423 |
+
reinterpret_cast<scalar_t*>(&c)[j] =
|
| 1424 |
+
Dtype::float2num(reinterpret_cast<float*>(
|
| 1425 |
+
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
|
| 1426 |
+
}
|
| 1427 |
+
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
|
| 1428 |
+
c;
|
| 1429 |
+
}
|
| 1430 |
+
}
|
| 1431 |
+
}
|
| 1432 |
+
}
|
| 1433 |
+
};
|
| 1434 |
+
|
| 1435 |
+
// Globally reduce over threadblocks that compute the same column block.
|
| 1436 |
+
// We use a tmp C buffer to reduce in full fp32 precision.
|
| 1437 |
+
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
|
| 1438 |
+
constexpr int tb_m = thread_m_blocks * 16;
|
| 1439 |
+
constexpr int tb_n = thread_n_blocks * 16;
|
| 1440 |
+
|
| 1441 |
+
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
|
| 1442 |
+
|
| 1443 |
+
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
| 1444 |
+
bool is_th_active = threadIdx.x < active_threads;
|
| 1445 |
+
|
| 1446 |
+
int par_offset = c_size * n_tiles * par_id;
|
| 1447 |
+
int slice_offset = c_size * slice_col;
|
| 1448 |
+
|
| 1449 |
+
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
|
| 1450 |
+
constexpr int th_size = num_floats * sizeof(float) / 16;
|
| 1451 |
+
|
| 1452 |
+
int c_cur_offset = par_offset + slice_offset;
|
| 1453 |
+
|
| 1454 |
+
if (!is_th_active) {
|
| 1455 |
+
return;
|
| 1456 |
+
}
|
| 1457 |
+
|
| 1458 |
+
if (!first) {
|
| 1459 |
+
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
| 1460 |
+
#pragma unroll
|
| 1461 |
+
for (int k = 0; k < th_size; k++) {
|
| 1462 |
+
sh[threadIdx.x] =
|
| 1463 |
+
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
| 1464 |
+
|
| 1465 |
+
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
|
| 1466 |
+
#pragma unroll
|
| 1467 |
+
for (int f = 0; f < 4; f++) {
|
| 1468 |
+
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
| 1469 |
+
}
|
| 1470 |
+
}
|
| 1471 |
+
}
|
| 1472 |
+
|
| 1473 |
+
if (!last) {
|
| 1474 |
+
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
|
| 1475 |
+
#pragma unroll
|
| 1476 |
+
for (int k = 0; k < th_size; k++) {
|
| 1477 |
+
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
|
| 1478 |
+
}
|
| 1479 |
+
}
|
| 1480 |
+
};
|
| 1481 |
+
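// Trade-off between the two reduce paths: global_reduce_fp16 accumulates
// directly in the fp16 output buffer C, so it needs no extra storage but
// rounds every partial sum through fp16, while global_reduce_fp32 keeps the
// partial sums in the fp32 scratch buffer C_tmp and only converts to fp16
// once, at the final write-out.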
|
| 1482 |
+
// Write out the reduced final result in the correct layout. We only actually
|
| 1483 |
+
// reshuffle matrix fragments in this step; the reduction above is performed
|
| 1484 |
+
// in fragment layout.
|
| 1485 |
+
auto write_result = [&]() {
|
| 1486 |
+
int c_gl_stride = prob_n / 8;
|
| 1487 |
+
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
| 1488 |
+
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
| 1489 |
+
constexpr int c_sh_rd_delta =
|
| 1490 |
+
c_sh_stride * (threads / (2 * thread_n_blocks));
|
| 1491 |
+
|
| 1492 |
+
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
| 1493 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
| 1494 |
+
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
| 1495 |
+
int c_sh_wr =
|
| 1496 |
+
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
|
| 1497 |
+
c_sh_wr += 32 * (threadIdx.x / 32);
|
| 1498 |
+
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
|
| 1499 |
+
(threadIdx.x % (2 * thread_n_blocks));
|
| 1500 |
+
|
| 1501 |
+
int c_gl_wr_end = c_gl_stride * prob_m;
|
| 1502 |
+
|
| 1503 |
+
// We first reorder in shared memory to guarantee the most efficient final
|
| 1504 |
+
// global write patterns
|
| 1505 |
+
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
| 1506 |
+
scalar_t2 res =
|
| 1507 |
+
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
| 1508 |
+
|
| 1509 |
+
// For per-column quantization we finally apply the scale here (only for
|
| 1510 |
+
// 4-bit)
|
| 1511 |
+
if constexpr (!has_act_order && group_blocks == -1 &&
|
| 1512 |
+
w_type.size_bits() == 4) {
|
| 1513 |
+
res = __hmul2(res, s[0]);
|
| 1514 |
+
}
|
| 1515 |
+
|
| 1516 |
+
((scalar_t2*)sh)[idx] = res;
|
| 1517 |
+
};
|
| 1518 |
+
|
| 1519 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
| 1520 |
+
#pragma unroll
|
| 1521 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
| 1522 |
+
#pragma unroll
|
| 1523 |
+
for (int j = 0; j < 4; j++) {
|
| 1524 |
+
int wr = c_sh_wr + 8 * j;
|
| 1525 |
+
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
| 1526 |
+
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
| 1527 |
+
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
| 1528 |
+
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
| 1529 |
+
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
| 1530 |
+
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
| 1531 |
+
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
| 1532 |
+
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
| 1533 |
+
}
|
| 1534 |
+
c_sh_wr += 16 * (4 * c_sh_stride);
|
| 1535 |
+
}
|
| 1536 |
+
}
|
| 1537 |
+
__syncthreads();
|
| 1538 |
+
|
| 1539 |
+
#pragma unroll
|
| 1540 |
+
for (int i = 0;
|
| 1541 |
+
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
| 1542 |
+
i++) {
|
| 1543 |
+
if (c_gl_wr < c_gl_wr_end) {
|
| 1544 |
+
C[c_gl_wr] = sh[c_sh_rd];
|
| 1545 |
+
c_gl_wr += c_gl_wr_delta;
|
| 1546 |
+
c_sh_rd += c_sh_rd_delta;
|
| 1547 |
+
}
|
| 1548 |
+
}
|
| 1549 |
+
};
|
| 1550 |
+
|
| 1551 |
+
// Start global fetch and register load pipelines.
|
| 1552 |
+
auto start_pipes = [&]() {
|
| 1553 |
+
|
| 1554 |
+
#pragma unroll
|
| 1555 |
+
for (int i = 0; i < stages - 1; i++) {
|
| 1556 |
+
if (has_act_order && i == 0) {
|
| 1557 |
+
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
| 1558 |
+
if (last_g_idx >= prob_k) {
|
| 1559 |
+
last_g_idx = prob_k - 1;
|
| 1560 |
+
}
|
| 1561 |
+
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
| 1562 |
+
}
|
| 1563 |
+
|
| 1564 |
+
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
| 1565 |
+
if (i == 0) {
|
| 1566 |
+
fetch_zp_to_shared();
|
| 1567 |
+
}
|
| 1568 |
+
}
|
| 1569 |
+
fetch_to_shared(i, i, i < slice_iters);
|
| 1570 |
+
}
|
| 1571 |
+
|
| 1572 |
+
zero_accums();
|
| 1573 |
+
wait_for_stage();
|
| 1574 |
+
init_same_group(0);
|
| 1575 |
+
fetch_to_registers(0, 0);
|
| 1576 |
+
fetch_scales_to_registers(0, 0);
|
| 1577 |
+
fetch_zp_to_registers(0, 0);
|
| 1578 |
+
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
| 1579 |
+
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
| 1580 |
+
};
|
| 1581 |
+
if (slice_iters) {
|
| 1582 |
+
start_pipes();
|
| 1583 |
+
}
|
| 1584 |
+
|
| 1585 |
+
// Main loop.
|
| 1586 |
+
while (slice_iters) {
|
| 1587 |
+
// We unroll over both the global fetch and the register load pipeline to
|
| 1588 |
+
// ensure all shared memory accesses are static. Note that both pipelines
|
| 1589 |
+
// have even length meaning that the next iteration will always start at
|
| 1590 |
+
// index 0.
|
| 1591 |
+
|
| 1592 |
+
#pragma unroll
|
| 1593 |
+
for (int pipe = 0; pipe < stages;) {
|
| 1594 |
+
#pragma unroll
|
| 1595 |
+
for (int k = 0; k < b_sh_wr_iters; k++) {
|
| 1596 |
+
fetch_to_registers(k + 1, pipe % stages);
|
| 1597 |
+
fetch_scales_to_registers(k + 1, pipe);
|
| 1598 |
+
fetch_zp_to_registers(k + 1, pipe);
|
| 1599 |
+
if (k == b_sh_wr_iters - 2) {
|
| 1600 |
+
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
| 1601 |
+
slice_iters >= stages);
|
| 1602 |
+
pipe++;
|
| 1603 |
+
wait_for_stage();
|
| 1604 |
+
init_same_group(pipe % stages);
|
| 1605 |
+
}
|
| 1606 |
+
matmul(k);
|
| 1607 |
+
}
|
| 1608 |
+
slice_iters--;
|
| 1609 |
+
if (slice_iters == 0) {
|
| 1610 |
+
break;
|
| 1611 |
+
}
|
| 1612 |
+
}
|
| 1613 |
+
|
| 1614 |
+
a_gl_rd += a_gl_rd_delta_o * stages;
|
| 1615 |
+
slice_k_start += tb_k * stages;
|
| 1616 |
+
slice_k_start_shared_fetch += tb_k * stages;
|
| 1617 |
+
|
| 1618 |
+
if constexpr (has_act_order) {
|
| 1619 |
+
int first_group_id = g_idx[slice_k_start];
|
| 1620 |
+
int last_g_idx = slice_k_start + stages * tb_k * 2;
|
| 1621 |
+
if (last_g_idx >= prob_k) {
|
| 1622 |
+
last_g_idx = prob_k - 1;
|
| 1623 |
+
}
|
| 1624 |
+
int last_group_id = g_idx[last_g_idx];
|
| 1625 |
+
if (last_group_id >= sh_first_group_id + sh_num_groups) {
|
| 1626 |
+
fetch_scales_to_shared(false, first_group_id, last_group_id);
|
| 1627 |
+
__syncthreads();
|
| 1628 |
+
}
|
| 1629 |
+
}
|
| 1630 |
+
|
| 1631 |
+
// Process results and, if necessary, proceed to the next column slice.
|
| 1632 |
+
// While this pattern may not be the most readable, other ways of writing
|
| 1633 |
+
// the loop seemed to yield noticeably worse performance after compilation.
|
| 1634 |
+
if (slice_iters == 0) {
|
| 1635 |
+
cp_async_wait<0>();
|
| 1636 |
+
bool last = slice_idx == slice_count - 1;
|
| 1637 |
+
// For per-column scales, we only fetch them here in the final step before
|
| 1638 |
+
// write-out
|
| 1639 |
+
if constexpr (!has_act_order && group_blocks == -1) {
|
| 1640 |
+
if constexpr (w_type.size_bits() == 8) {
|
| 1641 |
+
if (s_sh_wr_pred) {
|
| 1642 |
+
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
| 1643 |
+
}
|
| 1644 |
+
cp_async_fence();
|
| 1645 |
+
} else {
|
| 1646 |
+
if (last) {
|
| 1647 |
+
if (s_sh_wr_pred) {
|
| 1648 |
+
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
| 1649 |
+
}
|
| 1650 |
+
cp_async_fence();
|
| 1651 |
+
}
|
| 1652 |
+
}
|
| 1653 |
+
}
|
| 1654 |
+
|
| 1655 |
+
thread_block_reduce();
|
| 1656 |
+
if constexpr (!has_act_order && group_blocks == -1) {
|
| 1657 |
+
if constexpr (w_type.size_bits() == 8) {
|
| 1658 |
+
cp_async_wait<0>();
|
| 1659 |
+
__syncthreads();
|
| 1660 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
| 1661 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
| 1662 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
| 1663 |
+
}
|
| 1664 |
+
|
| 1665 |
+
} else {
|
| 1666 |
+
if (last) {
|
| 1667 |
+
cp_async_wait<0>();
|
| 1668 |
+
__syncthreads();
|
| 1669 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
| 1670 |
+
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
|
| 1671 |
+
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
|
| 1672 |
+
}
|
| 1673 |
+
}
|
| 1674 |
+
}
|
| 1675 |
+
}
|
| 1676 |
+
|
| 1677 |
+
// For 8-bit channelwise, we apply the scale before the global reduction
|
| 1678 |
+
// that converts the fp32 results to fp16 (so that we avoid possible
|
| 1679 |
+
// overflow in fp16)
|
| 1680 |
+
if constexpr (!has_act_order && group_blocks == -1 &&
|
| 1681 |
+
w_type.size_bits() == 8) {
|
| 1682 |
+
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
| 1683 |
+
#pragma unroll
|
| 1684 |
+
for (int i = 0; i < thread_m_blocks; i++) {
|
| 1685 |
+
#pragma unroll
|
| 1686 |
+
for (int j = 0; j < 4; j++) {
|
| 1687 |
+
scale_float<scalar_t>(
|
| 1688 |
+
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
|
| 1689 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
| 1690 |
+
scale_float<scalar_t>(
|
| 1691 |
+
reinterpret_cast<float*>(&frag_c[i][j][0][2]),
|
| 1692 |
+
frag_s[j / 2][2 * (j % 2) + 0]);
|
| 1693 |
+
|
| 1694 |
+
scale_float<scalar_t>(
|
| 1695 |
+
reinterpret_cast<float*>(&frag_c[i][j][1][0]),
|
| 1696 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
| 1697 |
+
scale_float<scalar_t>(
|
| 1698 |
+
reinterpret_cast<float*>(&frag_c[i][j][1][2]),
|
| 1699 |
+
frag_s[j / 2][2 * (j % 2) + 1]);
|
| 1700 |
+
}
|
| 1701 |
+
}
|
| 1702 |
+
}
|
| 1703 |
+
}
|
| 1704 |
+
|
| 1705 |
+
if (slice_count > 1) { // only globally reduce if there is more than one
|
| 1706 |
+
// block in a slice
|
| 1707 |
+
barrier_acquire(&locks[slice_col], slice_idx);
|
| 1708 |
+
if (use_fp32_reduce) {
|
| 1709 |
+
global_reduce_fp32(slice_idx == 0, last);
|
| 1710 |
+
} else {
|
| 1711 |
+
global_reduce_fp16(slice_idx == 0, last);
|
| 1712 |
+
}
|
| 1713 |
+
barrier_release(&locks[slice_col], last);
|
| 1714 |
+
}
|
| 1715 |
+
if (last) // only the last block in a slice actually writes the result
|
| 1716 |
+
write_result();
|
| 1717 |
+
slice_row = 0;
|
| 1718 |
+
slice_col_par++;
|
| 1719 |
+
slice_col++;
|
| 1720 |
+
init_slice();
|
| 1721 |
+
if (slice_iters) {
|
| 1722 |
+
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
| 1723 |
+
(threadIdx.x % a_gl_rd_delta_o);
|
| 1724 |
+
#pragma unroll
|
| 1725 |
+
for (int i = 0; i < b_sh_wr_iters; i++)
|
| 1726 |
+
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
|
| 1727 |
+
if (slice_col == 0) {
|
| 1728 |
+
#pragma unroll
|
| 1729 |
+
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
| 1730 |
+
}
|
| 1731 |
+
|
| 1732 |
+
// Update slice k/n for scales loading
|
| 1733 |
+
if constexpr (has_act_order) {
|
| 1734 |
+
slice_k_start = tb_k * slice_row;
|
| 1735 |
+
slice_k_finish = slice_k_start + tb_k * slice_iters;
|
| 1736 |
+
slice_k_start_shared_fetch = slice_k_start;
|
| 1737 |
+
slice_n_offset = act_s_col_tb_stride * slice_col;
|
| 1738 |
+
|
| 1739 |
+
} else {
|
| 1740 |
+
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
| 1741 |
+
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
| 1742 |
+
}
|
| 1743 |
+
|
| 1744 |
+
start_pipes();
|
| 1745 |
+
}
|
| 1746 |
+
}
|
| 1747 |
+
}
|
| 1748 |
+
}
|
| 1749 |
+
|
| 1750 |
+
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
| 1751 |
+
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
|
| 1752 |
+
IS_ZP_FLOAT) \
|
| 1753 |
+
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
| 1754 |
+
thread_n_blocks == THREAD_N_BLOCKS && \
|
| 1755 |
+
thread_k_blocks == THREAD_K_BLOCKS && \
|
| 1756 |
+
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
| 1757 |
+
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
| 1758 |
+
is_zp_float == IS_ZP_FLOAT) { \
|
| 1759 |
+
if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
|
| 1760 |
+
cudaFuncSetAttribute( \
|
| 1761 |
+
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
| 1762 |
+
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
|
| 1763 |
+
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
|
| 1764 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
| 1765 |
+
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
| 1766 |
+
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
| 1767 |
+
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
|
| 1768 |
+
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
| 1769 |
+
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
|
| 1770 |
+
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
|
| 1771 |
+
} \
|
| 1772 |
+
}
|
| 1773 |
+
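// Every expansion of __CALL_IF begins with `else if`, so the dispatch site is
// expected to open the chain with a trivial `if` and then list one invocation
// per supported template configuration, roughly along the lines of (W_TYPE
// standing in for one of the supported weight ScalarType constants):
//
//   if (false) {
//   }
//   __CALL_IF(W_TYPE, 1, 8, 8, /*act_order*/ true, /*zp*/ false, 0, 256, false)
//   __CALL_IF(W_TYPE, 1, 8, 4, /*act_order*/ true, /*zp*/ false, 0, 128, false)
//   else {
//     TORCH_CHECK(false, "unsupported marlin configuration");
//   }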
|
| 1774 |
+
typedef struct {
|
| 1775 |
+
int thread_k;
|
| 1776 |
+
int thread_n;
|
| 1777 |
+
int num_threads;
|
| 1778 |
+
} thread_config_t;
|
| 1779 |
+
|
| 1780 |
+
typedef struct {
|
| 1781 |
+
int max_m_blocks;
|
| 1782 |
+
thread_config_t tb_cfg;
|
| 1783 |
+
} exec_config_t;
|
| 1784 |
+
|
| 1785 |
+
thread_config_t small_batch_thread_configs[] = {
|
| 1786 |
+
// Ordered by priority
|
| 1787 |
+
|
| 1788 |
+
// thread_k, thread_n, num_threads
|
| 1789 |
+
{128, 128, 256},
|
| 1790 |
+
{64, 128, 128},
|
| 1791 |
+
{128, 64, 128},
|
| 1792 |
+
};
|
| 1793 |
+
|
| 1794 |
+
thread_config_t large_batch_thread_configs[] = {
|
| 1795 |
+
// Ordered by priority
|
| 1796 |
+
|
| 1797 |
+
// thread_k, thread_n, num_threads
|
| 1798 |
+
{64, 256, 256},
|
| 1799 |
+
{64, 128, 128},
|
| 1800 |
+
{128, 64, 128},
|
| 1801 |
+
|
| 1802 |
+
};
|
| 1803 |
+
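// Reading a config such as {64, 256, 256}: each threadblock of 256 threads
// covers a 64-wide slab of K and a 256-wide slab of N per tile. The
// small-batch table tries {128, 128, 256} first while the large-batch table
// tries {64, 256, 256}, i.e. it trades K-width for N-width as the batch grows.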
|
| 1804 |
+
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
| 1805 |
+
int prob_n, int prob_k, int num_bits, int group_size,
|
| 1806 |
+
bool has_act_order, bool is_k_full) {
|
| 1807 |
+
bool cache_scales_chunk = has_act_order && !is_k_full;
|
| 1808 |
+
|
| 1809 |
+
int tb_n = th_config.thread_n;
|
| 1810 |
+
int tb_k = th_config.thread_k;
|
| 1811 |
+
|
| 1812 |
+
// Get max scale groups per thread-block
|
| 1813 |
+
int tb_groups;
|
| 1814 |
+
if (group_size == -1) {
|
| 1815 |
+
tb_groups = 1;
|
| 1816 |
+
} else if (group_size == 0) {
|
| 1817 |
+
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
| 1818 |
+
} else {
|
| 1819 |
+
tb_groups = div_ceil(tb_k, group_size);
|
| 1820 |
+
}
|
| 1821 |
+
|
| 1822 |
+
if (cache_scales_chunk) {
|
| 1823 |
+
int load_groups =
|
| 1824 |
+
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
| 1825 |
+
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
| 1826 |
+
return load_groups * tb_n * 2;
|
| 1827 |
+
|
| 1828 |
+
} else {
|
| 1829 |
+
int tb_scales = tb_groups * tb_n * 2;
|
| 1830 |
+
|
| 1831 |
+
return tb_scales * pipe_stages;
|
| 1832 |
+
}
|
| 1833 |
+
}
|
| 1834 |
+
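// All sizes here are in bytes (2 bytes per fp16 scale). For example, in the
// non-act-order path with thread_k = 128, thread_n = 128 and group_size = 128:
// tb_groups = 1, so the cached scales take 1 * 128 * 2 * pipe_stages bytes.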
|
| 1835 |
+
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
|
| 1836 |
+
int prob_m, int prob_n, int prob_k, int num_bits,
|
| 1837 |
+
int scales_cache_size, int max_shared_mem) {
|
| 1838 |
+
int pack_factor = 32 / num_bits;
|
| 1839 |
+
|
| 1840 |
+
// Get B size
|
| 1841 |
+
int tb_k = th_config.thread_k;
|
| 1842 |
+
int tb_n = th_config.thread_n;
|
| 1843 |
+
|
| 1844 |
+
int b_size = (tb_k * tb_n / pack_factor) * 4;
|
| 1845 |
+
|
| 1846 |
+
// Get A size
|
| 1847 |
+
int m_blocks = div_ceil(prob_m, 16);
|
| 1848 |
+
int tb_max_m = 16;
|
| 1849 |
+
|
| 1850 |
+
while (true) {
|
| 1851 |
+
if (m_blocks >= max_m_blocks) {
|
| 1852 |
+
tb_max_m *= max_m_blocks;
|
| 1853 |
+
break;
|
| 1854 |
+
}
|
| 1855 |
+
|
| 1856 |
+
max_m_blocks--;
|
| 1857 |
+
if (max_m_blocks == 0) {
|
| 1858 |
+
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
|
| 1859 |
+
}
|
| 1860 |
+
}
|
| 1861 |
+
|
| 1862 |
+
int a_size = (tb_max_m * tb_k) * 2;
|
| 1863 |
+
|
| 1864 |
+
float pipe_size = (a_size + b_size) * pipe_stages;
|
| 1865 |
+
|
| 1866 |
+
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
|
| 1867 |
+
|
| 1868 |
+
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
|
| 1869 |
+
}
|
| 1870 |
+
|
| 1871 |
+
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
| 1872 |
+
int prob_m, int prob_n, int prob_k, int num_bits,
|
| 1873 |
+
int group_size, bool has_act_order, bool is_k_full,
|
| 1874 |
+
int max_shared_mem) {
|
| 1875 |
+
// Sanity
|
| 1876 |
+
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
| 1877 |
+
th_config.num_threads == -1) {
|
| 1878 |
+
return false;
|
| 1879 |
+
}
|
| 1880 |
+
|
| 1881 |
+
// Verify K/N are divisible by thread K/N
|
| 1882 |
+
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
| 1883 |
+
return false;
|
| 1884 |
+
}
|
| 1885 |
+
|
| 1886 |
+
// Verify min for thread K/N
|
| 1887 |
+
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
| 1888 |
+
return false;
|
| 1889 |
+
}
|
| 1890 |
+
|
| 1891 |
+
// num_threads must be at least 128 (= 4 warps)
|
| 1892 |
+
if (th_config.num_threads < 128) {
|
| 1893 |
+
return false;
|
| 1894 |
+
}
|
| 1895 |
+
|
| 1896 |
+
// Determine cache for scales
|
| 1897 |
+
int scales_cache_size =
|
| 1898 |
+
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
| 1899 |
+
group_size, has_act_order, is_k_full);
|
| 1900 |
+
|
| 1901 |
+
// Check that pipeline fits into cache
|
| 1902 |
+
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
| 1903 |
+
num_bits, scales_cache_size, max_shared_mem)) {
|
| 1904 |
+
return false;
|
| 1905 |
+
}
|
| 1906 |
+
|
| 1907 |
+
return true;
|
| 1908 |
+
}
|
| 1909 |
+
|
| 1910 |
+
int determine_reduce_max_m(int prob_m, int max_par) {
|
| 1911 |
+
constexpr int tile_m_size = 16;
|
| 1912 |
+
|
| 1913 |
+
if (prob_m <= tile_m_size) {
|
| 1914 |
+
return tile_m_size;
|
| 1915 |
+
|
| 1916 |
+
} else if (prob_m <= tile_m_size * 2) {
|
| 1917 |
+
return tile_m_size * 2;
|
| 1918 |
+
|
| 1919 |
+
} else if (prob_m <= tile_m_size * 3) {
|
| 1920 |
+
return tile_m_size * 3;
|
| 1921 |
+
|
| 1922 |
+
} else if (prob_m <= tile_m_size * 4) {
|
| 1923 |
+
return tile_m_size * 4;
|
| 1924 |
+
|
| 1925 |
+
} else {
|
| 1926 |
+
int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
|
| 1927 |
+
return tile_m_size * 4 * cur_par;
|
| 1928 |
+
}
|
| 1929 |
+
}
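// Editorial worked example: with the assumed max_par = 16 from marlin.cuh,
//   determine_reduce_max_m(24, 16)  == 32   (rounded up to two 16-row tiles),
//   determine_reduce_max_m(64, 16)  == 64   (four tiles),
//   determine_reduce_max_m(200, 16) == 64 * min(div_ceil(200, 64), 16) == 256,
// so the fp32 reduce buffer never grows past 64 * max_par rows.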
|
| 1930 |
+
|
| 1931 |
+
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
| 1932 |
+
int num_bits, int group_size,
|
| 1933 |
+
bool has_act_order, bool is_k_full,
|
| 1934 |
+
int max_shared_mem) {
|
| 1935 |
+
int max_m_blocks = 4;
|
| 1936 |
+
while (max_m_blocks > 0) {
|
| 1937 |
+
if (prob_m <= 16) {
|
| 1938 |
+
for (auto th_config : small_batch_thread_configs) {
|
| 1939 |
+
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
| 1940 |
+
num_bits, group_size, has_act_order, is_k_full,
|
| 1941 |
+
max_shared_mem)) {
|
| 1942 |
+
return exec_config_t{max_m_blocks, th_config};
|
| 1943 |
+
}
|
| 1944 |
+
}
|
| 1945 |
+
} else {
|
| 1946 |
+
for (auto th_config : large_batch_thread_configs) {
|
| 1947 |
+
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
|
| 1948 |
+
num_bits, group_size, has_act_order, is_k_full,
|
| 1949 |
+
max_shared_mem)) {
|
| 1950 |
+
return exec_config_t{max_m_blocks, th_config};
|
| 1951 |
+
}
|
| 1952 |
+
}
|
| 1953 |
+
}
|
| 1954 |
+
|
| 1955 |
+
max_m_blocks--; // Process less M blocks per invocation to reduce cache
|
| 1956 |
+
// usage
|
| 1957 |
+
}
|
| 1958 |
+
|
| 1959 |
+
return exec_config_t{0, {-1, -1, -1}};
|
| 1960 |
+
}
|
| 1961 |
+
|
| 1962 |
+
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
| 1963 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
| 1964 |
+
false) \
|
| 1965 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
| 1966 |
+
false) \
|
| 1967 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
| 1968 |
+
false) \
|
| 1969 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
| 1970 |
+
false) \
|
| 1971 |
+
\
|
| 1972 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
| 1973 |
+
false) \
|
| 1974 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
| 1975 |
+
false) \
|
| 1976 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
| 1977 |
+
false) \
|
| 1978 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
| 1979 |
+
false) \
|
| 1980 |
+
\
|
| 1981 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
| 1982 |
+
false) \
|
| 1983 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
| 1984 |
+
false) \
|
| 1985 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
| 1986 |
+
false) \
|
| 1987 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
| 1988 |
+
false) \
|
| 1989 |
+
\
|
| 1990 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
| 1991 |
+
false) \
|
| 1992 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
| 1993 |
+
false) \
|
| 1994 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
| 1995 |
+
false) \
|
| 1996 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
| 1997 |
+
false) \
|
| 1998 |
+
\
|
| 1999 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
| 2000 |
+
false) \
|
| 2001 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
| 2002 |
+
false) \
|
| 2003 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
| 2004 |
+
false) \
|
| 2005 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
| 2006 |
+
false)
|
| 2007 |
+
|
| 2008 |
+
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
| 2009 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
| 2010 |
+
false) \
|
| 2011 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
| 2012 |
+
false) \
|
| 2013 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
| 2014 |
+
false) \
|
| 2015 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
| 2016 |
+
false) \
|
| 2017 |
+
\
|
| 2018 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
| 2019 |
+
false) \
|
| 2020 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
| 2021 |
+
false) \
|
| 2022 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
| 2023 |
+
false) \
|
| 2024 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
| 2025 |
+
false) \
|
| 2026 |
+
\
|
| 2027 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
| 2028 |
+
false) \
|
| 2029 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
| 2030 |
+
false) \
|
| 2031 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
| 2032 |
+
false) \
|
| 2033 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
| 2034 |
+
false) \
|
| 2035 |
+
\
|
| 2036 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
| 2037 |
+
false) \
|
| 2038 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
| 2039 |
+
false) \
|
| 2040 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
| 2041 |
+
false) \
|
| 2042 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
|
| 2043 |
+
|
| 2044 |
+
// We currently have 4-bit models only with group_blocks == 4
|
| 2045 |
+
#define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
| 2046 |
+
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
| 2047 |
+
true) \
|
| 2048 |
+
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
| 2049 |
+
true) \
|
| 2050 |
+
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
| 2051 |
+
true) \
|
| 2052 |
+
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
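// Editorial note: GPTQ_CALL_IF, AWQ_CALL_IF and HQQ_CALL_IF each expand into a
// chain of `else if` branches, one per (weight type, thread_m_blocks,
// group_blocks, zero-point flavour) combination, so the `if (false) {}` chain
// in marlin_mm below selects exactly one Marlin template instantiation per
// launch.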
|
| 2053 |
+
|
| 2054 |
+
template <typename scalar_t>
|
| 2055 |
+
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
| 2056 |
+
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
|
| 2057 |
+
int prob_n, int prob_k, void* workspace,
|
| 2058 |
+
vllm::ScalarType const& q_type, bool has_act_order,
|
| 2059 |
+
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
| 2060 |
+
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
| 2061 |
+
int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
|
| 2062 |
+
if (has_zp) {
|
| 2063 |
+
TORCH_CHECK(
|
| 2064 |
+
q_type == vllm::kU4 || q_type == vllm::kU8,
|
| 2065 |
+
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
| 2066 |
+
} else {
|
| 2067 |
+
TORCH_CHECK(
|
| 2068 |
+
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
| 2069 |
+
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
| 2070 |
+
q_type.str());
|
| 2071 |
+
}
|
| 2072 |
+
|
| 2073 |
+
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
| 2074 |
+
", ", prob_n, ", ", prob_k, "]");
|
| 2075 |
+
|
| 2076 |
+
// TODO: remove alias when we start supporting other 8bit types
|
| 2077 |
+
int num_bits = q_type.size_bits();
|
| 2078 |
+
int tot_m = prob_m;
|
| 2079 |
+
int tot_m_blocks = div_ceil(tot_m, 16);
|
| 2080 |
+
int pad = 16 * tot_m_blocks - tot_m;
|
| 2081 |
+
|
| 2082 |
+
if (sms == -1) {
|
| 2083 |
+
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
|
| 2084 |
+
}
|
| 2085 |
+
|
| 2086 |
+
int max_shared_mem = 0;
|
| 2087 |
+
cudaDeviceGetAttribute(&max_shared_mem,
|
| 2088 |
+
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
| 2089 |
+
TORCH_CHECK(max_shared_mem > 0);
|
| 2090 |
+
|
| 2091 |
+
// Set thread config
|
| 2092 |
+
exec_config_t exec_cfg;
|
| 2093 |
+
if (thread_k != -1 && thread_n != -1) {
|
| 2094 |
+
// User-defined config
|
| 2095 |
+
exec_cfg =
|
| 2096 |
+
exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
|
| 2097 |
+
} else {
|
| 2098 |
+
// Auto config
|
| 2099 |
+
exec_cfg =
|
| 2100 |
+
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
|
| 2101 |
+
has_act_order, is_k_full, max_shared_mem);
|
| 2102 |
+
}
|
| 2103 |
+
|
| 2104 |
+
TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
|
| 2105 |
+
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
|
| 2106 |
+
prob_m, prob_n, prob_k, num_bits, group_size,
|
| 2107 |
+
has_act_order, is_k_full, max_shared_mem),
|
| 2108 |
+
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
|
| 2109 |
+
", thread_k = ", exec_cfg.tb_cfg.thread_k,
|
| 2110 |
+
", thread_n = ", exec_cfg.tb_cfg.thread_n,
|
| 2111 |
+
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
|
| 2112 |
+
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
| 2113 |
+
", group_size = ", group_size,
|
| 2114 |
+
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
| 2115 |
+
", max_shared_mem = ", max_shared_mem);
|
| 2116 |
+
|
| 2117 |
+
int num_threads = exec_cfg.tb_cfg.num_threads;
|
| 2118 |
+
thread_k = exec_cfg.tb_cfg.thread_k;
|
| 2119 |
+
thread_n = exec_cfg.tb_cfg.thread_n;
|
| 2120 |
+
|
| 2121 |
+
int thread_k_blocks = thread_k / 16;
|
| 2122 |
+
int thread_n_blocks = thread_n / 16;
|
| 2123 |
+
|
| 2124 |
+
int blocks = sms;
|
| 2125 |
+
|
| 2126 |
+
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
| 2127 |
+
" is not divisible by thread_n = ", thread_n);
|
| 2128 |
+
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
| 2129 |
+
" is not divisible by thread_k = ", thread_k);
|
| 2130 |
+
|
| 2131 |
+
int group_blocks = 0;
|
| 2132 |
+
if (has_act_order) {
|
| 2133 |
+
if (is_k_full) {
|
| 2134 |
+
TORCH_CHECK(group_size != -1);
|
| 2135 |
+
group_blocks = group_size / 16;
|
| 2136 |
+
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
| 2137 |
+
" is not divisible by group_blocks = ", group_blocks);
|
| 2138 |
+
} else {
|
| 2139 |
+
TORCH_CHECK(group_size == 0);
|
| 2140 |
+
group_blocks = 0;
|
| 2141 |
+
}
|
| 2142 |
+
|
| 2143 |
+
} else {
|
| 2144 |
+
if (group_size == -1) {
|
| 2145 |
+
group_blocks = -1;
|
| 2146 |
+
} else {
|
| 2147 |
+
group_blocks = group_size / 16;
|
| 2148 |
+
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
| 2149 |
+
" is not divisible by group_blocks = ", group_blocks);
|
| 2150 |
+
}
|
| 2151 |
+
}
|
| 2152 |
+
|
| 2153 |
+
const int4* A_ptr = (const int4*)A;
|
| 2154 |
+
const int4* B_ptr = (const int4*)B;
|
| 2155 |
+
int4* C_ptr = (int4*)C;
|
| 2156 |
+
int4* C_tmp_ptr = (int4*)C_tmp;
|
| 2157 |
+
const int4* s_ptr = (const int4*)s;
|
| 2158 |
+
const int4* zp_ptr = (const int4*)zp;
|
| 2159 |
+
const int* g_idx_ptr = (const int*)g_idx;
|
| 2160 |
+
const int* perm_ptr = (const int*)perm;
|
| 2161 |
+
int4* a_tmp_ptr = (int4*)a_tmp;
|
| 2162 |
+
|
| 2163 |
+
int* locks = (int*)workspace;
|
| 2164 |
+
|
| 2165 |
+
if (has_act_order) {
|
| 2166 |
+
// Permute A columns
|
| 2167 |
+
int block_rows = div_ceil(prob_m, blocks);
|
| 2168 |
+
permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
|
| 2169 |
+
A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
|
| 2170 |
+
A_ptr = a_tmp_ptr;
|
| 2171 |
+
}
|
| 2172 |
+
|
| 2173 |
+
// If we have a full K, then we can run the non-act-order version of Marlin
|
| 2174 |
+
// (since the weight rows are reordered by increasing group ids, and by having
|
| 2175 |
+
// a full K, we have full original groups)
|
| 2176 |
+
if (is_k_full) {
|
| 2177 |
+
has_act_order = false;
|
| 2178 |
+
}
|
| 2179 |
+
|
| 2180 |
+
// Main loop
|
| 2181 |
+
for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
|
| 2182 |
+
int thread_m_blocks = tot_m_blocks - i;
|
| 2183 |
+
prob_m = tot_m - 16 * i;
|
| 2184 |
+
int par = 1;
|
| 2185 |
+
if (thread_m_blocks > exec_cfg.max_m_blocks) {
|
| 2186 |
+
// Note that parallel > 1 currently only works for inputs without any
|
| 2187 |
+
// padding
|
| 2188 |
+
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
|
| 2189 |
+
if (par > max_par) par = max_par;
|
| 2190 |
+
prob_m = (16 * exec_cfg.max_m_blocks) * par;
|
| 2191 |
+
i += exec_cfg.max_m_blocks * (par - 1);
|
| 2192 |
+
thread_m_blocks = exec_cfg.max_m_blocks;
|
| 2193 |
+
}
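// Editorial worked example: with tot_m = 256 (tot_m_blocks = 16), pad = 0 and
// max_m_blocks = 4, the first iteration sees thread_m_blocks = 16 > 4, so
// par = (16 * 16 - 0) / (16 * 4) = 4 (within max_par). The launch then covers
// prob_m = 16 * 4 * 4 = 256 rows at once, and i advances by 4 * (4 - 1) = 12
// extra tiles, so the whole problem is handled by a single kernel launch with
// 4-way parallelism over M.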
|
| 2194 |
+
|
| 2195 |
+
if (false) {
|
| 2196 |
+
}
|
| 2197 |
+
GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
|
| 2198 |
+
GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
|
| 2199 |
+
GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
|
| 2200 |
+
GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
|
| 2201 |
+
GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
|
| 2202 |
+
GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
|
| 2203 |
+
GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
|
| 2204 |
+
GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
|
| 2205 |
+
|
| 2206 |
+
AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
|
| 2207 |
+
AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
|
| 2208 |
+
AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
|
| 2209 |
+
AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
|
| 2210 |
+
AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
|
| 2211 |
+
AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
|
| 2212 |
+
AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
|
| 2213 |
+
AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
|
| 2214 |
+
|
| 2215 |
+
HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
|
| 2216 |
+
HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
|
| 2217 |
+
HQQ_CALL_IF(vllm::kU4, 8, 4, 128)
|
| 2218 |
+
HQQ_CALL_IF(vllm::kU4, 4, 8, 128)
|
| 2219 |
+
else {
|
| 2220 |
+
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
| 2221 |
+
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
| 2222 |
+
", num_groups = ", num_groups, ", group_size = ", group_size,
|
| 2223 |
+
", thread_m_blocks = ", thread_m_blocks,
|
| 2224 |
+
", thread_n_blocks = ", thread_n_blocks,
|
| 2225 |
+
", thread_k_blocks = ", thread_k_blocks,
|
| 2226 |
+
", num_bits = ", num_bits);
|
| 2227 |
+
}
|
| 2228 |
+
|
| 2229 |
+
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
|
| 2230 |
+
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
|
| 2231 |
+
}
|
| 2232 |
+
}
|
| 2233 |
+
|
| 2234 |
+
} // namespace marlin
|
| 2235 |
+
|
| 2236 |
+
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
| 2237 |
+
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
| 2238 |
+
torch::Tensor& g_idx, torch::Tensor& perm,
|
| 2239 |
+
torch::Tensor& workspace,
|
| 2240 |
+
vllm::ScalarTypeId const& b_q_type_id,
|
| 2241 |
+
int64_t size_m, int64_t size_n, int64_t size_k,
|
| 2242 |
+
bool is_k_full, bool has_zp,
|
| 2243 |
+
bool use_fp32_reduce, bool is_zp_float) {
|
| 2244 |
+
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
| 2245 |
+
if (has_zp) {
|
| 2246 |
+
TORCH_CHECK(
|
| 2247 |
+
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
| 2248 |
+
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
| 2249 |
+
} else {
|
| 2250 |
+
TORCH_CHECK(
|
| 2251 |
+
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
| 2252 |
+
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
| 2253 |
+
b_q_type.str());
|
| 2254 |
+
}
|
| 2255 |
+
|
| 2256 |
+
if (has_zp && is_zp_float) {
|
| 2257 |
+
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
|
| 2258 |
+
"Computation type must be float16 (half) when using float zero "
|
| 2259 |
+
"points.");
|
| 2260 |
+
}
|
| 2261 |
+
|
| 2262 |
+
int pack_factor = 32 / b_q_type.size_bits();
|
| 2263 |
+
|
| 2264 |
+
// Verify A
|
| 2265 |
+
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
| 2266 |
+
", size_m = ", size_m);
|
| 2267 |
+
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
|
| 2268 |
+
", size_k = ", size_k);
|
| 2269 |
+
|
| 2270 |
+
// Verify B
|
| 2271 |
+
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
|
| 2272 |
+
" is not divisible by tile_size = ", marlin::tile_size);
|
| 2273 |
+
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
|
| 2274 |
+
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
| 2275 |
+
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
|
| 2276 |
+
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
|
| 2277 |
+
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
| 2278 |
+
" is not divisible by tile_size = ", marlin::tile_size);
|
| 2279 |
+
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
|
| 2280 |
+
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
| 2281 |
+
", actual_size_n = ", actual_size_n);
|
| 2282 |
+
|
| 2283 |
+
// Verify device and strides
|
| 2284 |
+
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
| 2285 |
+
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
| 2286 |
+
|
| 2287 |
+
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
| 2288 |
+
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
| 2289 |
+
|
| 2290 |
+
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
| 2291 |
+
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
| 2292 |
+
|
| 2293 |
+
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
| 2294 |
+
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
| 2295 |
+
|
| 2296 |
+
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
| 2297 |
+
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
| 2298 |
+
|
| 2299 |
+
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
| 2300 |
+
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
| 2301 |
+
|
| 2302 |
+
// Alloc buffers
|
| 2303 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
| 2304 |
+
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
| 2305 |
+
torch::Tensor c = torch::empty({size_m, size_n}, options);
|
| 2306 |
+
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
|
| 2307 |
+
|
| 2308 |
+
// Alloc C tmp buffer that is going to be used for the global reduce
|
| 2309 |
+
int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
|
| 2310 |
+
int reduce_n = size_n;
|
| 2311 |
+
auto options_fp32 =
|
| 2312 |
+
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
| 2313 |
+
if (!use_fp32_reduce) {
|
| 2314 |
+
reduce_max_m = 0;
|
| 2315 |
+
reduce_n = 0;
|
| 2316 |
+
}
|
| 2317 |
+
torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
|
| 2318 |
+
|
| 2319 |
+
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
| 2320 |
+
// auto -1)
|
| 2321 |
+
int thread_k = -1;
|
| 2322 |
+
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
| 2323 |
+
// auto -1)
|
| 2324 |
+
int thread_n = -1;
|
| 2325 |
+
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
|
| 2326 |
+
int sms = -1;
|
| 2327 |
+
|
| 2328 |
+
// Verify g_idx and perm
|
| 2329 |
+
TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
|
| 2330 |
+
(g_idx.size(0) == size_k && perm.size(0) == size_k),
|
| 2331 |
+
"Unexpected g_idx.size(0) = ", g_idx.size(0),
|
| 2332 |
+
" and perm.size(0) = ", perm.size(0),
|
| 2333 |
+
", where size_k = ", size_k);
|
| 2334 |
+
|
| 2335 |
+
// Detect groupsize and act_order
|
| 2336 |
+
int num_groups = -1;
|
| 2337 |
+
int group_size = -1;
|
| 2338 |
+
bool has_act_order = g_idx.size(0) != 0;
|
| 2339 |
+
|
| 2340 |
+
int rank = b_scales.sizes().size();
|
| 2341 |
+
TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
|
| 2342 |
+
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
|
| 2343 |
+
" is not size_n = ", size_n);
|
| 2344 |
+
num_groups = b_scales.size(0);
|
| 2345 |
+
|
| 2346 |
+
if (has_act_order) {
|
| 2347 |
+
if (is_k_full) {
|
| 2348 |
+
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
|
| 2349 |
+
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
|
| 2350 |
+
", is not divisible by num_groups = ", num_groups);
|
| 2351 |
+
group_size = size_k / num_groups;
|
| 2352 |
+
} else {
|
| 2353 |
+
group_size = 0;
|
| 2354 |
+
}
|
| 2355 |
+
|
| 2356 |
+
} else {
|
| 2357 |
+
if (num_groups > 1) {
|
| 2358 |
+
TORCH_CHECK(
|
| 2359 |
+
size_k % num_groups == 0, "size_k = ", size_k,
|
| 2360 |
+
", is not divisible by b_scales.size(0) = ", b_scales.size(0));
|
| 2361 |
+
group_size = size_k / num_groups;
|
| 2362 |
+
} else {
|
| 2363 |
+
group_size = -1;
|
| 2364 |
+
}
|
| 2365 |
+
}
|
| 2366 |
+
|
| 2367 |
+
// Verify b_zeros
|
| 2368 |
+
if (has_zp) {
|
| 2369 |
+
int rank = b_zeros.sizes().size();
|
| 2370 |
+
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
|
| 2371 |
+
if (is_zp_float) {
|
| 2372 |
+
TORCH_CHECK(b_zeros.size(1) == size_n,
|
| 2373 |
+
"b_zeros dim 1 = ", b_zeros.size(1),
|
| 2374 |
+
" is not size_n = ", size_n);
|
| 2375 |
+
TORCH_CHECK(num_groups == b_zeros.size(0),
|
| 2376 |
+
"b_zeros dim 0 = ", b_zeros.size(0),
|
| 2377 |
+
" is not num_groups = ", num_groups);
|
| 2378 |
+
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
|
| 2379 |
+
} else {
|
| 2380 |
+
TORCH_CHECK(b_zeros.size(0) == num_groups,
|
| 2381 |
+
"b_zeros dim 0 = ", b_zeros.size(0),
|
| 2382 |
+
" is not num_groups = ", num_groups);
|
| 2383 |
+
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
|
| 2384 |
+
"b_zeros dim 1 = ", b_zeros.size(1),
|
| 2385 |
+
" is not size_n / pack_factor = ", size_n / pack_factor);
|
| 2386 |
+
}
|
| 2387 |
+
}
|
| 2388 |
+
|
| 2389 |
+
// Verify workspace size
|
| 2390 |
+
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
|
| 2391 |
+
", is not divisible by min_thread_n = ", marlin::min_thread_n);
|
| 2392 |
+
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
|
| 2393 |
+
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
| 2394 |
+
"workspace.numel = ", workspace.numel(),
|
| 2395 |
+
" is below min_workspace_size = ", min_workspace_size);
|
| 2396 |
+
|
| 2397 |
+
int dev = a.get_device();
|
| 2398 |
+
if (a.scalar_type() == at::ScalarType::Half) {
|
| 2399 |
+
marlin::marlin_mm<half>(
|
| 2400 |
+
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
| 2401 |
+
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
| 2402 |
+
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
| 2403 |
+
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
| 2404 |
+
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
| 2405 |
+
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
| 2406 |
+
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
|
| 2407 |
+
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
| 2408 |
+
marlin::marlin_mm<nv_bfloat16>(
|
| 2409 |
+
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
| 2410 |
+
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
| 2411 |
+
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
| 2412 |
+
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
| 2413 |
+
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
| 2414 |
+
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
| 2415 |
+
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
|
| 2416 |
+
} else {
|
| 2417 |
+
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
| 2418 |
+
}
|
| 2419 |
+
|
| 2420 |
+
return c;
|
| 2421 |
+
}
|
| 2422 |
+
|
| 2423 |
+
#endif
|
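As a usage illustration, here is a minimal host-side sketch (editorial, not part of the commit) that drives the op above for a plain symmetric 4-bit GPTQ layer with no act_order and no zero points. It assumes the gptq_marlin_gemm declaration is visible, that marlin::tile_size is 16, and that min_thread_n = 64 and max_par = 16 match marlin.cuh; the empty g_idx/perm/b_zeros tensors follow the shape checks performed inside the op.

// Editorial sketch: symmetric 4-bit GPTQ (uint4b8) matmul via gptq_marlin_gemm.
torch::Tensor gptq_marlin_matmul_sym4(torch::Tensor a,           // [m, k], fp16/bf16, CUDA
                                      torch::Tensor b_q_weight,  // [k / 16, n * 16 / 8], int32, Marlin layout
                                      torch::Tensor b_scales) {  // [num_groups, n], dtype of a
  int64_t m = a.size(0), k = a.size(1), n = b_scales.size(1);
  // Lock array for the global reduce; assumes min_thread_n = 64 and
  // max_par = 16, and the locks must start at zero.
  auto int_opts = torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor workspace = torch::zeros({(n / 64) * 16}, int_opts);
  torch::Tensor empty_idx = torch::empty({0}, int_opts);              // no act_order
  torch::Tensor empty_zp = torch::empty({0, 0}, b_scales.options());  // no zero points
  return gptq_marlin_gemm(a, b_q_weight, b_scales, /*b_zeros=*/empty_zp,
                          /*g_idx=*/empty_idx, /*perm=*/empty_idx, workspace,
                          /*b_q_type_id=*/vllm::kU4B8.id(), m, n, k,
                          /*is_k_full=*/true, /*has_zp=*/false,
                          /*use_fp32_reduce=*/true, /*is_zp_float=*/false);
}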
gptq_marlin/gptq_marlin_repack.cu
ADDED
|
@@ -0,0 +1,333 @@
| 1 |
+
#include "marlin.cuh"
|
| 2 |
+
|
| 3 |
+
namespace marlin {
|
| 4 |
+
|
| 5 |
+
template <int const num_threads, int const num_bits, bool const has_perm>
|
| 6 |
+
__global__ void gptq_marlin_repack_kernel(
|
| 7 |
+
uint32_t const* __restrict__ b_q_weight_ptr,
|
| 8 |
+
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
| 9 |
+
int size_k, int size_n) {
|
| 10 |
+
constexpr int pack_factor = 32 / num_bits;
|
| 11 |
+
|
| 12 |
+
int k_tiles = size_k / tile_k_size;
|
| 13 |
+
int n_tiles = size_n / tile_n_size;
|
| 14 |
+
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
| 15 |
+
|
| 16 |
+
int start_k_tile = blockIdx.x * block_k_tiles;
|
| 17 |
+
if (start_k_tile >= k_tiles) {
|
| 18 |
+
return;
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
| 22 |
+
|
| 23 |
+
// Wait until the next thread tile has been loaded to shared memory.
|
| 24 |
+
auto wait_for_stage = [&]() {
|
| 25 |
+
// We only have `stages - 2` active fetches since we are double buffering
|
| 26 |
+
// and can only issue the next fetch when it is guaranteed that the previous
|
| 27 |
+
// shared memory load is fully complete (as it may otherwise be
|
| 28 |
+
// overwritten).
|
| 29 |
+
cp_async_wait<repack_stages - 2>();
|
| 30 |
+
__syncthreads();
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
extern __shared__ int4 sh[];
|
| 34 |
+
|
| 35 |
+
constexpr int perm_size = tile_k_size / 4;
|
| 36 |
+
|
| 37 |
+
int4* sh_perm_ptr = sh;
|
| 38 |
+
int4* sh_pipe_ptr = sh_perm_ptr;
|
| 39 |
+
if constexpr (has_perm) {
|
| 40 |
+
sh_pipe_ptr += perm_size;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
constexpr int tile_ints = tile_k_size / pack_factor;
|
| 44 |
+
|
| 45 |
+
constexpr int stage_n_threads = tile_n_size / 4;
|
| 46 |
+
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
|
| 47 |
+
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
| 48 |
+
|
| 49 |
+
auto load_perm_to_shared = [&](int k_tile_id) {
|
| 50 |
+
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
|
| 51 |
+
|
| 52 |
+
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
|
| 53 |
+
|
| 54 |
+
if (threadIdx.x < perm_size) {
|
| 55 |
+
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
|
| 56 |
+
}
|
| 57 |
+
__syncthreads();
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
| 61 |
+
if (n_tile_id >= n_tiles) {
|
| 62 |
+
cp_async_fence();
|
| 63 |
+
return;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
int first_n = n_tile_id * tile_n_size;
|
| 67 |
+
|
| 68 |
+
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
|
| 69 |
+
|
| 70 |
+
if constexpr (has_perm) {
|
| 71 |
+
if (threadIdx.x < stage_size) {
|
| 72 |
+
int k_id = threadIdx.x / stage_n_threads;
|
| 73 |
+
int n_id = threadIdx.x % stage_n_threads;
|
| 74 |
+
|
| 75 |
+
uint32_t const* sh_perm_int_ptr =
|
| 76 |
+
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
|
| 77 |
+
|
| 78 |
+
int src_k = sh_perm_int_ptr[k_id];
|
| 79 |
+
int src_k_packed = src_k / pack_factor;
|
| 80 |
+
|
| 81 |
+
cp_async4(
|
| 82 |
+
&sh_ptr[k_id * stage_n_threads + n_id],
|
| 83 |
+
reinterpret_cast<int4 const*>(&(
|
| 84 |
+
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
} else {
|
| 88 |
+
if (threadIdx.x < stage_size) {
|
| 89 |
+
int k_id = threadIdx.x / stage_n_threads;
|
| 90 |
+
int n_id = threadIdx.x % stage_n_threads;
|
| 91 |
+
|
| 92 |
+
int first_k = k_tile_id * tile_k_size;
|
| 93 |
+
int first_k_packed = first_k / pack_factor;
|
| 94 |
+
|
| 95 |
+
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
|
| 96 |
+
reinterpret_cast<int4 const*>(
|
| 97 |
+
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
|
| 98 |
+
first_n + (n_id * 4)])));
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
cp_async_fence();
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
| 106 |
+
if (n_tile_id >= n_tiles) {
|
| 107 |
+
return;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
int warp_id = threadIdx.x / 32;
|
| 111 |
+
int th_id = threadIdx.x % 32;
|
| 112 |
+
|
| 113 |
+
if (warp_id >= 4) {
|
| 114 |
+
return;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
int tc_col = th_id / 4;
|
| 118 |
+
int tc_row = (th_id % 4) * 2;
|
| 119 |
+
|
| 120 |
+
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
| 121 |
+
|
| 122 |
+
int cur_n = warp_id * 16 + tc_col;
|
| 123 |
+
|
| 124 |
+
constexpr int sh_stride = 64;
|
| 125 |
+
constexpr uint32_t mask = (1 << num_bits) - 1;
|
| 126 |
+
|
| 127 |
+
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
|
| 128 |
+
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
|
| 129 |
+
|
| 130 |
+
uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
|
| 131 |
+
|
| 132 |
+
uint32_t vals[8];
|
| 133 |
+
|
| 134 |
+
if constexpr (has_perm) {
|
| 135 |
+
for (int i = 0; i < 4; i++) {
|
| 136 |
+
int k_idx = tc_row + tc_offsets[i];
|
| 137 |
+
|
| 138 |
+
uint32_t src_k = sh_perm_int_ptr[k_idx];
|
| 139 |
+
uint32_t src_k_pos = src_k % pack_factor;
|
| 140 |
+
|
| 141 |
+
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
|
| 142 |
+
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
|
| 143 |
+
|
| 144 |
+
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
|
| 145 |
+
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
|
| 146 |
+
|
| 147 |
+
vals[i] = b1_cur_val;
|
| 148 |
+
vals[4 + i] = b2_cur_val;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
} else {
|
| 152 |
+
uint32_t b1_vals[tile_ints];
|
| 153 |
+
uint32_t b2_vals[tile_ints];
|
| 154 |
+
|
| 155 |
+
#pragma unroll
|
| 156 |
+
for (int i = 0; i < tile_ints; i++) {
|
| 157 |
+
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
| 158 |
+
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
#pragma unroll
|
| 162 |
+
for (int i = 0; i < 4; i++) {
|
| 163 |
+
int cur_elem = tc_row + tc_offsets[i];
|
| 164 |
+
int cur_int = cur_elem / pack_factor;
|
| 165 |
+
int cur_pos = cur_elem % pack_factor;
|
| 166 |
+
|
| 167 |
+
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
| 168 |
+
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
| 173 |
+
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
| 174 |
+
|
| 175 |
+
// Result of:
|
| 176 |
+
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
| 177 |
+
if constexpr (num_bits == 4) {
|
| 178 |
+
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
| 179 |
+
|
| 180 |
+
uint32_t res = 0;
|
| 181 |
+
#pragma unroll
|
| 182 |
+
for (int i = 0; i < 8; i++) {
|
| 183 |
+
res |= vals[pack_idx[i]] << (i * 4);
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
| 187 |
+
|
| 188 |
+
} else {
|
| 189 |
+
constexpr int pack_idx[4] = {0, 2, 1, 3};
|
| 190 |
+
|
| 191 |
+
uint32_t res1 = 0;
|
| 192 |
+
uint32_t res2 = 0;
|
| 193 |
+
#pragma unroll
|
| 194 |
+
for (int i = 0; i < 4; i++) {
|
| 195 |
+
res1 |= vals[pack_idx[i]] << (i * 8);
|
| 196 |
+
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
| 200 |
+
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
|
| 201 |
+
}
|
| 202 |
+
};
|
| 203 |
+
|
| 204 |
+
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
| 205 |
+
#pragma unroll
|
| 206 |
+
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
| 207 |
+
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
wait_for_stage();
|
| 211 |
+
};
|
| 212 |
+
#pragma unroll
|
| 213 |
+
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
| 214 |
+
int n_tile_id = 0;
|
| 215 |
+
|
| 216 |
+
if constexpr (has_perm) {
|
| 217 |
+
load_perm_to_shared(k_tile_id);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
start_pipes(k_tile_id, n_tile_id);
|
| 221 |
+
|
| 222 |
+
while (n_tile_id < n_tiles) {
|
| 223 |
+
#pragma unroll
|
| 224 |
+
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
| 225 |
+
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
| 226 |
+
n_tile_id + pipe + repack_stages - 1);
|
| 227 |
+
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
| 228 |
+
wait_for_stage();
|
| 229 |
+
}
|
| 230 |
+
n_tile_id += repack_stages;
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
} // namespace marlin
|
| 236 |
+
|
| 237 |
+
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
| 238 |
+
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
| 239 |
+
cudaFuncSetAttribute( \
|
| 240 |
+
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
| 241 |
+
HAS_PERM>, \
|
| 242 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
| 243 |
+
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
| 244 |
+
HAS_PERM> \
|
| 245 |
+
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
| 246 |
+
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
| 250 |
+
int64_t size_k, int64_t size_n,
|
| 251 |
+
int64_t num_bits) {
|
| 252 |
+
// Verify compatibility with marlin tile of 16x64
|
| 253 |
+
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
|
| 254 |
+
" is not divisible by tile_k_size = ", marlin::tile_k_size);
|
| 255 |
+
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
|
| 256 |
+
" is not divisible by tile_n_size = ", marlin::tile_n_size);
|
| 257 |
+
|
| 258 |
+
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
| 259 |
+
"num_bits must be 4 or 8. Got = ", num_bits);
|
| 260 |
+
int const pack_factor = 32 / num_bits;
|
| 261 |
+
|
| 262 |
+
// Verify B
|
| 263 |
+
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
|
| 264 |
+
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
| 265 |
+
", size_k = ", size_k, ", pack_factor = ", pack_factor);
|
| 266 |
+
TORCH_CHECK(b_q_weight.size(1) == size_n,
|
| 267 |
+
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
| 268 |
+
" is not size_n = ", size_n);
|
| 269 |
+
|
| 270 |
+
// Verify device and strides
|
| 271 |
+
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
| 272 |
+
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
| 273 |
+
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
|
| 274 |
+
|
| 275 |
+
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
| 276 |
+
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
| 277 |
+
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
|
| 278 |
+
|
| 279 |
+
// Alloc buffers
|
| 280 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
| 281 |
+
auto options = torch::TensorOptions()
|
| 282 |
+
.dtype(b_q_weight.dtype())
|
| 283 |
+
.device(b_q_weight.device());
|
| 284 |
+
torch::Tensor out = torch::empty(
|
| 285 |
+
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
| 286 |
+
options);
|
| 287 |
+
|
| 288 |
+
// Detect if there is act_order
|
| 289 |
+
bool has_perm = perm.size(0) != 0;
|
| 290 |
+
|
| 291 |
+
// Get ptrs
|
| 292 |
+
uint32_t const* b_q_weight_ptr =
|
| 293 |
+
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
| 294 |
+
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
|
| 295 |
+
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
|
| 296 |
+
|
| 297 |
+
// Get dev info
|
| 298 |
+
int dev = b_q_weight.get_device();
|
| 299 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
| 300 |
+
int blocks;
|
| 301 |
+
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
| 302 |
+
|
| 303 |
+
int max_shared_mem = 0;
|
| 304 |
+
cudaDeviceGetAttribute(&max_shared_mem,
|
| 305 |
+
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
| 306 |
+
TORCH_CHECK(max_shared_mem > 0);
|
| 307 |
+
|
| 308 |
+
if (false) {
|
| 309 |
+
}
|
| 310 |
+
CALL_IF(4, false)
|
| 311 |
+
CALL_IF(4, true)
|
| 312 |
+
CALL_IF(8, false)
|
| 313 |
+
CALL_IF(8, true)
|
| 314 |
+
else {
|
| 315 |
+
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
|
| 316 |
+
", has_perm = ", has_perm);
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
return out;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
| 323 |
+
torch::Tensor& perm, c10::SymInt size_k,
|
| 324 |
+
c10::SymInt size_n, int64_t num_bits) {
|
| 325 |
+
int const pack_factor = 32 / num_bits;
|
| 326 |
+
auto options = torch::TensorOptions()
|
| 327 |
+
.dtype(b_q_weight.dtype())
|
| 328 |
+
.device(b_q_weight.device());
|
| 329 |
+
return torch::empty_symint(
|
| 330 |
+
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
| 331 |
+
options);
|
| 332 |
+
}
|
| 333 |
+
|
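A companion sketch (again editorial, under the same assumptions) shows how a GPTQ checkpoint's packed qweight would be converted into the Marlin layout consumed above; the perm tensor is left empty because this example assumes the checkpoint was quantized without act_order (desc_act = False).

// Editorial sketch: repack a 4-bit GPTQ qweight ([k / 8, n], int32) into the
// Marlin layout ([k / 16, n * 16 / 8], int32) expected by gptq_marlin_gemm.
torch::Tensor repack_gptq_4bit(torch::Tensor qweight, int64_t size_k,
                               int64_t size_n) {
  // Empty perm <=> no act_order; it must still be an int32 CUDA tensor.
  torch::Tensor perm = torch::empty(
      {0}, torch::TensorOptions().dtype(torch::kInt).device(qweight.device()));
  return gptq_marlin_repack(qweight, perm, size_k, size_n, /*num_bits=*/4);
}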