kernel
drbh commited on
Commit
3224250
·
1 Parent(s): a585153

feat: vendor grouped gemm

Browse files
build.toml CHANGED
@@ -35,4 +35,8 @@ src = [
35
  "csrc/new_replicate.h",
36
  "csrc/new_sort.h",
37
  "csrc/new_sort.cu",
 
 
 
 
38
  ]
 
35
  "csrc/new_replicate.h",
36
  "csrc/new_sort.h",
37
  "csrc/new_sort.cu",
38
+ # vendored grouped gemm
39
+ "csrc/grouped_gemm/fill_arguments.cuh",
40
+ "csrc/grouped_gemm/grouped_gemm.cu",
41
+ "csrc/grouped_gemm/grouped_gemm.h",
42
  ]
csrc/grouped_gemm/fill_arguments.cuh ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/detail/KernelUtils.h>
4
+ #include <cub/cub.cuh>
5
+ #include <cutlass/bfloat16.h>
6
+ #include <cutlass/gemm_coord.h>
7
+
8
+ namespace grouped_gemm {
9
+
10
+ constexpr int kDynamicDim = -1;
11
+ constexpr int kMaxExperts = 512;
12
+
13
+ struct GemmProblem {
14
+ ::cutlass::gemm::GemmCoord dims;
15
+ int64_t lda, ldb, ldc;
16
+ // All offsets are in elements.
17
+ int64_t a_offset, b_offset, c_offset;
18
+ };
19
+
20
+ // TODO: revisit `ExtractGemmProblemK` struct
21
+ // struct ExtractGemmProblemK {
22
+ // __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
23
+ // return {problem.dims.k()};
24
+ // }
25
+ // };
26
+
27
+ template <
28
+ // If `k` is dynamic, we sort the problems by `k` in descending order.
29
+ // Otherwise, `m` is dynamic, and no sorting happens.
30
+ bool kDynamicK,
31
+ typename ElementA, typename ElementB, typename ElementC,
32
+ typename LayoutA, typename LayoutB, typename LayoutC,
33
+ typename Args
34
+ >
35
+ __global__ void FillArguments(
36
+ int num_experts, const int64_t* batch_sizes,
37
+ ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
38
+ Args args, ::cutlass::gemm::GemmCoord dims
39
+ ) {
40
+ const int expert_idx = threadIdx.x;
41
+ const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;
42
+
43
+ if (kDynamicK) {
44
+ assert(dims.k() == kDynamicDim);
45
+ dims.k() = batch_size;
46
+ } else {
47
+ assert(dims.m() == kDynamicDim);
48
+ dims.m() = batch_size;
49
+ }
50
+
51
+ using BlockScan = cub::BlockScan<int, kMaxExperts>;
52
+ using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;
53
+
54
+ union SharedMemory {
55
+ typename BlockScan::TempStorage scan_storage;
56
+ typename BlockSort::TempStorage sort_storage;
57
+ };
58
+ __shared__ SharedMemory shared_memory;
59
+
60
+ int dynamic_dim = kDynamicK ? dims.k() : dims.m();
61
+ int dynamic_dim_cumsum;
62
+ BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
63
+ __syncthreads();
64
+
65
+ // We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
66
+ // `KeyT (&)[ITEMS_PER_THREAD]` for the `keys` argument (i.e., `GemmProblem (&keys)[1]` in our case).
67
+ GemmProblem problem[1] = {
68
+ GemmProblem {
69
+ .dims = dims,
70
+ .lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
71
+ .ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
72
+ .ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
73
+ .a_offset = kDynamicK
74
+ ? (dims.m() * dynamic_dim_cumsum)
75
+ : (dynamic_dim_cumsum * dims.k()),
76
+ .b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
77
+ .c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
78
+ },
79
+ };
80
+
81
+ if constexpr (kDynamicK) {
82
+ // Sort by k dimension in descending order
83
+ // We need to extract the key (k value) for sorting
84
+ int k_keys[1] = { problem[0].dims.k() };
85
+
86
+ BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);
87
+
88
+ // TODO: revisit original impl without `__syncthreads()`
89
+ // BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
90
+ // Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
91
+ // > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage [...]
92
+ // > is **to be reused or repurposed**.
93
+ // We don't need `__syncthreads()` here, since we don't do either of these things.
94
+ }
95
+
96
+ if (expert_idx < num_experts) {
97
+ args.problem_sizes[expert_idx] = problem[0].dims;
98
+ args.lda[expert_idx] = problem[0].lda;
99
+ args.ldb[expert_idx] = problem[0].ldb;
100
+ args.ldc[expert_idx] = problem[0].ldc;
101
+
102
+ args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
103
+ args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
104
+ args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
105
+ }
106
+ }
107
+
108
+ template <typename Args>
109
+ __global__ void ZeroOutK0Outputs(int num_experts, Args args) {
110
+ const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
111
+ const int64_t delta = (int64_t)gridDim.x * blockDim.x;
112
+ for (int ei = 0; ei < num_experts; ++ei) {
113
+ auto& dims = args.problem_sizes[ei];
114
+ // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
115
+ // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
116
+ // * (here) set the output to zero
117
+ // * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
118
+ if (dims.k() == 0) {
119
+ // Assume packed layout, run a grid-strided loop over the output.
120
+ int64_t total_elems = (int64_t)dims.m() * dims.n();
121
+ auto* out = args.ptr_C[ei];
122
+ for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
123
+ out[idx] = {};
124
+ }
125
+ }
126
+ }
127
+ }
128
+
129
+ template <typename Args>
130
+ __global__ void IgnoreK0Problems(int num_experts, Args args) {
131
+ const int expert_idx = threadIdx.x;
132
+ if (expert_idx < num_experts) {
133
+ auto& dims = args.problem_sizes[expert_idx];
134
+ if (dims.k() == 0) {
135
+ dims.m() = 0;
136
+ dims.n() = 0;
137
+ }
138
+ }
139
+ }
140
+
141
+ } // namespace grouped_gemm
csrc/grouped_gemm/grouped_gemm.cu ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "grouped_gemm.h"
2
+ #include "fill_arguments.cuh"
3
+
4
+ #include <ATen/cuda/CUDAContext.h>
5
+ #include <ATen/cuda/detail/KernelUtils.h>
6
+ #include <c10/util/BFloat16.h>
7
+ #include <c10/cuda/CUDAStream.h>
8
+ #include <cub/cub.cuh>
9
+ #include <torch/torch.h>
10
+
11
+ #include "cutlass/bfloat16.h"
12
+ #include "cutlass/complex.h"
13
+ #include "cutlass/gemm/kernel/gemm_grouped.h"
14
+ #include "cutlass/gemm/kernel/default_gemm_grouped.h"
15
+ #include "cutlass/gemm/device/gemm_grouped.h"
16
+
17
+ #include <type_traits>
18
+
19
+ namespace grouped_gemm {
20
+
21
+ #define CUDA_CALL(code) \
22
+ do { \
23
+ cudaError_t status = code; \
24
+ std::string err = cudaGetErrorString(status); \
25
+ TORCH_CHECK(status == cudaSuccess, err); \
26
+ } while (0)
27
+
28
+ #define CUBLAS_CALL(code) \
29
+ do { \
30
+ cublasStatus_t status = code; \
31
+ TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \
32
+ } while (0)
33
+
34
+ #define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
35
+ #define GROUPED_GEMM_STRINGIFY(x) \
36
+ GROUPED_GEMM_STRINGIFY_HELPER(x)
37
+
38
+ template <bool trans>
39
+ using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
40
+
41
+ using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
42
+ ::cutlass::arch::OpClassTensorOp,
43
+ ::cutlass::arch::Sm80,
44
+ ::cutlass::bfloat16_t,
45
+ ::cutlass::bfloat16_t,
46
+ ::cutlass::bfloat16_t,
47
+ float
48
+ >;
49
+
50
+ // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
51
+ template <bool trans_a, bool trans_b>
52
+ using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
53
+ // A operand.
54
+ ::cutlass::bfloat16_t,
55
+ GroupedGemmInputLayout<trans_a>,
56
+ ::cutlass::ComplexTransform::kNone,
57
+ GroupedGemmConfig::kAlignmentA,
58
+ // B operand.
59
+ ::cutlass::bfloat16_t,
60
+ GroupedGemmInputLayout<trans_b>,
61
+ ::cutlass::ComplexTransform::kNone,
62
+ GroupedGemmConfig::kAlignmentB,
63
+ // C operand.
64
+ ::cutlass::bfloat16_t,
65
+ ::cutlass::layout::RowMajor,
66
+ float,
67
+ ::cutlass::arch::OpClassTensorOp,
68
+ ::cutlass::arch::Sm80,
69
+ GroupedGemmConfig::ThreadblockShape,
70
+ GroupedGemmConfig::WarpShape,
71
+ GroupedGemmConfig::InstructionShape,
72
+ GroupedGemmConfig::EpilogueOutputOp,
73
+ // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
74
+ // This parameter is passed in at present to match the APIs of other kernels. The parameter
75
+ // is unused within the kernel.
76
+ ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
77
+ // TODO(tgale): Tune this for SM90.
78
+ GroupedGemmConfig::kStages>::GemmKernel;
79
+
80
+ template <bool trans_a, bool trans_b>
81
+ using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
82
+
83
+ template <typename T>
84
+ torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
85
+ size_t bytes = x.size() * sizeof(T);
86
+ auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
87
+ torch::Tensor out = torch::empty(bytes, options);
88
+
89
+ CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
90
+ x.data(), bytes,
91
+ cudaMemcpyHostToDevice,
92
+ c10::cuda::getCurrentCUDAStream()));
93
+ return out;
94
+ }
95
+
96
+ template <typename T>
97
+ static void ReorderArray(T* data, const std::vector<size_t>& indices) {
98
+ // For now, simply create a copy of the data and then copy over to the original.
99
+ std::vector<T> copy(data, data + indices.size());
100
+ for (size_t i = 0; i < indices.size(); ++i) {
101
+ data[i] = copy.at(indices[i]);
102
+ }
103
+ }
104
+
105
+ template <typename T>
106
+ torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
107
+ return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device));
108
+ }
109
+
110
+ struct RawGemmArguments {
111
+ torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes;
112
+ int threadblock_count{};
113
+ };
114
+
115
+ template <
116
+ typename Gemm,
117
+ typename ElementA, typename ElementB, typename ElementC
118
+ >
119
+ RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
120
+ TORCH_CHECK(
121
+ num_experts <= kMaxExperts,
122
+ "At most ", kMaxExperts,
123
+ " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
124
+ );
125
+
126
+ return RawGemmArguments {
127
+ .lda = TypedEmpty<int64_t>(num_experts, device),
128
+ .ldb = TypedEmpty<int64_t>(num_experts, device),
129
+ .ldc = TypedEmpty<int64_t>(num_experts, device),
130
+ .ptr_a = TypedEmpty<ElementA*>(num_experts, device),
131
+ .ptr_b = TypedEmpty<ElementB*>(num_experts, device),
132
+ .ptr_c = TypedEmpty<ElementC*>(num_experts, device),
133
+ .problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device),
134
+
135
+ // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
136
+ .threadblock_count = Gemm::sufficient(),
137
+ };
138
+ }
139
+
140
+ template <
141
+ bool kDynamicK,
142
+ typename Gemm,
143
+ typename ElementA, typename ElementB, typename ElementC,
144
+ typename LayoutA, typename LayoutB, typename LayoutC
145
+ >
146
+ RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
147
+ torch::Tensor b,
148
+ torch::Tensor c,
149
+ torch::Tensor batch_sizes,
150
+ ::cutlass::gemm::GemmCoord coord_template,
151
+ int64_t num_experts) {
152
+ std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);
153
+
154
+ // Create the host arrays of leading dimension data and pointer data.
155
+ std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
156
+ int64_t elements_a = 0, elements_b = 0, elements_c = 0;
157
+
158
+ std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts);
159
+
160
+ for (int i = 0; i < num_experts; ++i) {
161
+ auto& problem = problem_sizes_host[i];
162
+ problem = coord_template;
163
+ (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i];
164
+
165
+ lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
166
+ ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
167
+ ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
168
+
169
+ ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a;
170
+ ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b;
171
+ ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c;
172
+
173
+ elements_a += problem.m() * problem.k();
174
+ elements_b += problem.k() * problem.n();
175
+ elements_c += problem.m() * problem.n();
176
+
177
+ if (problem.k() == 0) {
178
+ // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
179
+ // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
180
+ // * set the output to zero with `cudaMemsetAsync()`
181
+ // * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
182
+ CUDA_CALL(cudaMemsetAsync(ptr_c_host[i],
183
+ 0,
184
+ problem.m() * problem.n() * sizeof(ElementC),
185
+ c10::cuda::getCurrentCUDAStream()));
186
+
187
+ problem.m() = 0;
188
+ problem.n() = 0;
189
+ }
190
+ }
191
+
192
+ // Only sort problems when K are different
193
+ if (kDynamicK) {
194
+ std::vector<size_t> indices(num_experts);
195
+ std::iota(indices.begin(), indices.end(), 0);
196
+ std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
197
+ return problem_sizes_host[i].k() > problem_sizes_host[j].k();
198
+ });
199
+
200
+ ReorderArray(problem_sizes_host.data(), indices);
201
+ ReorderArray(lda_host.data(), indices);
202
+ ReorderArray(ldb_host.data(), indices);
203
+ ReorderArray(ldc_host.data(), indices);
204
+ ReorderArray(ptr_a_host.data(), indices);
205
+ ReorderArray(ptr_b_host.data(), indices);
206
+ ReorderArray(ptr_c_host.data(), indices);
207
+ }
208
+
209
+ // Copy the problem sizes, pointers and leading dimension data to the device.
210
+ return RawGemmArguments {
211
+ .lda = CopyToDevice(lda_host, a.device()),
212
+ .ldb = CopyToDevice(ldb_host, a.device()),
213
+ .ldc = CopyToDevice(ldc_host, a.device()),
214
+ .ptr_a = CopyToDevice(ptr_a_host, a.device()),
215
+ .ptr_b = CopyToDevice(ptr_b_host, a.device()),
216
+ .ptr_c = CopyToDevice(ptr_c_host, a.device()),
217
+ .problem_sizes = CopyToDevice(problem_sizes_host, a.device()),
218
+
219
+ // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
220
+ .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
221
+ };
222
+ }
223
+
224
+ template <
225
+ bool kDynamicK,
226
+ typename Gemm,
227
+ typename ElementA, typename ElementB, typename ElementC,
228
+ typename LayoutA, typename LayoutB, typename LayoutC
229
+ >
230
+ typename Gemm::Arguments MakeArguments(torch::Tensor a,
231
+ torch::Tensor b,
232
+ torch::Tensor c,
233
+ torch::Tensor batch_sizes,
234
+ ::cutlass::gemm::GemmCoord coord_template,
235
+ int64_t num_experts) {
236
+ RawGemmArguments raw_args;
237
+ if (batch_sizes.is_cuda()) {
238
+ raw_args = MakeArgumentsOnDevice<
239
+ Gemm, ElementA, ElementB, ElementC
240
+ >(num_experts, a.device());
241
+ } else {
242
+ raw_args = MakeArgumentsOnHost<
243
+ kDynamicK,
244
+ Gemm,
245
+ ElementA, ElementB, ElementC,
246
+ LayoutA, LayoutB, LayoutC
247
+ >(a, b, c, batch_sizes, coord_template, num_experts);
248
+ }
249
+
250
+ printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count);
251
+ // Validate the result.
252
+ if (!raw_args.threadblock_count) {
253
+ TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
254
+ }
255
+
256
+ typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
257
+ // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
258
+ // so we can safely pass `nullptr` for `host_problem_sizes`.
259
+ // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
260
+ // know the problem dimensions on the host.
261
+ typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
262
+ (int)num_experts,
263
+ (int)raw_args.threadblock_count,
264
+ epilogue_op,
265
+ (ElementA**)raw_args.ptr_a.data_ptr(),
266
+ (ElementB**)raw_args.ptr_b.data_ptr(),
267
+ (ElementC**)raw_args.ptr_c.data_ptr(),
268
+ (ElementC**)raw_args.ptr_c.data_ptr(),
269
+ /*lda=*/(int64_t*)raw_args.lda.data_ptr(),
270
+ /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
271
+ /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
272
+ /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
273
+ /*host_problem_sizes=*/nullptr);
274
+ return arguments;
275
+ }
276
+
277
+ template <
278
+ bool trans_a,
279
+ typename ElementA, typename ElementB, typename ElementC,
280
+ typename LayoutA, typename LayoutB, typename LayoutC,
281
+ typename Arguments
282
+ >
283
+ void FillCutlassArguments(int num_experts,
284
+ torch::Tensor batch_sizes,
285
+ torch::Tensor a,
286
+ torch::Tensor b,
287
+ torch::Tensor c,
288
+ const Arguments& arguments,
289
+ ::cutlass::gemm::GemmCoord coord_template) {
290
+ // Convert the batch sizes to the format CUTLASS understands on the device.
291
+ // Use a single block here because:
292
+ // * the number of elements to process is microscopically small
293
+ // * we don't need any additional global memory
294
+ FillArguments<
295
+ /*kDynamicK*/trans_a,
296
+ ElementA, ElementB, ElementC,
297
+ LayoutA, LayoutB, LayoutC
298
+ ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>(
299
+ num_experts, batch_sizes.data_ptr<int64_t>(),
300
+ (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(),
301
+ arguments, coord_template
302
+ );
303
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
304
+ }
305
+
306
+ template <typename Args>
307
+ void RemoveK0Problems(int num_experts, const Args& arguments) {
308
+ // For zeroing out the outputs (which might be arbitrarily large), we want to use
309
+ // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
310
+ // `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
311
+ // should be a good approximation for this.
312
+ // When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
313
+ ZeroOutK0Outputs<><<<
314
+ arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()
315
+ >>>(
316
+ num_experts, arguments
317
+ );
318
+ IgnoreK0Problems<><<<
319
+ 1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()
320
+ >>>(
321
+ num_experts, arguments
322
+ );
323
+ }
324
+
325
+ template <bool trans_a, bool trans_b>
326
+ torch::Tensor CutlassGroupedGemm(torch::Tensor a,
327
+ torch::Tensor b,
328
+ torch::Tensor c,
329
+ torch::Tensor batch_sizes,
330
+ ::cutlass::gemm::GemmCoord coord_template) {
331
+ using Gemm = GemmGrouped<trans_a, trans_b>;
332
+ using LayoutA = typename Gemm::LayoutA;
333
+ using LayoutB = typename Gemm::LayoutB;
334
+ using LayoutC = typename Gemm::LayoutC;
335
+
336
+ using ElementA = typename Gemm::ElementA;
337
+ using ElementB = typename Gemm::ElementB;
338
+ using ElementC = typename Gemm::ElementC;
339
+
340
+ Gemm gemm;
341
+ int64_t num_experts = batch_sizes.size(0);
342
+ auto arguments = MakeArguments<
343
+ /*kDynamicK*/trans_a,
344
+ Gemm,
345
+ ElementA, ElementB, ElementC,
346
+ LayoutA, LayoutB, LayoutC
347
+ >(a, b, c, batch_sizes, coord_template, num_experts);
348
+ int64_t workspace_size = gemm.get_workspace_size(arguments);
349
+ auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
350
+ torch::Tensor workspace = torch::empty(workspace_size, options);
351
+
352
+ if (batch_sizes.is_cuda()) {
353
+ FillCutlassArguments<
354
+ trans_a,
355
+ ElementA, ElementB, ElementC,
356
+ LayoutA, LayoutB, LayoutC
357
+ >(num_experts, batch_sizes, a, b, c, arguments, coord_template);
358
+
359
+ RemoveK0Problems<>(num_experts, arguments);
360
+ }
361
+
362
+ // Initialize the kernel.
363
+ if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
364
+ TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
365
+ }
366
+
367
+ // Execute the kernel in the current stream.
368
+ if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
369
+ TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
370
+ }
371
+ return c;
372
+ }
373
+
374
+ void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
375
+ c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
376
+ c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
377
+ int m = trans_b ? b_rows : b_cols;
378
+ int k = trans_b ? b_cols : b_rows;
379
+ int n = trans_a ? a_cols : a_rows;
380
+
381
+ int lda = trans_a ? n : k;
382
+ int ldb = trans_b ? k : m;
383
+ cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
384
+ cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
385
+
386
+ float alpha = 1.0, beta = 0.0;
387
+ CUBLAS_CALL(cublasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
388
+ transpose_b, transpose_a,
389
+ m, n, k, &alpha,
390
+ b, CUDA_R_16BF, ldb,
391
+ a, CUDA_R_16BF, lda,
392
+ &beta,
393
+ c, CUDA_R_16BF, c_cols, CUDA_R_32F,
394
+ CUBLAS_GEMM_DEFAULT));
395
+ }
396
+
397
+ void CublasGroupedGemm(torch::Tensor a,
398
+ torch::Tensor b,
399
+ torch::Tensor c,
400
+ torch::Tensor batch_sizes,
401
+ bool trans_b) {
402
+ int64_t bs = batch_sizes.size(0), k = a.size(1);
403
+ int64_t n = trans_b ? b.size(1) : b.size(2);
404
+ int64_t b_rows = b.size(1), b_cols = b.size(2);
405
+ c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
406
+ c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
407
+ c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
408
+ for (int i = 0; i < bs; ++i) {
409
+ int64_t m = batch_sizes.data_ptr<int64_t>()[i];
410
+ CublasGemm(a_ptr, m, k, /*trans_a=*/false,
411
+ b_ptr, b_rows, b_cols, trans_b,
412
+ c_ptr, m, n);
413
+ a_ptr += m * k;
414
+ b_ptr += b_rows * b_cols;
415
+ c_ptr += m * n;
416
+ }
417
+ }
418
+
419
+ void CublasGroupedGemmVariableK(torch::Tensor a,
420
+ torch::Tensor b,
421
+ torch::Tensor c,
422
+ torch::Tensor batch_sizes) {
423
+ int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
424
+ c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
425
+ c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
426
+ c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
427
+ for (int i = 0; i < bs; ++i) {
428
+ int64_t k = batch_sizes.data_ptr<int64_t>()[i];
429
+ CublasGemm(a_ptr, k, m, /*trans_a=*/true,
430
+ b_ptr, k, n, /*trans_b=*/false,
431
+ c_ptr, m, n);
432
+ a_ptr += k * m;
433
+ b_ptr += k * n;
434
+ c_ptr += m * n;
435
+ }
436
+ }
437
+
438
+ void GroupedGemmVariableK(torch::Tensor a,
439
+ torch::Tensor b,
440
+ torch::Tensor c,
441
+ torch::Tensor batch_sizes) {
442
+ // We expected a CUDA tensor with two dimensions and shape
443
+ // (tokens, hidden_out) for 'b'.
444
+ TORCH_CHECK(b.is_cuda());
445
+ TORCH_CHECK(b.ndimension() == 2);
446
+ TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
447
+
448
+ // Validate the dimensions.
449
+ int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
450
+ int64_t m = a.size(1), n = b.size(1);
451
+
452
+ // Validate that we have the same contraction dimension.
453
+ TORCH_CHECK(tokens == b.size(0));
454
+
455
+ // Validate the output shape.
456
+ TORCH_CHECK(c.is_cuda());
457
+ TORCH_CHECK(c.ndimension() == 3);
458
+ TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
459
+ TORCH_CHECK(c.size(0) == num_experts);
460
+ TORCH_CHECK(c.size(1) == m);
461
+ TORCH_CHECK(c.size(2) == n);
462
+
463
+ // Run the computation.
464
+ CublasGroupedGemmVariableK(a, b, c, batch_sizes);
465
+ }
466
+
467
+ // NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
468
+ // assumed to be batched with fixed sized batches.
469
+ //
470
+ // TODO(tgale): Validate alignment is true for every batch element.
471
+ void GroupedGemm(torch::Tensor a,
472
+ torch::Tensor b,
473
+ torch::Tensor c,
474
+ torch::Tensor batch_sizes,
475
+ bool trans_a, bool trans_b) {
476
+ // NOTE: We only support 'trans_a' or 'trans_b', not both.
477
+ TORCH_CHECK(!(trans_a && trans_b));
478
+
479
+ #if !defined(GROUPED_GEMM_CUTLASS)
480
+ // No way to run cuBLAS kernels if the problem dimensions are not known on the host.
481
+ TORCH_CHECK(batch_sizes.is_cpu());
482
+ #else
483
+ // CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
484
+ TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
485
+ #endif
486
+ TORCH_CHECK(batch_sizes.ndimension() == 1);
487
+ TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);
488
+
489
+ // We expected a CUDA tensor with two dimensions and shape
490
+ // (tokens, hidden_in) for 'a'.
491
+ TORCH_CHECK(a.is_cuda());
492
+ TORCH_CHECK(a.ndimension() == 2);
493
+ TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
494
+
495
+ #if !defined(GROUPED_GEMM_CUTLASS)
496
+ if (trans_a) {
497
+ // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
498
+ // for the rest of the op.
499
+ GroupedGemmVariableK(a, b, c, batch_sizes);
500
+ return;
501
+ }
502
+ #endif
503
+
504
+ TORCH_CHECK(b.is_cuda());
505
+ TORCH_CHECK(c.is_cuda());
506
+ TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
507
+ TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
508
+
509
+ // The expected shapes of 'b' and 'c' are:
510
+ // * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
511
+ // * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
512
+ // * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden
513
+ size_t hidden_in{}, hidden_out{};
514
+ if (trans_a) {
515
+ hidden_in = a.size(1);
516
+ hidden_out = b.size(1);
517
+
518
+ TORCH_CHECK(b.ndimension() == 2);
519
+ TORCH_CHECK(c.ndimension() == 3);
520
+ TORCH_CHECK(b.size(0) == a.size(0));
521
+ TORCH_CHECK(c.size(0) == batch_sizes.size(0));
522
+ TORCH_CHECK(c.size(1) == hidden_in);
523
+ TORCH_CHECK(c.size(2) == hidden_out);
524
+ } else {
525
+ TORCH_CHECK(b.ndimension() == 3);
526
+ TORCH_CHECK(c.ndimension() == 2);
527
+
528
+ // Validate the contraction dimensions match.
529
+ int64_t tokens = a.size(0), num_experts = b.size(0);
530
+ hidden_in = trans_b ? b.size(2) : b.size(1);
531
+ hidden_out = trans_b ? b.size(1) : b.size(2);
532
+ TORCH_CHECK(hidden_in == a.size(1));
533
+
534
+ // Validate that we have one size per expert.
535
+ TORCH_CHECK(batch_sizes.size(0) == num_experts);
536
+ }
537
+
538
+ // NOTE: We support transposition through the 'trans_b' flag.
539
+ TORCH_CHECK(a.is_contiguous());
540
+ TORCH_CHECK(b.is_contiguous());
541
+ TORCH_CHECK(c.is_contiguous());
542
+
543
+ #if !defined(GROUPED_GEMM_CUTLASS)
544
+ CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
545
+ return;
546
+ #else
547
+ // The `coord_template` argument contains `kDynamicDim` as one of its dimensions
548
+ // as a placeholder. This placeholder is later expanded into the actual dimension
549
+ // for every element of the batch, either on the host or on the device
550
+ // (if we can't do in on the host).
551
+ const auto coord_template = trans_a
552
+ ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
553
+ : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
554
+ if (trans_a) {
555
+ CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
556
+ return;
557
+ }
558
+ if (trans_b) {
559
+ CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
560
+ return;
561
+ }
562
+ CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
563
+ return;
564
+ #endif
565
+ }
566
+
567
+ } // namespace grouped_gemm
csrc/grouped_gemm/grouped_gemm.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // // Set default if not already defined
4
+ // #ifndef GROUPED_GEMM_CUTLASS
5
+ // #define GROUPED_GEMM_CUTLASS 0
6
+ // #endif
7
+
8
+ // #include <torch/extension.h>
9
+ #include <torch/torch.h>
10
+
11
+ namespace grouped_gemm {
12
+
13
+ void GroupedGemm(torch::Tensor a,
14
+ torch::Tensor b,
15
+ torch::Tensor c,
16
+ torch::Tensor batch_sizes,
17
+ bool trans_a, bool trans_b);
18
+
19
+ } // namespace grouped_gemm
20
+
csrc/grouped_gemm/ops.cu ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "grouped_gemm.h"
2
+
3
+ #include <torch/extension.h>
4
+
5
+ namespace grouped_gemm {
6
+
7
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
8
+ m.def("gmm", &GroupedGemm, "Grouped GEMM.");
9
+ }
10
+
11
+ } // namespace grouped_gemm
tests/ops_test.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import itertools
3
+
4
+ from absl.testing import parameterized
5
+ import megablocks
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ def allclose(x, y, pct=2.0):
11
+ mask = torch.isclose(x, y, rtol=1e-5)
12
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
13
+ if pct_diff > pct:
14
+ print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
15
+ print("{:.2f}% of values not close.".format(pct_diff))
16
+ return False
17
+ return True
18
+
19
+
20
+ def add_flags(x):
21
+ out = []
22
+ for y in x:
23
+ for trans_b in (False, True):
24
+ out.append(y + (trans_b, False))
25
+
26
+ # TODO: Revisit enabling batch_sizes_on_device
27
+ # for batch_sizes_on_device in (False, True):
28
+ # out.append(y + (trans_b, batch_sizes_on_device))
29
+ return out
30
+
31
+
32
+ _TEST_PROBLEMS = add_flags((
33
+ (1, 128, 128, 128),
34
+ (8, 128, 128, 128),
35
+ (16, 128, 128, 128),
36
+ (1, 128, 256, 512),
37
+ (8, 128, 256, 512),
38
+ (16, 128, 256, 512),
39
+ ))
40
+
41
+
42
+ def randn(bs, x, y):
43
+ out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
44
+ return out.cuda().to(torch.bfloat16)
45
+
46
+
47
+ def gmm(a, b, batch_sizes, trans_b=False):
48
+ batch_sizes = batch_sizes.cpu().numpy()
49
+
50
+ out = []
51
+ start = 0
52
+ for i, size in enumerate(batch_sizes):
53
+ rhs = b[i, :, :].t() if trans_b else b[i, :, :]
54
+ out.append(a[start:start + size, :] @ rhs)
55
+ start += size
56
+ return torch.cat(out)
57
+
58
+
59
+ @parameterized.parameters(*_TEST_PROBLEMS)
60
+ class OpsTest(parameterized.TestCase):
61
+
62
+ def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
63
+ torch.manual_seed(0)
64
+ a = randn(z, m, k).view(-1, k)
65
+ b = randn(z, n, k) if trans_b else randn(z, k, n)
66
+ batch_sizes = torch.tensor([m] * z)
67
+ if batch_sizes_on_device:
68
+ batch_sizes = batch_sizes.cuda()
69
+
70
+ a.requires_grad_(True)
71
+ b.requires_grad_(True)
72
+ a_ref = a.detach().clone().requires_grad_(True)
73
+ b_ref = b.detach().clone().requires_grad_(True)
74
+
75
+ # out = ops.gmm(a, b, batch_sizes, trans_b)
76
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
77
+ # print("out", out)
78
+ expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
79
+ self.assertTrue(allclose(out, expected_out))
80
+
81
+ # Check gradients.
82
+ out.sum().backward()
83
+ expected_out.sum().backward()
84
+ self.assertTrue(allclose(a.grad, a_ref.grad))
85
+ self.assertTrue(allclose(b.grad, b_ref.grad))
86
+
87
+ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
88
+ torch.manual_seed(0)
89
+ a = randn(z, m, k).view(-1, k)
90
+ b = randn(z, n, k) if trans_b else randn(z, k, n)
91
+
92
+ dist = torch.rand(z, )
93
+ dist /= dist.sum()
94
+ batch_sizes = (dist * m).to(torch.long)
95
+ error = m * z - batch_sizes.sum()
96
+ batch_sizes[-1] += error
97
+ assert batch_sizes.sum() == (m * z)
98
+ if batch_sizes_on_device:
99
+ batch_sizes = batch_sizes.cuda()
100
+
101
+ a.requires_grad_(True)
102
+ b.requires_grad_(True)
103
+ a_ref = a.detach().clone().requires_grad_(True)
104
+ b_ref = b.detach().clone().requires_grad_(True)
105
+
106
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
107
+ expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
108
+ self.assertTrue(allclose(out, expected_out))
109
+
110
+ # Check gradients.
111
+ out.sum().backward()
112
+ expected_out.sum().backward()
113
+ self.assertTrue(allclose(a.grad, a_ref.grad))
114
+
115
+ # TODO: Review to ensure that the gradients are correct.
116
+ # self.assertTrue(allclose(b.grad, b_ref.grad))
117
+
118
+
119
+ # @parameterized.parameters(False, True)
120
+ @parameterized.parameters(False, False)
121
+ class EdgeCasesTest(unittest.TestCase):
122
+
123
+ def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
124
+ torch.manual_seed(0)
125
+ m = 16384
126
+ k = 4096
127
+ n = 14336
128
+ num_experts = 8
129
+
130
+ a = randn(num_experts, m // num_experts, k).view(-1, k)
131
+ b = randn(num_experts, k, n)
132
+ batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
133
+ if batch_sizes_on_device:
134
+ batch_sizes = batch_sizes.cuda()
135
+
136
+ a.requires_grad_(True)
137
+ b.requires_grad_(True)
138
+ a_ref = a.detach().clone().requires_grad_(True)
139
+ b_ref = b.detach().clone().requires_grad_(True)
140
+
141
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes)
142
+ expected_out = gmm(a_ref, b_ref, batch_sizes)
143
+ self.assertTrue(allclose(out, expected_out))
144
+
145
+ # Check gradients.
146
+ out.sum().backward()
147
+ expected_out.sum().backward()
148
+ self.assertTrue(allclose(a.grad, a_ref.grad))
149
+ self.assertTrue(allclose(b.grad, b_ref.grad))
150
+
151
+ def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
152
+ sz = 128
153
+ total_tokens = 192
154
+
155
+ a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
156
+ b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
157
+ c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
158
+ batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
159
+ if batch_sizes_on_device:
160
+ batch_sizes = batch_sizes.cuda()
161
+
162
+ megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
163
+ self.assertTrue((c[0] == 0).all())
164
+ self.assertTrue((c[1] == 128).all())
165
+ self.assertTrue((c[2] == 0).all())
166
+ self.assertTrue((c[3] == 64).all())
167
+
168
+
169
+ if __name__ == '__main__':
170
+ unittest.main()
tests/test_gg.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import megablocks
3
+
4
+
5
+ def randn(bs, x, y):
6
+ out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
7
+ return out.cuda().to(torch.bfloat16)
8
+
9
+
10
+ def gmm(a, b, batch_sizes, trans_b=False):
11
+ batch_sizes = batch_sizes.cpu().numpy()
12
+
13
+ out = []
14
+ start = 0
15
+ for i, size in enumerate(batch_sizes):
16
+ rhs = b[i, :, :].t() if trans_b else b[i, :, :]
17
+ out.append(a[start : start + size, :] @ rhs)
18
+ start += size
19
+ return torch.cat(out)
20
+
21
+
22
+ def test_gmm():
23
+ z = 1
24
+ m = 128
25
+ n = 128
26
+ k = 128
27
+ trans_b = False
28
+ batch_sizes_on_device = False
29
+ # TODO: fix to enable batch_sizes_on_device
30
+ # batch_sizes_on_device = True
31
+
32
+ torch.manual_seed(0)
33
+ a = randn(z, m, k).view(-1, k)
34
+ b = randn(z, n, k) if trans_b else randn(z, k, n)
35
+ batch_sizes = torch.tensor([m] * z)
36
+ if batch_sizes_on_device:
37
+ batch_sizes = batch_sizes.cuda()
38
+
39
+ a.requires_grad_(True)
40
+ b.requires_grad_(True)
41
+ a_ref = a.detach().clone().requires_grad_(True)
42
+ b_ref = b.detach().clone().requires_grad_(True)
43
+
44
+ # out = ops.gmm(a, b, batch_sizes, trans_b)
45
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
46
+ print("out", out)
47
+
48
+ expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
49
+
50
+ assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"
51
+
52
+ out.sum().backward()
53
+
54
+ expected_out.sum().backward()
55
+ assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
56
+ assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
57
+ print("Test passed successfully!")
torch-ext/megablocks/__init__.py CHANGED
@@ -5,11 +5,15 @@ import torch
5
 
