#pragma once

#include <cassert>

#include <ATen/cuda/detail/KernelUtils.h>
#include <cub/cub.cuh>
#include <cutlass/bfloat16.h>
#include <cutlass/gemm_coord.h>

namespace grouped_gemm {

constexpr int kDynamicDim = -1;
constexpr int kMaxExperts = 512;
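// `kDynamicDim` marks the GEMM dimension (`m` or `k`) whose per-expert value is only
// known at run time; `kMaxExperts` is also the thread-block size used by the CUB
// collectives in `FillArguments`, so `num_experts` must not exceed it.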

struct GemmProblem {
  ::cutlass::gemm::GemmCoord dims;
  int64_t lda, ldb, ldc;
  // All offsets are in elements.
  int64_t a_offset, b_offset, c_offset;
};
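
// Illustrative example (a hedged sketch, assuming row-major layouts and the !kDynamicK
// case, i.e. per-expert `m`): with k = 1024, n = 4096 and batch_sizes = {2, 3, ...},
// `FillArguments` below would describe expert 1 as
//   dims = (m=3, n=4096, k=1024), lda = 1024, ldb = 4096, ldc = 4096,
//   a_offset = 2 * 1024, b_offset = 1 * 1024 * 4096, c_offset = 2 * 4096,
// i.e. its A/C slices start right after expert 0's two rows, and its B operand is the
// second [k, n] slab of the stacked expert weights.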

// TODO: revisit `ExtractGemmProblemK` struct
// struct ExtractGemmProblemK {
//   __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
//       return {problem.dims.k()};
//   }
// };

template <
    // If `k` is dynamic, we sort the problems by `k` in descending order.
    // Otherwise, `m` is dynamic, and no sorting happens.
    bool kDynamicK,
    typename ElementA, typename ElementB, typename ElementC,
    typename LayoutA, typename LayoutB, typename LayoutC,
    typename Args
>
__global__ void FillArguments(
    int num_experts, const int64_t* batch_sizes,
    ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
    Args args, ::cutlass::gemm::GemmCoord dims
) {
  const int expert_idx = threadIdx.x;
  const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;

  if (kDynamicK) {
    assert(dims.k() == kDynamicDim);
    dims.k() = batch_size;
  } else {
    assert(dims.m() == kDynamicDim);
    dims.m() = batch_size;
  }

  using BlockScan = cub::BlockScan<int, kMaxExperts>;
  using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;

  union SharedMemory {
    typename BlockScan::TempStorage scan_storage;
    typename BlockSort::TempStorage sort_storage;
  };
  __shared__ SharedMemory shared_memory;

  int dynamic_dim = kDynamicK ? dims.k() : dims.m();
  int dynamic_dim_cumsum;
  BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
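  // The scan's temporary storage shares the union above with the sort's; per the CUB
  // documentation (quoted further below), a barrier is needed before it is repurposed.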
  __syncthreads();

  // We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
  // per-thread arrays of length `ITEMS_PER_THREAD` (i.e., `GemmProblem (&values)[1]` in our case).
  GemmProblem problem[1] = {
    GemmProblem {
      .dims = dims,
      .lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
      .ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
      .ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
      .a_offset = kDynamicK
          ? (dims.m() * dynamic_dim_cumsum)
          : (dynamic_dim_cumsum * dims.k()),
      .b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
      .c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
    },
  };

  if constexpr (kDynamicK) {
    // Sort the problems by `k` in descending order. The radix sort needs a plain integer
    // key, so `k` is extracted into its own per-thread array.
    int k_keys[1] = { problem[0].dims.k() };
    
    BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);
    
    // TODO: revisit the original implementation below, which sorted `problem` directly via a key
    // decomposer and likewise needed no `__syncthreads()`:
    // BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
    // Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
    // > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage [...]
    // > is **to be reused or repurposed**.
    // We don't need `__syncthreads()` here, since we don't do either of these things.
  }

  if (expert_idx < num_experts) {
    args.problem_sizes[expert_idx] = problem[0].dims;
    args.lda[expert_idx] = problem[0].lda;
    args.ldb[expert_idx] = problem[0].ldb;
    args.ldc[expert_idx] = problem[0].ldc;

    args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
    args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
    args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
  }
}
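
// A minimal host-side launch sketch, not part of the original interface: the block-wide
// CUB collectives in `FillArguments` require a single thread block of exactly
// `kMaxExperts` threads, with `num_experts <= kMaxExperts`.
template <
    bool kDynamicK,
    typename ElementA, typename ElementB, typename ElementC,
    typename LayoutA, typename LayoutB, typename LayoutC,
    typename Args
>
void LaunchFillArguments(
    int num_experts, const int64_t* batch_sizes,
    ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
    Args args, ::cutlass::gemm::GemmCoord dims, cudaStream_t stream
) {
  assert(num_experts <= kMaxExperts);
  // One block of `kMaxExperts` threads: one thread per (potential) expert.
  FillArguments<kDynamicK, ElementA, ElementB, ElementC, LayoutA, LayoutB, LayoutC, Args>
      <<<1, kMaxExperts, 0, stream>>>(num_experts, batch_sizes, ptr_a, ptr_b, ptr_c, args, dims);
}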

template <typename Args>
__global__ void ZeroOutK0Outputs(int num_experts, Args args) {
  const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t delta     = (int64_t)gridDim.x * blockDim.x;
  for (int ei = 0; ei < num_experts; ++ei) {
    auto& dims = args.problem_sizes[ei];
    // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
    // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
    //   * (here) set the output to zero
    //   * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
    if (dims.k() == 0) {
      // Assume packed layout, run a grid-strided loop over the output.
      int64_t total_elems = (int64_t)dims.m() * dims.n();
      auto* out           = args.ptr_C[ei];
      for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
        out[idx] = {};
      }
    }
  }
}

template <typename Args>
__global__ void IgnoreK0Problems(int num_experts, Args args) {
  const int expert_idx = threadIdx.x;
  if (expert_idx < num_experts) {
    auto& dims = args.problem_sizes[expert_idx];
    if (dims.k() == 0) {
      dims.m() = 0;
      dims.n() = 0;
    }
  }
}
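
// Expected call order, inferred from the comments above rather than guaranteed here:
// `FillArguments` fills the per-expert problems, `ZeroOutK0Outputs` then zeroes the
// outputs of any `k = 0` problems, and `IgnoreK0Problems` finally turns them into
// no-ops (`m = n = 0`) before the grouped GEMM itself runs.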

}  // namespace grouped_gemm