#pragma once
#include <ATen/cuda/detail/KernelUtils.h>
#include <cub/cub.cuh>
#include <cutlass/bfloat16.h>
#include <cutlass/gemm_coord.h>
namespace grouped_gemm {
constexpr int kDynamicDim = -1;
constexpr int kMaxExperts = 512;
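// The kernels below assume a single thread block with one thread per expert (the CUB
// collectives in `FillArguments` are instantiated with `kMaxExperts` as the block size),
// so a single launch supports at most `kMaxExperts` experts.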
struct GemmProblem {
::cutlass::gemm::GemmCoord dims;
int64_t lda, ldb, ldc;
// All offsets are in elements.
int64_t a_offset, b_offset, c_offset;
};
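// Worked example (variable-`m` case, i.e. `kDynamicK == false` in `FillArguments` below):
// with batch_sizes = {2, 3}, n = 4, k = 8, expert 1 gets dims = {m=3, n=4, k=8},
// a_offset = 2 * 8 = 16, b_offset = 1 * 8 * 4 = 32, c_offset = 2 * 4 = 8, i.e. its rows of
// A and C start right after expert 0's, and it uses the second k x n weight matrix in B.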
// TODO: revisit `ExtractGemmProblemK` struct
// struct ExtractGemmProblemK {
// __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
// return {problem.dims.k()};
// }
// };
template <
// If `k` is dynamic, we sort the problems by `k` in descending order.
// Otherwise, `m` is dynamic, and no sorting happens.
bool kDynamicK,
typename ElementA, typename ElementB, typename ElementC,
typename LayoutA, typename LayoutB, typename LayoutC,
typename Args
>
__global__ void FillArguments(
int num_experts, const int64_t* batch_sizes,
ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
Args args, ::cutlass::gemm::GemmCoord dims
) {
const int expert_idx = threadIdx.x;
const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;
if (kDynamicK) {
assert(dims.k() == kDynamicDim);
dims.k() = batch_size;
} else {
assert(dims.m() == kDynamicDim);
dims.m() = batch_size;
}
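// Threads with `expert_idx >= num_experts` still take part in the block-wide scan/sort
// below (CUB collectives must be called by every thread in the block); they carry a dummy
// dynamic dimension of -1 and simply skip the final write-out.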
using BlockScan = cub::BlockScan<int, kMaxExperts>;
using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;
union SharedMemory {
typename BlockScan::TempStorage scan_storage;
typename BlockSort::TempStorage sort_storage;
};
__shared__ SharedMemory shared_memory;
int dynamic_dim = kDynamicK ? dims.k() : dims.m();
int dynamic_dim_cumsum;
BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
__syncthreads();
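// `dynamic_dim_cumsum` is the sum of the dynamic dimension over experts 0..expert_idx-1:
// this expert's element offset into the packed buffers (in rows of A/C for variable-`m`,
// or in `k`-slices of A/B for variable-`k`). It feeds the offset computations below.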
// We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
// `KeyT (&)[ITEMS_PER_THREAD]` for the `keys` argument (i.e., `GemmProblem (&keys)[1]` in our case).
GemmProblem problem[1] = {
GemmProblem {
.dims = dims,
.lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
.ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
.ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
.a_offset = kDynamicK
? (dims.m() * dynamic_dim_cumsum)
: (dynamic_dim_cumsum * dims.k()),
.b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
.c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
},
};
if constexpr (kDynamicK) {
// Sort the problems by `k` in descending order. `SortDescending()` takes keys and values
// separately, so we extract each problem's `k` as its sort key.
int k_keys[1] = { problem[0].dims.k() };
BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);
// TODO: revisit original impl without `__syncthreads()`
// BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
// Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
// > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage [...]
// > is **to be reused or repurposed**.
// We don't need `__syncthreads()` here, since we don't do either of these things.
}
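// After the (optional) sort, thread `expert_idx` holds the problem that belongs in slot
// `expert_idx` of the schedule; publish it to the device-side argument arrays.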
if (expert_idx < num_experts) {
args.problem_sizes[expert_idx] = problem[0].dims;
args.lda[expert_idx] = problem[0].lda;
args.ldb[expert_idx] = problem[0].ldb;
args.ldc[expert_idx] = problem[0].ldc;
args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
}
}
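// Minimal host-side sketch (illustrative only; `stream`, `num_experts`, `batch_sizes`,
// `a`/`b`/`c`, `args`, `n`, `k` and the element/layout choices are assumptions, not defined
// in this file) of launching `FillArguments` for the variable-`m` case, with `m` marked
// dynamic in the `dims` argument:
//
//   FillArguments<
//       /*kDynamicK=*/false,
//       ::cutlass::bfloat16_t, ::cutlass::bfloat16_t, ::cutlass::bfloat16_t,
//       ::cutlass::layout::RowMajor, ::cutlass::layout::RowMajor, ::cutlass::layout::RowMajor,
//       Args
//   ><<<1, kMaxExperts, 0, stream>>>(
//       num_experts, batch_sizes, a, b, c, args,
//       ::cutlass::gemm::GemmCoord(kDynamicDim, n, k));
//
// The block size must be `kMaxExperts` so that the CUB collectives above are valid for
// every thread in the block.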
template <typename Args>
__global__ void ZeroOutK0Outputs(int num_experts, Args args) {
const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
const int64_t delta = (int64_t)gridDim.x * blockDim.x;
for (int ei = 0; ei < num_experts; ++ei) {
auto& dims = args.problem_sizes[ei];
// CUTLASS doesn't handle problems with `k=0` correctly; see https://github.com/NVIDIA/cutlass/pull/1593.
// Until a fix is available on the CUTLASS side, we handle these problems ourselves:
// * (here) set the output to zero
// * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
if (dims.k() == 0) {
// Assuming a packed layout, run a grid-strided loop to zero the output.
int64_t total_elems = (int64_t)dims.m() * dims.n();
auto* out = args.ptr_C[ei];
for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
out[idx] = {};
}
}
}
}
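// Note: `ZeroOutK0Outputs` must run before `IgnoreK0Problems`, since the latter clears
// `m` and `n`, which `ZeroOutK0Outputs` needs in order to size the zero-fill of the output.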
template <typename Args>
__global__ void IgnoreK0Problems(int num_experts, Args args) {
const int expert_idx = threadIdx.x;
if (expert_idx < num_experts) {
auto& dims = args.problem_sizes[expert_idx];
if (dims.k() == 0) {
dims.m() = 0;
dims.n() = 0;
}
}
}
} // namespace grouped_gemm