6
  from ._ops import ops
7
 
8
- from megablocks.layers.arguments import Arguments
9
- from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
10
- from megablocks.layers.glu import SparseGLU
11
- from megablocks.layers.mlp import MLP, SparseMLP
12
- from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 
 
 
13
 
14
  # This section contains the direct kernel exports (not inlcuded in the original code)
15
  def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
 
5
 
6
  from ._ops import ops
7
 
8
+ from .grouped_gemm import backend as gg_backend
9
+ from .grouped_gemm import ops as gg_ops
10
+
11
+
12
+ from .layers.arguments import Arguments
13
+ from .layers.dmoe import ParallelDroplessMLP, dMoE
14
+ from .layers.glu import SparseGLU
15
+ from .layers.mlp import MLP, SparseMLP
16
+ from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
17
 
18
  # This section contains the direct kernel exports (not inlcuded in the original code)
19
  def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
torch-ext/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from . import ops
2
+ from . import backend
torch-ext/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NOTE: Torch needs to be imported before the custom
2
+ # extensions. Otherwise libc10.so cannot be found.
3
+ import torch
4
+
5
+ # # TODO(tgale): Wrap this in a try-block with better
6
+ # # error message and instructions for building the
7
+ # # c++ operations.
8
+ # import grouped_gemm_backend as backend
9
+
10
+ # We import the backend operations from the megablocks package as
11
+ # grouped_gemm is vendored in megablocks in this repository.
12
+ # from ... import _ops as backend
13
+ from megablocks._ops import ops as backend # type: ignore
14
+
15
+ def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
16
+ assert not (trans_a and trans_b)
17
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
18
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
19
+ assert b.ndim == (2 if trans_a else 3)
20
+
21
+ shape = (
22
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
23
+ if trans_a else
24
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
25
+ )
26
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
27
+
28
+ def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
29
+ if c is None:
30
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
31
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
32
+ return c
torch-ext/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import backend
2
+ import torch
3
+
4
+
5
+ class GroupedGemm(torch.autograd.Function):
6
+
7
+ @staticmethod
8
+ def forward(ctx, a, b, batch_sizes, trans_b):
9
+ ctx.save_for_backward(a, b, batch_sizes)
10
+ ctx.trans_b = trans_b
11
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
12
+
13
+ @staticmethod
14
+ def backward(ctx, grad):
15
+ grad = grad.contiguous()
16
+ a, b, batch_sizes = ctx.saved_tensors
17
+ trans_b = ctx.trans_b
18
+
19
+ agrad = None
20
+ if ctx.needs_input_grad[0]:
21
+ agrad = backend.gmm(
22
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
23
+
24
+ bgrad = None
25
+ if ctx.needs_input_grad[1]:
26
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
27
+ bgrad = backend.gmm(
28
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
29
+ return agrad, bgrad, None, None
30
+
31
+
32
+ def gmm(a, b, batch_sizes, trans_b=False):
33
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
torch-ext/megablocks/grouped_gemm_util.py CHANGED
@@ -4,7 +4,8 @@ import warnings
4
 
5
  _grouped_gemm_is_available: bool = False
6
  try:
7
- import grouped_gemm
 
8
  _grouped_gemm_is_available = True
9
  except ImportError as error:
10
  warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
22
  assert _grouped_gemm_is_available, msg
23
 
24
 
25
- backend = grouped_gemm.backend if grouped_gemm_is_available() else None
26
- ops = grouped_gemm.ops if grouped_gemm_is_available() else None
 
 
 
 
 
4
 
5
  _grouped_gemm_is_available: bool = False
6
  try:
7
+ # import grouped_gemm
8
+ pass
9
  _grouped_gemm_is_available = True
10
  except ImportError as error:
11
  warnings.warn('Grouped GEMM not available.')
 
23
  assert _grouped_gemm_is_available, msg
24
 
25
 
26
+ # backend = grouped_gemm.backend if grouped_gemm_is_available() else None
27
+ # ops = grouped_gemm.ops if grouped_gemm_is_available() else None
28
+
29
+
30
+ from .grouped_gemm import backend as ops
31
+ from .grouped_gemm import ops as backend
torch-ext/megablocks/layers/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
  # from megablocks.layers.dmoe import dMoE
5
- from megablocks.layers.moe import MoE
6
 
7
  __all__ = [
8
  'MoE',
 
2
  # SPDX-License-Identifier: Apache-2.0
3
 
4
  # from megablocks.layers.dmoe import dMoE
5
+ from .moe import MoE
6
 
7
  __all__ = [
8
  'MoE',
torch-ext/torch_binding.cpp CHANGED
@@ -9,6 +9,8 @@
9
  #include "new_replicate.h"
10
  #include "new_sort.h"
11
 
 
 
12
  // void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
13
  torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) {
14
  megablocks::exclusive_cumsum(x, dim, out);
@@ -70,6 +72,12 @@ torch::Tensor sort_wrapper(torch::Tensor x, int64_t end_bit, torch::Tensor x_out
70
  return x_out;
71
  }
72
 
 
 
 
 
 
 
73
  // Reference implementation:
74
  //
75
  // m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
@@ -101,6 +109,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
101
 
102
  ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)");
103
  ops.impl("sort", torch::kCUDA, &sort_wrapper);
 
 
 
 
104
  }
105
 
106
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
9
  #include "new_replicate.h"
10
  #include "new_sort.h"
11
 
12
+ #include "grouped_gemm/grouped_gemm.h"
13
+
14
  // void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
15
  torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) {
16
  megablocks::exclusive_cumsum(x, dim, out);
 
72
  return x_out;
73
  }
74
 
75
+ // GroupedGemm operation
76
+ torch::Tensor gmm(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, bool trans_a, bool trans_b) {
77
+ grouped_gemm::GroupedGemm(a, b, c, batch_sizes, trans_a, trans_b);
78
+ return c;
79
+ }
80
+
81
  // Reference implementation:
82
  //
83
  // m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
 
109
 
110
  ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)");
111
  ops.impl("sort", torch::kCUDA, &sort_wrapper);
112
+
113
+ // Register the gmm GroupedGemm operation
114
+ ops.def("gmm(Tensor (a!) a, Tensor (b!) b, Tensor(c!) c, Tensor batch_sizes, bool trans_a, bool trans_b) -> Tensor(c!)");
115
+ ops.impl("gmm", torch::kCUDA, &gmm);
116
  }
117
 
118
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)