drbh committed
Commit: 3224250 · Parent(s): a585153

feat: vendor grouped gemm
- build.toml +4 -0
- csrc/grouped_gemm/fill_arguments.cuh +141 -0
- csrc/grouped_gemm/grouped_gemm.cu +567 -0
- csrc/grouped_gemm/grouped_gemm.h +20 -0
- csrc/grouped_gemm/ops.cu +11 -0
- tests/ops_test.py +170 -0
- tests/test_gg.py +57 -0
- torch-ext/megablocks/__init__.py +9 -5
- torch-ext/megablocks/grouped_gemm/__init__.py +2 -0
- torch-ext/megablocks/grouped_gemm/backend.py +32 -0
- torch-ext/megablocks/grouped_gemm/ops.py +33 -0
- torch-ext/megablocks/grouped_gemm_util.py +8 -3
- torch-ext/megablocks/layers/__init__.py +1 -1
- torch-ext/torch_binding.cpp +12 -0
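
For orientation, here is a minimal sketch of how the vendored path is meant to be driven from Python once this commit is built; it mirrors tests/test_gg.py below. The shapes are illustrative, and the CPU int64 batch_sizes follows the tests.

import torch
import megablocks

num_experts, tokens_per_expert, hidden_in, hidden_out = 4, 128, 256, 512

# Tokens for all experts stacked along dim 0; the kernels require bf16 CUDA tensors.
a = torch.randn(num_experts * tokens_per_expert, hidden_in,
                device="cuda", dtype=torch.bfloat16)
# One weight matrix per expert.
b = torch.randn(num_experts, hidden_in, hidden_out,
                device="cuda", dtype=torch.bfloat16)
# Rows of a owned by each expert; CPU int64, as in the tests.
batch_sizes = torch.tensor([tokens_per_expert] * num_experts)

out = megablocks.gg_ops.gmm(a, b, batch_sizes)  # -> (num_experts * tokens_per_expert, hidden_out)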
build.toml
CHANGED
@@ -35,4 +35,8 @@ src = [
   "csrc/new_replicate.h",
   "csrc/new_sort.h",
   "csrc/new_sort.cu",
+  # vendored grouped gemm
+  "csrc/grouped_gemm/fill_arguments.cuh",
+  "csrc/grouped_gemm/grouped_gemm.cu",
+  "csrc/grouped_gemm/grouped_gemm.h",
 ]
csrc/grouped_gemm/fill_arguments.cuh
ADDED
@@ -0,0 +1,141 @@
#pragma once

#include <ATen/cuda/detail/KernelUtils.h>
#include <cub/cub.cuh>
#include <cutlass/bfloat16.h>
#include <cutlass/gemm_coord.h>

namespace grouped_gemm {

constexpr int kDynamicDim = -1;
constexpr int kMaxExperts = 512;

struct GemmProblem {
  ::cutlass::gemm::GemmCoord dims;
  int64_t lda, ldb, ldc;
  // All offsets are in elements.
  int64_t a_offset, b_offset, c_offset;
};

// TODO: revisit `ExtractGemmProblemK` struct
// struct ExtractGemmProblemK {
//   __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
//     return {problem.dims.k()};
//   }
// };

template <
  // If `k` is dynamic, we sort the problems by `k` in descending order.
  // Otherwise, `m` is dynamic, and no sorting happens.
  bool kDynamicK,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC,
  typename Args
>
__global__ void FillArguments(
  int num_experts, const int64_t* batch_sizes,
  ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
  Args args, ::cutlass::gemm::GemmCoord dims
) {
  const int expert_idx = threadIdx.x;
  const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;

  if (kDynamicK) {
    assert(dims.k() == kDynamicDim);
    dims.k() = batch_size;
  } else {
    assert(dims.m() == kDynamicDim);
    dims.m() = batch_size;
  }

  using BlockScan = cub::BlockScan<int, kMaxExperts>;
  using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;

  union SharedMemory {
    typename BlockScan::TempStorage scan_storage;
    typename BlockSort::TempStorage sort_storage;
  };
  __shared__ SharedMemory shared_memory;

  int dynamic_dim = kDynamicK ? dims.k() : dims.m();
  int dynamic_dim_cumsum;
  BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
  __syncthreads();

  // We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
  // `KeyT (&)[ITEMS_PER_THREAD]` for the `keys` argument (i.e., `GemmProblem (&keys)[1]` in our case).
  GemmProblem problem[1] = {
    GemmProblem {
      .dims = dims,
      .lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
      .ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
      .ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
      .a_offset = kDynamicK
          ? (dims.m() * dynamic_dim_cumsum)
          : (dynamic_dim_cumsum * dims.k()),
      .b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
      .c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
    },
  };

  if constexpr (kDynamicK) {
    // Sort by k dimension in descending order.
    // We need to extract the key (k value) for sorting.
    int k_keys[1] = { problem[0].dims.k() };

    BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);

    // TODO: revisit original impl without `__syncthreads()`
    // BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
    // Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
    // > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective's temporary storage [...]
    // > is **to be reused or repurposed**.
    // We don't need `__syncthreads()` here, since we don't do either of these things.
  }

  if (expert_idx < num_experts) {
    args.problem_sizes[expert_idx] = problem[0].dims;
    args.lda[expert_idx] = problem[0].lda;
    args.ldb[expert_idx] = problem[0].ldb;
    args.ldc[expert_idx] = problem[0].ldc;

    args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
    args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
    args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
  }
}

template <typename Args>
__global__ void ZeroOutK0Outputs(int num_experts, Args args) {
  const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t delta = (int64_t)gridDim.x * blockDim.x;
  for (int ei = 0; ei < num_experts; ++ei) {
    auto& dims = args.problem_sizes[ei];
    // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
    // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
    //  * (here) set the output to zero
    //  * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
    if (dims.k() == 0) {
      // Assume packed layout, run a grid-strided loop over the output.
      int64_t total_elems = (int64_t)dims.m() * dims.n();
      auto* out = args.ptr_C[ei];
      for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
        out[idx] = {};
      }
    }
  }
}

template <typename Args>
__global__ void IgnoreK0Problems(int num_experts, Args args) {
  const int expert_idx = threadIdx.x;
  if (expert_idx < num_experts) {
    auto& dims = args.problem_sizes[expert_idx];
    if (dims.k() == 0) {
      dims.m() = 0;
      dims.n() = 0;
    }
  }
}

} // namespace grouped_gemm
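
FillArguments builds one GEMM problem per expert entirely on the device: a block-wide exclusive prefix sum over the dynamic dimension turns batch_sizes into element offsets into a, b and c. Below is a host-side Python sketch of the same bookkeeping for the non-transposed case (m dynamic, k and n fixed); it is a reference for the arithmetic only and not part of the kernel.

import itertools

def make_offsets(batch_sizes, k, n):
    # Exclusive prefix sum over the dynamic dimension (m here), like BlockScan::ExclusiveSum.
    starts = [0] + list(itertools.accumulate(batch_sizes))[:-1]
    problems = []
    for expert, (m, start) in enumerate(zip(batch_sizes, starts)):
        problems.append({
            "dims (m, n, k)": (m, n, k),
            "a_offset": start * k,       # rows of a consumed by earlier experts
            "b_offset": expert * k * n,  # fixed-size weight slab per expert
            "c_offset": start * n,       # rows of c produced by earlier experts
        })
    return problems

print(make_offsets([3, 0, 5], k=8, n=4))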
csrc/grouped_gemm/grouped_gemm.cu
ADDED
@@ -0,0 +1,567 @@
#include "grouped_gemm.h"
#include "fill_arguments.cuh"

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/util/BFloat16.h>
#include <c10/cuda/CUDAStream.h>
#include <cub/cub.cuh>
#include <torch/torch.h>

#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"

#include <type_traits>

namespace grouped_gemm {

#define CUDA_CALL(code)                                     \
  do {                                                      \
    cudaError_t status = code;                              \
    std::string err = cudaGetErrorString(status);           \
    TORCH_CHECK(status == cudaSuccess, err);                \
  } while (0)

#define CUBLAS_CALL(code)                                             \
  do {                                                                \
    cublasStatus_t status = code;                                     \
    TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error");     \
  } while (0)

#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
#define GROUPED_GEMM_STRINGIFY(x) \
  GROUPED_GEMM_STRINGIFY_HELPER(x)

template <bool trans>
using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;

using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
  ::cutlass::arch::OpClassTensorOp,
  ::cutlass::arch::Sm80,
  ::cutlass::bfloat16_t,
  ::cutlass::bfloat16_t,
  ::cutlass::bfloat16_t,
  float
>;

// TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
template <bool trans_a, bool trans_b>
using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
  // A operand.
  ::cutlass::bfloat16_t,
  GroupedGemmInputLayout<trans_a>,
  ::cutlass::ComplexTransform::kNone,
  GroupedGemmConfig::kAlignmentA,
  // B operand.
  ::cutlass::bfloat16_t,
  GroupedGemmInputLayout<trans_b>,
  ::cutlass::ComplexTransform::kNone,
  GroupedGemmConfig::kAlignmentB,
  // C operand.
  ::cutlass::bfloat16_t,
  ::cutlass::layout::RowMajor,
  float,
  ::cutlass::arch::OpClassTensorOp,
  ::cutlass::arch::Sm80,
  GroupedGemmConfig::ThreadblockShape,
  GroupedGemmConfig::WarpShape,
  GroupedGemmConfig::InstructionShape,
  GroupedGemmConfig::EpilogueOutputOp,
  // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
  // This parameter is passed in at present to match the APIs of other kernels. The parameter
  // is unused within the kernel.
  ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
  // TODO(tgale): Tune this for SM90.
  GroupedGemmConfig::kStages>::GemmKernel;

template <bool trans_a, bool trans_b>
using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;

template <typename T>
torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
  size_t bytes = x.size() * sizeof(T);
  auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
  torch::Tensor out = torch::empty(bytes, options);

  CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
                            x.data(), bytes,
                            cudaMemcpyHostToDevice,
                            c10::cuda::getCurrentCUDAStream()));
  return out;
}

template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
  // For now, simply create a copy of the data and then copy over to the original.
  std::vector<T> copy(data, data + indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    data[i] = copy.at(indices[i]);
  }
}

template <typename T>
torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
  return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device));
}

struct RawGemmArguments {
  torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes;
  int threadblock_count{};
};

template <
  typename Gemm,
  typename ElementA, typename ElementB, typename ElementC
>
RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
  TORCH_CHECK(
    num_experts <= kMaxExperts,
    "At most ", kMaxExperts,
    " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
  );

  return RawGemmArguments {
    .lda = TypedEmpty<int64_t>(num_experts, device),
    .ldb = TypedEmpty<int64_t>(num_experts, device),
    .ldc = TypedEmpty<int64_t>(num_experts, device),
    .ptr_a = TypedEmpty<ElementA*>(num_experts, device),
    .ptr_b = TypedEmpty<ElementB*>(num_experts, device),
    .ptr_c = TypedEmpty<ElementC*>(num_experts, device),
    .problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device),

    // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
    .threadblock_count = Gemm::sufficient(),
  };
}

template <
  bool kDynamicK,
  typename Gemm,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC
>
RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
                                     torch::Tensor b,
                                     torch::Tensor c,
                                     torch::Tensor batch_sizes,
                                     ::cutlass::gemm::GemmCoord coord_template,
                                     int64_t num_experts) {
  std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);

  // Create the host arrays of leading dimension data and pointer data.
  std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
  int64_t elements_a = 0, elements_b = 0, elements_c = 0;

  std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts);

  for (int i = 0; i < num_experts; ++i) {
    auto& problem = problem_sizes_host[i];
    problem = coord_template;
    (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i];

    lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
    ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
    ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);

    ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a;
    ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b;
    ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c;

    elements_a += problem.m() * problem.k();
    elements_b += problem.k() * problem.n();
    elements_c += problem.m() * problem.n();

    if (problem.k() == 0) {
      // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
      // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
      //  * set the output to zero with `cudaMemsetAsync()`
      //  * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
      CUDA_CALL(cudaMemsetAsync(ptr_c_host[i],
                                0,
                                problem.m() * problem.n() * sizeof(ElementC),
                                c10::cuda::getCurrentCUDAStream()));

      problem.m() = 0;
      problem.n() = 0;
    }
  }

  // Only sort problems when K are different
  if (kDynamicK) {
    std::vector<size_t> indices(num_experts);
    std::iota(indices.begin(), indices.end(), 0);
    std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
      return problem_sizes_host[i].k() > problem_sizes_host[j].k();
    });

    ReorderArray(problem_sizes_host.data(), indices);
    ReorderArray(lda_host.data(), indices);
    ReorderArray(ldb_host.data(), indices);
    ReorderArray(ldc_host.data(), indices);
    ReorderArray(ptr_a_host.data(), indices);
    ReorderArray(ptr_b_host.data(), indices);
    ReorderArray(ptr_c_host.data(), indices);
  }

  // Copy the problem sizes, pointers and leading dimension data to the device.
  return RawGemmArguments {
    .lda = CopyToDevice(lda_host, a.device()),
    .ldb = CopyToDevice(ldb_host, a.device()),
    .ldc = CopyToDevice(ldc_host, a.device()),
    .ptr_a = CopyToDevice(ptr_a_host, a.device()),
    .ptr_b = CopyToDevice(ptr_b_host, a.device()),
    .ptr_c = CopyToDevice(ptr_c_host, a.device()),
    .problem_sizes = CopyToDevice(problem_sizes_host, a.device()),

    // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
    .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
  };
}

template <
  bool kDynamicK,
  typename Gemm,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC
>
typename Gemm::Arguments MakeArguments(torch::Tensor a,
                                       torch::Tensor b,
                                       torch::Tensor c,
                                       torch::Tensor batch_sizes,
                                       ::cutlass::gemm::GemmCoord coord_template,
                                       int64_t num_experts) {
  RawGemmArguments raw_args;
  if (batch_sizes.is_cuda()) {
    raw_args = MakeArgumentsOnDevice<
      Gemm, ElementA, ElementB, ElementC
    >(num_experts, a.device());
  } else {
    raw_args = MakeArgumentsOnHost<
      kDynamicK,
      Gemm,
      ElementA, ElementB, ElementC,
      LayoutA, LayoutB, LayoutC
    >(a, b, c, batch_sizes, coord_template, num_experts);
  }

  printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count);
  // Validate the result.
  if (!raw_args.threadblock_count) {
    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
  }

  typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
  // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
  // so we can safely pass `nullptr` for `host_problem_sizes`.
  // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
  // know the problem dimensions on the host.
  typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
                                     (int)num_experts,
                                     (int)raw_args.threadblock_count,
                                     epilogue_op,
                                     (ElementA**)raw_args.ptr_a.data_ptr(),
                                     (ElementB**)raw_args.ptr_b.data_ptr(),
                                     (ElementC**)raw_args.ptr_c.data_ptr(),
                                     (ElementC**)raw_args.ptr_c.data_ptr(),
                                     /*lda=*/(int64_t*)raw_args.lda.data_ptr(),
                                     /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
                                     /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
                                     /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
                                     /*host_problem_sizes=*/nullptr);
  return arguments;
}

template <
  bool trans_a,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC,
  typename Arguments
>
void FillCutlassArguments(int num_experts,
                          torch::Tensor batch_sizes,
                          torch::Tensor a,
                          torch::Tensor b,
                          torch::Tensor c,
                          const Arguments& arguments,
                          ::cutlass::gemm::GemmCoord coord_template) {
  // Convert the batch sizes to the format CUTLASS understands on the device.
  // Use a single block here because:
  //  * the number of elements to process is microscopically small
  //  * we don't need any additional global memory
  FillArguments<
    /*kDynamicK*/trans_a,
    ElementA, ElementB, ElementC,
    LayoutA, LayoutB, LayoutC
  ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>(
    num_experts, batch_sizes.data_ptr<int64_t>(),
    (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(),
    arguments, coord_template
  );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename Args>
void RemoveK0Problems(int num_experts, const Args& arguments) {
  // For zeroing out the outputs (which might be arbitrarily large), we want to use
  // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
  // `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
  // should be a good approximation for this.
  // When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
  ZeroOutK0Outputs<><<<
    arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()
  >>>(
    num_experts, arguments
  );
  IgnoreK0Problems<><<<
    1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()
  >>>(
    num_experts, arguments
  );
}

template <bool trans_a, bool trans_b>
torch::Tensor CutlassGroupedGemm(torch::Tensor a,
                                 torch::Tensor b,
                                 torch::Tensor c,
                                 torch::Tensor batch_sizes,
                                 ::cutlass::gemm::GemmCoord coord_template) {
  using Gemm = GemmGrouped<trans_a, trans_b>;
  using LayoutA = typename Gemm::LayoutA;
  using LayoutB = typename Gemm::LayoutB;
  using LayoutC = typename Gemm::LayoutC;

  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementC = typename Gemm::ElementC;

  Gemm gemm;
  int64_t num_experts = batch_sizes.size(0);
  auto arguments = MakeArguments<
    /*kDynamicK*/trans_a,
    Gemm,
    ElementA, ElementB, ElementC,
    LayoutA, LayoutB, LayoutC
  >(a, b, c, batch_sizes, coord_template, num_experts);
  int64_t workspace_size = gemm.get_workspace_size(arguments);
  auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
  torch::Tensor workspace = torch::empty(workspace_size, options);

  if (batch_sizes.is_cuda()) {
    FillCutlassArguments<
      trans_a,
      ElementA, ElementB, ElementC,
      LayoutA, LayoutB, LayoutC
    >(num_experts, batch_sizes, a, b, c, arguments, coord_template);

    RemoveK0Problems<>(num_experts, arguments);
  }

  // Initialize the kernel.
  if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
  }

  // Execute the kernel in the current stream.
  if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
  }
  return c;
}

void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
                c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
                c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
  int m = trans_b ? b_rows : b_cols;
  int k = trans_b ? b_cols : b_rows;
  int n = trans_a ? a_cols : a_rows;

  int lda = trans_a ? n : k;
  int ldb = trans_b ? k : m;
  cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;

  float alpha = 1.0, beta = 0.0;
  CUBLAS_CALL(cublasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
                           transpose_b, transpose_a,
                           m, n, k, &alpha,
                           b, CUDA_R_16BF, ldb,
                           a, CUDA_R_16BF, lda,
                           &beta,
                           c, CUDA_R_16BF, c_cols, CUDA_R_32F,
                           CUBLAS_GEMM_DEFAULT));
}

void CublasGroupedGemm(torch::Tensor a,
                       torch::Tensor b,
                       torch::Tensor c,
                       torch::Tensor batch_sizes,
                       bool trans_b) {
  int64_t bs = batch_sizes.size(0), k = a.size(1);
  int64_t n = trans_b ? b.size(1) : b.size(2);
  int64_t b_rows = b.size(1), b_cols = b.size(2);
  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
  for (int i = 0; i < bs; ++i) {
    int64_t m = batch_sizes.data_ptr<int64_t>()[i];
    CublasGemm(a_ptr, m, k, /*trans_a=*/false,
               b_ptr, b_rows, b_cols, trans_b,
               c_ptr, m, n);
    a_ptr += m * k;
    b_ptr += b_rows * b_cols;
    c_ptr += m * n;
  }
}

void CublasGroupedGemmVariableK(torch::Tensor a,
                                torch::Tensor b,
                                torch::Tensor c,
                                torch::Tensor batch_sizes) {
  int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
  for (int i = 0; i < bs; ++i) {
    int64_t k = batch_sizes.data_ptr<int64_t>()[i];
    CublasGemm(a_ptr, k, m, /*trans_a=*/true,
               b_ptr, k, n, /*trans_b=*/false,
               c_ptr, m, n);
    a_ptr += k * m;
    b_ptr += k * n;
    c_ptr += m * n;
  }
}

void GroupedGemmVariableK(torch::Tensor a,
                          torch::Tensor b,
                          torch::Tensor c,
                          torch::Tensor batch_sizes) {
  // We expect a CUDA tensor with two dimensions and shape
  // (tokens, hidden_out) for 'b'.
  TORCH_CHECK(b.is_cuda());
  TORCH_CHECK(b.ndimension() == 2);
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);

  // Validate the dimensions.
  int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
  int64_t m = a.size(1), n = b.size(1);

  // Validate that we have the same contraction dimension.
  TORCH_CHECK(tokens == b.size(0));

  // Validate the output shape.
  TORCH_CHECK(c.is_cuda());
  TORCH_CHECK(c.ndimension() == 3);
  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
  TORCH_CHECK(c.size(0) == num_experts);
  TORCH_CHECK(c.size(1) == m);
  TORCH_CHECK(c.size(2) == n);

  // Run the computation.
  CublasGroupedGemmVariableK(a, b, c, batch_sizes);
}

// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
// assumed to be batched with fixed sized batches.
//
// TODO(tgale): Validate alignment is true for every batch element.
void GroupedGemm(torch::Tensor a,
                 torch::Tensor b,
                 torch::Tensor c,
                 torch::Tensor batch_sizes,
                 bool trans_a, bool trans_b) {
  // NOTE: We only support 'trans_a' or 'trans_b', not both.
  TORCH_CHECK(!(trans_a && trans_b));

#if !defined(GROUPED_GEMM_CUTLASS)
  // No way to run cuBLAS kernels if the problem dimensions are not known on the host.
  TORCH_CHECK(batch_sizes.is_cpu());
#else
  // CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
  TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
#endif
  TORCH_CHECK(batch_sizes.ndimension() == 1);
  TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);

  // We expect a CUDA tensor with two dimensions and shape
  // (tokens, hidden_in) for 'a'.
  TORCH_CHECK(a.is_cuda());
  TORCH_CHECK(a.ndimension() == 2);
  TORCH_CHECK(a.scalar_type() == torch::kBFloat16);

#if !defined(GROUPED_GEMM_CUTLASS)
  if (trans_a) {
    // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
    // for the rest of the op.
    GroupedGemmVariableK(a, b, c, batch_sizes);
    return;
  }
#endif

  TORCH_CHECK(b.is_cuda());
  TORCH_CHECK(c.is_cuda());
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);

  // The expected shapes of 'b' and 'c' are:
  //  * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
  //  * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
  //  * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
  size_t hidden_in{}, hidden_out{};
  if (trans_a) {
    hidden_in = a.size(1);
    hidden_out = b.size(1);

    TORCH_CHECK(b.ndimension() == 2);
    TORCH_CHECK(c.ndimension() == 3);
    TORCH_CHECK(b.size(0) == a.size(0));
    TORCH_CHECK(c.size(0) == batch_sizes.size(0));
    TORCH_CHECK(c.size(1) == hidden_in);
    TORCH_CHECK(c.size(2) == hidden_out);
  } else {
    TORCH_CHECK(b.ndimension() == 3);
    TORCH_CHECK(c.ndimension() == 2);

    // Validate the contraction dimensions match.
    int64_t tokens = a.size(0), num_experts = b.size(0);
    hidden_in = trans_b ? b.size(2) : b.size(1);
    hidden_out = trans_b ? b.size(1) : b.size(2);
    TORCH_CHECK(hidden_in == a.size(1));

    // Validate that we have one size per expert.
    TORCH_CHECK(batch_sizes.size(0) == num_experts);
  }

  // NOTE: We support transposition through the 'trans_b' flag.
  TORCH_CHECK(a.is_contiguous());
  TORCH_CHECK(b.is_contiguous());
  TORCH_CHECK(c.is_contiguous());

#if !defined(GROUPED_GEMM_CUTLASS)
  CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
  return;
#else
  // The `coord_template` argument contains `kDynamicDim` as one of its dimensions
  // as a placeholder. This placeholder is later expanded into the actual dimension
  // for every element of the batch, either on the host or on the device
  // (if we can't do it on the host).
  const auto coord_template = trans_a
    ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
    : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
  if (trans_a) {
    CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
    return;
  }
  if (trans_b) {
    CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
    return;
  }
  CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
  return;
#endif
}

} // namespace grouped_gemm
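
For the variable-k path (trans_a set), each expert contracts over its own k = batch_sizes[i] rows and writes a (hidden_in, hidden_out) block, which is what CublasGroupedGemmVariableK above loops over. The following rough PyTorch reference of those semantics is illustrative only, useful for checking shapes and the k = 0 behaviour, not a drop-in for the kernel.

import torch

def grouped_gemm_variable_k_ref(a, b, batch_sizes):
    # a: (tokens, hidden_in), b: (tokens, hidden_out); bf16 CUDA tensors in the real op.
    hidden_in, hidden_out = a.size(1), b.size(1)
    c = torch.zeros(len(batch_sizes), hidden_in, hidden_out, dtype=a.dtype, device=a.device)
    start = 0
    for i, k in enumerate(batch_sizes.tolist()):
        # Contract over this expert's k rows; a k == 0 expert leaves an all-zero block,
        # which is what the k=0 workaround in the CUTLASS path enforces explicitly.
        c[i] = a[start:start + k].t() @ b[start:start + k]
        start += k
    return c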
csrc/grouped_gemm/grouped_gemm.h
ADDED
@@ -0,0 +1,20 @@
#pragma once

// // Set default if not already defined
// #ifndef GROUPED_GEMM_CUTLASS
// #define GROUPED_GEMM_CUTLASS 0
// #endif

// #include <torch/extension.h>
#include <torch/torch.h>

namespace grouped_gemm {

void GroupedGemm(torch::Tensor a,
                 torch::Tensor b,
                 torch::Tensor c,
                 torch::Tensor batch_sizes,
                 bool trans_a, bool trans_b);

} // namespace grouped_gemm
csrc/grouped_gemm/ops.cu
ADDED
@@ -0,0 +1,11 @@
#include "grouped_gemm.h"

#include <torch/extension.h>

namespace grouped_gemm {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gmm", &GroupedGemm, "Grouped GEMM.");
}

} // namespace grouped_gemm
tests/ops_test.py
ADDED
@@ -0,0 +1,170 @@
import unittest
import itertools

from absl.testing import parameterized
import megablocks
import numpy as np
import torch


def allclose(x, y, pct=2.0):
    mask = torch.isclose(x, y, rtol=1e-5)
    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
    if pct_diff > pct:
        print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
        print("{:.2f}% of values not close.".format(pct_diff))
        return False
    return True


def add_flags(x):
    out = []
    for y in x:
        for trans_b in (False, True):
            out.append(y + (trans_b, False))

        # TODO: Revisit enabling batch_sizes_on_device
        # for batch_sizes_on_device in (False, True):
        #     out.append(y + (trans_b, batch_sizes_on_device))
    return out


_TEST_PROBLEMS = add_flags((
    (1, 128, 128, 128),
    (8, 128, 128, 128),
    (16, 128, 128, 128),
    (1, 128, 256, 512),
    (8, 128, 256, 512),
    (16, 128, 256, 512),
))


def randn(bs, x, y):
    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start:start + size, :] @ rhs)
        start += size
    return torch.cat(out)


@parameterized.parameters(*_TEST_PROBLEMS)
class OpsTest(parameterized.TestCase):

    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        batch_sizes = torch.tensor([m] * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        # out = ops.gmm(a, b, batch_sizes, trans_b)
        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        # print("out", out)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)

        dist = torch.rand(z, )
        dist /= dist.sum()
        batch_sizes = (dist * m).to(torch.long)
        error = m * z - batch_sizes.sum()
        batch_sizes[-1] += error
        assert batch_sizes.sum() == (m * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))

        # TODO: Review to ensure that the gradients are correct.
        # self.assertTrue(allclose(b.grad, b_ref.grad))


# @parameterized.parameters(False, True)
@parameterized.parameters(False, False)
class EdgeCasesTest(unittest.TestCase):

    def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
        torch.manual_seed(0)
        m = 16384
        k = 4096
        n = 14336
        num_experts = 8

        a = randn(num_experts, m // num_experts, k).view(-1, k)
        b = randn(num_experts, k, n)
        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes)
        expected_out = gmm(a_ref, b_ref, batch_sizes)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
        sz = 128
        total_tokens = 192

        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
        self.assertTrue((c[0] == 0).all())
        self.assertTrue((c[1] == 128).all())
        self.assertTrue((c[2] == 0).all())
        self.assertTrue((c[3] == 64).all())


if __name__ == '__main__':
    unittest.main()
tests/test_gg.py
ADDED
@@ -0,0 +1,57 @@
import torch
import megablocks


def randn(bs, x, y):
    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)


def test_gmm():
    z = 1
    m = 128
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False
    # TODO: fix to enable batch_sizes_on_device
    # batch_sizes_on_device = True

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    # out = ops.gmm(a, b, batch_sizes, trans_b)
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("out", out)

    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)

    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    out.sum().backward()

    expected_out.sum().backward()
    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
    print("Test passed successfully!")
torch-ext/megablocks/__init__.py
CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from …
-from …
-
-
-from …
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
torch-ext/megablocks/grouped_gemm/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import ops
from . import backend
torch-ext/megablocks/grouped_gemm/backend.py
ADDED
@@ -0,0 +1,32 @@
# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# # TODO(tgale): Wrap this in a try-block with better
# # error message and instructions for building the
# # c++ operations.
# import grouped_gemm_backend as backend

# We import the backend operations from the megablocks package as
# grouped_gemm is vendored in megablocks in this repository.
# from ... import _ops as backend
from megablocks._ops import ops as backend  # type: ignore

def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
    assert not (trans_a and trans_b)
    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
    assert a.ndim == 2, "Expected 2d tensor for 'a'"
    assert b.ndim == (2 if trans_a else 3)

    shape = (
        (batch_sizes.shape[0], a.shape[1], b.shape[1])
        if trans_a else
        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
    )
    return torch.empty(*shape, device=a.device, dtype=a.dtype)

def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
    if c is None:
        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
    return c
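
The shape chosen by _allocate_output is the only place the two calling modes differ: with trans_a the result is one (hidden, hidden_out) block per expert, otherwise a flat (tokens, hidden_out) matrix. A small usage sketch reusing the shapes from tests/ops_test.py (illustrative only, assumes a CUDA build of this branch):

import torch
from megablocks import gg_backend

a = torch.ones(192, 128, device="cuda", dtype=torch.bfloat16)  # (tokens, hidden)
b = torch.ones(192, 128, device="cuda", dtype=torch.bfloat16)  # (tokens, hidden_out)
batch_sizes = torch.tensor([0, 128, 0, 64])

# trans_a=True: one (hidden, hidden_out) block per expert -> (4, 128, 128) is allocated.
c = gg_backend.gmm(a, b, batch_sizes, trans_a=True)
assert c.shape == (4, 128, 128)

# Passing a preallocated c skips _allocate_output and writes into it in place.
gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)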
torch-ext/megablocks/grouped_gemm/ops.py
ADDED
@@ -0,0 +1,33 @@
from . import backend
import torch


class GroupedGemm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, b, batch_sizes, trans_b):
        ctx.save_for_backward(a, b, batch_sizes)
        ctx.trans_b = trans_b
        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)

    @staticmethod
    def backward(ctx, grad):
        grad = grad.contiguous()
        a, b, batch_sizes = ctx.saved_tensors
        trans_b = ctx.trans_b

        agrad = None
        if ctx.needs_input_grad[0]:
            agrad = backend.gmm(
                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)

        bgrad = None
        if ctx.needs_input_grad[1]:
            lhs, rhs = (grad, a) if trans_b else (a, grad)
            bgrad = backend.gmm(
                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
        return agrad, bgrad, None, None


def gmm(a, b, batch_sizes, trans_b=False):
    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
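
The autograd wrapper routes both gradient products back through the same grouped kernel, flipping trans_a and trans_b as needed. A short, illustrative forward/backward round trip with arbitrary shapes:

import torch
from megablocks import gg_ops

a = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(2, 64, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
batch_sizes = torch.tensor([100, 156])  # rows of a per expert, summing to a.size(0)

out = gg_ops.gmm(a, b, batch_sizes)  # forward: out rows for expert i = a_i @ b[i]
out.sum().backward()                 # backward: grad @ b[i].T and a_i.T @ grad via the same op
print(a.grad.shape, b.grad.shape)    # (256, 64) and (2, 64, 32)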
torch-ext/megablocks/grouped_gemm_util.py
CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
torch-ext/megablocks/layers/__init__.py
CHANGED
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from …
+from .moe import MoE
 
 __all__ = [
     'MoE',
torch-ext/torch_binding.cpp
CHANGED
@@ -9,6 +9,8 @@
 #include "new_replicate.h"
 #include "new_sort.h"
 
+#include "grouped_gemm/grouped_gemm.h"
+
 // void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
 torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) {
   megablocks::exclusive_cumsum(x, dim, out);
@@ -70,6 +72,12 @@ torch::Tensor sort_wrapper(torch::Tensor x, int64_t end_bit, torch::Tensor x_out
   return x_out;
 }
 
+// GroupedGemm operation
+torch::Tensor gmm(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, bool trans_a, bool trans_b) {
+  grouped_gemm::GroupedGemm(a, b, c, batch_sizes, trans_a, trans_b);
+  return c;
+}
+
 // Reference implementation:
 //
 // m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
@@ -101,6 +109,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)");
   ops.impl("sort", torch::kCUDA, &sort_wrapper);
+
+  // Register the gmm GroupedGemm operation
+  ops.def("gmm(Tensor (a!) a, Tensor (b!) b, Tensor(c!) c, Tensor batch_sizes, bool trans_a, bool trans_b) -> Tensor(c!)");
+  ops.impl("gmm", torch::kCUDA, &gmm);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
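
With this registration in place, the op is reachable through the generated megablocks._ops table, the same handle grouped_gemm/backend.py aliases as backend; the schema mutates and returns the preallocated c. An illustrative direct call (assumes a CUDA build of this branch; normally you would go through gg_backend.gmm or gg_ops.gmm instead):

import torch
from megablocks._ops import ops  # same handle backend.py imports as `backend`

a = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(2, 64, 32, device="cuda", dtype=torch.bfloat16)
c = torch.empty(128, 32, device="cuda", dtype=torch.bfloat16)
batch_sizes = torch.tensor([64, 64])

ops.gmm(a, b, c, batch_sizes, False, False)  # trans_a=False, trans_b=False; fills c in place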