kernel
drbh commited on
Commit
2595c46
·
0 Parent(s):

feat: initial port of megablocks to builder format

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +35 -0
  2. .gitignore +5 -0
  3. README.md +6 -0
  4. build.toml +30 -0
  5. csrc/bak.ops.cu +21 -0
  6. csrc/cuda_util.h +62 -0
  7. csrc/cumsum.h +163 -0
  8. csrc/histogram.h +86 -0
  9. csrc/indices.h +95 -0
  10. csrc/new_cumsum.cu +161 -0
  11. csrc/new_cumsum.h +11 -0
  12. csrc/new_histogram.cu +85 -0
  13. csrc/new_histogram.h +10 -0
  14. csrc/new_indices.cu +97 -0
  15. csrc/new_indices.h +14 -0
  16. csrc/new_replicate.cu +210 -0
  17. csrc/new_replicate.h +17 -0
  18. csrc/new_sort.cu +90 -0
  19. csrc/new_sort.h +13 -0
  20. csrc/replicate.h +211 -0
  21. csrc/sort.h +91 -0
  22. flake.lock +164 -0
  23. flake.nix +18 -0
  24. tests/__init__.py +0 -0
  25. tests/test_mb_moe.py +6 -0
  26. torch-ext/megablocks/__init__.py +191 -0
  27. torch-ext/megablocks/_version.py +6 -0
  28. torch-ext/megablocks/backend/__init__.py +2 -0
  29. torch-ext/megablocks/backend/kernels.py +543 -0
  30. torch-ext/megablocks/bak.__init__.py +23 -0
  31. torch-ext/megablocks/benchmark_util.py +35 -0
  32. torch-ext/megablocks/grouped_gemm_util.py +26 -0
  33. torch-ext/megablocks/layers/__init__.py +10 -0
  34. torch-ext/megablocks/layers/activation_fn.py +33 -0
  35. torch-ext/megablocks/layers/all_to_all.py +54 -0
  36. torch-ext/megablocks/layers/arguments.py +100 -0
  37. torch-ext/megablocks/layers/common.py +26 -0
  38. torch-ext/megablocks/layers/dmlp_registry.py +42 -0
  39. torch-ext/megablocks/layers/dmoe.py +327 -0
  40. torch-ext/megablocks/layers/gelu.py +43 -0
  41. torch-ext/megablocks/layers/glu.py +223 -0
  42. torch-ext/megablocks/layers/memory_test.py +102 -0
  43. torch-ext/megablocks/layers/memory_test.sh +12 -0
  44. torch-ext/megablocks/layers/mlp.py +574 -0
  45. torch-ext/megablocks/layers/moe.py +475 -0
  46. torch-ext/megablocks/layers/mpu.py +93 -0
  47. torch-ext/megablocks/layers/router.py +114 -0
  48. torch-ext/megablocks/layers/sharedexpert_registry.py +30 -0
  49. torch-ext/megablocks/ops/__init__.py +35 -0
  50. torch-ext/megablocks/ops/all_to_all_benchmark.py +60 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .venv
2
+ __pycache__
3
+ .bak
4
+ megablocks-moe/.bak
5
+ .pytest_cache
README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - kernel
5
+ ---
6
+
build.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "megablocks"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h"
9
+ ]
10
+
11
+ [kernel.megablocks]
12
+ backend = "cuda"
13
+ src = [
14
+ "csrc/new_cumsum.h",
15
+ "csrc/new_cumsum.cu",
16
+ "csrc/new_histogram.h",
17
+ "csrc/new_histogram.cu",
18
+ "csrc/new_indices.h",
19
+ "csrc/new_indices.cu",
20
+ "csrc/new_replicate.cu",
21
+ "csrc/new_replicate.h",
22
+ "csrc/new_sort.h",
23
+ "csrc/new_sort.cu",
24
+ ]
25
+ depends = [ "torch", "cutlass_3_8" ]
26
+
27
+ [test]
28
+ python-git-packages = [
29
+ { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" }
30
+ ]
csrc/bak.ops.cu ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "cumsum.h"
2
+ #include "histogram.h"
3
+ #include "indices.h"
4
+ #include "replicate.h"
5
+ #include "sort.h"
6
+
7
+ #include <torch/extension.h>
8
+
9
+ namespace megablocks {
10
+
11
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
12
+ m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
13
+ m.def("histogram", &histogram, "even width histogram.");
14
+ m.def("inclusive_cumsum", &inclusive_cumsum, "batched inclusive cumsum");
15
+ m.def("indices", &indices, "indices construction for sparse matrix.");
16
+ m.def("replicate_forward", &replicate_forward, "(fwd) replicate a vector dynamically.");
17
+ m.def("replicate_backward", &replicate_backward, "(bwd) replicate a vector dynamically.");
18
+ m.def("sort", &sort, "key/value sort.");
19
+ }
20
+
21
+ } // namespace megablocks
csrc/cuda_util.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef BLOCKPARTY_CSRC_CUDA_UTIL_H_
2
+ #define BLOCKPARTY_CSRC_CUDA_UTIL_H_
3
+
4
+ #include <cuda_fp16.h>
5
+ #include <cuda_runtime.h>
6
+ // #include <torch/extension.h>
7
+
8
+ namespace megablocks {
9
+
10
+ typedef __half2 half2;
11
+
12
+ struct __align__(8) half4 {
13
+ half2 x, y;
14
+ };
15
+
16
+ struct __align__(16) half8 {
17
+ half2 x, y, z, w;
18
+ };
19
+
20
+ template <class To, class From>
21
+ __device__ __forceinline__ To BitCast(const From& src) noexcept {
22
+ To dst;
23
+ std::memcpy(&dst, &src, sizeof(To));
24
+ return dst;
25
+ }
26
+
27
+ template <typename T>
28
+ __device__ __forceinline__ void Store(const T& value, T* ptr) {
29
+ *ptr = value;
30
+ }
31
+
32
+ template <typename T>
33
+ __device__ __forceinline__ T Load(const T* address) {
34
+ return __ldg(address);
35
+ }
36
+
37
+ __device__ __forceinline__ half4 Load(const half4* address) {
38
+ float2 x = __ldg(reinterpret_cast<const float2*>(address));
39
+ return BitCast<half4>(x);
40
+ }
41
+
42
+ __device__ __forceinline__ half8 Load(const half8* address) {
43
+ float4 x = __ldg(reinterpret_cast<const float4*>(address));
44
+ return BitCast<half8>(x);
45
+ }
46
+
47
+ template <typename T>
48
+ __device__ __forceinline__ T Zero() { return 0; };
49
+
50
+ template <>
51
+ __device__ __forceinline__ half2 Zero<half2>() {
52
+ return {(c10::Half)0., (c10::Half)0.};
53
+ };
54
+
55
+ template <>
56
+ __device__ __forceinline__ half4 Zero<half4>() {
57
+ return {Zero<half2>(), Zero<half2>()};
58
+ };
59
+
60
+ } // namespace megablocks
61
+
62
+ #endif // BLOCKPARTY_CSRC_CUDA_UTIL_H_
csrc/cumsum.h ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define CUB_IGNORE_DEPRECATED_API
2
+
3
+ #undef CUB_WRAPPED_NAMESPACE
4
+ #define CUB_WRAPPED_NAMESPACE megablocks
5
+
6
+ #include <cstdint>
7
+
8
+ #include <cub/cub.cuh>
9
+ #include <c10/cuda/CUDAStream.h>
10
+ #include <torch/all.h>
11
+ // #include <torch/extension.h>
12
+
13
+ #define CUDA_CALL(code) \
14
+ do { \
15
+ cudaError_t status = code; \
16
+ std::string err = cudaGetErrorString(status); \
17
+ TORCH_CHECK(status == cudaSuccess, err); \
18
+ } while (0)
19
+
20
+ namespace megablocks {
21
+
22
+ struct Inclusive {};
23
+ struct Exclusive {};
24
+
25
+ template <typename Type> struct Cumsum {
26
+
27
+ template<
28
+ typename InputIteratorT,
29
+ typename OutputIteratorT>
30
+ static void Run(void * d_temp_storage,
31
+ size_t & temp_storage_bytes,
32
+ InputIteratorT d_in,
33
+ OutputIteratorT d_out,
34
+ int num_items,
35
+ cudaStream_t stream = 0,
36
+ bool debug_synchronous = false) {
37
+ CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
38
+ temp_storage_bytes,
39
+ d_in,
40
+ d_out,
41
+ num_items,
42
+ stream));//,
43
+ //debug_synchronous));
44
+ }
45
+ };
46
+
47
+ template <> struct Cumsum<Inclusive> {
48
+ template<
49
+ typename InputIteratorT,
50
+ typename OutputIteratorT>
51
+ static void Run(void * d_temp_storage,
52
+ size_t & temp_storage_bytes,
53
+ InputIteratorT d_in,
54
+ OutputIteratorT d_out,
55
+ int num_items,
56
+ cudaStream_t stream = 0,
57
+ bool debug_synchronous = false) {
58
+ CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
59
+ temp_storage_bytes,
60
+ d_in,
61
+ d_out,
62
+ num_items,
63
+ stream));//,
64
+ //debug_synchronous));
65
+ }
66
+ };
67
+
68
+ template <typename SumType, typename T>
69
+ void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
70
+ // Get temporary storage size.
71
+ size_t scratchpad_bytes = 0;
72
+ Cumsum<SumType>::Run(nullptr,
73
+ scratchpad_bytes,
74
+ x.data_ptr<T>(),
75
+ out.data_ptr<T>(),
76
+ x.size(1),
77
+ c10::cuda::getCurrentCUDAStream());
78
+
79
+ // Allocate scratchpad.
80
+ //
81
+ // NOTE: We scale for the batch dimension so we can run in parallel.
82
+ auto options = torch::TensorOptions()
83
+ .dtype(torch::kInt8)
84
+ .device(x.device());
85
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
86
+ options);
87
+
88
+ // Run the kernel.
89
+ //
90
+ // NOTE: Using different streams for each issue does not appear to
91
+ // yield performance gains for our problem set. The overhead of
92
+ // event/stream synchronization appears to outweigh the benfits.
93
+ // We could write a true batched cumsum, but this would require
94
+ // significant code duplication from cub and we might move away
95
+ // from this formulation anyways.
96
+ for (int i = 0; i < x.size(0); ++i) {
97
+ void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
98
+ Cumsum<SumType>::Run(scratchpad_ptr,
99
+ scratchpad_bytes,
100
+ x.data_ptr<T>() + x.size(1) * i,
101
+ out.data_ptr<T>() + x.size(1) * i,
102
+ x.size(1),
103
+ c10::cuda::getCurrentCUDAStream());
104
+ }
105
+ }
106
+
107
+ void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
108
+ // Validate the input matrix.
109
+ TORCH_CHECK(x.is_cuda());
110
+ TORCH_CHECK(x.ndimension() == 2);
111
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
112
+ x.scalar_type() == torch::kInt32 ||
113
+ x.scalar_type() == torch::kInt64);
114
+ TORCH_CHECK(out.is_cuda());
115
+ TORCH_CHECK(out.ndimension() == 2);
116
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
117
+
118
+ // NOTE: We currently only support contraction across the contiguous
119
+ // dimension in the matrix.
120
+ TORCH_CHECK(dim == 1);
121
+
122
+ switch (x.scalar_type()) {
123
+ case torch::kInt16:
124
+ cub_cumsum<Exclusive, short>(x, dim, out);
125
+ return;
126
+ case torch::kInt32:
127
+ cub_cumsum<Exclusive, int>(x, dim, out);
128
+ return;
129
+ }
130
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
131
+ cub_cumsum<Exclusive, long>(x, dim, out);
132
+ }
133
+
134
+ void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
135
+ // Validate the input matrix.
136
+ TORCH_CHECK(x.is_cuda());
137
+ TORCH_CHECK(x.ndimension() == 2);
138
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
139
+ x.scalar_type() == torch::kInt32 ||
140
+ x.scalar_type() == torch::kInt64);
141
+ TORCH_CHECK(out.is_cuda());
142
+ TORCH_CHECK(out.ndimension() == 2);
143
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
144
+
145
+ // NOTE: We currently only support contraction across the contiguous
146
+ // dimension in the matrix.
147
+ TORCH_CHECK(dim == 1);
148
+
149
+ switch (x.scalar_type()) {
150
+ case torch::kInt16:
151
+ cub_cumsum<Inclusive, short>(x, dim, out);
152
+ return;
153
+ case torch::kInt32:
154
+ cub_cumsum<Inclusive, int>(x, dim, out);
155
+ return;
156
+ }
157
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
158
+ cub_cumsum<Inclusive, long>(x, dim, out);
159
+ }
160
+
161
+ } // namespace megablocks
162
+
163
+ #undef CUB_WRAPPED_NAMESPACE
csrc/histogram.h ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include <cstdint>
5
+
6
+ #include <cub/cub.cuh>
7
+ #include <c10/cuda/CUDAStream.h>
8
+ // #include <torch/extension.h>
9
+
10
+ #define CUDA_CALL(code) \
11
+ do { \
12
+ cudaError_t status = code; \
13
+ std::string err = cudaGetErrorString(status); \
14
+ TORCH_CHECK(status == cudaSuccess, err); \
15
+ } while (0)
16
+
17
+ namespace megablocks {
18
+
19
+ template <typename T>
20
+ torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
21
+ // Allocate the count buffer.
22
+ auto options = torch::TensorOptions()
23
+ .dtype(torch::kInt32)
24
+ .device(x.device());
25
+ torch::Tensor out = torch::empty({x.size(0), num_bins}, options);
26
+
27
+ // Exit early if there is not work to do.
28
+ if (out.numel() == 0) return out;
29
+
30
+ // Get scratchpad size.
31
+ size_t scratchpad_bytes = 0;
32
+ CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr,
33
+ scratchpad_bytes,
34
+ x.data_ptr<T>(),
35
+ out.data_ptr<int>(),
36
+ /*num_levels=*/num_bins + 1,
37
+ /*lower_level=*/0,
38
+ /*upper_level=*/num_bins,
39
+ /*num_samples=*/int(x.size(1)),
40
+ c10::cuda::getCurrentCUDAStream()));
41
+
42
+ // Allocate scratchpad.
43
+ options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
44
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
45
+
46
+ // Run the kernel.
47
+ for (int i = 0; i < x.size(0); ++i) {
48
+ CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
49
+ scratchpad_bytes,
50
+ x.data_ptr<T>() + x.size(1) * i,
51
+ out.data_ptr<int>() + out.size(1) * i,
52
+ /*num_levels=*/num_bins + 1,
53
+ /*lower_level=*/0,
54
+ /*upper_level=*/num_bins,
55
+ /*num_samples=*/int(x.size(1)),
56
+ c10::cuda::getCurrentCUDAStream()));
57
+ }
58
+ return out;
59
+ }
60
+
61
+ torch::Tensor histogram(torch::Tensor x, int num_bins) {
62
+ TORCH_CHECK(x.is_cuda());
63
+ TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
64
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
65
+ x.scalar_type() == torch::kInt32 ||
66
+ x.scalar_type() == torch::kInt64);
67
+ bool no_batch = x.ndimension() == 1;
68
+ if (no_batch) x = x.view({1, x.numel()});
69
+
70
+ if (x.scalar_type() == torch::kInt16) {
71
+ auto out = cub_histogram<short>(x, num_bins);
72
+ return no_batch ? out.flatten() : out;
73
+ } else if (x.scalar_type() == torch::kInt32) {
74
+ auto out = cub_histogram<int>(x, num_bins);
75
+ return no_batch ? out.flatten() : out;
76
+ } else {
77
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
78
+ auto out = cub_histogram<long>(x, num_bins);
79
+ return no_batch ? out.flatten() : out;
80
+ }
81
+ }
82
+
83
+ } // namespace megablocks
84
+
85
+ #undef CUDA_CALL
86
+ #undef CUB_WRAPPED_NAMESPACE
csrc/indices.h ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cstdint>
2
+ #include <c10/util/Half.h>
3
+ // #include <torch/extension.h>
4
+ #include <c10/cuda/CUDAStream.h>
5
+
6
+ #define CUDA_CALL(code) \
7
+ do { \
8
+ cudaError_t status = code; \
9
+ std::string err = cudaGetErrorString(status); \
10
+ TORCH_CHECK(status == cudaSuccess, err); \
11
+ } while (0)
12
+
13
+ namespace megablocks {
14
+ namespace construct_indices {
15
+
16
+ // We expect the number of outputs per block to be small. For
17
+ // example, with ffn_hidden_size=4096, we only need to write
18
+ // 32 elements per block per iteration.
19
+ const int kThreadsPerBlock = 32;
20
+
21
+ __global__ void __launch_bounds__(kThreadsPerBlock)
22
+ ConstructIndicesKernel(short * __restrict__ indices,
23
+ int num_columns,
24
+ int block_size,
25
+ const int * __restrict__ padded_bins) {
26
+ // Load the offset for this bins indices.
27
+ int start = 0;
28
+ if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
29
+ int end = __ldg(padded_bins + blockIdx.x);
30
+
31
+ // Divide the start and end into blocks.
32
+ start /= block_size;
33
+ end /= block_size;
34
+
35
+ // Offset the output buffer to the start of the bin.
36
+ indices += (start + blockIdx.y) * num_columns + threadIdx.x;
37
+
38
+ // Write the indices to the output.
39
+ int bin_offset = blockIdx.y;
40
+ int num_rows = end - start;
41
+ for (; bin_offset < num_rows; num_rows -= gridDim.y) {
42
+ short *out = indices;
43
+ for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
44
+ *out = bid + (blockIdx.x * num_columns);
45
+ out += kThreadsPerBlock;
46
+ }
47
+ indices += gridDim.y * num_columns;
48
+ }
49
+ }
50
+
51
+ cudaError_t ConstructIndices(short * __restrict__ indices,
52
+ int output_block_rows,
53
+ int output_block_columns,
54
+ int block_size,
55
+ const int * __restrict__ padded_bins,
56
+ int num_bins,
57
+ cudaStream_t stream) {
58
+ dim3 block_dim(kThreadsPerBlock);
59
+ dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
60
+ ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
61
+ output_block_columns,
62
+ block_size,
63
+ padded_bins);
64
+ return cudaGetLastError();
65
+ }
66
+
67
+ } // namespace construct_indices
68
+
69
+ void indices(torch::Tensor padded_bins,
70
+ int block_size,
71
+ int output_block_rows,
72
+ int output_block_columns,
73
+ torch::Tensor out) {
74
+ TORCH_CHECK(padded_bins.is_cuda());
75
+ TORCH_CHECK(padded_bins.ndimension() == 1);
76
+ TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);
77
+
78
+ TORCH_CHECK(out.is_cuda());
79
+ TORCH_CHECK(out.ndimension() == 1);
80
+ TORCH_CHECK(out.scalar_type() == torch::kInt16);
81
+ TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));
82
+
83
+ // Exit early if there is no work to do.
84
+ if (out.numel() == 0) return;
85
+
86
+ CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
87
+ output_block_rows,
88
+ output_block_columns,
89
+ block_size,
90
+ padded_bins.data_ptr<int>(),
91
+ padded_bins.numel(),
92
+ c10::cuda::getCurrentCUDAStream()));
93
+ }
94
+
95
+ } // namespace megablocks
csrc/new_cumsum.cu ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define CUB_IGNORE_DEPRECATED_API
2
+
3
+ #undef CUB_WRAPPED_NAMESPACE
4
+ #define CUB_WRAPPED_NAMESPACE megablocks
5
+
6
+ #include "new_cumsum.h"
7
+ #include <cstdint>
8
+ #include <cub/cub.cuh>
9
+ #include <c10/cuda/CUDAStream.h>
10
+
11
+ #define CUDA_CALL(code) \
12
+ do { \
13
+ cudaError_t status = code; \
14
+ std::string err = cudaGetErrorString(status); \
15
+ TORCH_CHECK(status == cudaSuccess, err); \
16
+ } while (0)
17
+
18
+ namespace megablocks {
19
+
20
+ struct Inclusive {};
21
+ struct Exclusive {};
22
+
23
+ template <typename Type> struct Cumsum {
24
+
25
+ template<
26
+ typename InputIteratorT,
27
+ typename OutputIteratorT>
28
+ static void Run(void * d_temp_storage,
29
+ size_t & temp_storage_bytes,
30
+ InputIteratorT d_in,
31
+ OutputIteratorT d_out,
32
+ int num_items,
33
+ cudaStream_t stream = 0,
34
+ bool debug_synchronous = false) {
35
+ CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
36
+ temp_storage_bytes,
37
+ d_in,
38
+ d_out,
39
+ num_items,
40
+ stream));//,
41
+ //debug_synchronous));
42
+ }
43
+ };
44
+
45
+ template <> struct Cumsum<Inclusive> {
46
+ template<
47
+ typename InputIteratorT,
48
+ typename OutputIteratorT>
49
+ static void Run(void * d_temp_storage,
50
+ size_t & temp_storage_bytes,
51
+ InputIteratorT d_in,
52
+ OutputIteratorT d_out,
53
+ int num_items,
54
+ cudaStream_t stream = 0,
55
+ bool debug_synchronous = false) {
56
+ CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
57
+ temp_storage_bytes,
58
+ d_in,
59
+ d_out,
60
+ num_items,
61
+ stream));//,
62
+ //debug_synchronous));
63
+ }
64
+ };
65
+
66
+ template <typename SumType, typename T>
67
+ void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
68
+ // Get temporary storage size.
69
+ size_t scratchpad_bytes = 0;
70
+ Cumsum<SumType>::Run(nullptr,
71
+ scratchpad_bytes,
72
+ x.data_ptr<T>(),
73
+ out.data_ptr<T>(),
74
+ x.size(1),
75
+ c10::cuda::getCurrentCUDAStream());
76
+
77
+ // Allocate scratchpad.
78
+ //
79
+ // NOTE: We scale for the batch dimension so we can run in parallel.
80
+ auto options = torch::TensorOptions()
81
+ .dtype(torch::kInt8)
82
+ .device(x.device());
83
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
84
+ options);
85
+
86
+ // Run the kernel.
87
+ //
88
+ // NOTE: Using different streams for each issue does not appear to
89
+ // yield performance gains for our problem set. The overhead of
90
+ // event/stream synchronization appears to outweigh the benfits.
91
+ // We could write a true batched cumsum, but this would require
92
+ // significant code duplication from cub and we might move away
93
+ // from this formulation anyways.
94
+ for (int i = 0; i < x.size(0); ++i) {
95
+ void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
96
+ Cumsum<SumType>::Run(scratchpad_ptr,
97
+ scratchpad_bytes,
98
+ x.data_ptr<T>() + x.size(1) * i,
99
+ out.data_ptr<T>() + x.size(1) * i,
100
+ x.size(1),
101
+ c10::cuda::getCurrentCUDAStream());
102
+ }
103
+ }
104
+
105
+ void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
106
+ // Validate the input matrix.
107
+ TORCH_CHECK(x.is_cuda());
108
+ TORCH_CHECK(x.ndimension() == 2);
109
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
110
+ x.scalar_type() == torch::kInt32 ||
111
+ x.scalar_type() == torch::kInt64);
112
+ TORCH_CHECK(out.is_cuda());
113
+ TORCH_CHECK(out.ndimension() == 2);
114
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
115
+
116
+ // NOTE: We currently only support contraction across the contiguous
117
+ // dimension in the matrix.
118
+ TORCH_CHECK(dim == 1);
119
+
120
+ switch (x.scalar_type()) {
121
+ case torch::kInt16:
122
+ cub_cumsum<Exclusive, short>(x, dim, out);
123
+ return;
124
+ case torch::kInt32:
125
+ cub_cumsum<Exclusive, int>(x, dim, out);
126
+ return;
127
+ }
128
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
129
+ cub_cumsum<Exclusive, long>(x, dim, out);
130
+ }
131
+
132
+ void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
133
+ // Validate the input matrix.
134
+ TORCH_CHECK(x.is_cuda());
135
+ TORCH_CHECK(x.ndimension() == 2);
136
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
137
+ x.scalar_type() == torch::kInt32 ||
138
+ x.scalar_type() == torch::kInt64);
139
+ TORCH_CHECK(out.is_cuda());
140
+ TORCH_CHECK(out.ndimension() == 2);
141
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
142
+
143
+ // NOTE: We currently only support contraction across the contiguous
144
+ // dimension in the matrix.
145
+ TORCH_CHECK(dim == 1);
146
+
147
+ switch (x.scalar_type()) {
148
+ case torch::kInt16:
149
+ cub_cumsum<Inclusive, short>(x, dim, out);
150
+ return;
151
+ case torch::kInt32:
152
+ cub_cumsum<Inclusive, int>(x, dim, out);
153
+ return;
154
+ }
155
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
156
+ cub_cumsum<Inclusive, long>(x, dim, out);
157
+ }
158
+
159
+ } // namespace megablocks
160
+
161
+ #undef CUB_WRAPPED_NAMESPACE
csrc/new_cumsum.h ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Forward declarations for the public interface functions
8
+ void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
9
+ void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
10
+
11
+ } // namespace megablocks
csrc/new_histogram.cu ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include "new_histogram.h"
5
+ #include <cstdint>
6
+ #include <cub/cub.cuh>
7
+ #include <c10/cuda/CUDAStream.h>
8
+
9
+ #define CUDA_CALL(code) \
10
+ do { \
11
+ cudaError_t status = code; \
12
+ std::string err = cudaGetErrorString(status); \
13
+ TORCH_CHECK(status == cudaSuccess, err); \
14
+ } while (0)
15
+
16
+ namespace megablocks {
17
+
18
+ template <typename T>
19
+ torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
20
+ // Allocate the count buffer.
21
+ auto options = torch::TensorOptions()
22
+ .dtype(torch::kInt32)
23
+ .device(x.device());
24
+ torch::Tensor out = torch::empty({x.size(0), num_bins}, options);
25
+
26
+ // Exit early if there is not work to do.
27
+ if (out.numel() == 0) return out;
28
+
29
+ // Get scratchpad size.
30
+ size_t scratchpad_bytes = 0;
31
+ CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr,
32
+ scratchpad_bytes,
33
+ x.data_ptr<T>(),
34
+ out.data_ptr<int>(),
35
+ /*num_levels=*/num_bins + 1,
36
+ /*lower_level=*/0,
37
+ /*upper_level=*/num_bins,
38
+ /*num_samples=*/int(x.size(1)),
39
+ c10::cuda::getCurrentCUDAStream()));
40
+
41
+ // Allocate scratchpad.
42
+ options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
43
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
44
+
45
+ // Run the kernel.
46
+ for (int i = 0; i < x.size(0); ++i) {
47
+ CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
48
+ scratchpad_bytes,
49
+ x.data_ptr<T>() + x.size(1) * i,
50
+ out.data_ptr<int>() + out.size(1) * i,
51
+ /*num_levels=*/num_bins + 1,
52
+ /*lower_level=*/0,
53
+ /*upper_level=*/num_bins,
54
+ /*num_samples=*/int(x.size(1)),
55
+ c10::cuda::getCurrentCUDAStream()));
56
+ }
57
+ return out;
58
+ }
59
+
60
+ torch::Tensor histogram(torch::Tensor x, int num_bins) {
61
+ TORCH_CHECK(x.is_cuda());
62
+ TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
63
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
64
+ x.scalar_type() == torch::kInt32 ||
65
+ x.scalar_type() == torch::kInt64);
66
+ bool no_batch = x.ndimension() == 1;
67
+ if (no_batch) x = x.view({1, x.numel()});
68
+
69
+ if (x.scalar_type() == torch::kInt16) {
70
+ auto out = cub_histogram<short>(x, num_bins);
71
+ return no_batch ? out.flatten() : out;
72
+ } else if (x.scalar_type() == torch::kInt32) {
73
+ auto out = cub_histogram<int>(x, num_bins);
74
+ return no_batch ? out.flatten() : out;
75
+ } else {
76
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
77
+ auto out = cub_histogram<long>(x, num_bins);
78
+ return no_batch ? out.flatten() : out;
79
+ }
80
+ }
81
+
82
+ } // namespace megablocks
83
+
84
+ #undef CUDA_CALL
85
+ #undef CUB_WRAPPED_NAMESPACE
csrc/new_histogram.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Public interface function for computing histograms
8
+ torch::Tensor histogram(torch::Tensor x, int num_bins);
9
+
10
+ } // namespace megablocks
csrc/new_indices.cu ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "new_indices.h"
2
+ #include <cstdint>
3
+ #include <c10/util/Half.h>
4
+ #include <c10/cuda/CUDAStream.h>
5
+
6
+ #define CUDA_CALL(code) \
7
+ do { \
8
+ cudaError_t status = code; \
9
+ std::string err = cudaGetErrorString(status); \
10
+ TORCH_CHECK(status == cudaSuccess, err); \
11
+ } while (0)
12
+
13
+ namespace megablocks {
14
+ namespace construct_indices {
15
+
16
+ // We expect the number of outputs per block to be small. For
17
+ // example, with ffn_hidden_size=4096, we only need to write
18
+ // 32 elements per block per iteration.
19
+ const int kThreadsPerBlock = 32;
20
+
21
+ __global__ void __launch_bounds__(kThreadsPerBlock)
22
+ ConstructIndicesKernel(short * __restrict__ indices,
23
+ int num_columns,
24
+ int block_size,
25
+ const int * __restrict__ padded_bins) {
26
+ // Load the offset for this bins indices.
27
+ int start = 0;
28
+ if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
29
+ int end = __ldg(padded_bins + blockIdx.x);
30
+
31
+ // Divide the start and end into blocks.
32
+ start /= block_size;
33
+ end /= block_size;
34
+
35
+ // Offset the output buffer to the start of the bin.
36
+ indices += (start + blockIdx.y) * num_columns + threadIdx.x;
37
+
38
+ // Write the indices to the output.
39
+ int bin_offset = blockIdx.y;
40
+ int num_rows = end - start;
41
+ for (; bin_offset < num_rows; num_rows -= gridDim.y) {
42
+ short *out = indices;
43
+ for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
44
+ *out = bid + (blockIdx.x * num_columns);
45
+ out += kThreadsPerBlock;
46
+ }
47
+ indices += gridDim.y * num_columns;
48
+ }
49
+ }
50
+
51
+ cudaError_t ConstructIndices(short * __restrict__ indices,
52
+ int output_block_rows,
53
+ int output_block_columns,
54
+ int block_size,
55
+ const int * __restrict__ padded_bins,
56
+ int num_bins,
57
+ cudaStream_t stream) {
58
+ dim3 block_dim(kThreadsPerBlock);
59
+ dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
60
+ ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
61
+ output_block_columns,
62
+ block_size,
63
+ padded_bins);
64
+ return cudaGetLastError();
65
+ }
66
+
67
+ } // namespace construct_indices
68
+
69
+ void indices(torch::Tensor padded_bins,
70
+ int block_size,
71
+ int output_block_rows,
72
+ int output_block_columns,
73
+ torch::Tensor out) {
74
+ TORCH_CHECK(padded_bins.is_cuda());
75
+ TORCH_CHECK(padded_bins.ndimension() == 1);
76
+ TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);
77
+
78
+ TORCH_CHECK(out.is_cuda());
79
+ TORCH_CHECK(out.ndimension() == 1);
80
+ TORCH_CHECK(out.scalar_type() == torch::kInt16);
81
+ TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));
82
+
83
+ // Exit early if there is no work to do.
84
+ if (out.numel() == 0) return;
85
+
86
+ CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
87
+ output_block_rows,
88
+ output_block_columns,
89
+ block_size,
90
+ padded_bins.data_ptr<int>(),
91
+ padded_bins.numel(),
92
+ c10::cuda::getCurrentCUDAStream()));
93
+ }
94
+
95
+ } // namespace megablocks
96
+
97
+ #undef CUDA_CALL
csrc/new_indices.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Public interface function for constructing indices from padded bins
8
+ void indices(torch::Tensor padded_bins,
9
+ int block_size,
10
+ int output_block_rows,
11
+ int output_block_columns,
12
+ torch::Tensor out);
13
+
14
+ } // namespace megablocks
csrc/new_replicate.cu ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include "new_replicate.h"
5
+ #include <cstdint>
6
+ #include <cub/cub.cuh>
7
+ #include <c10/util/Half.h>
8
+ #include <c10/cuda/CUDAStream.h>
9
+
10
+ #define CUDA_CALL(code) \
11
+ do { \
12
+ cudaError_t status = code; \
13
+ std::string err = cudaGetErrorString(status); \
14
+ TORCH_CHECK(status == cudaSuccess, err); \
15
+ } while (0)
16
+
17
+ namespace megablocks {
18
+ namespace replicate {
19
+
20
+ template <typename T, int kThreadsPerBlock>
21
+ __global__ void __launch_bounds__(kThreadsPerBlock)
22
+ ReplicateForwardKernel(T * __restrict__ x,
23
+ int * __restrict__ bins,
24
+ T * __restrict__ out,
25
+ int columns) {
26
+ // Offset to this threadblocks batch.
27
+ //
28
+ // x is [batch_size, num_bins]
29
+ // out is [batch_size, columns]
30
+ // bins is [num_bins]
31
+ int batch_idx = blockIdx.y;
32
+ int num_bins = gridDim.x;
33
+ x += batch_idx * num_bins;
34
+ out += batch_idx * columns;
35
+
36
+ // Load the start/end for this bin.
37
+ int bin_idx = blockIdx.x;
38
+ int start = 0;
39
+ if (bin_idx > 0) start = __ldg(bins + bin_idx - 1);
40
+ int end = __ldg(bins + bin_idx);
41
+
42
+ // Load the value to replicate.
43
+ T value = __ldg((T*)x + bin_idx);
44
+
45
+ // Offset to this threadblocks bin and this threads
46
+ // offset within the bin.
47
+ int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
48
+ out += start + bin_offset;
49
+
50
+ // Replicate the value to the output.
51
+ //
52
+ // TODO(tgale): Vectorize these stores.
53
+ int num_elements = end - start;
54
+ const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
55
+ T *out_ptr = (T*)out;
56
+ for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
57
+ *out_ptr = value;
58
+ out_ptr += kElementsPerLoop;
59
+ }
60
+ }
61
+
62
+ template <typename T>
63
+ cudaError_t ReplicateForward(T *x,
64
+ int batch_size,
65
+ int num_bins,
66
+ int *bins,
67
+ T *out,
68
+ int columns,
69
+ cudaStream_t stream) {
70
+ const int kThreadsPerBlock = 64;
71
+ dim3 block_dim(kThreadsPerBlock, 1, 1);
72
+ int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
73
+ dim3 grid_dim(num_bins, batch_size, group_size);
74
+ ReplicateForwardKernel<T, kThreadsPerBlock><<<
75
+ grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
76
+ return cudaGetLastError();
77
+ }
78
+
79
+ void cub_segmented_reduce(torch::Tensor grad,
80
+ torch::Tensor bins,
81
+ torch::Tensor out,
82
+ cudaStream_t stream) {
83
+ // Append a zero to the bin boundaries for CUB.
84
+ torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
85
+ CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
86
+ 0,
87
+ offsets.numel() * sizeof(int),
88
+ stream));
89
+ CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
90
+ bins.data_ptr<int>(),
91
+ bins.numel() * sizeof(int),
92
+ cudaMemcpyDeviceToDevice,
93
+ stream));
94
+
95
+ // Get temporary buffer size.
96
+ size_t scratchpad_bytes = 0;
97
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
98
+ scratchpad_bytes,
99
+ grad.data_ptr<c10::Half>(),
100
+ out.data_ptr<c10::Half>(),
101
+ bins.numel(),
102
+ offsets.data_ptr<int>(),
103
+ offsets.data_ptr<int>() + 1,
104
+ stream));
105
+
106
+ // Allocate scratchpad.
107
+ auto options = torch::TensorOptions()
108
+ .dtype(torch::kInt8)
109
+ .device(grad.device());
110
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
111
+
112
+ // Run the kernel for each batch item.
113
+ for (int i = 0; i < grad.size(0); ++i) {
114
+ int num_bins = out.size(1);
115
+ int num_values = grad.size(1);
116
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
117
+ scratchpad_bytes,
118
+ grad.data_ptr<c10::Half>() + i * num_values,
119
+ out.data_ptr<c10::Half>() + i * num_bins,
120
+ bins.numel(),
121
+ offsets.data_ptr<int>(),
122
+ offsets.data_ptr<int>() + 1,
123
+ stream));
124
+ }
125
+ }
126
+
127
+ } // namespace replicate
128
+
129
+ void replicate_forward(torch::Tensor x,
130
+ torch::Tensor bins,
131
+ torch::Tensor out) {
132
+ // Validate the inputs.
133
+ TORCH_CHECK(x.is_cuda());
134
+ TORCH_CHECK(x.ndimension() == 2);
135
+ TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
136
+ x.scalar_type() == torch::kInt16 ||
137
+ x.scalar_type() == torch::kInt32);
138
+ TORCH_CHECK(bins.is_cuda());
139
+ TORCH_CHECK(bins.ndimension() == 1);
140
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
141
+ TORCH_CHECK(out.is_cuda());
142
+ TORCH_CHECK(out.ndimension() == 2);
143
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
144
+
145
+ // Batch dimensions should match for input/output.
146
+ TORCH_CHECK(x.size(0) == out.size(0));
147
+
148
+ // One input for each bin (in each batch).
149
+ TORCH_CHECK(x.size(1) == bins.size(0));
150
+
151
+ // Exit early if there is no work to do.
152
+ if (out.numel() == 0) return;
153
+
154
+ switch (x.scalar_type()) {
155
+ case torch::kFloat16:
156
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
157
+ x.size(0),
158
+ x.size(1),
159
+ bins.data_ptr<int>(),
160
+ out.data_ptr<c10::Half>(),
161
+ out.size(1),
162
+ c10::cuda::getCurrentCUDAStream()));
163
+ return;
164
+ case torch::kInt32:
165
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
166
+ x.size(0),
167
+ x.size(1),
168
+ bins.data_ptr<int>(),
169
+ out.data_ptr<int>(),
170
+ out.size(1),
171
+ c10::cuda::getCurrentCUDAStream()));
172
+ return;
173
+ }
174
+ TORCH_CHECK(x.scalar_type() == torch::kInt16);
175
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
176
+ x.size(0),
177
+ x.size(1),
178
+ bins.data_ptr<int>(),
179
+ out.data_ptr<short>(),
180
+ out.size(1),
181
+ c10::cuda::getCurrentCUDAStream()));
182
+ }
183
+
184
+ void replicate_backward(torch::Tensor grad,
185
+ torch::Tensor bins,
186
+ torch::Tensor out) {
187
+ // Validate the inputs.
188
+ TORCH_CHECK(grad.is_cuda());
189
+ TORCH_CHECK(grad.ndimension() == 2);
190
+ TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
191
+ TORCH_CHECK(bins.is_cuda());
192
+ TORCH_CHECK(bins.ndimension() == 1);
193
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
194
+ TORCH_CHECK(out.is_cuda());
195
+ TORCH_CHECK(out.ndimension() == 2);
196
+ TORCH_CHECK(out.scalar_type() == torch::kFloat16);
197
+
198
+ // Batch dimensions should match for input/output.
199
+ TORCH_CHECK(grad.size(0) == out.size(0));
200
+
201
+ // One output for each bin (in each batch).
202
+ TORCH_CHECK(out.size(1) == bins.size(0));
203
+
204
+ replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
205
+ }
206
+
207
+ } // namespace megablocks
208
+
209
+ #undef CUDA_CALL
210
+ #undef CUB_WRAPPED_NAMESPACE
csrc/new_replicate.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Forward pass: replicate values from x according to bin sizes
8
+ void replicate_forward(torch::Tensor x,
9
+ torch::Tensor bins,
10
+ torch::Tensor out);
11
+
12
+ // Backward pass: reduce gradients back to bins using segmented reduction
13
+ void replicate_backward(torch::Tensor grad,
14
+ torch::Tensor bins,
15
+ torch::Tensor out);
16
+
17
+ } // namespace megablocks
csrc/new_sort.cu ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include "new_sort.h"
5
+ #include <cstdint>
6
+ #include <cub/cub.cuh>
7
+ #include <c10/cuda/CUDAStream.h>
8
+
9
+ #define CUDA_CALL(code) \
10
+ do { \
11
+ cudaError_t status = code; \
12
+ std::string err = cudaGetErrorString(status); \
13
+ TORCH_CHECK(status == cudaSuccess, err); \
14
+ } while (0)
15
+
16
+ namespace megablocks {
17
+
18
+ template <typename T>
19
+ void cub_radix_sort(torch::Tensor x,
20
+ int end_bit,
21
+ torch::Tensor x_out,
22
+ torch::Tensor iota_out) {
23
+ // Get iota for values in sort.
24
+ torch::Tensor iota = torch::arange(0, x.numel(), x.options());
25
+
26
+ // Get temporary buffer size.
27
+ size_t scratchpad_bytes = 0;
28
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
29
+ scratchpad_bytes,
30
+ x.data_ptr<T>(),
31
+ x_out.data_ptr<T>(),
32
+ iota.data_ptr<T>(),
33
+ iota_out.data_ptr<T>(),
34
+ x.numel(),
35
+ /*begin_bit*/0,
36
+ /*end_bit=*/end_bit,
37
+ c10::cuda::getCurrentCUDAStream()));
38
+
39
+ // Allocate scratchpad.
40
+ auto options = torch::TensorOptions()
41
+ .dtype(torch::kInt8)
42
+ .device(x.device());
43
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
44
+
45
+ // Run the kernel.
46
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
47
+ scratchpad_bytes,
48
+ x.data_ptr<T>(),
49
+ x_out.data_ptr<T>(),
50
+ iota.data_ptr<T>(),
51
+ iota_out.data_ptr<T>(),
52
+ x.numel(),
53
+ /*begin_bit=*/0,
54
+ /*end_bit=*/end_bit,
55
+ c10::cuda::getCurrentCUDAStream()));
56
+ }
57
+
58
+ void sort(torch::Tensor x,
59
+ int end_bit,
60
+ torch::Tensor x_out,
61
+ torch::Tensor iota_out) {
62
+ TORCH_CHECK(x.is_cuda());
63
+ TORCH_CHECK(x.ndimension() == 1);
64
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
65
+ x.scalar_type() == torch::kInt32 ||
66
+ x.scalar_type() == torch::kInt64);
67
+ TORCH_CHECK(x_out.is_cuda());
68
+ TORCH_CHECK(x_out.ndimension() == 1);
69
+ TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
70
+ TORCH_CHECK(iota_out.is_cuda());
71
+ TORCH_CHECK(iota_out.ndimension() == 1);
72
+ TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
73
+
74
+ // Exit early if there is not work to do.
75
+ if (x_out.numel() == 0) return;
76
+
77
+ switch (x.scalar_type()) {
78
+ case torch::kInt16:
79
+ return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
80
+ case torch::kInt32:
81
+ return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
82
+ }
83
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
84
+ return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
85
+ }
86
+
87
+ } // namespace megablocks
88
+
89
+ #undef CUDA_CALL
90
+ #undef CUB_WRAPPED_NAMESPACE
csrc/new_sort.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Public interface function for radix sorting with indices
8
+ void sort(torch::Tensor x,
9
+ int end_bit,
10
+ torch::Tensor x_out,
11
+ torch::Tensor iota_out);
12
+
13
+ } // namespace megablocks
csrc/replicate.h ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include <cstdint>
5
+
6
+ #include <cub/cub.cuh>
7
+ #include <c10/util/Half.h>
8
+ #include <c10/cuda/CUDAStream.h>
9
+ // #include <torch/extension.h>
10
+
11
+ #define CUDA_CALL(code) \
12
+ do { \
13
+ cudaError_t status = code; \
14
+ std::string err = cudaGetErrorString(status); \
15
+ TORCH_CHECK(status == cudaSuccess, err); \
16
+ } while (0)
17
+
18
+ namespace megablocks {
19
+ namespace replicate {
20
+
21
+ template <typename T, int kThreadsPerBlock>
22
+ __global__ void __launch_bounds__(kThreadsPerBlock)
23
+ ReplicateForwardKernel(T * __restrict__ x,
24
+ int * __restrict__ bins,
25
+ T * __restrict__ out,
26
+ int columns) {
27
+ // Offset to this threadblocks batch.
28
+ //
29
+ // x is [batch_size, num_bins]
30
+ // out is [batch_size, columns]
31
+ // bins is [num_bins]
32
+ int batch_idx = blockIdx.y;
33
+ int num_bins = gridDim.x;
34
+ x += batch_idx * num_bins;
35
+ out += batch_idx * columns;
36
+
37
+ // Load the start/end for this bin.
38
+ int bin_idx = blockIdx.x;
39
+ int start = 0;
40
+ if (bin_idx > 0) start = __ldg(bins + bin_idx - 1);
41
+ int end = __ldg(bins + bin_idx);
42
+
43
+ // Load the value to replicate.
44
+ T value = __ldg((T*)x + bin_idx);
45
+
46
+ // Offset to this threadblocks bin and this threads
47
+ // offset within the bin.
48
+ int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
49
+ out += start + bin_offset;
50
+
51
+ // Replicate the value to the output.
52
+ //
53
+ // TODO(tgale): Vectorize these stores.
54
+ int num_elements = end - start;
55
+ const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
56
+ T *out_ptr = (T*)out;
57
+ for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
58
+ *out_ptr = value;
59
+ out_ptr += kElementsPerLoop;
60
+ }
61
+ }
62
+
63
+ template <typename T>
64
+ cudaError_t ReplicateForward(T *x,
65
+ int batch_size,
66
+ int num_bins,
67
+ int *bins,
68
+ T *out,
69
+ int columns,
70
+ cudaStream_t stream) {
71
+ const int kThreadsPerBlock = 64;
72
+ dim3 block_dim(kThreadsPerBlock, 1, 1);
73
+ int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
74
+ dim3 grid_dim(num_bins, batch_size, group_size);
75
+ ReplicateForwardKernel<T, kThreadsPerBlock><<<
76
+ grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
77
+ return cudaGetLastError();
78
+ }
79
+
80
+ void cub_segmented_reduce(torch::Tensor grad,
81
+ torch::Tensor bins,
82
+ torch::Tensor out,
83
+ cudaStream_t stream) {
84
+ // Append a zero to the bin boundaries for CUB.
85
+ torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
86
+ CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
87
+ 0,
88
+ offsets.numel() * sizeof(int),
89
+ stream));
90
+ CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
91
+ bins.data_ptr<int>(),
92
+ bins.numel() * sizeof(int),
93
+ cudaMemcpyDeviceToDevice,
94
+ stream));
95
+
96
+ // Get temporary buffer size.
97
+ size_t scratchpad_bytes = 0;
98
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
99
+ scratchpad_bytes,
100
+ grad.data_ptr<c10::Half>(),
101
+ out.data_ptr<c10::Half>(),
102
+ bins.numel(),
103
+ offsets.data_ptr<int>(),
104
+ offsets.data_ptr<int>() + 1,
105
+ stream));
106
+
107
+ // Allocate scratchpad.
108
+ auto options = torch::TensorOptions()
109
+ .dtype(torch::kInt8)
110
+ .device(grad.device());
111
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
112
+
113
+ // Run the kernel for each batch item.
114
+ for (int i = 0; i < grad.size(0); ++i) {
115
+ int num_bins = out.size(1);
116
+ int num_values = grad.size(1);
117
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
118
+ scratchpad_bytes,
119
+ grad.data_ptr<c10::Half>() + i * num_values,
120
+ out.data_ptr<c10::Half>() + i * num_bins,
121
+ bins.numel(),
122
+ offsets.data_ptr<int>(),
123
+ offsets.data_ptr<int>() + 1,
124
+ stream));
125
+ }
126
+ }
127
+
128
+ } // namespace replicate
129
+
130
+ void replicate_forward(torch::Tensor x,
131
+ torch::Tensor bins,
132
+ torch::Tensor out) {
133
+ // Validate the inputs.
134
+ TORCH_CHECK(x.is_cuda());
135
+ TORCH_CHECK(x.ndimension() == 2);
136
+ TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
137
+ x.scalar_type() == torch::kInt16 ||
138
+ x.scalar_type() == torch::kInt32);
139
+ TORCH_CHECK(bins.is_cuda());
140
+ TORCH_CHECK(bins.ndimension() == 1);
141
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
142
+ TORCH_CHECK(out.is_cuda());
143
+ TORCH_CHECK(out.ndimension() == 2);
144
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
145
+
146
+ // Batch dimensions should match for input/output.
147
+ TORCH_CHECK(x.size(0) == out.size(0));
148
+
149
+ // One input for each bin (in each batch).
150
+ TORCH_CHECK(x.size(1) == bins.size(0));
151
+
152
+ // Exit early if there is no work to do.
153
+ if (out.numel() == 0) return;
154
+
155
+ switch (x.scalar_type()) {
156
+ case torch::kFloat16:
157
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
158
+ x.size(0),
159
+ x.size(1),
160
+ bins.data_ptr<int>(),
161
+ out.data_ptr<c10::Half>(),
162
+ out.size(1),
163
+ c10::cuda::getCurrentCUDAStream()));
164
+ return;
165
+ case torch::kInt32:
166
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
167
+ x.size(0),
168
+ x.size(1),
169
+ bins.data_ptr<int>(),
170
+ out.data_ptr<int>(),
171
+ out.size(1),
172
+ c10::cuda::getCurrentCUDAStream()));
173
+ return;
174
+ }
175
+ TORCH_CHECK(x.scalar_type() == torch::kInt16);
176
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
177
+ x.size(0),
178
+ x.size(1),
179
+ bins.data_ptr<int>(),
180
+ out.data_ptr<short>(),
181
+ out.size(1),
182
+ c10::cuda::getCurrentCUDAStream()));
183
+ }
184
+
185
+ void replicate_backward(torch::Tensor grad,
186
+ torch::Tensor bins,
187
+ torch::Tensor out) {
188
+ // Validate the inputs.
189
+ TORCH_CHECK(grad.is_cuda());
190
+ TORCH_CHECK(grad.ndimension() == 2);
191
+ TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
192
+ TORCH_CHECK(bins.is_cuda());
193
+ TORCH_CHECK(bins.ndimension() == 1);
194
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
195
+ TORCH_CHECK(out.is_cuda());
196
+ TORCH_CHECK(out.ndimension() == 2);
197
+ TORCH_CHECK(out.scalar_type() == torch::kFloat16);
198
+
199
+ // Batch dimensions should match for input/output.
200
+ TORCH_CHECK(grad.size(0) == out.size(0));
201
+
202
+ // One output for each bin (in each batch).
203
+ TORCH_CHECK(out.size(1) == bins.size(0));
204
+
205
+ replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
206
+ }
207
+
208
+ } // namespace megablocks
209
+
210
+ #undef CUDA_CALL
211
+ #undef CUB_WRAPPED_NAMESPACE
csrc/sort.h ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include <cstdint>
5
+
6
+ #include <cub/cub.cuh>
7
+ #include <c10/cuda/CUDAStream.h>
8
+ // #include <torch/extension.h>
9
+
10
+ #define CUDA_CALL(code) \
11
+ do { \
12
+ cudaError_t status = code; \
13
+ std::string err = cudaGetErrorString(status); \
14
+ TORCH_CHECK(status == cudaSuccess, err); \
15
+ } while (0)
16
+
17
+ namespace megablocks {
18
+
19
+ template <typename T>
20
+ void cub_radix_sort(torch::Tensor x,
21
+ int end_bit,
22
+ torch::Tensor x_out,
23
+ torch::Tensor iota_out) {
24
+ // Get iota for values in sort.
25
+ torch::Tensor iota = torch::arange(0, x.numel(), x.options());
26
+
27
+ // Get temporary buffer size.
28
+ size_t scratchpad_bytes = 0;
29
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
30
+ scratchpad_bytes,
31
+ x.data_ptr<T>(),
32
+ x_out.data_ptr<T>(),
33
+ iota.data_ptr<T>(),
34
+ iota_out.data_ptr<T>(),
35
+ x.numel(),
36
+ /*begin_bit*/0,
37
+ /*end_bit=*/end_bit,
38
+ c10::cuda::getCurrentCUDAStream()));
39
+
40
+ // Allocate scratchpad.
41
+ auto options = torch::TensorOptions()
42
+ .dtype(torch::kInt8)
43
+ .device(x.device());
44
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
45
+
46
+ // Run the kernel.
47
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
48
+ scratchpad_bytes,
49
+ x.data_ptr<T>(),
50
+ x_out.data_ptr<T>(),
51
+ iota.data_ptr<T>(),
52
+ iota_out.data_ptr<T>(),
53
+ x.numel(),
54
+ /*begin_bit=*/0,
55
+ /*end_bit=*/end_bit,
56
+ c10::cuda::getCurrentCUDAStream()));
57
+ }
58
+
59
+ void sort(torch::Tensor x,
60
+ int end_bit,
61
+ torch::Tensor x_out,
62
+ torch::Tensor iota_out) {
63
+ TORCH_CHECK(x.is_cuda());
64
+ TORCH_CHECK(x.ndimension() == 1);
65
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
66
+ x.scalar_type() == torch::kInt32 ||
67
+ x.scalar_type() == torch::kInt64);
68
+ TORCH_CHECK(x_out.is_cuda());
69
+ TORCH_CHECK(x_out.ndimension() == 1);
70
+ TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
71
+ TORCH_CHECK(iota_out.is_cuda());
72
+ TORCH_CHECK(iota_out.ndimension() == 1);
73
+ TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
74
+
75
+ // Exit early if there is not work to do.
76
+ if (x_out.numel() == 0) return;
77
+
78
+ switch (x.scalar_type()) {
79
+ case torch::kInt16:
80
+ return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
81
+ case torch::kInt32:
82
+ return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
83
+ }
84
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
85
+ return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
86
+ }
87
+
88
+ } // namespace megablocks
89
+
90
+ #undef CUDA_CALL
91
+ #undef CUB_WRAPPED_NAMESPACE
flake.lock ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1748598786,
77
+ "owner": "huggingface",
78
+ "repo": "hf-nix",
79
+ "rev": "6ca679441494139fde1f2355691ddb5dc8170269",
80
+ "type": "github"
81
+ },
82
+ "original": {
83
+ "owner": "huggingface",
84
+ "repo": "hf-nix",
85
+ "type": "github"
86
+ }
87
+ },
88
+ "kernel-builder": {
89
+ "inputs": {
90
+ "flake-compat": "flake-compat",
91
+ "flake-utils": "flake-utils",
92
+ "hf-nix": "hf-nix",
93
+ "nixpkgs": [
94
+ "kernel-builder",
95
+ "hf-nix",
96
+ "nixpkgs"
97
+ ]
98
+ },
99
+ "locked": {
100
+ "lastModified": 1749576434,
101
+ "narHash": "sha256-wSdtZih2fMQ3ne/U7OKIhmP43zCIuRBhJ5zMMz747u0=",
102
+ "path": "/home/ubuntu/Projects/kernel-builder",
103
+ "type": "path"
104
+ },
105
+ "original": {
106
+ "path": "/home/ubuntu/Projects/kernel-builder",
107
+ "type": "path"
108
+ }
109
+ },
110
+ "nixpkgs": {
111
+ "locked": {
112
+ "lastModified": 1747820358,
113
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
114
+ "owner": "danieldk",
115
+ "repo": "nixpkgs",
116
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
117
+ "type": "github"
118
+ },
119
+ "original": {
120
+ "owner": "danieldk",
121
+ "ref": "cudatoolkit-12.9-kernel-builder",
122
+ "repo": "nixpkgs",
123
+ "type": "github"
124
+ }
125
+ },
126
+ "root": {
127
+ "inputs": {
128
+ "kernel-builder": "kernel-builder"
129
+ }
130
+ },
131
+ "systems": {
132
+ "locked": {
133
+ "lastModified": 1681028828,
134
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
135
+ "owner": "nix-systems",
136
+ "repo": "default",
137
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
138
+ "type": "github"
139
+ },
140
+ "original": {
141
+ "owner": "nix-systems",
142
+ "repo": "default",
143
+ "type": "github"
144
+ }
145
+ },
146
+ "systems_2": {
147
+ "locked": {
148
+ "lastModified": 1681028828,
149
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
150
+ "owner": "nix-systems",
151
+ "repo": "default",
152
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
153
+ "type": "github"
154
+ },
155
+ "original": {
156
+ "owner": "nix-systems",
157
+ "repo": "default",
158
+ "type": "github"
159
+ }
160
+ }
161
+ },
162
+ "root": "root",
163
+ "version": 7
164
+ }
flake.nix ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ description = "Flake for megablocks_moe kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "path:/home/ubuntu/Projects/kernel-builder";
6
+ # kernel-builder.url = "github:huggingface/kernel-builder/v0.4.0";
7
+ };
8
+
9
+ outputs =
10
+ {
11
+ self,
12
+ kernel-builder,
13
+ }:
14
+ kernel-builder.lib.genFlakeOutputs {
15
+ path = ./.;
16
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
17
+ };
18
+ }
tests/__init__.py ADDED
File without changes
tests/test_mb_moe.py ADDED
@@ -0,0 +1,6 @@
1
+ import megablocks
2
+
3
+ def test_import():
4
+ """Simple test to check if the module can be imported."""
5
+ print("megablocks_moe module imported successfully.")
6
+ print("Available functions:", dir(megablocks))
torch-ext/megablocks/__init__.py ADDED
@@ -0,0 +1,191 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ from ._ops import ops
7
+
8
+ from megablocks.layers.arguments import Arguments
9
+ from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
10
+ from megablocks.layers.glu import SparseGLU
11
+ from megablocks.layers.mlp import MLP, SparseMLP
12
+ from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss
13
+
14
+ # This section contains the direct kernel exports (not included in the original code)
15
+ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
16
+ """
17
+ Compute exclusive cumulative sum along the specified dimension.
18
+
19
+ Args:
20
+ x: Input tensor
21
+ dim: Dimension along which to compute cumsum
22
+ out: Output tensor (modified in-place)
23
+
24
+ Returns:
25
+ The output tensor
26
+ """
27
+ return ops.exclusive_cumsum(x, dim, out)
28
+
29
+
30
+ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
31
+ """
32
+ Compute inclusive cumulative sum along the specified dimension.
33
+
34
+ Args:
35
+ x: Input tensor
36
+ dim: Dimension along which to compute cumsum
37
+ out: Output tensor (modified in-place)
38
+
39
+ Returns:
40
+ The output tensor
41
+ """
42
+ return ops.inclusive_cumsum(x, dim, out)
43
+
44
+
45
+ def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
46
+ """
47
+ Compute histogram of input tensor values.
48
+
49
+ Args:
50
+ x: Input tensor
51
+ num_bins: Number of histogram bins
52
+
53
+ Returns:
54
+ Histogram tensor with counts for each bin
55
+ """
56
+ return ops.histogram(x, num_bins)
57
+
58
+
59
+ def indices(
60
+ padded_bins: torch.Tensor,
61
+ block_size: int,
62
+ output_block_rows: int,
63
+ output_block_columns: int,
64
+ ) -> torch.Tensor:
65
+ """
66
+ Construct indices from padded bins for sparse operations.
67
+
68
+ Args:
69
+ padded_bins: Tensor containing bin boundaries
70
+ block_size: Size of each block
71
+ output_block_rows: Number of rows in output blocks
72
+ output_block_columns: Number of columns in output blocks
73
+
74
+ Returns:
75
+ Tensor containing constructed indices
76
+ """
77
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
78
+
79
+
80
+ def replicate_forward(
81
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
82
+ ) -> torch.Tensor:
83
+ """
84
+ Forward pass of replicate operation - replicate values according to bin sizes.
85
+
86
+ Args:
87
+ x: Input tensor with values to replicate
88
+ bins: Tensor containing bin sizes
89
+ out: Output tensor (modified in-place)
90
+
91
+ Returns:
92
+ The output tensor
93
+ """
94
+ return ops.replicate_forward(x, bins, out)
95
+
96
+
97
+ def replicate_backward(
98
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
99
+ ) -> torch.Tensor:
100
+ """
101
+ Backward pass of replicate operation - reduce gradients back to bins.
102
+
103
+ Args:
104
+ grad: Gradient tensor to reduce
105
+ bins: Tensor containing bin sizes
106
+ out: Output tensor (modified in-place)
107
+
108
+ Returns:
109
+ The output tensor
110
+ """
111
+ return ops.replicate_backward(grad, bins, out)
112
+
113
+
114
+ def sort(
115
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
116
+ ) -> torch.Tensor:
117
+ """
118
+ Radix sort with index tracking.
119
+
120
+ Args:
121
+ x: Input tensor to sort
122
+ end_bit: Number of bits to consider in sorting
123
+ x_out: Output tensor for sorted values
124
+ iota_out: Output tensor for sorted indices
125
+
126
+ Returns:
127
+ The sorted values tensor
128
+ """
129
+ return ops.sort(x, end_bit, x_out, iota_out)
130
+
131
+
132
+ # Convenience functions for common use cases
133
+ def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
134
+ """
135
+ Compute cumulative sum with automatic output allocation.
136
+
137
+ Args:
138
+ x: Input tensor
139
+ dim: Dimension along which to compute cumsum (default: last dimension)
140
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
141
+
142
+ Returns:
143
+ New tensor containing the cumulative sum
144
+ """
145
+ out = torch.empty_like(x)
146
+ if exclusive:
147
+ return exclusive_cumsum(x, dim, out)
148
+ else:
149
+ return inclusive_cumsum(x, dim, out)
150
+
151
+
152
+ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
153
+ """
154
+ Sort tensor and return both sorted values and indices.
155
+
156
+ Args:
157
+ x: Input tensor to sort
158
+ end_bit: Number of bits to consider in sorting
159
+
160
+ Returns:
161
+ Tuple of (sorted_values, sorted_indices)
162
+ """
163
+ x_out = torch.empty_like(x)
164
+ iota_out = torch.empty_like(x)
165
+ sort(x, end_bit, x_out, iota_out)
166
+ return x_out, iota_out
167
+
168
+
169
+ # Export public API
170
+ __all__ = [
171
+ # Direct kernel exports
172
+ "exclusive_cumsum",
173
+ "inclusive_cumsum",
174
+ "histogram",
175
+ "indices",
176
+ "replicate_forward",
177
+ "replicate_backward",
178
+ "sort",
179
+ "cumsum",
180
+ "argsort",
181
+ # Original exports
182
+ "Arguments",
183
+ "ParallelDroplessMLP",
184
+ "dMoE",
185
+ "SparseGLU",
186
+ "MLP",
187
+ "SparseMLP",
188
+ "MoE",
189
+ "ParallelMLP",
190
+ "get_load_balancing_loss",
191
+ ]
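
A minimal usage sketch of the convenience wrappers defined above. This is illustrative only: it assumes the compiled `_ops` extension is installed, a CUDA device is available, and that the underlying kernels accept int32 inputs of this size.

```python
import torch
import megablocks

# Expert assignments for 16 tokens routed across 4 experts.
x = torch.randint(0, 4, (16,), dtype=torch.int32, device="cuda")

# Histogram of expert ids; the inclusive cumsum of the counts gives bin bounds.
counts = megablocks.histogram(x, num_bins=4)
bins = megablocks.cumsum(counts, dim=0, exclusive=False)

# Radix sort the expert ids (2 bits cover 4 experts) and keep the permutation.
sorted_ids, perm = megablocks.argsort(x, end_bit=2)
```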
torch-ext/megablocks/_version.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """The MegaBlocks Version."""
5
+
6
+ __version__ = '0.11.0.dev0'
torch-ext/megablocks/backend/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
torch-ext/megablocks/backend/kernels.py ADDED
@@ -0,0 +1,543 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
+ def assert_is_tensor(x, ndim):
10
+ if x.ndim != ndim:
11
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
12
+
13
+
14
+ def assert_is_matrix(x):
15
+ assert_is_tensor(x, 2)
16
+
17
+
18
+ def assert_is_vector(x):
19
+ if x.ndim != 1:
20
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
21
+
22
+
23
+ def assert_equal(a, b):
24
+ if a != b:
25
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
26
+
27
+
28
+ # a: (tokens, hidden_size), real.
29
+ # indices: (tokens * top_k), integer.
30
+ # bin_ids: (tokens * top_k), integer.
31
+ # weights: (tokens * top_k), real.
32
+ # bins: (num_experts), integer.
33
+ # padded_bins: (num_experts), integer.
34
+ @triton.autotune(
35
+ configs=[
36
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
37
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
38
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
39
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
40
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
41
+ ],
42
+ key=['NUM_COLUMNS'],
43
+ )
44
+ @triton.jit
45
+ def _padded_copy(
46
+ a,
47
+ b,
48
+ indices,
49
+ bin_ids,
50
+ weights,
51
+ bins,
52
+ padded_bins,
53
+ NUM_COLUMNS: tl.constexpr,
54
+ TOP_K: tl.constexpr,
55
+ BLOCK_X: tl.constexpr,
56
+ A_TO_B: tl.constexpr,
57
+ SCALE: tl.constexpr,
58
+ ):
59
+ # Our index into array 'a'.
60
+ index_a = tl.load(indices + tl.program_id(0))
61
+
62
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
63
+ # number of rows since they could be padded.
64
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
65
+
66
+ # Now we know what bin we're assigned to, but we need to know how
67
+ # many threadblocks were assigned to earlier bins so we can offset
68
+ # in our bin properly.
69
+ offset_in_bin = tl.program_id(0)
70
+ if bin_idx > 0:
71
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
72
+
73
+ # Load the starting index of our bin in array 'b'.
74
+ index_b = offset_in_bin
75
+ if bin_idx > 0:
76
+ index_b += tl.load(padded_bins + bin_idx - 1)
77
+
78
+ # Offset the input and output pointers.
79
+ #
80
+ # If we're going from A to B, divide the input index to copy
81
+ # the same input repeatedly. If we're going from B to A we
82
+ # need to reduce the result. Using atomics is slow, so we
83
+ # do the reduce step in a second kernel.
84
+ offset = index_a // TOP_K if A_TO_B else index_a
85
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
86
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
87
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
88
+
89
+ # Load the scale, if requested.
90
+ scale = tl.load(weights + index_a) if SCALE else 1
91
+
92
+ # Swap the pointers depending on the direction.
93
+ iptr = a if A_TO_B else b
94
+ optr = b if A_TO_B else a
95
+
96
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
97
+ for _ in range(iterations):
98
+ mask = offsets < NUM_COLUMNS
99
+ x = tl.load(iptr + offsets, mask=mask)
100
+ x = x.to(tl.float32) * scale.to(tl.float32)
101
+
102
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
103
+
104
+ offsets += BLOCK_X
105
+
106
+
107
+ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
108
+ # Validate the input shapes.
109
+ assert_is_matrix(x)
110
+ assert_is_vector(indices)
111
+ assert_is_vector(bin_ids)
112
+ assert_is_vector(bins)
113
+ assert_is_vector(padded_bins)
114
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
115
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
116
+ assert_equal(bins.size(), padded_bins.size())
117
+
118
+ if weights is not None:
119
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
120
+
121
+ # NOTE: Because of the padding, the output size is dynamic.
122
+ # We load the final padded bin bound to get the output rows.
123
+ output_rows = padded_bins[-1].cpu().item()
124
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
125
+ _padded_copy[(indices.shape[0],)](
126
+ x,
127
+ out,
128
+ indices,
129
+ bin_ids,
130
+ weights,
131
+ bins,
132
+ padded_bins,
133
+ NUM_COLUMNS=x.shape[1],
134
+ A_TO_B=True,
135
+ TOP_K=top_k,
136
+ SCALE=weights is not None,
137
+ )
138
+ return out
139
+
140
+
141
+ def gather(x, indices, bin_ids, weights, bins, top_k):
142
+ # Validate the input shapes.
143
+ assert_is_matrix(x)
144
+ assert_is_vector(indices)
145
+ assert_is_vector(bin_ids)
146
+ assert_is_vector(bins)
147
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
148
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
149
+
150
+ if weights is not None:
151
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
152
+
153
+ # NOTE: There is no padding so the output rows equals the
154
+ # input rows multiplied by top_k.
155
+ output_rows = x.shape[0] * top_k
156
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
157
+ _padded_copy[(indices.shape[0],)](
158
+ x,
159
+ out,
160
+ indices,
161
+ bin_ids,
162
+ weights,
163
+ bins,
164
+ bins,
165
+ NUM_COLUMNS=x.shape[1],
166
+ A_TO_B=True,
167
+ TOP_K=top_k,
168
+ SCALE=weights is not None,
169
+ )
170
+ return out
171
+
172
+
173
+ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
174
+ # Validate the input shapes.
175
+ assert_is_matrix(x)
176
+ assert_is_vector(indices)
177
+ assert_is_vector(bin_ids)
178
+ assert_is_vector(bins)
179
+ assert_is_vector(padded_bins)
180
+ assert_equal(indices.shape[0], bin_ids.shape[0])
181
+ assert_equal(bins.size(), padded_bins.size())
182
+
183
+ if weights is not None:
184
+ assert_equal(indices.shape[0], weights.shape[0])
185
+
186
+ tokens = indices.shape[0] // top_k
187
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
188
+ _padded_copy[(indices.shape[0],)](
189
+ out,
190
+ x,
191
+ indices,
192
+ bin_ids,
193
+ weights,
194
+ bins,
195
+ padded_bins,
196
+ NUM_COLUMNS=x.shape[1],
197
+ A_TO_B=False,
198
+ TOP_K=top_k,
199
+ SCALE=weights is not None,
200
+ )
201
+
202
+ # Reduce along the top-k dimension, if needed.
203
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
204
+
205
+
206
+ def scatter(x, indices, bin_ids, weights, bins, top_k):
207
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
208
+
209
+
210
+ # x: (tokens, top_k, hidden_size), real
211
+ # grad: (tokens, hidden_size), real.
212
+ # wgrad: (tokens, top_k), real.
213
+ # indices: (tokens * top_k), integer.
214
+ # bin_ids: (tokens * top_k), integer.
215
+ # bins: (num_experts), integer.
216
+ # padded_bins: (num_experts), integer.
217
+ @triton.autotune(
218
+ configs=[
219
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
220
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
221
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
222
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
223
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
224
+ ],
225
+ key=['NUM_COLUMNS'],
226
+ )
227
+ @triton.jit
228
+ def _padded_copy_wgrad(
229
+ x,
230
+ grad,
231
+ wgrad,
232
+ indices,
233
+ bin_ids,
234
+ bins,
235
+ padded_bins,
236
+ NUM_COLUMNS: tl.constexpr,
237
+ TOP_K: tl.constexpr,
238
+ BLOCK_X: tl.constexpr,
239
+ ):
240
+ # Our index into 'tokens * top_k'.
241
+ index_out = tl.load(indices + tl.program_id(0))
242
+
243
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
244
+ # number of rows since they could be padded.
245
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
246
+
247
+ # Now we know what bin we're assigned to, but we need to know how
248
+ # many threadblocks were assigned to earlier bins so we can offset
249
+ # in our bin properly.
250
+ offset_in_bin = tl.program_id(0)
251
+ if bin_idx > 0:
252
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
253
+
254
+ # Load the starting index of our bin in array 'x'.
255
+ index_x = offset_in_bin
256
+ if bin_idx > 0:
257
+ index_x += tl.load(padded_bins + bin_idx - 1)
258
+
259
+ # Offset the input and output pointers.
260
+ wgrad += index_out
261
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
262
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
263
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
264
+
265
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
266
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
267
+ for _ in range(iterations):
268
+ mask = offsets < NUM_COLUMNS
269
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
270
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
271
+ acc += data * scale
272
+ offsets += BLOCK_X
273
+
274
+ # Reduce to get the final result and store.
275
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
276
+ tl.store(wgrad, out)
277
+
278
+
279
+ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
280
+ # Validate the input shapes.
281
+ assert_is_matrix(x)
282
+ assert_is_matrix(grad)
283
+ assert_is_vector(indices)
284
+ assert_is_vector(bin_ids)
285
+ assert_is_vector(bins)
286
+ assert_is_vector(padded_bins)
287
+ assert_equal(indices.shape[0], bin_ids.shape[0])
288
+ assert_equal(bins.size(), padded_bins.size())
289
+
290
+ tokens = indices.shape[0] // top_k
291
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
292
+ _padded_copy_wgrad[(indices.shape[0],)](
293
+ x,
294
+ grad,
295
+ out,
296
+ indices,
297
+ bin_ids,
298
+ bins,
299
+ padded_bins,
300
+ NUM_COLUMNS=x.shape[1],
301
+ TOP_K=top_k,
302
+ )
303
+ return out
304
+
305
+
306
+ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
307
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
308
+
309
+
310
+ # a: (tokens, hidden_size), real.
311
+ # b: (num_experts, expert_capacity, num_columns), real.
312
+ # indices: (tokens * top_k), integer.
313
+ # weights: (tokens * top_k), real.
314
+ # bins: (num_experts), integer.
315
+ @triton.autotune(
316
+ configs=[
317
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
318
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
319
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
320
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
321
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
322
+ ],
323
+ key=['NUM_COLUMNS'],
324
+ )
325
+ @triton.jit
326
+ def _binned_copy(
327
+ a,
328
+ b,
329
+ num_experts,
330
+ expert_capacity,
331
+ indices,
332
+ weights,
333
+ bins,
334
+ NUM_COLUMNS: tl.constexpr,
335
+ TOP_K: tl.constexpr,
336
+ BLOCK_X: tl.constexpr,
337
+ A_TO_B: tl.constexpr,
338
+ SCALE: tl.constexpr,
339
+ ):
340
+ # Load our indices into the output.
341
+ expert_idx = tl.program_id(0)
342
+ entry_idx = tl.program_id(1)
343
+
344
+ # Calculate our offset into the output.
345
+ index_b = expert_idx * expert_capacity + entry_idx
346
+
347
+ # Load the index bounds for our bin and calculate
348
+ # the number of tokens assigned to our expert.
349
+ start = 0
350
+ if expert_idx > 0:
351
+ start = tl.load(bins + expert_idx - 1)
352
+ end = tl.load(bins + expert_idx)
353
+ num_tokens = end - start
354
+
355
+ # Calculate our offset into the input. If we don't
356
+ # have an input exit early.
357
+ if entry_idx >= num_tokens:
358
+ return
359
+ index_a = tl.load(indices + start + entry_idx)
360
+
361
+ # Offset the input and output pointers.
362
+ #
363
+ # If we're going from A to B, divide the input index to copy
364
+ # the same input repeatedly. If we're going from B to A we
365
+ # need to reduce the result. Using atomics is slow, so we
366
+ # do the reduce step in a second kernel.
367
+ offset = index_a // TOP_K if A_TO_B else index_a
368
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
369
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
370
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
371
+
372
+ # Load the scale, if requested.
373
+ scale = tl.load(weights + index_a) if SCALE else 1
374
+
375
+ # Swap the pointers depending on the direction.
376
+ #
377
+ # NOTE: We need to zero the output in both directions.
378
+ iptr = a if A_TO_B else b
379
+ optr = b if A_TO_B else a
380
+
381
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
382
+ for _ in range(iterations):
383
+ mask = offsets < NUM_COLUMNS
384
+ x = tl.load(iptr + offsets, mask=mask)
385
+ x = x.to(tl.float32) * scale.to(tl.float32)
386
+
387
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
388
+
389
+ offsets += BLOCK_X
390
+
391
+
392
+ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
393
+ # Validate the input shapes.
394
+ assert_is_matrix(x)
395
+ assert_is_vector(indices)
396
+ assert_is_vector(bins)
397
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
398
+
399
+ if weights is not None:
400
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
401
+
402
+ num_experts = bins.shape[0]
403
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
404
+
405
+ _binned_copy[(num_experts, expert_capacity)](
406
+ x,
407
+ out,
408
+ num_experts,
409
+ expert_capacity,
410
+ indices,
411
+ weights,
412
+ bins,
413
+ NUM_COLUMNS=x.shape[1],
414
+ A_TO_B=True,
415
+ TOP_K=top_k,
416
+ SCALE=weights is not None,
417
+ )
418
+ return out
419
+
420
+
421
+ def binned_scatter(x, indices, weights, bins, top_k):
422
+ # Validate the input shapes.
423
+ assert_is_tensor(x, 3)
424
+ assert_is_vector(indices)
425
+ assert_is_vector(bins)
426
+ assert_equal(bins.shape[0], x.shape[0])
427
+
428
+ if weights is not None:
429
+ assert_equal(indices.shape[0], weights.shape[0])
430
+
431
+ num_experts, expert_capacity, hidden_size = x.shape
432
+ tokens = indices.shape[0] // top_k
433
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
434
+ _binned_copy[(num_experts, expert_capacity)](
435
+ out,
436
+ x,
437
+ num_experts,
438
+ expert_capacity,
439
+ indices,
440
+ weights,
441
+ bins,
442
+ NUM_COLUMNS=hidden_size,
443
+ A_TO_B=False,
444
+ TOP_K=top_k,
445
+ SCALE=weights is not None,
446
+ )
447
+
448
+ # Reduce along the top-k dimension, if needed.
449
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
450
+
451
+
452
+ # a: (tokens, hidden_size), real.
453
+ # b: (num_experts, expert_capacity, num_columns), real.
454
+ # indices: (tokens * top_k), integer.
455
+ # weights: (tokens * top_k), real.
456
+ # bins: (num_experts), integer.
457
+ @triton.autotune(
458
+ configs=[
459
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
460
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
461
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
462
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
463
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
464
+ ],
465
+ key=['NUM_COLUMNS'],
466
+ )
467
+ @triton.jit
468
+ def _binned_copy_wgrad(
469
+ x,
470
+ grad,
471
+ wgrad,
472
+ num_experts,
473
+ expert_capacity,
474
+ indices,
475
+ bins,
476
+ NUM_COLUMNS: tl.constexpr,
477
+ TOP_K: tl.constexpr,
478
+ BLOCK_X: tl.constexpr,
479
+ ):
480
+ # Load our indices into the output.
481
+ expert_idx = tl.program_id(0)
482
+ entry_idx = tl.program_id(1)
483
+
484
+ # Calculate our offset into the output.
485
+ index_x = expert_idx * expert_capacity + entry_idx
486
+
487
+ # Load the index bounds for our bin and calculate
488
+ # the number of tokens assigned to our expert.
489
+ start = 0
490
+ if expert_idx > 0:
491
+ start = tl.load(bins + expert_idx - 1)
492
+ end = tl.load(bins + expert_idx)
493
+ num_tokens = end - start
494
+
495
+ # Calculate our offset into the input. If we don't
496
+ # have an input exit early.
497
+ if entry_idx >= num_tokens:
498
+ return
499
+ index_out = tl.load(indices + start + entry_idx)
500
+
501
+ # Offset the input and output pointers.
502
+ wgrad += index_out
503
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
504
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
505
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
506
+
507
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
508
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
509
+ for _ in range(iterations):
510
+ mask = offsets < NUM_COLUMNS
511
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
512
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
513
+ acc += data * scale
514
+ offsets += BLOCK_X
515
+
516
+ # Reduce to get the final result and store.
517
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
518
+ tl.store(wgrad, out)
519
+
520
+
521
+ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
522
+ # Validate the input shapes.
523
+ assert_is_tensor(x, 3)
524
+ assert_is_matrix(grad)
525
+ assert_is_vector(indices)
526
+ assert_is_vector(bins)
527
+ assert_equal(bins.shape[0], x.shape[0])
528
+
529
+ num_experts, expert_capacity, hidden_size = x.shape
530
+ tokens = indices.shape[0] // top_k
531
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
532
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
533
+ x,
534
+ grad,
535
+ out,
536
+ num_experts,
537
+ expert_capacity,
538
+ indices,
539
+ bins,
540
+ NUM_COLUMNS=hidden_size,
541
+ TOP_K=top_k,
542
+ )
543
+ return out
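
A rough round-trip sketch for the routing helpers above (hypothetical usage, assuming a CUDA device and a working Triton install): `indices` is the permutation that sorts tokens by expert, `bin_ids` is the sorted expert id per routed row, and `bins` is the inclusive cumsum of tokens per expert. With `top_k=1` and no weights, `scatter(gather(x))` should be a pure permutation round trip.

```python
import torch
from megablocks.backend.kernels import gather, scatter

num_tokens, hidden, num_experts, top_k = 8, 32, 4, 1
x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.float16)
top_experts = torch.randint(0, num_experts, (num_tokens,), device="cuda")

# Sort tokens by expert id to build the permutation and bin boundaries.
bin_ids, indices = torch.sort(top_experts)
tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)
bins = tokens_per_expert.cumsum(0).to(torch.int32)

routed = gather(x, indices.int(), bin_ids.int(), None, bins, top_k)          # route to experts
restored = scatter(routed, indices.int(), bin_ids.int(), None, bins, top_k)  # route back
torch.testing.assert_close(restored, x)
```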
torch-ext/megablocks/bak.__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from megablocks_moe.megablocks import (
2
+ MoE,
3
+ dMoE,
4
+ get_load_balancing_loss,
5
+ ParallelMLP,
6
+ ParallelDroplessMLP,
7
+ SparseMLP,
8
+ MLP,
9
+ SparseGLU,
10
+ Arguments,
11
+ )
12
+
13
+ __all__ = [
14
+ "MoE",
15
+ "dMoE",
16
+ "get_load_balancing_loss",
17
+ "ParallelMLP",
18
+ "ParallelDroplessMLP",
19
+ "SparseMLP",
20
+ "MLP",
21
+ "SparseGLU",
22
+ "Arguments",
23
+ ]
torch-ext/megablocks/benchmark_util.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def log_benchmark(name, arguments, time, std):
9
+ print('=' * 60)
10
+ print(f'{name} Benchmark')
11
+ print('Benchmark Parameters:')
12
+ for (key, value) in arguments.items():
13
+ print(f'{key} = {value}')
14
+ print('Results:')
15
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
16
+ print('=' * 60)
17
+
18
+
19
+ def benchmark_function(fn, iterations=100, warmup=10):
20
+ # Warmup iterations.
21
+ for _ in range(warmup):
22
+ fn()
23
+
24
+ times = []
25
+ for i in range(iterations):
26
+ start = torch.cuda.Event(enable_timing=True)
27
+ end = torch.cuda.Event(enable_timing=True)
28
+
29
+ start.record()
30
+ fn()
31
+ end.record()
32
+
33
+ torch.cuda.synchronize()
34
+ times.append(start.elapsed_time(end))
35
+ return np.mean(times), np.std(times)
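
A small (hypothetical) usage example of the helpers above; it assumes a CUDA device since the timing uses CUDA events, and the reported mean/std are in milliseconds.

```python
import torch
from megablocks.benchmark_util import benchmark_function, log_benchmark

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# torch.cuda.Event.elapsed_time returns milliseconds.
mean_ms, std_ms = benchmark_function(lambda: a @ b)
log_benchmark("matmul", {"m": 4096, "n": 4096, "k": 4096}, mean_ms, std_ms)
```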
torch-ext/megablocks/grouped_gemm_util.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import warnings
4
+
5
+ _grouped_gemm_is_available: bool = False
6
+ try:
7
+ import grouped_gemm
8
+ _grouped_gemm_is_available = True
9
+ except ImportError as error:
10
+ warnings.warn('Grouped GEMM not available.')
11
+
12
+
13
+ def grouped_gemm_is_available():
14
+ return _grouped_gemm_is_available
15
+
16
+
17
+ def assert_grouped_gemm_is_available():
18
+ msg = (
19
+ 'Grouped GEMM not available. Please run '
20
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
21
+ )
22
+ assert _grouped_gemm_is_available, msg
23
+
24
+
25
+ backend = grouped_gemm.backend if grouped_gemm_is_available() else None
26
+ ops = grouped_gemm.ops if grouped_gemm_is_available() else None
torch-ext/megablocks/layers/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # from megablocks.layers.dmoe import dMoE
5
+ from megablocks.layers.moe import MoE
6
+
7
+ __all__ = [
8
+ 'MoE',
9
+ # 'dMoE',
10
+ ]
torch-ext/megablocks/layers/activation_fn.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Callable, Union
5
+
6
+ import torch
7
+ from stk import Matrix
8
+
9
+
10
+ def act_fn(
11
+ x: Matrix,
12
+ function: Callable,
13
+ return_grad_fn: bool = False,
14
+ **kwargs,
15
+ ) -> Union[tuple[Matrix, Any], Matrix]:
16
+ assert isinstance(x, Matrix)
17
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
18
+ if return_grad_fn:
19
+ x.data.requires_grad = True
20
+ out = function(x.data, **kwargs)
21
+ y = Matrix(
22
+ x.size(),
23
+ out,
24
+ x.row_indices,
25
+ x.column_indices,
26
+ x.offsets,
27
+ x.column_indices_t,
28
+ x.offsets_t,
29
+ x.block_offsets_t,
30
+ )
31
+ if return_grad_fn:
32
+ return y, out.backward
33
+ return y
torch-ext/megablocks/layers/all_to_all.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+
8
+ class AllToAllOp(torch.autograd.Function):
9
+
10
+ @staticmethod
11
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
12
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
13
+
14
+ ctx.input_shape = x.shape
15
+ ctx.output_split_sizes = output_split_sizes
16
+ ctx.input_split_sizes = input_split_sizes
17
+ ctx.group = group
18
+ handle = dist.all_to_all_single(
19
+ out,
20
+ x,
21
+ output_split_sizes=output_split_sizes,
22
+ input_split_sizes=input_split_sizes,
23
+ group=group,
24
+ async_op=async_op,
25
+ )
26
+ return out, handle
27
+
28
+ @staticmethod
29
+ def backward(ctx, grad, _):
30
+ if ctx.needs_input_grad[0]:
31
+ out = torch.empty(
32
+ ctx.input_shape,
33
+ device=grad.device,
34
+ dtype=grad.dtype,
35
+ )
36
+ dist.all_to_all_single(
37
+ out,
38
+ grad,
39
+ output_split_sizes=ctx.input_split_sizes,
40
+ input_split_sizes=ctx.output_split_sizes,
41
+ group=ctx.group,
42
+ )
43
+ return out, None, None, None, None
44
+ return None, None, None, None, None
45
+
46
+
47
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
48
+ return AllToAllOp.apply(
49
+ x,
50
+ output_split_sizes,
51
+ input_split_sizes,
52
+ group,
53
+ async_op,
54
+ )
torch-ext/megablocks/layers/arguments.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import dataclasses
5
+ from functools import partial
6
+ from typing import Any, Callable, Optional, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ import torch.nn.functional as F
11
+
12
+ import megablocks.grouped_gemm_util as grouped_gemm
13
+
14
+ # Type annotation for in-place Tensor initialization function.
15
+ InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
16
+
17
+ _ALLOWED_BITWIDTHS = (-1, 4, 8)
18
+
19
+ DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Arguments:
24
+ # Model arguments.
25
+ hidden_size: int = 1024
26
+ ffn_hidden_size: int = 4096
27
+ num_layers: int = 1
28
+ bias: bool = True
29
+ return_bias: bool = True
30
+ activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
31
+
32
+ # MoE arguments.
33
+ moe_num_experts: int = 1
34
+ moe_top_k: int = 1
35
+ moe_capacity_factor: int = 1
36
+ moe_normalize_expert_weights: Optional[Union[int, float]] = None
37
+ moe_loss_weight: float = 0.1
38
+ moe_jitter_eps: Optional[float] = None
39
+ moe_lbl_in_fp32: bool = False
40
+
41
+ # Parallelism arguments.
42
+ moe_expert_model_parallelism: bool = False
43
+ expert_parallel_group: Optional[dist.ProcessGroup] = None
44
+ pipeline_model_parallel_size: int = 1
45
+ num_layers_per_virtual_pipeline_stage: Optional[int] = None
46
+
47
+ # Compute arguments.
48
+ memory_optimized_mlp: bool = False
49
+ mlp_type: str = 'mlp'
50
+ mlp_impl: str = 'sparse'
51
+
52
+ # Initialization arguments.
53
+ fp16: bool = True
54
+ bf16: bool = False
55
+ device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
56
+ init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
57
+ output_layer_init_method: InitFn = init_method
58
+
59
+ # Benchmarking arguments.
60
+ uniform_expert_assignment: bool = False
61
+
62
+ # shared expert arguments
63
+ shared_expert: bool = False # enable using shared expert
64
+ fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
65
+ fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
66
+ remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
67
+ shared_expert_hidden_size: Optional[
68
+ int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
69
+ shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
70
+
71
+ # Router Z-loss arguments
72
+ moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
73
+ moe_zloss_in_fp32: bool = False
74
+
75
+ def __post_init__(self):
76
+ # Sparse MLP is not supported with triton >=3.2.0
77
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
78
+ if self.__getattribute__('mlp_impl') == 'sparse':
79
+ try:
80
+ import triton
81
+ if triton.__version__ >= '3.2.0':
82
+ raise ValueError(
83
+ 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
84
+ )
85
+ except ImportError:
86
+ raise ImportError('Triton is required for sparse MLP implementation')
87
+
88
+ if self.__getattribute__('mlp_impl') == 'grouped':
89
+ grouped_gemm.assert_grouped_gemm_is_available()
90
+
91
+ if self.shared_expert_hidden_size is None:
92
+ self.shared_expert_hidden_size = self.ffn_hidden_size
93
+
94
+
95
+ def from_megatron(megatron_args: Any):
96
+ args = Arguments()
97
+ for field in dataclasses.fields(args):
98
+ if hasattr(megatron_args, field.name):
99
+ setattr(args, field.name, getattr(megatron_args, field.name))
100
+ return args
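
A hypothetical configuration sketch for the dataclass above. Note the assumptions: the default `device` factory requires CUDA, `__post_init__` rejects the 'sparse' implementation on triton >= 3.2.0, and 'grouped' requires the grouped_gemm package. The resulting `args` is what the MoE/dMoE layers (see layers/dmoe.py below) consume.

```python
import torch
from megablocks.layers.arguments import Arguments

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_type="glu",      # 'mlp' or 'glu' (see dmlp_registry below)
    mlp_impl="sparse",   # 'sparse' needs triton < 3.2.0; 'grouped' needs grouped_gemm
    fp16=False,
    bf16=True,
    device=torch.device("cuda"),
)
```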
torch-ext/megablocks/layers/common.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ from megablocks.layers.arguments import Arguments
7
+
8
+
9
+ def dtype(args: Arguments):
10
+ if args.fp16:
11
+ return torch.float16
12
+ elif args.bf16:
13
+ return torch.bfloat16
14
+ return None
15
+
16
+
17
+ def cast_if_autocast_enabled(tensor):
18
+ if torch.is_autocast_enabled():
19
+ if tensor.device.type == 'cuda':
20
+ dtype = torch.get_autocast_gpu_dtype()
21
+ elif tensor.device.type == 'cpu':
22
+ dtype = torch.get_autocast_cpu_dtype()
23
+ else:
24
+ raise NotImplementedError()
25
+ return tensor.to(dtype=dtype)
26
+ return tensor
torch-ext/megablocks/layers/dmlp_registry.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Union
5
+
6
+ from megablocks.layers import glu, mlp
7
+ from megablocks.layers.arguments import Arguments
8
+
9
+ MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
10
+
11
+ _REGISTRY = {
12
+ 'mlp': {
13
+ 'grouped': mlp.GroupedMLP,
14
+ 'sparse': mlp.SparseMLP,
15
+ },
16
+ 'glu': {
17
+ 'grouped': glu.GroupedGLU,
18
+ 'sparse': glu.SparseGLU,
19
+ },
20
+ }
21
+
22
+
23
+ def get(args: Arguments) -> MlpType:
24
+ """Returns an MLP for use in a dMoE instance.
25
+
26
+ Uses the provided arguments to instantiate the appropriate
27
+ MLP instance. This only contains MLPs for use in dMoEs
28
+ (i.e., only for the dropless versions of MoEs).
29
+
30
+ Args:
31
+ args: propagated Arguments dataclass.
32
+
33
+ Returns:
34
+ An instantiated MLP constructed using the input args.
35
+ """
36
+ if args.mlp_type not in _REGISTRY:
37
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
38
+
39
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
40
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
41
+
42
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
torch-ext/megablocks/layers/dmoe.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import stk.ops
6
+ import torch
7
+ from stk import Matrix
8
+
9
+ import megablocks.ops as ops
10
+ # from megablocks.ops import ops
11
+ from megablocks.layers import common, dmlp_registry, moe, mpu
12
+ from megablocks.layers.arguments import Arguments
13
+
14
+
15
+ def promote_scalar(x):
16
+ return x.view(1) if not len(x.size()) else x
17
+
18
+
19
+ class ParallelDroplessMLP(moe.ParallelMLP):
20
+
21
+ def __init__(self, args: Arguments):
22
+ super(ParallelDroplessMLP, self).__init__(args)
23
+ self.hidden_size = args.hidden_size
24
+ self.ffn_hidden_size = mpu.features_per_rank(args)
25
+ self.blocking = 128
26
+ self.mlp = dmlp_registry.get(args)
27
+
28
+ # Calculate the number of bits needed to represent the column indices
29
+ # in the intermediate sparse matrix.
30
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
31
+ self.transpose_sort_end_bit = max(
32
+ int(np.ceil(np.log2(max_column_index))),
33
+ 1,
34
+ )
35
+
36
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
37
+ block_columns = size[1] // self.blocking
38
+
39
+ # Sort row indices by column indices to get the transposed matrix's
40
+ # column indices.
41
+ #
42
+ # NOTE: Our sort operation uses the same width indices as the input values.
43
+ # To avoid overflow when we have large activation matrices we cast to
44
+ # 32-bit before sorting.
45
+ _, gather_indices = ops.sort(
46
+ column_indices.int(),
47
+ self.transpose_sort_end_bit,
48
+ )
49
+
50
+ # There are a constant number of blocks in every row of the sparse matrix.
51
+ # A block's offset is:
52
+ #
53
+ # row_index * blocks_per_row + column_index % blocks_per_row
54
+ #
55
+ # Once we have the block offsets ordered for transposition we can divide
56
+ # by blocks_per_row to get the transposed column indices.
57
+ column_indices_t = row_indices.gather(0, gather_indices.long())
58
+ block_offsets_t = gather_indices.int()
59
+
60
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
61
+ nnz_per_column = ops.histogram(column_indices, block_columns)
62
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
63
+ if nnz_per_column.dim() == 0:
64
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
65
+ nnz_per_column = nnz_per_column.unsqueeze(0)
66
+ offsets_t = torch.cat([zero, nnz_per_column])
67
+ return column_indices_t, offsets_t, block_offsets_t
68
+
69
+ def topology(self, x, padded_bins):
70
+ padded_tokens, _ = x.size()
71
+ assert padded_tokens % self.blocking == 0
72
+ if self.ffn_hidden_size % self.blocking != 0:
73
+ raise ValueError(
74
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
75
+ f'the block size {self.blocking}. Please update your configuration.',
76
+ )
77
+
78
+ # Offsets for the sparse matrix. All rows have the
79
+ # same number of nonzero blocks dictated by the
80
+ # dimensionality of a single expert.
81
+ block_rows = padded_tokens // self.blocking
82
+ blocks_per_row = self.ffn_hidden_size // self.blocking
83
+ offsets = torch.arange(
84
+ 0,
85
+ block_rows * blocks_per_row + 1,
86
+ blocks_per_row,
87
+ dtype=torch.int32,
88
+ device=x.device,
89
+ )
90
+
91
+ # Indices for the sparse matrix. The indices for
92
+ # the intermediate matrix are dynamic depending
93
+ # on the mapping of tokens to experts.
94
+ column_indices = ops.topology(
95
+ padded_bins,
96
+ self.blocking,
97
+ block_rows,
98
+ blocks_per_row,
99
+ )
100
+
101
+ # TODO(tgale): This is unused. Remove the need for this in stk.
102
+ # For now, use meta init to save the device memory.
103
+ data = torch.empty(
104
+ column_indices.numel(),
105
+ self.blocking,
106
+ self.blocking,
107
+ dtype=common.dtype(self.args),
108
+ device='meta',
109
+ )
110
+ shape = (
111
+ padded_tokens,
112
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
113
+ )
114
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
115
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
116
+ shape,
117
+ row_indices,
118
+ column_indices,
119
+ offsets,
120
+ )
121
+ return stk.Matrix(
122
+ shape,
123
+ data,
124
+ row_indices,
125
+ column_indices,
126
+ offsets,
127
+ column_indices_t,
128
+ offsets_t,
129
+ block_offsets_t,
130
+ )
131
+
132
+ def indices_and_padded_bins(self, top_experts):
133
+ # Sort the expert ids to produce the scatter/gather
134
+ # indices for the permutation.
135
+ top_experts = top_experts.int()
136
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
137
+
138
+ # Histogram the expert ids to identify the number of
139
+ # tokens routed to each expert.
140
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
141
+
142
+ # Round the token counts up to the block size used in
143
+ # the matrix multiplications. Calculate the starting
144
+ # position of each bin.
145
+ padded_tokens_per_expert = ops.round_up(
146
+ tokens_per_expert,
147
+ self.blocking,
148
+ )
149
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
150
+ padded_bins = promote_scalar(padded_bins)
151
+
152
+ # Calculate the bin bounds for the sorted tokens.
153
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
154
+ bins = promote_scalar(bins)
155
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
156
+
157
+ def sparse_forward_once(self, x, expert_weights, top_experts):
158
+ # x: [sl, bs, hs]
159
+ # expert_weights: [sl * bs, top-k]
160
+ # top_experts: [sl * bs, top-k]
161
+ expert_weights = expert_weights.flatten()
162
+ top_experts = top_experts.flatten()
163
+ with torch.no_grad():
164
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
165
+
166
+ # Route the tokens for MoE computation.
167
+ x = x.view(-1, x.shape[-1])
168
+ x = ops.padded_gather(
169
+ x,
170
+ indices,
171
+ bin_ids,
172
+ bins,
173
+ padded_bins,
174
+ self.top_k,
175
+ )
176
+
177
+ # Create the sparse matrix topology.
178
+ with torch.no_grad():
179
+ topo = self.topology(x, padded_bins)
180
+
181
+ # Perform the expert computation.
182
+ x = self.mlp(x, topo)
183
+
184
+ # Un-route the data for the MoE output.
185
+ x = ops.padded_scatter(
186
+ x,
187
+ indices,
188
+ bin_ids,
189
+ expert_weights,
190
+ bins,
191
+ padded_bins,
192
+ self.top_k,
193
+ )
194
+ return x, tokens_per_expert
195
+
196
+ # For use in the base-class parallel_forward_once.
197
+ def sparse_permute_and_compute(
198
+ self,
199
+ x,
200
+ tokens_per_expert,
201
+ indices,
202
+ bin_ids,
203
+ expert_weights,
204
+ bins,
205
+ expert_capactiy, # unused
206
+ top_k,
207
+ ):
208
+
209
+ # Round the token counts up to the block size used in the matrix
210
+ # multiplication. Calculate the starting position of each bin.
211
+ padded_tokens_per_expert = ops.round_up(
212
+ tokens_per_expert,
213
+ self.blocking,
214
+ )
215
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
216
+ padded_bins = promote_scalar(padded_bins)
217
+
218
+ # Route the tokens for MoE computation.
219
+ x = x.view(-1, x.shape[-1])
220
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
221
+
222
+ # Create the sparse matrix topology.
223
+ with torch.no_grad():
224
+ topo = self.topology(x, padded_bins)
225
+
226
+ # Perform the expert computation.
227
+ x = self.mlp(x, topo)
228
+
229
+ # Un-route the data for the MoE output.
230
+ return ops.padded_scatter(
231
+ x,
232
+ indices,
233
+ bin_ids,
234
+ expert_weights,
235
+ bins,
236
+ padded_bins,
237
+ top_k,
238
+ )
239
+
240
+ def grouped_forward_once(self, x, expert_weights, top_experts):
241
+ # x: [sl, bs, hs]
242
+ # expert_weights: [sl * bs, top-k]
243
+ # top_experts: [sl * bs, top-k]
244
+ expert_weights = expert_weights.flatten()
245
+ top_experts = top_experts.flatten()
246
+ with torch.no_grad():
247
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
248
+
249
+ out = self.grouped_permute_and_compute(
250
+ x,
251
+ tokens_per_expert,
252
+ indices,
253
+ bin_ids,
254
+ expert_weights,
255
+ bins,
256
+ -1, # unused
257
+ self.args.moe_top_k,
258
+ )
259
+ return out, tokens_per_expert
260
+
261
+ def grouped_permute_and_compute(
262
+ self,
263
+ x,
264
+ tokens_per_expert,
265
+ indices,
266
+ bin_ids,
267
+ expert_weights,
268
+ bins,
269
+ expert_capactiy, # unused
270
+ top_k,
271
+ ):
272
+
273
+ # Route the tokens for MoE computation.
274
+ x = x.view(-1, x.shape[-1])
275
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
276
+
277
+ # Perform the expert computation.
278
+ x = self.mlp(x, tokens_per_expert)
279
+
280
+ # Un-route the data for the MoE output.
281
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
282
+
283
+ def forward_once(self, x, expert_weights, top_experts):
284
+ if self.args.mlp_impl == 'sparse':
285
+ return self.sparse_forward_once(x, expert_weights, top_experts)
286
+ else:
287
+ return self.grouped_forward_once(x, expert_weights, top_experts)
288
+
289
+ def permute_and_compute(
290
+ self,
291
+ x,
292
+ tokens_per_expert,
293
+ indices,
294
+ bin_ids,
295
+ expert_weights,
296
+ bins,
297
+ expert_capactiy,
298
+ top_k,
299
+ ):
300
+ if self.args.mlp_impl == 'sparse':
301
+ return self.sparse_permute_and_compute(
302
+ x,
303
+ tokens_per_expert,
304
+ indices,
305
+ bin_ids,
306
+ expert_weights,
307
+ bins,
308
+ expert_capactiy,
309
+ top_k,
310
+ )
311
+ else:
312
+ return self.grouped_permute_and_compute(
313
+ x,
314
+ tokens_per_expert,
315
+ indices,
316
+ bin_ids,
317
+ expert_weights,
318
+ bins,
319
+ expert_capactiy,
320
+ top_k,
321
+ )
322
+
323
+
324
+ class dMoE(moe.MoE):
325
+
326
+ def _init_experts_mlp(self, args: Arguments):
327
+ return ParallelDroplessMLP(args)
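
A plain-torch sketch of the padded-bin arithmetic in `indices_and_padded_bins` above, with stand-ins for `ops.histogram`, `ops.round_up`, and `ops.inclusive_cumsum` so it runs on CPU; the real path uses the CUDA/Triton ops on device.

```python
import torch

blocking, num_experts = 128, 4
top_experts = torch.randint(0, num_experts, (1000,))

# Tokens routed to each expert (stand-in for ops.histogram).
tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)

# Round each count up to the block size (stand-in for ops.round_up), then take
# inclusive cumsums to get per-expert bin boundaries (ops.inclusive_cumsum).
padded_tokens_per_expert = (tokens_per_expert + blocking - 1) // blocking * blocking
padded_bins = padded_tokens_per_expert.cumsum(0)
bins = tokens_per_expert.cumsum(0)

# The final padded bound is the row count of the block-sparse activation matrix
# that topology() describes and the expert MLP consumes.
print(int(padded_bins[-1]), "padded rows for", int(bins[-1]), "routed tokens")
```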
torch-ext/megablocks/layers/gelu.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import stk
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
+ @torch.jit.script
10
+ def _gelu_backward_inplace(g, x):
11
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
12
+ ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
13
+ return g.mul_(ff)
14
+
15
+
16
+ def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
17
+ # NOTE: The two sparse matrices must have the same topology.
18
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
19
+ return stk.Matrix(
20
+ x.size(),
21
+ _gelu_backward_inplace(grad.data, x.data),
22
+ x.row_indices,
23
+ x.column_indices,
24
+ x.offsets,
25
+ x.column_indices_t,
26
+ x.offsets_t,
27
+ x.block_offsets_t,
28
+ )
29
+ return _gelu_backward_inplace(grad, x)
30
+
31
+
32
+ def gelu(x: stk.Matrix):
33
+ assert isinstance(x, stk.Matrix)
34
+ return stk.Matrix(
35
+ x.size(),
36
+ F.gelu(x.data, approximate='tanh'),
37
+ x.row_indices,
38
+ x.column_indices,
39
+ x.offsets,
40
+ x.column_indices_t,
41
+ x.offsets_t,
42
+ x.block_offsets_t,
43
+ )
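
A dense sanity check (no stk required) that the closed-form derivative in `_gelu_backward_inplace` matches autograd of the tanh-approximate GeLU; the tolerances absorb the truncated constants in the kernel.

```python
import torch
import torch.nn.functional as F

x = torch.randn(64, dtype=torch.float64, requires_grad=True)
g = torch.randn(64, dtype=torch.float64)

# Autograd reference: d/dx sum(gelu(x) * g) = g * gelu'(x).
(F.gelu(x, approximate="tanh") * g).sum().backward()

# Closed-form derivative, as written in _gelu_backward_inplace.
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)

torch.testing.assert_close(x.grad, (g * ff).detach(), rtol=1e-6, atol=1e-6)
```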
torch-ext/megablocks/layers/glu.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import stk.ops
5
+ import torch
6
+
7
+ from megablocks import grouped_gemm_util as gg
8
+ from megablocks.layers import common, mpu
9
+ from megablocks.layers.activation_fn import act_fn
10
+ from megablocks.layers.arguments import Arguments
11
+ from megablocks.layers.mlp import (
12
+ SharedMLP,
13
+ SparseMLP,
14
+ create_dmoe_expert_weights,
15
+ resolve_dtensor,
16
+ )
17
+
18
+
19
+ class SparseGLU(SparseMLP):
20
+
21
+ def __init__(self, args: Arguments):
22
+ super().__init__(args)
23
+ self.v1 = torch.nn.Parameter(
24
+ torch.empty(
25
+ self._num_rows_per_rank,
26
+ args.hidden_size,
27
+ device=args.device,
28
+ dtype=common.dtype(args),
29
+ ),
30
+ )
31
+ with torch.no_grad():
32
+ self.v1.copy_(
33
+ create_dmoe_expert_weights(
34
+ args,
35
+ args.moe_num_experts,
36
+ args.ffn_hidden_size,
37
+ args.hidden_size,
38
+ args.init_method,
39
+ ),
40
+ )
41
+
42
+ mpu.set_expert_model_parallel_attributes(
43
+ self.v1,
44
+ self._should_set_parallelism_attribute,
45
+ )
46
+
47
+ def forward(self, x, topo):
48
+ if self.args.memory_optimized_mlp:
49
+ raise NotImplementedError(
50
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
51
+ )
52
+
53
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
54
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
55
+
56
+ # Compute the GLU.
57
+ x1 = stk.ops.sdd(x, w1.t(), topo)
58
+ x2 = stk.ops.sdd(x, v1.t(), topo)
59
+
60
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
61
+ x1 = stk.ops.mul(activation_fn_out, x2)
62
+
63
+ return stk.ops.dsd(x1, w2)
64
+
65
+
66
+ class MemoryOptimizedGroupedGLU(torch.autograd.Function):
67
+ """GroupedMLP with manually scheduled memory reuse."""
68
+
69
+ @staticmethod
70
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
71
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
72
+ # Cast inputs using ctx dtype from AMP
73
+ if ctx._fwd_used_autocast:
74
+ x = x.to(ctx._dtype)
75
+ w1 = w1.to(ctx._dtype)
76
+ v1 = v1.to(ctx._dtype)
77
+ w2 = w2.to(ctx._dtype)
78
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
79
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
80
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
81
+
82
+ # Layer 0: x @ w1.t().
83
+ assert gg.backend is not None
84
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
85
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
86
+
87
+ # GeLU.
88
+ activation_fn_out = activation_fn(sdd_out) * v1_out
89
+
90
+ # Layer 1: x @ w2.
91
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
92
+
93
+ # NOTE: Save the input to the layer and the activation_fn input for
94
+ # gradient computation. We'll re-compute the activation_fn forward
95
+ # pass in the backward pass to avoid materializing another
96
+ # intermediate.
97
+ ctx.x_shape = x.shape
98
+ ctx.sdd_out_shape = sdd_out.shape
99
+ ctx.dtype = x.dtype
100
+ ctx.activation_fn = activation_fn
101
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
102
+ return dsd_out
103
+
104
+ @staticmethod
105
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
106
+ def backward(ctx, ddsd_out):
107
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
108
+ raise ValueError('Expected all MLP inputs to need grad.')
109
+
110
+ # Unpack saved tensors
111
+ # dtype = ctx.dtype
112
+ saved_tensors = ctx.saved_tensors
113
+ w1, v1, w2 = saved_tensors[:3]
114
+ batch_sizes = saved_tensors[3]
115
+ x = saved_tensors[4]
116
+ sdd_out, v1_out = saved_tensors[5:7]
117
+
118
+ # Rematerialize activation_fn output.
119
+ activation_fn = ctx.activation_fn
120
+ with torch.set_grad_enabled(True):
121
+ sdd_out.requires_grad = True
122
+ v1_out.requires_grad = True
123
+ activation_fn_out = activation_fn(sdd_out) * v1_out
124
+ activation_grad_fn = activation_fn_out.backward
125
+
126
+ # Compute dw2 with recomputed activation_fn output.
127
+ assert gg.backend is not None
128
+ dw2 = gg.backend.gmm(
129
+ activation_fn_out,
130
+ ddsd_out,
131
+ batch_sizes,
132
+ trans_a=True,
133
+ )
134
+
135
+ # Compute dactivation_fn_out.
136
+ #
137
+ # NOTE: We reuse the activation_fn_out allocation.
138
+ dactivation_fn_out = activation_fn_out
139
+ gg.backend.gmm(
140
+ ddsd_out,
141
+ w2,
142
+ batch_sizes,
143
+ trans_b=True,
144
+ c=dactivation_fn_out,
145
+ )
146
+
147
+ # Compute dsdd_out.
148
+ #
149
+ # NOTE: This reuses the dactivation_fn_out allocation.
150
+ assert activation_grad_fn is not None
151
+ activation_grad_fn(dactivation_fn_out)
152
+ dsdd_out = sdd_out.grad
153
+ dv1_out = v1_out.grad
154
+
155
+ # Compute dw1.
156
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
157
+
158
+ # Compute dv1.
159
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
160
+
161
+ # Compute dx.
162
+ #
163
+ # NOTE: This reuses the ddsd_out allocation.
164
+ dx = ddsd_out
165
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
166
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
167
+ return dx, dw1, dv1, dw2, None, None
168
+
169
+
170
+ memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
171
+
172
+
173
+ class GroupedGLU(SparseGLU):
174
+
175
+ def forward(self, x, tokens_per_expert):
176
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
177
+ w1, v1, w2 = (
178
+ self.scale_grad(self.w1),
179
+ self.scale_grad(self.v1),
180
+ self.scale_grad(self.w2),
181
+ )
182
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
183
+
184
+ # Re-shape the weights for the grouped GEMMs.
185
+ ne = mpu.experts_per_rank(self.args)
186
+ w1 = w1.view(ne, -1, self.args.hidden_size)
187
+ v1 = v1.view(ne, -1, self.args.hidden_size)
188
+ w2 = w2.view(ne, -1, self.args.hidden_size)
189
+
190
+ if self.args.memory_optimized_mlp:
191
+ return memory_optimized_grouped_glu(
192
+ x,
193
+ w1,
194
+ v1,
195
+ w2,
196
+ batch_sizes,
197
+ self.args.activation_fn,
198
+ )
199
+
200
+ # Compute the MLP.
201
+ assert gg.ops is not None
202
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
203
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
204
+ x1 = self.args.activation_fn(x1) * x2
205
+ return gg.ops.gmm(x1, w2, batch_sizes)
206
+
207
+
208
+ class SharedGLU(SharedMLP):
209
+ """GPU for shared expert.
210
+
211
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
212
+ """
213
+
214
+ def __init__(self, args: Arguments):
215
+ super().__init__(args)
216
+ self.gate_proj = args.fc_cls(
217
+ args.hidden_size,
218
+ self.args.shared_expert_hidden_size,
219
+ **self.fc_kwargs,
220
+ )
221
+
222
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
223
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
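
A dense reference for the per-expert math in GroupedGLU.forward above: each expert applies an up projection and a gate projection, multiplies the activated up branch by the gate branch, then applies the down projection. The sketch below is illustrative only (single expert, plain matmuls instead of the grouped GEMMs provided by gg.ops/gg.backend):

import torch
import torch.nn.functional as F

def glu_reference(x, w1, v1, w2, activation_fn=F.gelu):
    # x: [tokens, hidden]; w1, v1, w2: [ffn_hidden, hidden] for one expert.
    x1 = x @ w1.t()                       # up projection
    x2 = x @ v1.t()                       # gate projection
    return (activation_fn(x1) * x2) @ w2  # gated activation, then down projection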
torch-ext/megablocks/layers/memory_test.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import gc
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ from megablocks.layers import arguments, dmoe
10
+
11
+ _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
12
+
13
+
14
+ def get_tensors():
15
+ ptrs = set()
16
+ out = []
17
+ for obj in gc.get_objects():
18
+ if torch.is_tensor(obj):
19
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
20
+ continue
21
+ out.append(obj)
22
+ ptrs.add(obj.data_ptr())
23
+ return out
24
+
25
+
26
+ def test_memory(
27
+ group,
28
+ batch_size,
29
+ sequence_length,
30
+ hidden_size,
31
+ ffn_hidden_size,
32
+ num_experts,
33
+ top_k,
34
+ ):
35
+ args = arguments.Arguments(
36
+ hidden_size=hidden_size,
37
+ ffn_hidden_size=ffn_hidden_size,
38
+ moe_num_experts=num_experts,
39
+ moe_top_k=top_k,
40
+ moe_expert_model_parallelism=True,
41
+ expert_parallel_group=group,
42
+ fp16=False,
43
+ bf16=True,
44
+ device=torch.cuda.current_device(),
45
+ )
46
+ layer = dmoe.dMoE(args).cuda()
47
+
48
+ x = torch.randn((batch_size, sequence_length, hidden_size),
49
+ device=torch.cuda.current_device(),
50
+ dtype=torch.bfloat16).requires_grad_(True)
51
+ torch.cuda.empty_cache()
52
+
53
+ # Run forward + backward.
54
+ # with torch.autograd.detect_anomaly():
55
+ out, _ = layer(x)
56
+ out.mean().backward()
57
+
58
+ # Report peak memory.
59
+ mem = torch.cuda.max_memory_allocated()
60
+ print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
61
+ print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
62
+
63
+ # Calculate weight and gradient memory usage.
64
+ weight_memory = 2 * (
65
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
66
+ )
67
+
68
+ def grad_numel(x):
69
+ if x.grad is not None:
70
+ return x.grad.numel()
71
+ return 0
72
+
73
+ grad_memory = 2 * (
74
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
75
+ )
76
+ weight_memory += grad_memory
77
+
78
+ print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
79
+ print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)
80
+
81
+ # Manually calculate GPU memory usage from the garbage
82
+ # collector.
83
+ gc.collect()
84
+ total = 0
85
+ tensors = get_tensors()
86
+ tensors = sorted(tensors, key=lambda x: -x.numel())
87
+ for i, t in enumerate(tensors):
88
+ total += t.numel()
89
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
90
+ del tensors
91
+
92
+ print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))
93
+
94
+
95
+ if __name__ == '__main__':
96
+ assert dist.is_available()
97
+ group = dist.init_process_group(backend='nccl')
98
+ local_rank = dist.get_rank(group)
99
+ torch.cuda.set_device(local_rank)
100
+
101
+ for args in _TESTS:
102
+ test_memory(group, *args)
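
The weight and gradient accounting above multiplies element counts by 2 because the test runs in bf16 (2 bytes per element). A hypothetical helper that makes that assumption explicit, for reference only:

import torch

def param_bytes(module: torch.nn.Module, bytes_per_element: int = 2) -> int:
    # bf16/fp16 parameters take 2 bytes each; fp32 would take 4.
    return sum(p.numel() * bytes_per_element for p in module.parameters())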
torch-ext/megablocks/layers/memory_test.sh ADDED
@@ -0,0 +1,12 @@
1
+ #!/bin/bash
2
+
3
+ DISTRIBUTED_ARGUMENTS="\
4
+ --nproc_per_node 1 \
5
+ --nnodes 1 \
6
+ --node_rank 0 \
7
+ --master_addr localhost \
8
+ --master_port 6000"
9
+
10
+ python -m torch.distributed.launch \
11
+ ${DISTRIBUTED_ARGUMENTS} \
12
+ megablocks/layers/memory_test.py
torch-ext/megablocks/layers/mlp.py ADDED
@@ -0,0 +1,574 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ import stk
7
+ import stk.backend.triton_kernels
8
+ import stk.ops
9
+ import torch
10
+ from packaging import version
11
+
12
+ from megablocks import grouped_gemm_util as gg
13
+ from megablocks.layers import common, gelu, mpu
14
+ from megablocks.layers.activation_fn import act_fn
15
+ from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
16
+
17
+
18
+ class ScaleGradient(torch.autograd.Function):
19
+
20
+ @staticmethod
21
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
22
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
23
+ ctx.scale = scale
24
+ return x
25
+
26
+ @staticmethod
27
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
28
+ def backward(ctx: Any, grad: torch.Tensor):
29
+ return grad * ctx.scale, None
30
+
31
+
32
+ scale_gradient = ScaleGradient.apply
33
+
34
+
35
+ def resolve_dtensor(weight: torch.Tensor):
36
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
37
+ from torch.distributed._tensor import DTensor
38
+ if isinstance(weight, DTensor):
39
+ return weight.to_local()
40
+ return weight
41
+
42
+
43
+ def create_moe_expert_weights(
44
+ args: Arguments,
45
+ num_experts: int,
46
+ ffn_hidden_size: int,
47
+ hidden_size: int,
48
+ init_method: InitFn,
49
+ ):
50
+ # Create the entire weight matrix such that the sampled weights will
51
+ # not vary between data parallelism and expert model parallelism for
52
+ # the same random seed.
53
+ master_weights = torch.empty(
54
+ num_experts,
55
+ ffn_hidden_size,
56
+ hidden_size,
57
+ device=args.device,
58
+ dtype=common.dtype(args),
59
+ )
60
+ init_method(master_weights)
61
+
62
+ if not args.moe_expert_model_parallelism:
63
+ return master_weights
64
+
65
+ # Calculate the amount of sharding in each dimension.
66
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
67
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
68
+
69
+ # Calculate the experts per rank.
70
+ #
71
+ # NOTE: We assign ranks to be expert parallel before going
72
+ # tensor parallel.
73
+ rank = mpu.get_expert_parallel_rank(args)
74
+ expert_rank = rank % expert_sharding_degree
75
+ num_experts_per_rank = num_experts // expert_sharding_degree
76
+ start_expert = expert_rank * num_experts_per_rank
77
+ end_expert = (expert_rank + 1) * num_experts_per_rank
78
+
79
+ # Calculate the rows per rank.
80
+ row_rank = rank // expert_sharding_degree
81
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
82
+ start_row = row_rank * num_rows_per_rank
83
+ end_row = (row_rank + 1) * num_rows_per_rank
84
+
85
+ # Slice the weight matrix to get the chunk for this rank.
86
+ with torch.no_grad():
87
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
88
+ return weights
89
+
90
+
91
+ class MLP(torch.nn.Module):
92
+
93
+ def __init__(self, args: Arguments):
94
+ super().__init__()
95
+ self.args = args
96
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
97
+ experts_per_rank = mpu.experts_per_rank(args)
98
+
99
+ self.w1 = torch.nn.Parameter(
100
+ torch.empty(
101
+ experts_per_rank,
102
+ args.hidden_size,
103
+ mpu.features_per_rank(args),
104
+ device=args.device,
105
+ dtype=common.dtype(args),
106
+ ),
107
+ )
108
+ self.w2 = torch.nn.Parameter(
109
+ torch.empty(
110
+ experts_per_rank,
111
+ mpu.features_per_rank(args),
112
+ args.hidden_size,
113
+ device=args.device,
114
+ dtype=common.dtype(args),
115
+ ),
116
+ )
117
+ mpu.set_expert_model_parallel_attributes(
118
+ self.w1,
119
+ args.moe_expert_model_parallelism,
120
+ )
121
+ mpu.set_expert_model_parallel_attributes(
122
+ self.w2,
123
+ args.moe_expert_model_parallelism,
124
+ )
125
+
126
+ # Initialize the parameters for the MLP.
127
+ #
128
+ # NOTE: It is important that we create the weight tensors prior
129
+ # to creating the master weights and slicing out the piece for
130
+ # this rank. If the master weights are created first the PyTorch
131
+ # caching allocator appears to use the same memory block for these
132
+ # and the slice which causes large increases in our peak memory
133
+ # usage.
134
+ with torch.no_grad():
135
+ w1 = create_moe_expert_weights(
136
+ args,
137
+ args.moe_num_experts,
138
+ args.ffn_hidden_size,
139
+ args.hidden_size,
140
+ args.init_method,
141
+ )
142
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
143
+ self.w2.copy_(
144
+ create_moe_expert_weights(
145
+ args,
146
+ args.moe_num_experts,
147
+ args.ffn_hidden_size,
148
+ args.hidden_size,
149
+ args.output_layer_init_method,
150
+ ),
151
+ )
152
+
153
+ self.gradient_scale = None
154
+ if self.args.moe_expert_model_parallelism:
155
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
156
+
157
+ def scale_grad(self, w):
158
+ if self.gradient_scale is None:
159
+ return w
160
+ return scale_gradient(w, self.gradient_scale)
161
+
162
+ def forward(self, x):
163
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
164
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
165
+ x = torch.bmm(x, w1)
166
+ x = self.args.activation_fn(x)
167
+ return torch.bmm(x, w2)
168
+
169
+
170
+ def create_dmoe_expert_weights(
171
+ args: Arguments,
172
+ num_experts: int,
173
+ rows: int,
174
+ columns: int,
175
+ init_method: InitFn,
176
+ ):
177
+ weights = create_moe_expert_weights(
178
+ args,
179
+ num_experts,
180
+ rows,
181
+ columns,
182
+ init_method,
183
+ )
184
+ return weights.view([-1, columns])
185
+
186
+
187
+ class MemoryOptimizedMLP(torch.autograd.Function):
188
+ """Sparse MLP with manually scheduled memory reuse."""
189
+
190
+ @staticmethod
191
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
192
+ def forward(ctx, x, w1, w2, topo, activation_fn):
193
+ # Cast inputs using ctx dtype from AMP
194
+ if ctx._fwd_used_autocast:
195
+ x = x.to(ctx._dtype)
196
+ w1 = w1.to(ctx._dtype)
197
+ w2 = w2.to(ctx._dtype)
198
+ # x: [m, k], w1: [n, k], w2: [n, k]
199
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
200
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
201
+
202
+ topo_tensors = (
203
+ topo.row_indices,
204
+ topo.column_indices,
205
+ topo.offsets,
206
+ topo.column_indices_t,
207
+ topo.offsets_t,
208
+ topo.block_offsets_t,
209
+ )
210
+
211
+ # Layer 0: x @ w1.t().
212
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
213
+
214
+ # Activation function (GeLU by default).
215
+ activation_fn_out = act_fn(sdd_out, activation_fn)
216
+
217
+ # Layer 1: x @ w2.
218
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
219
+
220
+ # NOTE: Save the input to the layer and the activation_fn input for
221
+ # gradient computation. We'll re-compute the activation_fn forward
222
+ # pass in the backward pass to avoid materializing another
223
+ # intermediate.
224
+ ctx.shape = topo.shape
225
+ ctx.x_shape = x.shape
226
+ ctx.sdd_out_shape = sdd_out.data.shape
227
+ ctx.dtype = x.dtype
228
+ ctx.activation_fn = activation_fn
229
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
230
+ return dsd_out
231
+
232
+ @staticmethod
233
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
234
+ def backward(ctx, ddsd_out):
235
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
236
+ raise ValueError('Expected all MLP inputs to need grad.')
237
+
238
+ # unpack saved tensors
239
+ # dtype = ctx.dtype
240
+ saved_tensors = ctx.saved_tensors
241
+ w1, w2 = saved_tensors[:2]
242
+ topo_tensors = saved_tensors[2:8]
243
+ x = saved_tensors[8]
244
+ sdd_out_data = saved_tensors[9]
245
+
246
+ # rematerialize activation function output
247
+ activation_fn = ctx.activation_fn
248
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
249
+ activation_fn_out, activation_grad_fn = act_fn(
250
+ sdd_out,
251
+ activation_fn,
252
+ return_grad_fn=True,
253
+ )
254
+
255
+ # Compute dw2 with recomputed activation_fn output.
256
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
257
+
258
+ # Compute dactivation_fn_out.
259
+ #
260
+ # NOTE: We reuse the activation_fn_out allocation.
261
+ dactivation_fn_out = activation_fn_out
262
+ stk.backend.triton_kernels.sdd(
263
+ ddsd_out,
264
+ w2.t(),
265
+ dactivation_fn_out.shape,
266
+ dactivation_fn_out.data,
267
+ dactivation_fn_out.offsets,
268
+ dactivation_fn_out.row_indices,
269
+ dactivation_fn_out.column_indices,
270
+ )
271
+
272
+ # Compute dsdd_out.
273
+ #
274
+ # NOTE: This reuses the dactivation_fn_out allocation.
275
+ if activation_fn is DEFAULT_ACTIVATION_FN:
276
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
277
+ else:
278
+ assert activation_grad_fn is not None
279
+ activation_grad_fn(dactivation_fn_out.data)
280
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
281
+
282
+ # Compute dw1.
283
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
284
+
285
+ # Compute dx.
286
+ #
287
+ # NOTE: This reuses the ddsd_out allocation.
288
+ stk.backend.triton_kernels.dsd(
289
+ dsdd_out.shape,
290
+ dsdd_out.data,
291
+ dsdd_out.offsets,
292
+ dsdd_out.row_indices,
293
+ dsdd_out.column_indices,
294
+ dsdd_out.offsets_t,
295
+ dsdd_out.column_indices_t,
296
+ dsdd_out.block_offsets_t,
297
+ False,
298
+ w1,
299
+ ddsd_out,
300
+ )
301
+ dx = ddsd_out
302
+ return dx, dw1, dw2, None, None
303
+
304
+
305
+ memory_optimized_mlp = MemoryOptimizedMLP.apply
306
+
307
+
308
+ class SparseMLP(torch.nn.Module):
309
+
310
+ def __init__(self, args: Arguments):
311
+ super().__init__()
312
+ self.args = args
313
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
314
+
315
+ self.w1 = torch.nn.Parameter(
316
+ torch.empty(
317
+ self._num_rows_per_rank,
318
+ args.hidden_size,
319
+ device=args.device,
320
+ dtype=common.dtype(args),
321
+ ),
322
+ )
323
+ self.w2 = torch.nn.Parameter(
324
+ torch.empty(
325
+ self._num_rows_per_rank,
326
+ args.hidden_size,
327
+ device=args.device,
328
+ dtype=common.dtype(args),
329
+ ),
330
+ )
331
+
332
+ # Initialize the parameters for the MLP.
333
+ #
334
+ # NOTE: It is important that we create the weight tensors prior
335
+ # to creating the master weights and slicing out the piece for
336
+ # this rank. If the master weights are created first the PyTorch
337
+ # caching allocator appears to use the same memory block for these
338
+ # and the slice which causes large increases in our peak memory
339
+ # usage.
340
+ with torch.no_grad():
341
+ self.w1.copy_(
342
+ create_dmoe_expert_weights(
343
+ args,
344
+ args.moe_num_experts,
345
+ args.ffn_hidden_size,
346
+ args.hidden_size,
347
+ args.init_method,
348
+ ),
349
+ )
350
+ self.w2.copy_(
351
+ create_dmoe_expert_weights(
352
+ args,
353
+ args.moe_num_experts,
354
+ args.ffn_hidden_size,
355
+ args.hidden_size,
356
+ args.output_layer_init_method,
357
+ ),
358
+ )
359
+
360
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
361
+ mpu.set_expert_model_parallel_attributes(
362
+ self.w1,
363
+ self._should_set_parallelism_attribute,
364
+ )
365
+ mpu.set_expert_model_parallel_attributes(
366
+ self.w2,
367
+ self._should_set_parallelism_attribute,
368
+ )
369
+
370
+ self.gradient_scale = None
371
+ if self.args.moe_expert_model_parallelism:
372
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
373
+
374
+ def scale_grad(self, w):
375
+ if self.gradient_scale is None:
376
+ return w
377
+ return scale_gradient(w, self.gradient_scale)
378
+
379
+ def forward(self, x, topo):
380
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
381
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
382
+ if self.args.memory_optimized_mlp:
383
+ return memory_optimized_mlp(
384
+ x,
385
+ w1,
386
+ w2,
387
+ topo,
388
+ self.args.activation_fn,
389
+ )
390
+
391
+ # Compute the MLP.
392
+ x = stk.ops.sdd(x, w1.t(), topo)
393
+ activation_fn_out = act_fn(x, self.args.activation_fn)
394
+ return stk.ops.dsd(activation_fn_out, w2)
395
+
396
+
397
+ class MemoryOptimizedGroupedMLP(torch.autograd.Function):
398
+ """GroupedMLP with manually scheduled memory reuse."""
399
+
400
+ @staticmethod
401
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
402
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
403
+ # Cast inputs using ctx dtype from AMP
404
+ if ctx._fwd_used_autocast:
405
+ x = x.to(ctx._dtype)
406
+ w1 = w1.to(ctx._dtype)
407
+ w2 = w2.to(ctx._dtype)
408
+ # x: [m, k], w1: [n, k], w2: [n, k]
409
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
410
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
411
+
412
+ # Layer 0: x @ w1.t().
413
+ assert gg.backend is not None
414
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
415
+
416
+ # activation_fn
417
+ activation_fn_out = activation_fn(sdd_out)
418
+
419
+ # Layer 1: x @ w2.
420
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
421
+
422
+ # NOTE: Save the input to the layer and the activation_fn input for
423
+ # gradient computation. We'll re-compute the activation_fn forward
424
+ # pass in the backward pass to avoid materializing another
425
+ # intermediate.
426
+ ctx.x_shape = x.shape
427
+ ctx.sdd_out_shape = sdd_out.shape
428
+ ctx.dtype = x.dtype
429
+ ctx.activation_fn = activation_fn
430
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
431
+ return dsd_out
432
+
433
+ @staticmethod
434
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
435
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
436
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
437
+ raise ValueError('Expected all MLP inputs to need grad.')
438
+
439
+ # Unpack saved tensors
440
+ # dtype = ctx.dtype
441
+ saved_tensors = ctx.saved_tensors
442
+ w1, w2 = saved_tensors[:2]
443
+ batch_sizes = saved_tensors[2]
444
+ x = saved_tensors[3]
445
+ sdd_out = saved_tensors[4]
446
+
447
+ # Rematerialize activation_fn output.
448
+ activation_fn = ctx.activation_fn
449
+ with torch.set_grad_enabled(True):
450
+ sdd_out.requires_grad = True
451
+ activation_fn_out = activation_fn(sdd_out)
452
+ activation_grad_fn = activation_fn_out.backward
453
+
454
+ # Compute dw2 with recomputed activation_fn output.
455
+ assert gg.backend is not None
456
+ dw2 = gg.backend.gmm(
457
+ activation_fn_out,
458
+ ddsd_out,
459
+ batch_sizes,
460
+ trans_a=True,
461
+ )
462
+
463
+ # Compute dactivation_fn_out.
464
+ #
465
+ # NOTE: We reuse the activation_fn_out allocation.
466
+ dactivation_fn_out = activation_fn_out
467
+ gg.backend.gmm(
468
+ ddsd_out,
469
+ w2,
470
+ batch_sizes,
471
+ trans_b=True,
472
+ c=dactivation_fn_out,
473
+ )
474
+
475
+ # Compute dsdd_out.
476
+ #
477
+ # NOTE: This reuses the dactivation_fn_out allocation.
478
+ if activation_fn is DEFAULT_ACTIVATION_FN:
479
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
480
+ else:
481
+ assert activation_grad_fn is not None
482
+ activation_grad_fn(dactivation_fn_out)
483
+ dsdd_out = sdd_out.grad
484
+
485
+ # Compute dw1.
486
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
487
+
488
+ # Compute dx.
489
+ #
490
+ # NOTE: This reuses the ddsd_out allocation.
491
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
492
+ dx = ddsd_out
493
+ return dx, dw1, dw2, None, None
494
+
495
+
496
+ memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
497
+
498
+
499
+ class GroupedMLP(SparseMLP):
500
+
501
+ def forward(self, x, tokens_per_expert):
502
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
503
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
504
+
505
+ # Re-shape the weights for the grouped GEMMs.
506
+ ne = mpu.experts_per_rank(self.args)
507
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
508
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
509
+
510
+ if self.args.memory_optimized_mlp:
511
+ return memory_optimized_grouped_mlp(
512
+ x,
513
+ w1,
514
+ w2,
515
+ batch_sizes,
516
+ self.args.activation_fn,
517
+ )
518
+
519
+ # Compute the MLP.
520
+ assert gg.ops is not None
521
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
522
+ x = self.args.activation_fn(x)
523
+ return gg.ops.gmm(x, w2, batch_sizes)
524
+
525
+
526
+ class SharedMLP(torch.nn.Module):
527
+ """MLP for shared expert.
528
+
529
+ Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
530
+ """
531
+
532
+ def __init__(self, args: Arguments):
533
+ super().__init__()
534
+ self.args = args
535
+ self.fc_kwargs: dict[str, Any] = {
536
+ 'bias': args.bias,
537
+ 'device': args.device,
538
+ }
539
+ self.fc_kwargs.update(args.fc_kwargs)
540
+
541
+ self.up_proj = args.fc_cls(
542
+ args.hidden_size,
543
+ args.shared_expert_hidden_size,
544
+ **self.fc_kwargs,
545
+ )
546
+ self.act = args.activation_fn
547
+ self.down_proj = args.fc_cls(
548
+ args.shared_expert_hidden_size,
549
+ args.hidden_size,
550
+ **self.fc_kwargs,
551
+ )
552
+ self.down_proj._is_residual = True # a flag for llm-foundry init
553
+
554
+ def add_experts_sharedexpert(
555
+ self,
556
+ shared_expert_out: torch.Tensor,
557
+ expert_out: torch.Tensor,
558
+ ) -> torch.Tensor:
559
+ # Helper function to add expert output to shared expert output
560
+ # with optional weighted sum.
561
+ if self.args.shared_expert_weighted_sum:
562
+ # enable using weighted sum for shared expert output
563
+ # weighted by the number of experts used
564
+ t_experts = self.args.moe_top_k + 1
565
+ sh_mlp_out = shared_expert_out / t_experts
566
+ return sh_mlp_out.add(
567
+ expert_out,
568
+ alpha=(self.args.moe_top_k / t_experts),
569
+ )
570
+
571
+ return shared_expert_out + expert_out
572
+
573
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
574
+ return self.down_proj(self.act(self.up_proj(x)))
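
The MemoryOptimized* functions above save only the pre-activation output and rematerialize the activation inside backward rather than keeping it alive between the passes. A minimal sketch of that recompute-in-backward pattern for a plain dense two-layer MLP (hypothetical class, no grouped GEMMs, no in-place allocation reuse) is shown below; the real implementations additionally reuse the activation and gradient buffers:

import torch

class RecomputeActivationMLP(torch.autograd.Function):
    """Sketch: keep only the pre-activation, recompute act(h) in backward."""

    @staticmethod
    def forward(ctx, x, w1, w2, activation_fn):
        h = x @ w1.t()                # layer 0 pre-activation, [m, n]
        out = activation_fn(h) @ w2   # layer 1; the activation output is not saved
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(x, w1, w2, h)
        return out

    @staticmethod
    def backward(ctx, dout):
        x, w1, w2, h = ctx.saved_tensors
        with torch.enable_grad():
            h = h.detach().requires_grad_(True)
            a = ctx.activation_fn(h)  # rematerialized activation output
        dw2 = a.t() @ dout            # grad of the second weight
        da = dout @ w2.t()            # grad of the activation output
        a.backward(da)                # chain rule through the activation
        dh = h.grad
        dw1 = dh.t() @ x              # grad of the first weight
        dx = dh @ w1                  # grad of the input
        return dx, dw1, dw2, None

The trade-off is the usual activation-checkpointing one: less memory held between forward and backward in exchange for recomputing the activation once.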
torch-ext/megablocks/layers/moe.py ADDED
@@ -0,0 +1,475 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ import megablocks.ops as ops
10
+ from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
11
+ from megablocks.layers.all_to_all import all_to_all
12
+ from megablocks.layers.arguments import Arguments
13
+
14
+ _LOAD_BALANCING_LOSS = []
15
+
16
+
17
+ def save_load_balancing_loss(loss):
18
+ global _LOAD_BALANCING_LOSS
19
+ _LOAD_BALANCING_LOSS.append(loss)
20
+
21
+
22
+ def get_load_balancing_loss():
23
+ global _LOAD_BALANCING_LOSS
24
+ return _LOAD_BALANCING_LOSS
25
+
26
+
27
+ def clear_load_balancing_loss():
28
+ global _LOAD_BALANCING_LOSS
29
+ _LOAD_BALANCING_LOSS.clear()
30
+
31
+
32
+ def batched_load_balancing_loss(args: Arguments):
33
+ if args.moe_loss_weight == 0:
34
+ return 0.0
35
+
36
+ # tokens_per_expert[i].shape = (num_experts)
37
+ # expert_scores[i].shape = (tokens, num_experts)
38
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
39
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
40
+ if args.num_layers_per_virtual_pipeline_stage is not None:
41
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
42
+
43
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
44
+ raise ValueError(
45
+ f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
46
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
47
+ f'{args.num_layers}\npipeline_model_parallel_size = '
48
+ f'{args.pipeline_model_parallel_size}\n'
49
+ 'num_layers_per_virtual_pipeline_stage'
50
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
51
+ )
52
+ if len(expert_scores) != num_layers_per_pipeline_stage:
53
+ raise ValueError(
54
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
55
+ f'but found {len(expert_scores)}.\nnum_layers = '
56
+ f'{args.num_layers}\npipeline_model_parallel_size = '
57
+ f'{args.pipeline_model_parallel_size}\n'
58
+ 'num_layers_per_virtual_pipeline_stage'
59
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
60
+ )
61
+
62
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
63
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
64
+
65
+ tokens = expert_scores[0].shape[0]
66
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
67
+
68
+ # Concatenate the contributions of each layer and convert to
69
+ # the correct types and formats for the dot product.
70
+ expert_scores = torch.cat(expert_scores, dim=1)
71
+ if args.moe_lbl_in_fp32:
72
+ expert_scores = expert_scores.float()
73
+ if tokens != 0:
74
+ expert_scores = expert_scores.mean(dim=0)
75
+ else:
76
+ expert_scores = expert_scores.sum(dim=0)
77
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
78
+
79
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
80
+ assert tokens_per_expert.numel() == expected_values
81
+ assert expert_scores.numel() == expected_values
82
+
83
+ # Calculate the total scale across all factors.
84
+ #
85
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
86
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
87
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
88
+ scale = scale_numerator / scale_denominator
89
+ return scale * torch.dot(tokens_per_expert, expert_scores)
90
+
91
+
92
+ # NOTE: This class defines MoE expert computation, including expert model parallel
93
+ # communication. When using FSDP on top of MegaBlocks this is the module that should
94
+ # be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
95
+ # parallel all2all.
96
+ class ParallelMLP(torch.nn.Module):
97
+
98
+ def __init__(self, args: Arguments):
99
+ super(ParallelMLP, self).__init__()
100
+ self.args = args
101
+
102
+ # Calculate the number of experts in total and the number of experts
103
+ # owned by this rank.
104
+ # world_size = mpu.get_expert_parallel_world_size(args)
105
+ self.num_experts = args.moe_num_experts
106
+ self.top_k = self.args.moe_top_k
107
+
108
+ # Calculate the number of bits needed to represent the expert indices
109
+ # so that we can pass it to radix sort.
110
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
111
+
112
+ # Expert MLP.
113
+ self.mlp = mlp.MLP(args)
114
+
115
+ self.bias: Optional[torch.Tensor]
116
+ if self.args.bias:
117
+ # Note that the output bias is not parallelized with expert
118
+ # model parallelism.
119
+ self.bias = torch.nn.Parameter(
120
+ torch.empty(
121
+ args.hidden_size,
122
+ device=args.device,
123
+ dtype=common.dtype(args),
124
+ ),
125
+ )
126
+ torch.nn.init.zeros_(self.bias)
127
+ else:
128
+ self.register_parameter('bias', None)
129
+
130
+ # Select the forward function for the operating mode.
131
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
132
+
133
+ def expert_capacity(self, tokens: int) -> int:
134
+ world_size = mpu.get_expert_parallel_world_size(self.args)
135
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
136
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
137
+
138
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
139
+ """Calculate the load balancing loss contribution."""
140
+ assert len(expert_scores.size()) == 2
141
+ tokens, num_experts = expert_scores.size()
142
+ assert num_experts == self.num_experts
143
+ assert len(tokens_per_expert.size()) == 1
144
+ num_experts, = tokens_per_expert.size()
145
+ assert num_experts == self.num_experts
146
+ scale = self.num_experts / (tokens * self.top_k)
147
+ return scale * torch.dot(
148
+ tokens_per_expert.to(expert_scores.dtype),
149
+ expert_scores.mean(dim=0),
150
+ )
151
+
152
+ def indices_and_bins(self,
153
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
154
+ # Sort the expert ids to produce the scatter/gather
155
+ # indices for the permutation.
156
+ #
157
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
158
+ # prior? Could we place the `torch.max` operation to return
159
+ # 32-bit expert indices?
160
+ top_expert = top_expert.int()
161
+ output = ops.sort(top_expert, self.sort_end_bit)
162
+ assert output is not None
163
+ bin_ids, indices = output
164
+
165
+ # Histogram the expert ids to identify the number of
166
+ # tokens routed to each expert.
167
+ #
168
+ # TODO(tgale): Does the sorted data produce a more favorable
169
+ # data distribution for histogram? Or is the op parallelism
170
+ # worth more?
171
+ tokens_per_expert = ops.histogram(top_expert, self.num_experts)
172
+
173
+ # Calculate the bin bounds for the sorted tokens.
174
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
175
+ assert bins is not None
176
+ bins = bins.view(1) if not len(bins.size()) else bins
177
+
178
+ assert isinstance(indices, torch.Tensor)
179
+ assert isinstance(bin_ids, torch.Tensor)
180
+ assert isinstance(bins, torch.Tensor)
181
+ assert isinstance(tokens_per_expert, torch.Tensor)
182
+
183
+ return indices, bin_ids, bins, tokens_per_expert
184
+
185
+ def permute_and_compute(
186
+ self,
187
+ x: torch.Tensor,
188
+ tokens_per_expert: int, # unused
189
+ indices: torch.Tensor,
190
+ bin_ids: torch.Tensor, # unused
191
+ expert_weights: torch.Tensor,
192
+ bins: torch.Tensor,
193
+ expert_capacity: int,
194
+ top_k: int,
195
+ ):
196
+ # Route the tokens for MoE computation.
197
+ x = x.view(-1, x.shape[-1])
198
+ output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
199
+ assert output is not None
200
+ x = output
201
+
202
+ # Perform the expert computation. Note that we don't
203
+ # use biases for these linear operations.
204
+ x = self.mlp(x)
205
+
206
+ # Un-route the data for the MoE output.
207
+ return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
208
+
209
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
210
+ # x: [sl, bs, hs]
211
+ # expert_weights: [sl * bs, top-k]
212
+ # top_experts: [sl * bs, top-k]
213
+ expert_weights = expert_weights.flatten()
214
+ top_experts = top_experts.flatten()
215
+ with torch.no_grad():
216
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
217
+
218
+ # If expert_capacity is set to zero, set the number of tokens
219
+ # per expert to the maximum we need to avoid dropping tokens.
220
+ sl, bs, _ = x.size()
221
+ expert_capacity = self.expert_capacity(sl * bs)
222
+ if expert_capacity == 0:
223
+ expert_capacity = torch.max(tokens_per_expert).item()
224
+
225
+ x = self.permute_and_compute(
226
+ x,
227
+ tokens_per_expert,
228
+ indices,
229
+ bin_ids,
230
+ expert_weights,
231
+ bins,
232
+ expert_capacity,
233
+ self.top_k,
234
+ )
235
+ return x, tokens_per_expert
236
+
237
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
238
+ # NOTE: This function implements the same computation as forward_once
239
+ # but with expert model parallelism.
240
+ #
241
+ # 1. Permute the tokens locally so that they are grouped by their
242
+ # expert assignments. This allows us to transfer all of the tokens
243
+ # for a remote device in one communication primitive.
244
+ #
245
+ # 2. Permute the tokens across the expert parallel devices. After
246
+ # this is completed each device has all of the tokens assigned to
247
+ # its set of experts in its local HBM.
248
+ #
249
+ # 3. Permute the tokens locally so that they are grouped by their
250
+ # expert assignment. After the distributed permutation the tokens
251
+ # are grouped by which device they came from. We re-order them
252
+ # locally to allow for efficient computation.
253
+ #
254
+ # After this series of permutations we compute the linear layers
255
+ # and then repeat these three steps in reverse to produce the final
256
+ # output.
257
+ #
258
+ # Compute the mapping of local tokens to experts.
259
+ expert_weights = expert_weights.flatten()
260
+ top_experts = top_experts.flatten()
261
+ with torch.no_grad():
262
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
263
+
264
+ # If we're sharding the experts along the hidden dimension
265
+ # multiple devices own parts of the same sets of experts.
266
+ # Replicate the token counts so every device gets the counts.
267
+ repeated_tokens_per_expert = ops.repeat(
268
+ tokens_per_expert,
269
+ (mpu.hidden_sharding_degree(self.args),),
270
+ )
271
+
272
+ # Pass token count information to the device on which the
273
+ # target expert resides.
274
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
275
+ tpe_handle = dist.all_to_all_single(
276
+ parallel_tokens_per_expert,
277
+ repeated_tokens_per_expert,
278
+ group=self.args.expert_parallel_group,
279
+ async_op=True,
280
+ )
281
+
282
+ # Permute locally and without any padding so that tokens for each
283
+ # parallel device are stored contiguously.
284
+ #
285
+ # This view updates the shape of the tensor from [sl, bs, hs] to
286
+ # [sl * bs, hs] prior to the permutation.
287
+ x = x.view(-1, x.shape[-1])
288
+ output = ops.gather(x, indices, bin_ids, bins, self.top_k)
289
+ assert output is not None
290
+ x = output
291
+
292
+ # Compute the number of tokens that will be received from each
293
+ # device and permute the input data across the devices.
294
+ with torch.no_grad():
295
+ tpe_handle.wait()
296
+ experts_per_rank = mpu.experts_per_rank(self.args)
297
+
298
+ # Reshape to [world_size, num_experts_per_rank].
299
+ world_size = mpu.get_expert_parallel_world_size(self.args)
300
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
301
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
302
+
303
+ # TODO(tgale): It might be faster to do this on the GPU and
304
+ # then communicate the results back to the host.
305
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
306
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
307
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
308
+
309
+ # Convert the send/recv counts to lists.
310
+ send_counts = send_counts.tolist()
311
+ recv_counts = recv_counts.tolist()
312
+ tokens_received = sum(recv_counts)
313
+
314
+ # If we're sharding the experts along the hidden dimension
315
+ # multiple devices own parts of the same sets of experts.
316
+ # Replicate the token counts so devices that share experts
317
+ # get all of the tokens assigned to them.
318
+ #
319
+ # TODO(tgale): Fuse this into the prior, local permutation.
320
+ x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
321
+
322
+ # Start the cross-device permutation asynchronously so we can
323
+ # overlap communication with computation.
324
+ parallel_x, parallel_x_handle = all_to_all(
325
+ x,
326
+ recv_counts,
327
+ send_counts,
328
+ self.args.expert_parallel_group,
329
+ async_op=True,
330
+ )
331
+
332
+ with torch.no_grad():
333
+ # After we do the cross-device permutation we have the tokens on the
334
+ # correct device but not yet grouped by expert because we received
335
+ # tokens from each device as contiguous chunks. To group the tokens
336
+ # for expert computation we'll do one more local permutation. The
337
+ # rest of this torch.no_grad() scope sets up the indices and bins
338
+ # for this permutation.
339
+ replicate_bins = ops.inclusive_cumsum(
340
+ parallel_tokens_per_expert.flatten(),
341
+ 0,
342
+ )
343
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
344
+
345
+ # Construct the expert indices for the permuted tokens.
346
+ parallel_top_expert = torch.remainder(
347
+ torch.arange(
348
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
349
+ dtype=torch.int32,
350
+ device=indices.device,
351
+ ),
352
+ mpu.experts_per_rank(self.args),
353
+ )
354
+ parallel_top_expert = ops.replicate(
355
+ parallel_top_expert.unsqueeze(dim=0),
356
+ replicate_bins,
357
+ tokens_received,
358
+ ).flatten()
359
+
360
+ # TODO(tgale): The sort_end_bit here can be reduced.
361
+ parallel_bin_ids, parallel_indices = ops.sort(
362
+ parallel_top_expert,
363
+ self.sort_end_bit,
364
+ )
365
+
366
+ # Calculate the bin boundaries from the token counts.
367
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
368
+ dim=0,
369
+ dtype=torch.int,
370
+ )
371
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
372
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
373
+
374
+ # If expert_capacity is set to zero, set the number of tokens
375
+ # per expert to the maximum we need to avoid dropping tokens.
376
+ tokens, _ = x.size()
377
+ expert_capacity = self.expert_capacity(tokens)
378
+ if expert_capacity == 0:
379
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
380
+
381
+ # Locally permute the tokens and perform the expert computation.
382
+ # Block to make sure that the cross-device permutation is complete.
383
+ if self.args.mlp_impl == 'grouped':
384
+ # GroupedMLP requires counts on CPU. We can use the tensor already
385
+ # moved to CPU for the prior all_to_all, which avoids an extra
386
+ # device synchronization.
387
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
388
+ dim=0,
389
+ dtype=torch.int,
390
+ )
391
+ parallel_x_handle.wait()
392
+ parallel_x = self.permute_and_compute(
393
+ parallel_x,
394
+ parallel_tokens_per_expert,
395
+ parallel_indices,
396
+ parallel_bin_ids,
397
+ None, # expert_weights
398
+ parallel_bins,
399
+ expert_capacity,
400
+ top_k=1,
401
+ )
402
+
403
+ # Un-permute the tokens across the devices.
404
+ x, _ = all_to_all(
405
+ parallel_x,
406
+ send_counts,
407
+ recv_counts,
408
+ self.args.expert_parallel_group,
409
+ )
410
+
411
+ # Reduce along the hidden sharding to get the final outputs.
412
+ #
413
+ # TODO(tgale): Fuse this into the following local permutation.
414
+ shape = (
415
+ mpu.hidden_sharding_degree(self.args),
416
+ -1,
417
+ self.args.hidden_size,
418
+ )
419
+ x = ops.sum(x.view(shape), dim=0)
420
+
421
+ # Un-permute locally to setup for the next series of operations.
422
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
423
+ return x, tokens_per_expert.flatten()
424
+
425
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
426
+ in_shape = x.size()
427
+
428
+ # Compute the experts.
429
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
430
+ if self.training and self.args.moe_loss_weight > 0:
431
+ save_load_balancing_loss((tokens_per_expert, scores))
432
+ x = x.view(in_shape)
433
+ if self.bias is not None:
434
+ if self.args.return_bias:
435
+ return x, self.bias
436
+ return x + self.bias
437
+ return x
438
+
439
+
440
+ class MoE(torch.nn.Module):
441
+
442
+ def __init__(self, args: Arguments):
443
+ super(MoE, self).__init__()
444
+
445
+ # Token router.
446
+ self.router = router.LearnedRouter(args)
447
+
448
+ # Expert computation helper.
449
+ self.experts = self._init_experts_mlp(args)
450
+
451
+ self.shared_expert = None
452
+ if args.shared_expert:
453
+ # SharedExpert computation helper.
454
+ self.shared_expert = sharedexpert_registry.get(args)
455
+
456
+ def _init_experts_mlp(self, args: Arguments):
457
+ return ParallelMLP(args)
458
+
459
+ def forward(self, x: torch.Tensor):
460
+ # NOTE: If we're going to cast the activations to lower precision
461
+ # do it before we permute the tokens to save bandwidth.
462
+ x = common.cast_if_autocast_enabled(x)
463
+
464
+ # Compute the expert scores and assignments.
465
+ scores, expert_weights, top_experts = self.router(x)
466
+
467
+ # Compute the experts.
468
+ out = self.experts(x, scores, expert_weights, top_experts)
469
+ if self.shared_expert is not None:
470
+ shared_expert_out = self.shared_expert(x)
471
+ out = self.shared_expert.add_experts_sharedexpert(
472
+ shared_expert_out,
473
+ out,
474
+ )
475
+ return out
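
For a single layer, batched_load_balancing_loss above reduces to scale * dot(tokens_per_expert, mean expert scores) with scale = (moe_num_experts * moe_loss_weight) / (num_layers * tokens * moe_top_k). A tiny numeric sketch with made-up sizes (not from the source) shows the pieces end to end:

import torch

num_experts, top_k, num_layers, tokens, loss_weight = 4, 2, 1, 8, 0.1

expert_scores = torch.rand(tokens, num_experts).softmax(dim=-1)   # router probabilities
top_experts = expert_scores.topk(top_k, dim=-1).indices           # routed expert ids
tokens_per_expert = torch.bincount(top_experts.flatten(), minlength=num_experts).float()

scale = (num_experts * loss_weight) / (num_layers * tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
print(loss)  # scalar auxiliary loss; penalizes piling tokens onto a few high-score experts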
torch-ext/megablocks/layers/mpu.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ from megablocks.layers.arguments import Arguments
10
+
11
+
12
+ class MoeParam(torch.Tensor):
13
+
14
+ def __init__(self):
15
+ super().__init__(self)
16
+ self.expert_model_parallel: bool
17
+
18
+
19
+ def is_moe_param(tensor: torch.Tensor) -> bool:
20
+ return hasattr(tensor, 'expert_model_parallel')
21
+
22
+
23
+ def get_expert_parallel_world_size(args: Arguments) -> int:
24
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
25
+
26
+
27
+ def get_expert_parallel_rank(args: Arguments) -> int:
28
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
29
+
30
+
31
+ def set_expert_model_parallel_attributes(
32
+ tensor: torch.Tensor,
33
+ is_parallel: bool,
34
+ ):
35
+ assert not hasattr(tensor, 'expert_model_parallel')
36
+ setattr(tensor, 'expert_model_parallel', is_parallel)
37
+
38
+
39
+ def param_is_expert_model_parallel(param: MoeParam) -> bool:
40
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
41
+
42
+
43
+ def copy_expert_model_parallel_attributes(
44
+ destination_tensor: torch.Tensor,
45
+ source_tensor: torch.Tensor,
46
+ ):
47
+ if hasattr(source_tensor, 'expert_model_parallel'):
48
+ setattr(
49
+ destination_tensor,
50
+ 'expert_model_parallel',
51
+ getattr(source_tensor, 'expert_model_parallel'),
52
+ )
53
+
54
+
55
+ def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
56
+ world_size = dist.get_world_size(group)
57
+ rank = dist.get_rank(group)
58
+ for i in range(world_size):
59
+ dist.barrier(group)
60
+ if i == rank:
61
+ print(f'rank = {rank}', *x)
62
+
63
+
64
+ # Helpers for expert/tensor sharding.
65
+ def expert_sharding_degree(args: Arguments) -> int:
66
+ world_size = get_expert_parallel_world_size(args)
67
+ esd = min(world_size, args.moe_num_experts)
68
+
69
+ if (args.moe_num_experts % esd) != 0:
70
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
71
+ return esd
72
+
73
+
74
+ def hidden_sharding_degree(args: Arguments) -> int:
75
+ world_size = get_expert_parallel_world_size(args)
76
+ esd = expert_sharding_degree(args)
77
+ hsd = world_size // esd
78
+
79
+ if (args.ffn_hidden_size % hsd) != 0:
80
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
81
+ if (esd * hsd) != world_size:
82
+ raise ValueError(
83
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
84
+ )
85
+ return hsd
86
+
87
+
88
+ def experts_per_rank(args: Arguments) -> int:
89
+ return args.moe_num_experts // expert_sharding_degree(args)
90
+
91
+
92
+ def features_per_rank(args: Arguments) -> int:
93
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
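
A worked example of the sharding arithmetic above, with assumed sizes (not from the source): the expert-parallel world size is factored into an expert dimension and a hidden dimension, and each rank owns a slice along both.

# Assumed: world_size = 8, moe_num_experts = 4, ffn_hidden_size = 4096.
expert_sharding_degree = min(8, 4)  # 4: experts are split across 4 ranks
hidden_sharding_degree = 8 // 4     # 2: each expert's FFN rows are split 2 ways
experts_per_rank = 4 // 4           # 1 expert owned per rank
features_per_rank = 4096 // 2       # 2048 FFN rows owned per rank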
torch-ext/megablocks/layers/router.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+ from megablocks.layers import common
8
+ from megablocks.layers.arguments import Arguments
9
+
10
+ _ROUTER_LOGITS = []
11
+
12
+
13
+ def _save_router_logits(logits: torch.Tensor, args: Arguments):
14
+ if args.moe_zloss_weight == 0:
15
+ return
16
+ global _ROUTER_LOGITS
17
+ _ROUTER_LOGITS.append(logits)
18
+
19
+
20
+ def clear_router_zloss():
21
+ global _ROUTER_LOGITS
22
+ _ROUTER_LOGITS.clear()
23
+
24
+
25
+ def batched_router_zloss(args: Arguments):
26
+ global _ROUTER_LOGITS
27
+
28
+ if args.moe_zloss_weight == 0:
29
+ import warnings
30
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
31
+ return 0
32
+
33
+ logits_per_router = _ROUTER_LOGITS
34
+
35
+ if args.moe_zloss_in_fp32:
36
+ logits_per_router = [logits.float() for logits in logits_per_router]
37
+
38
+ unscaled_zloss_per_router = torch.stack([
39
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
40
+ ])
41
+
42
+ return args.moe_zloss_weight * unscaled_zloss_per_router
43
+
44
+
45
+ # NOTE: To enable end-to-end benchmarking without convergence we
46
+ # support a flag to force the router to assign tokens uniformly
47
+ # across the experts. We do this with a custom autograd operation
48
+ # so that PyTorch still executes the full set of router operations.
49
+ class _UniformExpertAssignment(torch.autograd.Function):
50
+
51
+ @staticmethod
52
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
53
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
54
+ out = torch.remainder(out, num_experts)
55
+ return out.view(x.shape)
56
+
57
+
58
+ _uniform_expert_assignment = _UniformExpertAssignment.apply
59
+
60
+
61
+ class LearnedRouter(torch.nn.Module):
62
+
63
+ def __init__(self, args: Arguments):
64
+ super().__init__()
65
+ self.args = args
66
+
67
+ # Learned router parameters.
68
+ #
69
+ # NOTE: This weight matrix is not parallelized with expert model
70
+ # parallelism. Each device needs the entire router weight matrix
71
+ # so that it can route its batch of data correctly.
72
+ self.layer = torch.nn.Linear(
73
+ args.hidden_size,
74
+ args.moe_num_experts,
75
+ bias=False,
76
+ dtype=common.dtype(args),
77
+ device=args.device,
78
+ )
79
+ args.init_method(self.layer.weight)
80
+
81
+ def jitter(self, x: torch.Tensor):
82
+ low: float = 1.0 - self.args.moe_jitter_eps
83
+ high: float = 1.0 + self.args.moe_jitter_eps
84
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
85
+ return low + noise * (high - low)
86
+
87
+ def _top_k(self, scores: torch.Tensor):
88
+ if self.args.moe_top_k == 1:
89
+ return scores.max(dim=-1, keepdim=True)
90
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
91
+
92
+ def forward(self, x: torch.Tensor):
93
+ if self.training and self.args.moe_jitter_eps is not None:
94
+ x = x * self.jitter(x)
95
+
96
+ logits = self.layer(x.view(-1, x.shape[-1]))
97
+ _save_router_logits(logits, self.args)
98
+ scores = logits.softmax(dim=-1)
99
+ expert_weights, expert_indices = self._top_k(scores)
100
+ if self.args.moe_normalize_expert_weights:
101
+ expert_weights = expert_weights / torch.norm(
102
+ expert_weights,
103
+ p=self.args.moe_normalize_expert_weights,
104
+ dim=-1,
105
+ keepdim=True,
106
+ )
107
+
108
+ expert_indices = (
109
+ _uniform_expert_assignment(
110
+ expert_indices,
111
+ self.args.moe_num_experts,
112
+ ) if self.args.uniform_expert_assignment else expert_indices
113
+ )
114
+ return scores, expert_weights, expert_indices
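
LearnedRouter.forward above is softmax scoring followed by a top-k selection and an optional re-normalization of the selected weights. A standalone sketch of the same steps (assumed sizes, p=1 normalization as an illustrative choice, jitter and uniform assignment omitted):

import torch

hidden_size, num_experts, top_k = 16, 4, 2
layer = torch.nn.Linear(hidden_size, num_experts, bias=False)

x = torch.randn(8, hidden_size)                      # [tokens, hidden]
scores = layer(x).softmax(dim=-1)                    # [tokens, num_experts]
expert_weights, expert_indices = scores.topk(top_k, dim=-1)
expert_weights = expert_weights / expert_weights.norm(p=1, dim=-1, keepdim=True)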
torch-ext/megablocks/layers/sharedexpert_registry.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Union
5
+
6
+ from megablocks.layers import glu, mlp
7
+ from megablocks.layers.arguments import Arguments
8
+
9
+ _REGISTRY = {
10
+ 'mlp': mlp.SharedMLP,
11
+ 'glu': glu.SharedGLU,
12
+ }
13
+
14
+
15
+ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
16
+ """Returns a SharedMLP for use in a dMoE instance.
17
+
18
+ Uses the provided arguments to instantiate the appropriate
19
+ SharedMLP instance.
20
+
21
+ Args:
22
+ args: propagated Arguments dataclass.
23
+
24
+ Returns:
25
+ An instantiated SharedMLP constructed using the input args.
26
+ """
27
+ if args.mlp_type not in _REGISTRY:
28
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
29
+
30
+ return _REGISTRY[args.mlp_type](args)
torch-ext/megablocks/ops/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from megablocks.ops.binned_gather import binned_gather
5
+ from megablocks.ops.binned_scatter import binned_scatter
6
+ from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
7
+ from megablocks.ops.gather import gather
8
+ from megablocks.ops.histogram import histogram
9
+ from megablocks.ops.padded_gather import padded_gather
10
+ from megablocks.ops.padded_scatter import padded_scatter
11
+ from megablocks.ops.repeat import repeat
12
+ from megablocks.ops.replicate import replicate
13
+ from megablocks.ops.round_up import round_up
14
+ from megablocks.ops.scatter import scatter
15
+ from megablocks.ops.sort import sort
16
+ from megablocks.ops.sum import sum
17
+ from megablocks.ops.topology import topology
18
+
19
+ __all__ = [
20
+ 'binned_gather',
21
+ 'binned_scatter',
22
+ 'exclusive_cumsum',
23
+ 'inclusive_cumsum',
24
+ 'gather',
25
+ 'histogram',
26
+ 'padded_gather',
27
+ 'padded_scatter',
28
+ 'repeat',
29
+ 'replicate',
30
+ 'round_up',
31
+ 'scatter',
32
+ 'sort',
33
+ 'sum',
34
+ 'topology',
35
+ ]
torch-ext/megablocks/ops/all_to_all_benchmark.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ from megablocks import benchmark_util
8
+ from megablocks.layers.all_to_all import all_to_all
9
+
10
+ _ALL_TO_ALL_BENCHMARK = (
11
+ (8, 1024),
12
+ (16, 1024),
13
+ (32, 1024),
14
+ (64, 1024),
15
+ (128, 1024),
16
+ (256, 1024),
17
+ (512, 1024),
18
+ (1024, 1024),
19
+ (2 * 1024, 1024),
20
+ (4 * 1024, 1024),
21
+ (8 * 1024, 1024),
22
+ (16 * 1024, 1024),
23
+ (32 * 1024, 1024),
24
+ (64 * 1024, 1024),
25
+ (128 * 1024, 1024),
26
+ (256 * 1024, 1024),
27
+ (512 * 1024, 1024),
28
+ (1024 * 1024, 1024),
29
+ )
30
+
31
+
32
+ def benchmark_all_to_all(group, sl, hs):
33
+ world_size = dist.get_world_size(group)
34
+ assert (sl % world_size) == 0
35
+ send_recv_sizes = [sl // world_size] * world_size
36
+
37
+ x = torch.randn((sl, hs)).cuda().half()
38
+
39
+ details = {
40
+ 'world_size': world_size,
41
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2 bytes per element (fp16).
42
+ }
43
+
44
+ def benchmark():
45
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
46
+
47
+ time, std = benchmark_util.benchmark_function(benchmark)
48
+
49
+ if dist.get_rank(group) == 0:
50
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
51
+
52
+
53
+ if __name__ == '__main__':
54
+ assert dist.is_available()
55
+ group = dist.init_process_group(backend='nccl')
56
+ local_rank = dist.get_rank(group)
57
+ torch.cuda.set_device(local_rank)
58
+
59
+ for args in _ALL_TO_ALL_BENCHMARK:
60
+ benchmark_all_to_all(group, *args)
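
The 'message_size (B)' detail above is the number of bytes each rank exchanges with a single peer for fp16 activations. A quick back-of-the-envelope check for the largest configuration, assuming 8 ranks:

sl, hs, world_size = 1024 * 1024, 1024, 8      # largest (sl, hs) pair; 8 ranks assumed
tokens_per_peer = sl // world_size             # 131072 rows sent to each rank
message_size_bytes = tokens_per_peer * hs * 2  # 2 bytes per fp16 element
print(message_size_bytes)                      # 268435456 bytes, i.e. 256 MiB per peer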