drbh committed · commit 2595c46
feat: initial port of megablocks to builder format

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full listing.
- .gitattributes +35 -0
- .gitignore +5 -0
- README.md +6 -0
- build.toml +30 -0
- csrc/bak.ops.cu +21 -0
- csrc/cuda_util.h +62 -0
- csrc/cumsum.h +163 -0
- csrc/histogram.h +86 -0
- csrc/indices.h +95 -0
- csrc/new_cumsum.cu +161 -0
- csrc/new_cumsum.h +11 -0
- csrc/new_histogram.cu +85 -0
- csrc/new_histogram.h +10 -0
- csrc/new_indices.cu +97 -0
- csrc/new_indices.h +14 -0
- csrc/new_replicate.cu +210 -0
- csrc/new_replicate.h +17 -0
- csrc/new_sort.cu +90 -0
- csrc/new_sort.h +13 -0
- csrc/replicate.h +211 -0
- csrc/sort.h +91 -0
- flake.lock +164 -0
- flake.nix +18 -0
- tests/__init__.py +0 -0
- tests/test_mb_moe.py +6 -0
- torch-ext/megablocks/__init__.py +191 -0
- torch-ext/megablocks/_version.py +6 -0
- torch-ext/megablocks/backend/__init__.py +2 -0
- torch-ext/megablocks/backend/kernels.py +543 -0
- torch-ext/megablocks/bak.__init__.py +23 -0
- torch-ext/megablocks/benchmark_util.py +35 -0
- torch-ext/megablocks/grouped_gemm_util.py +26 -0
- torch-ext/megablocks/layers/__init__.py +10 -0
- torch-ext/megablocks/layers/activation_fn.py +33 -0
- torch-ext/megablocks/layers/all_to_all.py +54 -0
- torch-ext/megablocks/layers/arguments.py +100 -0
- torch-ext/megablocks/layers/common.py +26 -0
- torch-ext/megablocks/layers/dmlp_registry.py +42 -0
- torch-ext/megablocks/layers/dmoe.py +327 -0
- torch-ext/megablocks/layers/gelu.py +43 -0
- torch-ext/megablocks/layers/glu.py +223 -0
- torch-ext/megablocks/layers/memory_test.py +102 -0
- torch-ext/megablocks/layers/memory_test.sh +12 -0
- torch-ext/megablocks/layers/mlp.py +574 -0
- torch-ext/megablocks/layers/moe.py +475 -0
- torch-ext/megablocks/layers/mpu.py +93 -0
- torch-ext/megablocks/layers/router.py +114 -0
- torch-ext/megablocks/layers/sharedexpert_registry.py +30 -0
- torch-ext/megablocks/ops/__init__.py +35 -0
- torch-ext/megablocks/ops/all_to_all_benchmark.py +60 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,5 @@
.venv
__pycache__
.bak
megablocks-moe/.bak
.pytest_cache

README.md
ADDED
@@ -0,0 +1,6 @@
---
license: apache-2.0
tags:
- kernel
---

build.toml
ADDED
@@ -0,0 +1,30 @@
[general]
name = "megablocks"
universal = false

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]

[kernel.megablocks]
backend = "cuda"
src = [
  "csrc/new_cumsum.h",
  "csrc/new_cumsum.cu",
  "csrc/new_histogram.h",
  "csrc/new_histogram.cu",
  "csrc/new_indices.h",
  "csrc/new_indices.cu",
  "csrc/new_replicate.cu",
  "csrc/new_replicate.h",
  "csrc/new_sort.h",
  "csrc/new_sort.cu",
]
depends = [ "torch", "cutlass_3_8" ]

[test]
python-git-packages = [
  { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" }
]

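The layout above is the kernel-builder format: [torch] lists the C++ binding sources, [kernel.megablocks] lists the CUDA sources and their dependencies, and [test] pins the stk package by git revision. As a rough usage sketch (my assumption, not something this commit defines), a kernel packaged this way is typically fetched from the Hub with the kernels library and its ops called as attributes:

# Hedged sketch: the hub id "kernels-community/megablocks" and the exposed op
# surface are assumptions, not part of this commit.
import torch
from kernels import get_kernel

megablocks = get_kernel("kernels-community/megablocks")

x = torch.randint(0, 8, (4, 16), dtype=torch.int32, device="cuda")
out = torch.empty_like(x)
megablocks.exclusive_cumsum(x, 1, out)
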
csrc/bak.ops.cu
ADDED
@@ -0,0 +1,21 @@
#include "cumsum.h"
#include "histogram.h"
#include "indices.h"
#include "replicate.h"
#include "sort.h"

#include <torch/extension.h>

namespace megablocks {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
  m.def("histogram", &histogram, "even width histogram.");
  m.def("inclusive_cumsum", &inclusive_cumsum, "batched inclusive cumsum");
  m.def("indices", &indices, "indices construction for sparse matrix.");
  m.def("replicate_forward", &replicate_forward, "(fwd) replicate a vector dynamically.");
  m.def("replicate_backward", &replicate_backward, "(bwd) replicate a vector dynamically.");
  m.def("sort", &sort, "key/value sort.");
}

} // namespace megablocks

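The bak. prefix suggests this pybind module is kept only as a reference from the pre-builder layout. For orientation, a minimal sketch of how these tensor-out ops are driven from Python, assuming the extension were built and importable as megablocks_ops (the module name here is an assumption):

# Hedged sketch; `megablocks_ops` is an assumed module name for the built extension.
import torch
import megablocks_ops as ops

tokens_per_expert = torch.tensor([[3, 1, 4, 2]], dtype=torch.int32, device="cuda")
bins = torch.empty_like(tokens_per_expert)
ops.inclusive_cumsum(tokens_per_expert, 1, bins)   # bins -> [[3, 4, 8, 10]]

top_experts = torch.tensor([0, 2, 2, 3], dtype=torch.int32, device="cuda")
counts = ops.histogram(top_experts, 4)             # counts -> [1, 0, 2, 1]
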
csrc/cuda_util.h
ADDED
@@ -0,0 +1,62 @@
#ifndef BLOCKPARTY_CSRC_CUDA_UTIL_H_
#define BLOCKPARTY_CSRC_CUDA_UTIL_H_

#include <cuda_fp16.h>
#include <cuda_runtime.h>
// #include <torch/extension.h>

namespace megablocks {

typedef __half2 half2;

struct __align__(8) half4 {
  half2 x, y;
};

struct __align__(16) half8 {
  half2 x, y, z, w;
};

template <class To, class From>
__device__ __forceinline__ To BitCast(const From& src) noexcept {
  To dst;
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}

template <typename T>
__device__ __forceinline__ void Store(const T& value, T* ptr) {
  *ptr = value;
}

template <typename T>
__device__ __forceinline__ T Load(const T* address) {
  return __ldg(address);
}

__device__ __forceinline__ half4 Load(const half4* address) {
  float2 x = __ldg(reinterpret_cast<const float2*>(address));
  return BitCast<half4>(x);
}

__device__ __forceinline__ half8 Load(const half8* address) {
  float4 x = __ldg(reinterpret_cast<const float4*>(address));
  return BitCast<half8>(x);
}

template <typename T>
__device__ __forceinline__ T Zero() { return 0; };

template <>
__device__ __forceinline__ half2 Zero<half2>() {
  return {(c10::Half)0., (c10::Half)0.};
};

template <>
__device__ __forceinline__ half4 Zero<half4>() {
  return {Zero<half2>(), Zero<half2>()};
};

} // namespace megablocks

#endif  // BLOCKPARTY_CSRC_CUDA_UTIL_H_

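The half4/half8 wrappers above exist so that four or eight fp16 values can be fetched with a single 64-bit or 128-bit __ldg and then reinterpreted via BitCast. A small PyTorch illustration of the same reinterpretation idea (my example, not from this commit):

import torch

h = torch.arange(8, dtype=torch.float16)   # eight halves = 16 bytes
wide = h.view(torch.float32)               # reinterpret as four 32-bit words (a "half8"-style payload)
roundtrip = wide.view(torch.float16)
assert torch.equal(roundtrip, h)           # bit-exact round trip, like BitCast
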
csrc/cumsum.h
ADDED
@@ -0,0 +1,163 @@
#define CUB_IGNORE_DEPRECATED_API

#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>

#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
// #include <torch/extension.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {

struct Inclusive {};
struct Exclusive {};

template <typename Type> struct Cumsum {

  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
      size_t & temp_storage_bytes,
      InputIteratorT d_in,
      OutputIteratorT d_out,
      int num_items,
      cudaStream_t stream = 0,
      bool debug_synchronous = false) {
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
        temp_storage_bytes,
        d_in,
        d_out,
        num_items,
        stream));//,
        //debug_synchronous));
  }
};

template <> struct Cumsum<Inclusive> {
  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
      size_t & temp_storage_bytes,
      InputIteratorT d_in,
      OutputIteratorT d_out,
      int num_items,
      cudaStream_t stream = 0,
      bool debug_synchronous = false) {
    CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
        temp_storage_bytes,
        d_in,
        d_out,
        num_items,
        stream));//,
        //debug_synchronous));
  }
};

template <typename SumType, typename T>
void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Get temporary storage size.
  size_t scratchpad_bytes = 0;
  Cumsum<SumType>::Run(nullptr,
      scratchpad_bytes,
      x.data_ptr<T>(),
      out.data_ptr<T>(),
      x.size(1),
      c10::cuda::getCurrentCUDAStream());

  // Allocate scratchpad.
  //
  // NOTE: We scale for the batch dimension so we can run in parallel.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt8)
      .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
      options);

  // Run the kernel.
  //
  // NOTE: Using different streams for each issue does not appear to
  // yield performance gains for our problem set. The overhead of
  // event/stream synchronization appears to outweigh the benfits.
  // We could write a true batched cumsum, but this would require
  // significant code duplication from cub and we might move away
  // from this formulation anyways.
  for (int i = 0; i < x.size(0); ++i) {
    void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
    Cumsum<SumType>::Run(scratchpad_ptr,
        scratchpad_bytes,
        x.data_ptr<T>() + x.size(1) * i,
        out.data_ptr<T>() + x.size(1) * i,
        x.size(1),
        c10::cuda::getCurrentCUDAStream());
  }
}

void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Exclusive, short>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Exclusive, int>(x, dim, out);
      return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Exclusive, long>(x, dim, out);
}

void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Inclusive, short>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Inclusive, int>(x, dim, out);
      return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Inclusive, long>(x, dim, out);
}

} // namespace megablocks

#undef CUB_WRAPPED_NAMESPACE

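Semantically, the helpers above are just a row-wise integer scan over dim 1 (the only supported dim), run once per batch row with a shared scratchpad. A minimal PyTorch reference of what they compute (my sketch, not the CUB path):

import torch

def exclusive_cumsum_ref(x: torch.Tensor) -> torch.Tensor:
    # Exclusive scan = inclusive scan shifted right by one element.
    return torch.cumsum(x, dim=1) - x

def inclusive_cumsum_ref(x: torch.Tensor) -> torch.Tensor:
    return torch.cumsum(x, dim=1)

x = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
assert exclusive_cumsum_ref(x).tolist() == [[0, 1, 3, 6]]
assert inclusive_cumsum_ref(x).tolist() == [[1, 3, 6, 10]]
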
csrc/histogram.h
ADDED
@@ -0,0 +1,86 @@
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>

#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
// #include <torch/extension.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {

template <typename T>
torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
  // Allocate the count buffer.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt32)
      .device(x.device());
  torch::Tensor out = torch::empty({x.size(0), num_bins}, options);

  // Exit early if there is not work to do.
  if (out.numel() == 0) return out;

  // Get scratchpad size.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr,
      scratchpad_bytes,
      x.data_ptr<T>(),
      out.data_ptr<int>(),
      /*num_levels=*/num_bins + 1,
      /*lower_level=*/0,
      /*upper_level=*/num_bins,
      /*num_samples=*/int(x.size(1)),
      c10::cuda::getCurrentCUDAStream()));

  // Allocate scratchpad.
  options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);

  // Run the kernel.
  for (int i = 0; i < x.size(0); ++i) {
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
        scratchpad_bytes,
        x.data_ptr<T>() + x.size(1) * i,
        out.data_ptr<int>() + out.size(1) * i,
        /*num_levels=*/num_bins + 1,
        /*lower_level=*/0,
        /*upper_level=*/num_bins,
        /*num_samples=*/int(x.size(1)),
        c10::cuda::getCurrentCUDAStream()));
  }
  return out;
}

torch::Tensor histogram(torch::Tensor x, int num_bins) {
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  bool no_batch = x.ndimension() == 1;
  if (no_batch) x = x.view({1, x.numel()});

  if (x.scalar_type() == torch::kInt16) {
    auto out = cub_histogram<short>(x, num_bins);
    return no_batch ? out.flatten() : out;
  } else if (x.scalar_type() == torch::kInt32) {
    auto out = cub_histogram<int>(x, num_bins);
    return no_batch ? out.flatten() : out;
  } else {
    TORCH_CHECK(x.scalar_type() == torch::kInt64);
    auto out = cub_histogram<long>(x, num_bins);
    return no_batch ? out.flatten() : out;
  }
}

} // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE

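The op above is an even-width histogram over the value range [0, num_bins), computed independently per row when the input is batched. A short PyTorch reference of the same result (sketch only):

import torch

def histogram_ref(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    # Count integer values falling in [0, num_bins), row by row.
    if x.dim() == 1:
        return torch.bincount(x, minlength=num_bins)[:num_bins].to(torch.int32)
    rows = [torch.bincount(row, minlength=num_bins)[:num_bins] for row in x]
    return torch.stack(rows).to(torch.int32)

x = torch.tensor([0, 2, 2, 3, 1, 1, 1])
assert histogram_ref(x, 4).tolist() == [1, 3, 2, 1]
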
csrc/indices.h
ADDED
@@ -0,0 +1,95 @@
#include <cstdint>
#include <c10/util/Half.h>
// #include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {
namespace construct_indices {

// We expect the number of outputs per block to be small. For
// example, with ffn_hidden_size=4096, we only need to write
// 32 elements per block per iteration.
const int kThreadsPerBlock = 32;

__global__ void __launch_bounds__(kThreadsPerBlock)
ConstructIndicesKernel(short * __restrict__ indices,
    int num_columns,
    int block_size,
    const int * __restrict__ padded_bins) {
  // Load the offset for this bins indices.
  int start = 0;
  if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
  int end = __ldg(padded_bins + blockIdx.x);

  // Divide the start and end into blocks.
  start /= block_size;
  end /= block_size;

  // Offset the output buffer to the start of the bin.
  indices += (start + blockIdx.y) * num_columns + threadIdx.x;

  // Write the indices to the output.
  int bin_offset = blockIdx.y;
  int num_rows = end - start;
  for (; bin_offset < num_rows; num_rows -= gridDim.y) {
    short *out = indices;
    for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
      *out = bid + (blockIdx.x * num_columns);
      out += kThreadsPerBlock;
    }
    indices += gridDim.y * num_columns;
  }
}

cudaError_t ConstructIndices(short * __restrict__ indices,
    int output_block_rows,
    int output_block_columns,
    int block_size,
    const int * __restrict__ padded_bins,
    int num_bins,
    cudaStream_t stream) {
  dim3 block_dim(kThreadsPerBlock);
  dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
  ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
      output_block_columns,
      block_size,
      padded_bins);
  return cudaGetLastError();
}

} // namespace construct_indices

void indices(torch::Tensor padded_bins,
    int block_size,
    int output_block_rows,
    int output_block_columns,
    torch::Tensor out) {
  TORCH_CHECK(padded_bins.is_cuda());
  TORCH_CHECK(padded_bins.ndimension() == 1);
  TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);

  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 1);
  TORCH_CHECK(out.scalar_type() == torch::kInt16);
  TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));

  // Exit early if there is no work to do.
  if (out.numel() == 0) return;

  CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
      output_block_rows,
      output_block_columns,
      block_size,
      padded_bins.data_ptr<int>(),
      padded_bins.numel(),
      c10::cuda::getCurrentCUDAStream()));
}

} // namespace megablocks

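As I read the kernel, each bin owns a contiguous run of block rows (its padded_bins entry divided by block_size), and every row of bin b is filled with the column ids offset by b * num_columns; this is the index construction used for the block-sparse matrix. A rough PyTorch sketch of that semantics (my interpretation, not the kernel):

import torch

def indices_ref(padded_bins: torch.Tensor, block_size: int, columns: int) -> torch.Tensor:
    ends = (padded_bins // block_size).long()                 # cumulative block-row ends per bin
    starts = torch.cat([torch.zeros(1, dtype=torch.long), ends[:-1]])
    rows_per_bin = ends - starts                              # block rows owned by each bin
    col = torch.arange(columns, dtype=torch.int16)
    offsets = (torch.arange(len(ends)) * columns).to(torch.int16)
    per_bin = col.unsqueeze(0) + offsets.unsqueeze(1)         # one row template per bin
    return torch.repeat_interleave(per_bin, rows_per_bin, dim=0).flatten()

padded_bins = torch.tensor([128, 256], dtype=torch.int32)    # two bins, one 128-row block each
print(indices_ref(padded_bins, block_size=128, columns=4))
# tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int16)
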
csrc/new_cumsum.cu
ADDED
@@ -0,0 +1,161 @@
#define CUB_IGNORE_DEPRECATED_API

#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include "new_cumsum.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {

struct Inclusive {};
struct Exclusive {};

template <typename Type> struct Cumsum {

  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
      size_t & temp_storage_bytes,
      InputIteratorT d_in,
      OutputIteratorT d_out,
      int num_items,
      cudaStream_t stream = 0,
      bool debug_synchronous = false) {
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
        temp_storage_bytes,
        d_in,
        d_out,
        num_items,
        stream));//,
        //debug_synchronous));
  }
};

template <> struct Cumsum<Inclusive> {
  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
      size_t & temp_storage_bytes,
      InputIteratorT d_in,
      OutputIteratorT d_out,
      int num_items,
      cudaStream_t stream = 0,
      bool debug_synchronous = false) {
    CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
        temp_storage_bytes,
        d_in,
        d_out,
        num_items,
        stream));//,
        //debug_synchronous));
  }
};

template <typename SumType, typename T>
void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Get temporary storage size.
  size_t scratchpad_bytes = 0;
  Cumsum<SumType>::Run(nullptr,
      scratchpad_bytes,
      x.data_ptr<T>(),
      out.data_ptr<T>(),
      x.size(1),
      c10::cuda::getCurrentCUDAStream());

  // Allocate scratchpad.
  //
  // NOTE: We scale for the batch dimension so we can run in parallel.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt8)
      .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
      options);

  // Run the kernel.
  //
  // NOTE: Using different streams for each issue does not appear to
  // yield performance gains for our problem set. The overhead of
  // event/stream synchronization appears to outweigh the benfits.
  // We could write a true batched cumsum, but this would require
  // significant code duplication from cub and we might move away
  // from this formulation anyways.
  for (int i = 0; i < x.size(0); ++i) {
    void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
    Cumsum<SumType>::Run(scratchpad_ptr,
        scratchpad_bytes,
        x.data_ptr<T>() + x.size(1) * i,
        out.data_ptr<T>() + x.size(1) * i,
        x.size(1),
        c10::cuda::getCurrentCUDAStream());
  }
}

void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Exclusive, short>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Exclusive, int>(x, dim, out);
      return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Exclusive, long>(x, dim, out);
}

void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Inclusive, short>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Inclusive, int>(x, dim, out);
      return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Inclusive, long>(x, dim, out);
}

} // namespace megablocks

#undef CUB_WRAPPED_NAMESPACE

csrc/new_cumsum.h
ADDED
@@ -0,0 +1,11 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Forward declarations for the public interface functions
void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);

} // namespace megablocks

csrc/new_histogram.cu
ADDED
@@ -0,0 +1,85 @@
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include "new_histogram.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {

template <typename T>
torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
  // Allocate the count buffer.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt32)
      .device(x.device());
  torch::Tensor out = torch::empty({x.size(0), num_bins}, options);

  // Exit early if there is not work to do.
  if (out.numel() == 0) return out;

  // Get scratchpad size.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr,
      scratchpad_bytes,
      x.data_ptr<T>(),
      out.data_ptr<int>(),
      /*num_levels=*/num_bins + 1,
      /*lower_level=*/0,
      /*upper_level=*/num_bins,
      /*num_samples=*/int(x.size(1)),
      c10::cuda::getCurrentCUDAStream()));

  // Allocate scratchpad.
  options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);

  // Run the kernel.
  for (int i = 0; i < x.size(0); ++i) {
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
        scratchpad_bytes,
        x.data_ptr<T>() + x.size(1) * i,
        out.data_ptr<int>() + out.size(1) * i,
        /*num_levels=*/num_bins + 1,
        /*lower_level=*/0,
        /*upper_level=*/num_bins,
        /*num_samples=*/int(x.size(1)),
        c10::cuda::getCurrentCUDAStream()));
  }
  return out;
}

torch::Tensor histogram(torch::Tensor x, int num_bins) {
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  bool no_batch = x.ndimension() == 1;
  if (no_batch) x = x.view({1, x.numel()});

  if (x.scalar_type() == torch::kInt16) {
    auto out = cub_histogram<short>(x, num_bins);
    return no_batch ? out.flatten() : out;
  } else if (x.scalar_type() == torch::kInt32) {
    auto out = cub_histogram<int>(x, num_bins);
    return no_batch ? out.flatten() : out;
  } else {
    TORCH_CHECK(x.scalar_type() == torch::kInt64);
    auto out = cub_histogram<long>(x, num_bins);
    return no_batch ? out.flatten() : out;
  }
}

} // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE

csrc/new_histogram.h
ADDED
@@ -0,0 +1,10 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Public interface function for computing histograms
torch::Tensor histogram(torch::Tensor x, int num_bins);

} // namespace megablocks

csrc/new_indices.cu
ADDED
@@ -0,0 +1,97 @@
#include "new_indices.h"
#include <cstdint>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAStream.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {
namespace construct_indices {

// We expect the number of outputs per block to be small. For
// example, with ffn_hidden_size=4096, we only need to write
// 32 elements per block per iteration.
const int kThreadsPerBlock = 32;

__global__ void __launch_bounds__(kThreadsPerBlock)
ConstructIndicesKernel(short * __restrict__ indices,
    int num_columns,
    int block_size,
    const int * __restrict__ padded_bins) {
  // Load the offset for this bins indices.
  int start = 0;
  if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
  int end = __ldg(padded_bins + blockIdx.x);

  // Divide the start and end into blocks.
  start /= block_size;
  end /= block_size;

  // Offset the output buffer to the start of the bin.
  indices += (start + blockIdx.y) * num_columns + threadIdx.x;

  // Write the indices to the output.
  int bin_offset = blockIdx.y;
  int num_rows = end - start;
  for (; bin_offset < num_rows; num_rows -= gridDim.y) {
    short *out = indices;
    for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
      *out = bid + (blockIdx.x * num_columns);
      out += kThreadsPerBlock;
    }
    indices += gridDim.y * num_columns;
  }
}

cudaError_t ConstructIndices(short * __restrict__ indices,
    int output_block_rows,
    int output_block_columns,
    int block_size,
    const int * __restrict__ padded_bins,
    int num_bins,
    cudaStream_t stream) {
  dim3 block_dim(kThreadsPerBlock);
  dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
  ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
      output_block_columns,
      block_size,
      padded_bins);
  return cudaGetLastError();
}

} // namespace construct_indices

void indices(torch::Tensor padded_bins,
    int block_size,
    int output_block_rows,
    int output_block_columns,
    torch::Tensor out) {
  TORCH_CHECK(padded_bins.is_cuda());
  TORCH_CHECK(padded_bins.ndimension() == 1);
  TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);

  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 1);
  TORCH_CHECK(out.scalar_type() == torch::kInt16);
  TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));

  // Exit early if there is no work to do.
  if (out.numel() == 0) return;

  CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
      output_block_rows,
      output_block_columns,
      block_size,
      padded_bins.data_ptr<int>(),
      padded_bins.numel(),
      c10::cuda::getCurrentCUDAStream()));
}

} // namespace megablocks

#undef CUDA_CALL

csrc/new_indices.h
ADDED
@@ -0,0 +1,14 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Public interface function for constructing indices from padded bins
void indices(torch::Tensor padded_bins,
    int block_size,
    int output_block_rows,
    int output_block_columns,
    torch::Tensor out);

} // namespace megablocks

csrc/new_replicate.cu
ADDED
@@ -0,0 +1,210 @@
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include "new_replicate.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAStream.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {
namespace replicate {

template <typename T, int kThreadsPerBlock>
__global__ void __launch_bounds__(kThreadsPerBlock)
ReplicateForwardKernel(T * __restrict__ x,
    int * __restrict__ bins,
    T * __restrict__ out,
    int columns) {
  // Offset to this threadblocks batch.
  //
  // x is [batch_size, num_bins]
  // out is [batch_size, columns]
  // bins is [num_bins]
  int batch_idx = blockIdx.y;
  int num_bins = gridDim.x;
  x += batch_idx * num_bins;
  out += batch_idx * columns;

  // Load the start/end for this bin.
  int bin_idx = blockIdx.x;
  int start = 0;
  if (bin_idx > 0) start = __ldg(bins + bin_idx - 1);
  int end = __ldg(bins + bin_idx);

  // Load the value to replicate.
  T value = __ldg((T*)x + bin_idx);

  // Offset to this threadblocks bin and this threads
  // offset within the bin.
  int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
  out += start + bin_offset;

  // Replicate the value to the output.
  //
  // TODO(tgale): Vectorize these stores.
  int num_elements = end - start;
  const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
  T *out_ptr = (T*)out;
  for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
    *out_ptr = value;
    out_ptr += kElementsPerLoop;
  }
}

template <typename T>
cudaError_t ReplicateForward(T *x,
    int batch_size,
    int num_bins,
    int *bins,
    T *out,
    int columns,
    cudaStream_t stream) {
  const int kThreadsPerBlock = 64;
  dim3 block_dim(kThreadsPerBlock, 1, 1);
  int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
  dim3 grid_dim(num_bins, batch_size, group_size);
  ReplicateForwardKernel<T, kThreadsPerBlock><<<
    grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
  return cudaGetLastError();
}

void cub_segmented_reduce(torch::Tensor grad,
    torch::Tensor bins,
    torch::Tensor out,
    cudaStream_t stream) {
  // Append a zero to the bin boundaries for CUB.
  torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
  CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
      0,
      offsets.numel() * sizeof(int),
      stream));
  CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
      bins.data_ptr<int>(),
      bins.numel() * sizeof(int),
      cudaMemcpyDeviceToDevice,
      stream));

  // Get temporary buffer size.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
      scratchpad_bytes,
      grad.data_ptr<c10::Half>(),
      out.data_ptr<c10::Half>(),
      bins.numel(),
      offsets.data_ptr<int>(),
      offsets.data_ptr<int>() + 1,
      stream));

  // Allocate scratchpad.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt8)
      .device(grad.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);

  // Run the kernel for each batch item.
  for (int i = 0; i < grad.size(0); ++i) {
    int num_bins = out.size(1);
    int num_values = grad.size(1);
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
        scratchpad_bytes,
        grad.data_ptr<c10::Half>() + i * num_values,
        out.data_ptr<c10::Half>() + i * num_bins,
        bins.numel(),
        offsets.data_ptr<int>(),
        offsets.data_ptr<int>() + 1,
        stream));
  }
}

} // namespace replicate

void replicate_forward(torch::Tensor x,
    torch::Tensor bins,
    torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
              x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32);
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // Batch dimensions should match for input/output.
  TORCH_CHECK(x.size(0) == out.size(0));

  // One input for each bin (in each batch).
  TORCH_CHECK(x.size(1) == bins.size(0));

  // Exit early if there is no work to do.
  if (out.numel() == 0) return;

  switch (x.scalar_type()) {
    case torch::kFloat16:
      CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
          x.size(0),
          x.size(1),
          bins.data_ptr<int>(),
          out.data_ptr<c10::Half>(),
          out.size(1),
          c10::cuda::getCurrentCUDAStream()));
      return;
    case torch::kInt32:
      CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
          x.size(0),
          x.size(1),
          bins.data_ptr<int>(),
          out.data_ptr<int>(),
          out.size(1),
          c10::cuda::getCurrentCUDAStream()));
      return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt16);
  CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
      x.size(0),
      x.size(1),
      bins.data_ptr<int>(),
      out.data_ptr<short>(),
      out.size(1),
      c10::cuda::getCurrentCUDAStream()));
}

void replicate_backward(torch::Tensor grad,
    torch::Tensor bins,
    torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(grad.is_cuda());
  TORCH_CHECK(grad.ndimension() == 2);
  TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == torch::kFloat16);

  // Batch dimensions should match for input/output.
  TORCH_CHECK(grad.size(0) == out.size(0));

  // One output for each bin (in each batch).
  TORCH_CHECK(out.size(1) == bins.size(0));

  replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
}

} // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE

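replicate_forward broadcasts one value per (batch, bin) across that bin's span of the output row, and replicate_backward sums the incoming gradient over the same spans with a segmented reduction. A PyTorch reference of the pair (my sketch; the kernel additionally supports int16/int32 forward):

import torch

def replicate_forward_ref(x: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
    # x: [batch, num_bins]; bins: cumulative (inclusive) bin ends; out: [batch, bins[-1]]
    sizes = torch.diff(bins, prepend=torch.zeros(1, dtype=bins.dtype))
    return torch.repeat_interleave(x, sizes.long(), dim=1)

def replicate_backward_ref(grad: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
    # Segmented sum of grad columns back onto their bins.
    sizes = torch.diff(bins, prepend=torch.zeros(1, dtype=bins.dtype))
    segments = torch.split(grad, sizes.tolist(), dim=1)
    return torch.stack([s.sum(dim=1) for s in segments], dim=1)

x = torch.tensor([[1.0, 2.0]])
bins = torch.tensor([3, 5], dtype=torch.int32)
y = replicate_forward_ref(x, bins)                    # [[1., 1., 1., 2., 2.]]
g = replicate_backward_ref(torch.ones_like(y), bins)  # [[3., 2.]]
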
csrc/new_replicate.h
ADDED
@@ -0,0 +1,17 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Forward pass: replicate values from x according to bin sizes
void replicate_forward(torch::Tensor x,
    torch::Tensor bins,
    torch::Tensor out);

// Backward pass: reduce gradients back to bins using segmented reduction
void replicate_backward(torch::Tensor grad,
    torch::Tensor bins,
    torch::Tensor out);

} // namespace megablocks

csrc/new_sort.cu
ADDED
@@ -0,0 +1,90 @@
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include "new_sort.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>

#define CUDA_CALL(code) \
  do { \
    cudaError_t status = code; \
    std::string err = cudaGetErrorString(status); \
    TORCH_CHECK(status == cudaSuccess, err); \
  } while (0)

namespace megablocks {

template <typename T>
void cub_radix_sort(torch::Tensor x,
    int end_bit,
    torch::Tensor x_out,
    torch::Tensor iota_out) {
  // Get iota for values in sort.
  torch::Tensor iota = torch::arange(0, x.numel(), x.options());

  // Get temporary buffer size.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
      scratchpad_bytes,
      x.data_ptr<T>(),
      x_out.data_ptr<T>(),
      iota.data_ptr<T>(),
      iota_out.data_ptr<T>(),
      x.numel(),
      /*begin_bit*/0,
      /*end_bit=*/end_bit,
      c10::cuda::getCurrentCUDAStream()));

  // Allocate scratchpad.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt8)
      .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);

  // Run the kernel.
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
      scratchpad_bytes,
      x.data_ptr<T>(),
      x_out.data_ptr<T>(),
      iota.data_ptr<T>(),
      iota_out.data_ptr<T>(),
      x.numel(),
      /*begin_bit=*/0,
      /*end_bit=*/end_bit,
      c10::cuda::getCurrentCUDAStream()));
}

void sort(torch::Tensor x,
    int end_bit,
    torch::Tensor x_out,
    torch::Tensor iota_out) {
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 1);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(x_out.is_cuda());
  TORCH_CHECK(x_out.ndimension() == 1);
  TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
  TORCH_CHECK(iota_out.is_cuda());
  TORCH_CHECK(iota_out.ndimension() == 1);
  TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());

  // Exit early if there is not work to do.
  if (x_out.numel() == 0) return;

  switch (x.scalar_type()) {
    case torch::kInt16:
      return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
    case torch::kInt32:
      return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
}

} // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE

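The sort op is a key/value radix sort whose values are an iota, so it returns the sorted keys together with the permutation that sorts them; end_bit only bounds how many key bits CUB has to inspect. A PyTorch reference of the same outputs (sketch; a stable sort is used because radix sort is stable):

import torch

def sort_ref(x: torch.Tensor):
    # Returns (sorted keys, gather indices), matching x_out / iota_out.
    values, order = torch.sort(x, stable=True)
    return values, order.to(x.dtype)

x = torch.tensor([3, 1, 2, 1], dtype=torch.int32)
x_out, iota_out = sort_ref(x)
assert x_out.tolist() == [1, 1, 2, 3]
assert iota_out.tolist() == [1, 3, 2, 0]
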
csrc/new_sort.h
ADDED
@@ -0,0 +1,13 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Public interface function for radix sorting with indices
void sort(torch::Tensor x,
    int end_bit,
    torch::Tensor x_out,
    torch::Tensor iota_out);

} // namespace megablocks

csrc/replicate.h
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#undef CUB_WRAPPED_NAMESPACE
|
2 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
3 |
+
|
4 |
+
#include <cstdint>
|
5 |
+
|
6 |
+
#include <cub/cub.cuh>
|
7 |
+
#include <c10/util/Half.h>
|
8 |
+
#include <c10/cuda/CUDAStream.h>
|
9 |
+
// #include <torch/extension.h>
|
10 |
+
|
11 |
+
#define CUDA_CALL(code) \
|
12 |
+
do { \
|
13 |
+
cudaError_t status = code; \
|
14 |
+
std::string err = cudaGetErrorString(status); \
|
15 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
16 |
+
} while (0)
|
17 |
+
|
18 |
+
namespace megablocks {
|
19 |
+
namespace replicate {
|
20 |
+
|
21 |
+
template <typename T, int kThreadsPerBlock>
|
22 |
+
__global__ void __launch_bounds__(kThreadsPerBlock)
|
23 |
+
ReplicateForwardKernel(T * __restrict__ x,
|
24 |
+
int * __restrict__ bins,
|
25 |
+
T * __restrict__ out,
|
26 |
+
int columns) {
|
27 |
+
// Offset to this threadblocks batch.
|
28 |
+
//
|
29 |
+
// x is [batch_size, num_bins]
|
30 |
+
// out is [batch_size, columns]
|
31 |
+
// bins is [num_bins]
|
32 |
+
int batch_idx = blockIdx.y;
|
33 |
+
int num_bins = gridDim.x;
|
34 |
+
x += batch_idx * num_bins;
|
35 |
+
out += batch_idx * columns;
|
36 |
+
|
37 |
+
// Load the start/end for this bin.
|
38 |
+
int bin_idx = blockIdx.x;
|
39 |
+
int start = 0;
|
40 |
+
if (bin_idx > 0) start = __ldg(bins + bin_idx - 1);
|
41 |
+
int end = __ldg(bins + bin_idx);
|
42 |
+
|
43 |
+
// Load the value to replicate.
|
44 |
+
T value = __ldg((T*)x + bin_idx);
|
45 |
+
|
46 |
+
// Offset to this threadblocks bin and this threads
|
47 |
+
// offset within the bin.
|
48 |
+
int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
|
49 |
+
out += start + bin_offset;
|
50 |
+
|
51 |
+
// Replicate the value to the output.
|
52 |
+
//
|
53 |
+
// TODO(tgale): Vectorize these stores.
|
54 |
+
int num_elements = end - start;
|
55 |
+
const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
|
56 |
+
T *out_ptr = (T*)out;
|
57 |
+
  for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
    *out_ptr = value;
    out_ptr += kElementsPerLoop;
  }
}

template <typename T>
cudaError_t ReplicateForward(T *x,
                             int batch_size,
                             int num_bins,
                             int *bins,
                             T *out,
                             int columns,
                             cudaStream_t stream) {
  const int kThreadsPerBlock = 64;
  dim3 block_dim(kThreadsPerBlock, 1, 1);
  int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
  dim3 grid_dim(num_bins, batch_size, group_size);
  ReplicateForwardKernel<T, kThreadsPerBlock><<<
      grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
  return cudaGetLastError();
}

void cub_segmented_reduce(torch::Tensor grad,
                          torch::Tensor bins,
                          torch::Tensor out,
                          cudaStream_t stream) {
  // Append a zero to the bin boundaries for CUB.
  torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
  CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
                            0,
                            offsets.numel() * sizeof(int),
                            stream));
  CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
                            bins.data_ptr<int>(),
                            bins.numel() * sizeof(int),
                            cudaMemcpyDeviceToDevice,
                            stream));

  // Get temporary buffer size.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
                                            scratchpad_bytes,
                                            grad.data_ptr<c10::Half>(),
                                            out.data_ptr<c10::Half>(),
                                            bins.numel(),
                                            offsets.data_ptr<int>(),
                                            offsets.data_ptr<int>() + 1,
                                            stream));

  // Allocate scratchpad.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt8)
      .device(grad.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);

  // Run the kernel for each batch item.
  for (int i = 0; i < grad.size(0); ++i) {
    int num_bins = out.size(1);
    int num_values = grad.size(1);
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
                                              scratchpad_bytes,
                                              grad.data_ptr<c10::Half>() + i * num_values,
                                              out.data_ptr<c10::Half>() + i * num_bins,
                                              bins.numel(),
                                              offsets.data_ptr<int>(),
                                              offsets.data_ptr<int>() + 1,
                                              stream));
  }
}

} // namespace replicate

void replicate_forward(torch::Tensor x,
                       torch::Tensor bins,
                       torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
              x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32);
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // Batch dimensions should match for input/output.
  TORCH_CHECK(x.size(0) == out.size(0));

  // One input for each bin (in each batch).
  TORCH_CHECK(x.size(1) == bins.size(0));

  // Exit early if there is no work to do.
  if (out.numel() == 0) return;

  switch (x.scalar_type()) {
    case torch::kFloat16:
      CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
                                            x.size(0),
                                            x.size(1),
                                            bins.data_ptr<int>(),
                                            out.data_ptr<c10::Half>(),
                                            out.size(1),
                                            c10::cuda::getCurrentCUDAStream()));
      return;
    case torch::kInt32:
      CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
                                            x.size(0),
                                            x.size(1),
                                            bins.data_ptr<int>(),
                                            out.data_ptr<int>(),
                                            out.size(1),
                                            c10::cuda::getCurrentCUDAStream()));
      return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt16);
  CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
                                        x.size(0),
                                        x.size(1),
                                        bins.data_ptr<int>(),
                                        out.data_ptr<short>(),
                                        out.size(1),
                                        c10::cuda::getCurrentCUDAStream()));
}

void replicate_backward(torch::Tensor grad,
                        torch::Tensor bins,
                        torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(grad.is_cuda());
  TORCH_CHECK(grad.ndimension() == 2);
  TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == torch::kFloat16);

  // Batch dimensions should match for input/output.
  TORCH_CHECK(grad.size(0) == out.size(0));

  // One output for each bin (in each batch).
  TORCH_CHECK(out.size(1) == bins.size(0));

  replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
}

} // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE
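Note: replicate_forward copies one value per bin across each output row, and replicate_backward segment-sums the gradient back to one value per bin. The sketch below is a pure-PyTorch reference for those semantics only (an illustration, not part of this commit); it assumes bins holds inclusive cumulative end offsets, matching the offsets array built for CUB above.

import torch

def replicate_forward_ref(x, bins, out):
    # x: (batch, num_bins); bins: (num_bins,) inclusive end offsets; out: (batch, columns).
    sizes = torch.diff(bins, prepend=torch.zeros(1, dtype=bins.dtype, device=bins.device)).long()
    out.copy_(torch.repeat_interleave(x, sizes, dim=1))
    return out

def replicate_backward_ref(grad, bins, out):
    # Mirror of cub_segmented_reduce: sum each gradient segment back into its bin.
    sizes = torch.diff(bins, prepend=torch.zeros(1, dtype=bins.dtype, device=bins.device)).long()
    bin_ids = torch.repeat_interleave(torch.arange(bins.numel(), device=grad.device), sizes)
    out.zero_()
    out.index_add_(1, bin_ids, grad)
    return out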
csrc/sort.h
ADDED
@@ -0,0 +1,91 @@
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>

#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
// #include <torch/extension.h>

#define CUDA_CALL(code)                             \
  do {                                              \
    cudaError_t status = code;                      \
    std::string err = cudaGetErrorString(status);   \
    TORCH_CHECK(status == cudaSuccess, err);        \
  } while (0)

namespace megablocks {

template <typename T>
void cub_radix_sort(torch::Tensor x,
                    int end_bit,
                    torch::Tensor x_out,
                    torch::Tensor iota_out) {
  // Get iota for values in sort.
  torch::Tensor iota = torch::arange(0, x.numel(), x.options());

  // Get temporary buffer size.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
                                            scratchpad_bytes,
                                            x.data_ptr<T>(),
                                            x_out.data_ptr<T>(),
                                            iota.data_ptr<T>(),
                                            iota_out.data_ptr<T>(),
                                            x.numel(),
                                            /*begin_bit=*/0,
                                            /*end_bit=*/end_bit,
                                            c10::cuda::getCurrentCUDAStream()));

  // Allocate scratchpad.
  auto options = torch::TensorOptions()
      .dtype(torch::kInt8)
      .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);

  // Run the kernel.
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
                                            scratchpad_bytes,
                                            x.data_ptr<T>(),
                                            x_out.data_ptr<T>(),
                                            iota.data_ptr<T>(),
                                            iota_out.data_ptr<T>(),
                                            x.numel(),
                                            /*begin_bit=*/0,
                                            /*end_bit=*/end_bit,
                                            c10::cuda::getCurrentCUDAStream()));
}

void sort(torch::Tensor x,
          int end_bit,
          torch::Tensor x_out,
          torch::Tensor iota_out) {
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 1);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(x_out.is_cuda());
  TORCH_CHECK(x_out.ndimension() == 1);
  TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
  TORCH_CHECK(iota_out.is_cuda());
  TORCH_CHECK(iota_out.ndimension() == 1);
  TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());

  // Exit early if there is no work to do.
  if (x_out.numel() == 0) return;

  switch (x.scalar_type()) {
    case torch::kInt16:
      return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
    case torch::kInt32:
      return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
}

} // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE
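Note: sort pairs a CUB radix sort of the keys with an iota over positions, so iota_out is effectively an argsort of x; end_bit limits how many key bits are compared. A rough PyTorch cross-check (illustrative assumption, not part of this commit):

import torch
# import megablocks  # assumes the built extension is importable for the kernel path

keys = torch.tensor([3, 1, 2, 1], dtype=torch.int32, device='cuda')
x_out = torch.empty_like(keys)
iota_out = torch.empty_like(keys)
# megablocks.sort(keys, 32, x_out, iota_out)  # kernel path

# Reference: a stable sort plus the permutation that produced it. Both sorts are
# stable, so the outputs are expected (assumption) to agree, up to index dtype.
ref_vals, ref_idx = torch.sort(keys, stable=True)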
flake.lock
ADDED
@@ -0,0 +1,164 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "hf-nix": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1748598786,
        "owner": "huggingface",
        "repo": "hf-nix",
        "rev": "6ca679441494139fde1f2355691ddb5dc8170269",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "hf-nix",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "hf-nix": "hf-nix",
        "nixpkgs": [
          "kernel-builder",
          "hf-nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1749576434,
        "narHash": "sha256-wSdtZih2fMQ3ne/U7OKIhmP43zCIuRBhJ5zMMz747u0=",
        "path": "/home/ubuntu/Projects/kernel-builder",
        "type": "path"
      },
      "original": {
        "path": "/home/ubuntu/Projects/kernel-builder",
        "type": "path"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1747820358,
        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
        "owner": "danieldk",
        "repo": "nixpkgs",
        "rev": "d3c1681180717528068082103bf323147de6ab0b",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "ref": "cudatoolkit-12.9-kernel-builder",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "kernel-builder": "kernel-builder"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix
ADDED
@@ -0,0 +1,18 @@
{
  description = "Flake for megablocks_moe kernel";

  inputs = {
    kernel-builder.url = "path:/home/ubuntu/Projects/kernel-builder";
    # kernel-builder.url = "github:huggingface/kernel-builder/v0.4.0";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
tests/__init__.py
ADDED
File without changes
tests/test_mb_moe.py
ADDED
@@ -0,0 +1,6 @@
import megablocks

def test_import():
    """Simple test to check if the module can be imported."""
    print("megablocks_moe module imported successfully.")
    print("Available functions:", dir(megablocks))
torch-ext/megablocks/__init__.py
ADDED
@@ -0,0 +1,191 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from ._ops import ops

from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE
from megablocks.layers.glu import SparseGLU
from megablocks.layers.mlp import MLP, SparseMLP
from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss

# This section contains the direct kernel exports (not included in the original code)
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute exclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.exclusive_cumsum(x, dim, out)


def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute inclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.inclusive_cumsum(x, dim, out)


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """
    Compute histogram of input tensor values.

    Args:
        x: Input tensor
        num_bins: Number of histogram bins

    Returns:
        Histogram tensor with counts for each bin
    """
    return ops.histogram(x, num_bins)


def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
) -> torch.Tensor:
    """
    Construct indices from padded bins for sparse operations.

    Args:
        padded_bins: Tensor containing bin boundaries
        block_size: Size of each block
        output_block_rows: Number of rows in output blocks
        output_block_columns: Number of columns in output blocks

    Returns:
        Tensor containing constructed indices
    """
    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)


def replicate_forward(
    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Forward pass of replicate operation - replicate values according to bin sizes.

    Args:
        x: Input tensor with values to replicate
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_forward(x, bins, out)


def replicate_backward(
    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Backward pass of replicate operation - reduce gradients back to bins.

    Args:
        grad: Gradient tensor to reduce
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_backward(grad, bins, out)


def sort(
    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
    """
    Radix sort with index tracking.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting
        x_out: Output tensor for sorted values
        iota_out: Output tensor for sorted indices

    Returns:
        The sorted values tensor
    """
    return ops.sort(x, end_bit, x_out, iota_out)


# Convenience functions for common use cases
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
    """
    Compute cumulative sum with automatic output allocation.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum (default: last dimension)
        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum

    Returns:
        New tensor containing the cumulative sum
    """
    out = torch.empty_like(x)
    if exclusive:
        return exclusive_cumsum(x, dim, out)
    else:
        return inclusive_cumsum(x, dim, out)


def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sort tensor and return both sorted values and indices.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting

    Returns:
        Tuple of (sorted_values, sorted_indices)
    """
    x_out = torch.empty_like(x)
    iota_out = torch.empty_like(x)
    sort(x, end_bit, x_out, iota_out)
    return x_out, iota_out


# Export public API
__all__ = [
    # Direct kernel exports
    "exclusive_cumsum",
    "inclusive_cumsum",
    "histogram",
    "indices",
    "replicate_forward",
    "replicate_backward",
    "sort",
    "cumsum",
    "argsort",
    # Original exports
    "Arguments",
    "ParallelDroplessMLP",
    "dMoE",
    "SparseGLU",
    "MLP",
    "SparseMLP",
    "MoE",
    "ParallelMLP",
    "get_load_balancing_loss",
]
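A hypothetical usage sketch of the convenience wrappers above (assumes a CUDA build of the extension is importable as megablocks; values are illustrative):

import torch
import megablocks

# Routing-style example: 16 tokens, each assigned to one of 8 experts.
top_experts = torch.randint(0, 8, (16,), dtype=torch.int32, device='cuda')

counts = megablocks.histogram(top_experts, num_bins=8)            # tokens per expert
bins = megablocks.cumsum(counts, dim=0)                           # inclusive bin boundaries
sorted_ids, order = megablocks.argsort(top_experts, end_bit=3)    # 3 bits cover values 0..7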
torch-ext/megablocks/_version.py
ADDED
@@ -0,0 +1,6 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

"""The MegaBlocks Version."""

__version__ = '0.11.0.dev0'
torch-ext/megablocks/backend/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
torch-ext/megablocks/backend/kernels.py
ADDED
@@ -0,0 +1,543 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl


def assert_is_tensor(x, ndim):
    if x.ndim != ndim:
        raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


def assert_is_matrix(x):
    assert_is_tensor(x, 2)


def assert_is_vector(x):
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a, b):
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)


# a: (tokens, hidden_size), real.
# indices: (tokens * top_k), integer.
# bin_ids: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
# padded_bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy(
    a,
    b,
    indices,
    bin_ids,
    weights,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # Our index into array 'a'.
    index_a = tl.load(indices + tl.program_id(0))

    # One threadblock per row in 'a'. Array 'b' has greater or equal
    # number of rows since they could be padded.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Now we know what bin we're assigned to, but we need to know how
    # many threadblocks were assigned to earlier bins so we can offset
    # in our bin properly.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Load the starting index of our bin in array 'b'.
    index_b = offset_in_bin
    if bin_idx > 0:
        index_b += tl.load(padded_bins + bin_idx - 1)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X


def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: Because of the padding, the output size is dynamic.
    # We load the final padded bin bound to get the output rows.
    output_rows = padded_bins[-1].cpu().item()
    out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def gather(x, indices, bin_ids, weights, bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: There is no padding so the output rows equals the
    # input rows multiplied by top_k.
    output_rows = x.shape[0] * top_k
    out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        out,
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


def scatter(x, indices, bin_ids, weights, bins, top_k):
    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)


# x: (tokens, top_k, hidden_size), real
# grad: (tokens, hidden_size), real.
# wgrad: (tokens, top_k), real.
# indices: (tokens * top_k), integer.
# bin_ids: (tokens * top_k), integer.
# bins: (num_experts), integer.
# padded_bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy_wgrad(
    x,
    grad,
    wgrad,
    indices,
    bin_ids,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # Our index into 'tokens * top_k'.
    index_out = tl.load(indices + tl.program_id(0))

    # One threadblock per row in 'a'. Array 'b' has greater or equal
    # number of rows since they could be padded.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Now we know what bin we're assigned to, but we need to know how
    # many threadblocks were assigned to earlier bins so we can offset
    # in our bin properly.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Load the starting index of our bin in array 'x'.
    index_x = offset_in_bin
    if bin_idx > 0:
        index_x += tl.load(padded_bins + bin_idx - 1)

    # Offset the input and output pointers.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to get the final result and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)


def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
    _padded_copy_wgrad[(indices.shape[0],)](
        x,
        grad,
        out,
        indices,
        bin_ids,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        TOP_K=top_k,
    )
    return out


def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
    return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)


# a: (tokens, hidden_size), real.
# b: (num_experts, expert_capacity, num_columns), real.
# indices: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy(
    a,
    b,
    num_experts,
    expert_capacity,
    indices,
    weights,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # Load our indices into the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output.
    index_b = expert_idx * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input exit early.
    if entry_idx >= num_tokens:
        return
    index_a = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    #
    # NOTE: We need to zero the output in both directions.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X


def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    num_experts = bins.shape[0]
    out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

    _binned_copy[(num_experts, expert_capacity)](
        x,
        out,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def binned_scatter(x, indices, weights, bins, top_k):
    # Validate the input shapes.
    assert_is_tensor(x, 3)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
    _binned_copy[(num_experts, expert_capacity)](
        out,
        x,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=hidden_size,
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)


# a: (tokens, hidden_size), real.
# b: (num_experts, expert_capacity, num_columns), real.
# indices: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy_wgrad(
    x,
    grad,
    wgrad,
    num_experts,
    expert_capacity,
    indices,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # Load our indices into the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output.
    index_x = expert_idx * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input exit early.
    if entry_idx >= num_tokens:
        return
    index_out = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to get the final result and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)


def binned_scatter_wgrad(x, grad, indices, bins, top_k):
    # Validate the input shapes.
    assert_is_tensor(x, 3)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
    _binned_copy_wgrad[(num_experts, expert_capacity)](
        x,
        grad,
        out,
        num_experts,
        expert_capacity,
        indices,
        bins,
        NUM_COLUMNS=hidden_size,
        TOP_K=top_k,
    )
    return out
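The gather/scatter pair above implements the token permutation used by the dropless MoE layers: padded_gather copies each routed token into its expert's (padded) bin, and padded_scatter copies expert outputs back, summing over the top_k replicas and optionally scaling by router weights. A small shape sketch with hypothetical values (assumes the packaged module path megablocks.backend.kernels and a GPU with Triton available):

import torch
from megablocks.backend import kernels

tokens, hidden, top_k = 4, 8, 2
x = torch.randn(tokens, hidden, device='cuda', dtype=torch.float16)

# One entry per routed copy, already sorted by expert id.
bin_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32, device='cuda')
indices = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7], dtype=torch.int32, device='cuda')
bins = torch.tensor([4, 8], dtype=torch.int32, device='cuda')              # inclusive ends
padded_bins = torch.tensor([128, 256], dtype=torch.int32, device='cuda')   # padded to the block size

y = kernels.padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k)
assert y.shape == (256, hidden)     # rows = padded_bins[-1]
z = kernels.padded_scatter(y, indices, bin_ids, None, bins, padded_bins, top_k)
assert z.shape == (tokens, hidden)  # summed over the top_k copies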
torch-ext/megablocks/bak.__init__.py
ADDED
@@ -0,0 +1,23 @@
from megablocks_moe.megablocks import (
    MoE,
    dMoE,
    get_load_balancing_loss,
    ParallelMLP,
    ParallelDroplessMLP,
    SparseMLP,
    MLP,
    SparseGLU,
    Arguments,
)

__all__ = [
    "MoE",
    "dMoE",
    "get_load_balancing_loss",
    "ParallelMLP",
    "ParallelDroplessMLP",
    "SparseMLP",
    "MLP",
    "SparseGLU",
    "Arguments",
]
torch-ext/megablocks/benchmark_util.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch


def log_benchmark(name, arguments, time, std):
    print('=' * 60)
    print(f'{name} Benchmark')
    print('Benchmark Parameters:')
    for (key, value) in arguments.items():
        print(f'{key} = {value}')
    print('Results:')
    print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
    print('=' * 60)


def benchmark_function(fn, iterations=100, warmup=10):
    # Warmup iterations.
    for _ in range(warmup):
        fn()

    times = []
    for i in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        fn()
        end.record()

        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return np.mean(times), np.std(times)
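A hypothetical way to drive these helpers (requires a CUDA device, since timing uses CUDA events):

import torch
from megablocks.benchmark_util import benchmark_function, log_benchmark

a = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
b = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

mean_ms, std_ms = benchmark_function(lambda: a @ b)
log_benchmark('matmul', {'shape': '4096x4096', 'dtype': 'float16'}, mean_ms, std_ms)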
torch-ext/megablocks/grouped_gemm_util.py
ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import warnings

_grouped_gemm_is_available: bool = False
try:
    import grouped_gemm
    _grouped_gemm_is_available = True
except ImportError as error:
    warnings.warn('Grouped GEMM not available.')


def grouped_gemm_is_available():
    return _grouped_gemm_is_available


def assert_grouped_gemm_is_available():
    msg = (
        'Grouped GEMM not available. Please run '
        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
    )
    assert _grouped_gemm_is_available, msg


backend = grouped_gemm.backend if grouped_gemm_is_available() else None
ops = grouped_gemm.ops if grouped_gemm_is_available() else None
torch-ext/megablocks/layers/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# from megablocks.layers.dmoe import dMoE
from megablocks.layers.moe import MoE

__all__ = [
    'MoE',
    # 'dMoE',
]
torch-ext/megablocks/layers/activation_fn.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Union

import torch
from stk import Matrix


def act_fn(
    x: Matrix,
    function: Callable,
    return_grad_fn: bool = False,
    **kwargs,
) -> Union[tuple[Matrix, Any] | Matrix]:
    assert isinstance(x, Matrix)
    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
        if return_grad_fn:
            x.data.requires_grad = True
        out = function(x.data, **kwargs)
        y = Matrix(
            x.size(),
            out,
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
        if return_grad_fn:
            return y, out.backward
        return y
torch-ext/megablocks/layers/all_to_all.py
ADDED
@@ -0,0 +1,54 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist


class AllToAllOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
        out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)

        ctx.input_shape = x.shape
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.group = group
        handle = dist.all_to_all_single(
            out,
            x,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
        return out, handle

    @staticmethod
    def backward(ctx, grad, _):
        if ctx.needs_input_grad[0]:
            out = torch.empty(
                ctx.input_shape,
                device=grad.device,
                dtype=grad.dtype,
            )
            dist.all_to_all_single(
                out,
                grad,
                output_split_sizes=ctx.input_split_sizes,
                input_split_sizes=ctx.output_split_sizes,
                group=ctx.group,
            )
            return out, None, None, None, None
        return None, None, None, None, None


def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
    return AllToAllOp.apply(
        x,
        output_split_sizes,
        input_split_sizes,
        group,
        async_op,
    )
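all_to_all wraps dist.all_to_all_single in an autograd Function so the backward pass runs the inverse exchange with the split sizes swapped. A hypothetical call on one rank of an already-initialized process group:

import torch
import torch.distributed as dist
from megablocks.layers.all_to_all import all_to_all

# Assumes dist.init_process_group(...) has already run and this rank owns x.
x = torch.randn(8, 1024, device='cuda')
input_splits = [4, 4]    # rows this rank sends to each peer
output_splits = [4, 4]   # rows this rank receives from each peer

out, handle = all_to_all(x, output_splits, input_splits, dist.group.WORLD, async_op=True)
if handle is not None:
    handle.wait()        # other work can be overlapped before waiting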
torch-ext/megablocks/layers/arguments.py
ADDED
@@ -0,0 +1,100 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import dataclasses
from functools import partial
from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F

import megablocks.grouped_gemm_util as grouped_gemm

# Type annotation for in-place Tensor initialization function.
InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]

_ALLOWED_BITWIDTHS = (-1, 4, 8)

DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')


@dataclasses.dataclass
class Arguments:
    # Model arguments.
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096
    num_layers: int = 1
    bias: bool = True
    return_bias: bool = True
    activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN

    # MoE arguments.
    moe_num_experts: int = 1
    moe_top_k: int = 1
    moe_capacity_factor: int = 1
    moe_normalize_expert_weights: Optional[Union[int, float]] = None
    moe_loss_weight: float = 0.1
    moe_jitter_eps: Optional[float] = None
    moe_lbl_in_fp32: bool = False

    # Parallelism arguments.
    moe_expert_model_parallelism: bool = False
    expert_parallel_group: Optional[dist.ProcessGroup] = None
    pipeline_model_parallel_size: int = 1
    num_layers_per_virtual_pipeline_stage: Optional[int] = None

    # Compute arguments.
    memory_optimized_mlp: bool = False
    mlp_type: str = 'mlp'
    mlp_impl: str = 'sparse'

    # Initialization arguments.
    fp16: bool = True
    bf16: bool = False
    device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
    init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
    output_layer_init_method: InitFn = init_method

    # Benchmarking arguments.
    uniform_expert_assignment: bool = False

    # Shared expert arguments.
    shared_expert: bool = False  # enable using a shared expert
    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows a custom FC layer, e.g. te.Linear for FP8)
    fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,)  # kwargs for custom fc layers
    remat_act_fn: bool = True  # enable the act fn to be rematerialized instead of stored
    shared_expert_hidden_size: Optional[
        int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
    shared_expert_weighted_sum: bool = False  # enable a weighted sum for shared expert output (weighted by the number of experts used)

    # Router Z-loss arguments.
    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
    moe_zloss_in_fp32: bool = False

    def __post_init__(self):
        # Sparse MLP is not supported with triton >= 3.2.0.
        # TODO: Remove this once sparse is supported with triton >= 3.2.0.
        if self.__getattribute__('mlp_impl') == 'sparse':
            try:
                import triton
                if triton.__version__ >= '3.2.0':
                    raise ValueError(
                        'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
                    )
            except ImportError:
                raise ImportError('Triton is required for sparse MLP implementation')

        if self.__getattribute__('mlp_impl') == 'grouped':
            grouped_gemm.assert_grouped_gemm_is_available()

        if self.shared_expert_hidden_size is None:
            self.shared_expert_hidden_size = self.ffn_hidden_size


def from_megatron(megatron_args: Any):
    args = Arguments()
    for field in dataclasses.fields(args):
        if hasattr(megatron_args, field.name):
            setattr(args, field.name, getattr(megatron_args, field.name))
    return args
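A hypothetical configuration built from the dataclass above (mlp_impl='grouped' assumes the optional grouped_gemm package is installed, per __post_init__):

import torch
from megablocks.layers.arguments import Arguments

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_impl='grouped',   # 'sparse' is rejected when triton >= 3.2.0
    fp16=False,
    bf16=True,
    device=torch.device('cuda'),
)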
torch-ext/megablocks/layers/common.py
ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from megablocks.layers.arguments import Arguments


def dtype(args: Arguments):
    if args.fp16:
        return torch.float16
    elif args.bf16:
        return torch.bfloat16
    return None


def cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == 'cuda':
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == 'cpu':
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor
torch-ext/megablocks/layers/dmlp_registry.py
ADDED
@@ -0,0 +1,42 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Union

from megablocks.layers import glu, mlp
from megablocks.layers.arguments import Arguments

MlpType = Union[mlp.SparseMLP, glu.SparseGLU]

_REGISTRY = {
    'mlp': {
        'grouped': mlp.GroupedMLP,
        'sparse': mlp.SparseMLP,
    },
    'glu': {
        'grouped': glu.GroupedGLU,
        'sparse': glu.SparseGLU,
    },
}


def get(args: Arguments) -> MlpType:
    """Returns an MLP for use in a dMoE instance.

    Uses the provided arguments to instantiate the appropriate
    MLP instance. This only contains MLPs for use in dMoEs
    (i.e. only for the dropless versions of MoEs).

    Args:
        args: propagated Arguments dataclass.

    Returns:
        An instantiated MLP constructed using the input args.
    """
    if args.mlp_type not in _REGISTRY:
        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')

    if args.mlp_impl not in _REGISTRY[args.mlp_type]:
        raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)

    return _REGISTRY[args.mlp_type][args.mlp_impl](args)
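The registry keys mirror Arguments.mlp_type and Arguments.mlp_impl, so get(args) is a two-level lookup followed by construction. A sketch (hypothetical values; the grouped backend again assumes grouped_gemm is installed):

from megablocks.layers import dmlp_registry
from megablocks.layers.arguments import Arguments

args = Arguments(moe_num_experts=8, moe_top_k=2, mlp_type='glu', mlp_impl='grouped')
mlp = dmlp_registry.get(args)   # resolves to glu.GroupedGLU(args)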
torch-ext/megablocks/layers/dmoe.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import stk.ops
|
6 |
+
import torch
|
7 |
+
from stk import Matrix
|
8 |
+
|
9 |
+
import megablocks.ops as ops
|
10 |
+
# from megablocks.ops import ops
|
11 |
+
from megablocks.layers import common, dmlp_registry, moe, mpu
|
12 |
+
from megablocks.layers.arguments import Arguments
|
13 |
+
|
14 |
+
|
15 |
+
def promote_scalar(x):
|
16 |
+
return x.view(1) if not len(x.size()) else x
|
17 |
+
|
18 |
+
|
19 |
+
class ParallelDroplessMLP(moe.ParallelMLP):
|
20 |
+
|
21 |
+
def __init__(self, args: Arguments):
|
22 |
+
super(ParallelDroplessMLP, self).__init__(args)
|
23 |
+
self.hidden_size = args.hidden_size
|
24 |
+
self.ffn_hidden_size = mpu.features_per_rank(args)
|
25 |
+
self.blocking = 128
|
26 |
+
self.mlp = dmlp_registry.get(args)
|
27 |
+
|
28 |
+
# Calculate the number of bits needed to represent the column indices
|
29 |
+
# in the intermediate sparse matrix.
|
30 |
+
max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
|
31 |
+
self.transpose_sort_end_bit = max(
|
32 |
+
int(np.ceil(np.log2(max_column_index))),
|
33 |
+
1,
|
34 |
+
)
|
35 |
+
|
36 |
+
def sparse_transpose(self, size, row_indices, column_indices, offsets):
|
37 |
+
block_columns = size[1] // self.blocking
|
38 |
+
|
39 |
+
# Sort row indices by column indices to get the transposed matrix's
|
40 |
+
# column indices.
|
41 |
+
#
|
42 |
+
# NOTE: Our sort operation uses the same width indices as the input values.
|
43 |
+
# To avoid overflow when we have large activation matrices we cast to
|
44 |
+
# 32-bit before sorting.
|
45 |
+
_, gather_indices = ops.sort(
|
46 |
+
column_indices.int(),
|
47 |
+
self.transpose_sort_end_bit,
|
48 |
+
)
|
49 |
+
|
50 |
+
# There are a constant number of blocks in every row of the sparse matrix.
|
51 |
+
# A block's offset is:
|
52 |
+
#
|
53 |
+
# row_index * blocks_per_row + column_index % blocks_per_row
|
54 |
+
#
|
55 |
+
# Once we have the block offsets ordered for transposition we can divide
|
56 |
+
# by blocks_per_row to get the transposed column indices.
|
57 |
+
column_indices_t = row_indices.gather(0, gather_indices.long())
|
58 |
+
block_offsets_t = gather_indices.int()
|
59 |
+
|
60 |
+
zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
|
61 |
+
nnz_per_column = ops.histogram(column_indices, block_columns)
|
62 |
+
nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
|
63 |
+
if nnz_per_column.dim() == 0:
|
64 |
+
# This addresses an edge case when ffn_hidden_size is equal to self.blocking.
|
65 |
+
nnz_per_column = nnz_per_column.unsqueeze(0)
|
66 |
+
offsets_t = torch.cat([zero, nnz_per_column])
|
67 |
+
return column_indices_t, offsets_t, block_offsets_t
|
68 |
+
|
69 |
+
def topology(self, x, padded_bins):
|
70 |
+
padded_tokens, _ = x.size()
|
71 |
+
assert padded_tokens % self.blocking == 0
|
72 |
+
if self.ffn_hidden_size % self.blocking != 0:
|
73 |
+
raise ValueError(
|
74 |
+
f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
|
75 |
+
f'the block size {self.blocking}. Please update your configuration.',
|
76 |
+
)
|
77 |
+
|
78 |
+
# Offsets for the sparse matrix. All rows have the
|
79 |
+
# same number of nonzero blocks dictated by the
|
80 |
+
# dimensionality of a single expert.
|
81 |
+
block_rows = padded_tokens // self.blocking
|
82 |
+
blocks_per_row = self.ffn_hidden_size // self.blocking
|
83 |
+
offsets = torch.arange(
|
84 |
+
0,
|
85 |
+
block_rows * blocks_per_row + 1,
|
86 |
+
blocks_per_row,
|
87 |
+
dtype=torch.int32,
|
88 |
+
device=x.device,
|
89 |
+
)
|
90 |
+
|
91 |
+
# Indices for the sparse matrix. The indices for
|
92 |
+
# the intermediate matrix are dynamic depending
|
93 |
+
# on the mapping of tokens to experts.
|
94 |
+
column_indices = ops.topology(
|
95 |
+
padded_bins,
|
96 |
+
self.blocking,
|
97 |
+
block_rows,
|
98 |
+
blocks_per_row,
|
99 |
+
)
|
100 |
+
|
101 |
+
# TODO(tgale): This is unused. Remove the need for this in stk.
|
102 |
+
# For now, use meta init to save the device memory.
|
103 |
+
data = torch.empty(
|
104 |
+
column_indices.numel(),
|
105 |
+
self.blocking,
|
106 |
+
self.blocking,
|
107 |
+
dtype=common.dtype(self.args),
|
108 |
+
device='meta',
|
109 |
+
)
|
110 |
+
shape = (
|
111 |
+
padded_tokens,
|
112 |
+
self.ffn_hidden_size * mpu.experts_per_rank(self.args),
|
113 |
+
)
|
114 |
+
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
|
115 |
+
column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
|
116 |
+
shape,
|
117 |
+
row_indices,
|
118 |
+
column_indices,
|
119 |
+
offsets,
|
120 |
+
)
|
121 |
+
return stk.Matrix(
|
122 |
+
shape,
|
123 |
+
data,
|
124 |
+
row_indices,
|
125 |
+
column_indices,
|
126 |
+
offsets,
|
127 |
+
column_indices_t,
|
128 |
+
offsets_t,
|
129 |
+
block_offsets_t,
|
130 |
+
)
|
131 |
+
|
132 |
+
def indices_and_padded_bins(self, top_experts):
|
133 |
+
# Sort the expert ids to produce the scatter/gather
|
134 |
+
# indices for the permutation.
|
135 |
+
top_experts = top_experts.int()
|
136 |
+
bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
|
137 |
+
|
138 |
+
# Histogram the expert ids to identify the number of
|
139 |
+
# tokens routed to each expert.
|
140 |
+
tokens_per_expert = ops.histogram(top_experts, self.num_experts)
|
141 |
+
|
142 |
+
# Round the token counts up to the block size used in
|
143 |
+
# the matrix multiplications. Calculate the starting
|
144 |
+
# position of each bin.
|
145 |
+
padded_tokens_per_expert = ops.round_up(
|
146 |
+
tokens_per_expert,
|
147 |
+
self.blocking,
|
148 |
+
)
|
149 |
+
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
150 |
+
padded_bins = promote_scalar(padded_bins)
|
151 |
+
|
152 |
+
# Calculate the bin bounds for the sorted tokens.
|
153 |
+
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
154 |
+
bins = promote_scalar(bins)
|
155 |
+
return indices, bin_ids, bins, padded_bins, tokens_per_expert
|
156 |
+
|
157 |
+
def sparse_forward_once(self, x, expert_weights, top_experts):
|
158 |
+
# x: [sl, bs, hs]
|
159 |
+
# expert_weights: [sl * bs, top-k]
|
160 |
+
# top_experts: [sl * bs, top-k]
|
161 |
+
expert_weights = expert_weights.flatten()
|
162 |
+
top_experts = top_experts.flatten()
|
163 |
+
with torch.no_grad():
|
164 |
+
indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
|
165 |
+
|
166 |
+
# Route the tokens for MoE computation.
|
167 |
+
x = x.view(-1, x.shape[-1])
|
168 |
+
x = ops.padded_gather(
|
169 |
+
x,
|
170 |
+
indices,
|
171 |
+
bin_ids,
|
172 |
+
bins,
|
173 |
+
padded_bins,
|
174 |
+
self.top_k,
|
175 |
+
)
|
176 |
+
|
177 |
+
# Create the sparse matrix topology.
|
178 |
+
with torch.no_grad():
|
179 |
+
topo = self.topology(x, padded_bins)
|
180 |
+
|
181 |
+
# Perform the expert computation.
|
182 |
+
x = self.mlp(x, topo)
|
183 |
+
|
184 |
+
# Un-route the data for the MoE output.
|
185 |
+
x = ops.padded_scatter(
|
186 |
+
x,
|
187 |
+
indices,
|
188 |
+
bin_ids,
|
189 |
+
expert_weights,
|
190 |
+
bins,
|
191 |
+
padded_bins,
|
192 |
+
self.top_k,
|
193 |
+
)
|
194 |
+
return x, tokens_per_expert
|
195 |
+
|
196 |
+
# For use in the base-class parallel_forward_once.
|
197 |
+
def sparse_permute_and_compute(
|
198 |
+
self,
|
199 |
+
x,
|
200 |
+
tokens_per_expert,
|
201 |
+
indices,
|
202 |
+
bin_ids,
|
203 |
+
expert_weights,
|
204 |
+
bins,
|
205 |
+
expert_capactiy, # unused
|
206 |
+
top_k,
|
207 |
+
):
|
208 |
+
|
209 |
+
# Round the token counts up to the block size used in the matrix
|
210 |
+
# multiplication. Calculate the starting position of each bin.
|
211 |
+
padded_tokens_per_expert = ops.round_up(
|
212 |
+
tokens_per_expert,
|
213 |
+
self.blocking,
|
214 |
+
)
|
215 |
+
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
216 |
+
padded_bins = promote_scalar(padded_bins)
|
217 |
+
|
218 |
+
# Route the tokens for MoE computation.
|
219 |
+
x = x.view(-1, x.shape[-1])
|
220 |
+
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
|
221 |
+
|
222 |
+
# Create the sparse matrix topology.
|
223 |
+
with torch.no_grad():
|
224 |
+
topo = self.topology(x, padded_bins)
|
225 |
+
|
226 |
+
# Perform the expert computation.
|
227 |
+
x = self.mlp(x, topo)
|
228 |
+
|
229 |
+
# Un-route the data for the MoE output.
|
230 |
+
return ops.padded_scatter(
|
231 |
+
x,
|
232 |
+
indices,
|
233 |
+
bin_ids,
|
234 |
+
expert_weights,
|
235 |
+
bins,
|
236 |
+
padded_bins,
|
237 |
+
top_k,
|
238 |
+
)
|
239 |
+
|
240 |
+
def grouped_forward_once(self, x, expert_weights, top_experts):
|
241 |
+
# x: [sl, bs, hs]
|
242 |
+
# expert_weights: [sl * bs, top-k]
|
243 |
+
# top_experts: [sl * bs, top-k]
|
244 |
+
expert_weights = expert_weights.flatten()
|
245 |
+
top_experts = top_experts.flatten()
|
246 |
+
with torch.no_grad():
|
247 |
+
indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
|
248 |
+
|
249 |
+
out = self.grouped_permute_and_compute(
|
250 |
+
x,
|
251 |
+
tokens_per_expert,
|
252 |
+
indices,
|
253 |
+
bin_ids,
|
254 |
+
expert_weights,
|
255 |
+
bins,
|
256 |
+
-1, # unused
|
257 |
+
self.args.moe_top_k,
|
258 |
+
)
|
259 |
+
return out, tokens_per_expert
|
260 |
+
|
261 |
+
def grouped_permute_and_compute(
|
262 |
+
self,
|
263 |
+
x,
|
264 |
+
tokens_per_expert,
|
265 |
+
indices,
|
266 |
+
bin_ids,
|
267 |
+
expert_weights,
|
268 |
+
bins,
|
269 |
+
expert_capactiy, # unused
|
270 |
+
top_k,
|
271 |
+
):
|
272 |
+
|
273 |
+
# Route the tokens for MoE computation.
|
274 |
+
x = x.view(-1, x.shape[-1])
|
275 |
+
x = ops.gather(x, indices, bin_ids, bins, top_k)
|
276 |
+
|
277 |
+
# Perform the expert computation.
|
278 |
+
x = self.mlp(x, tokens_per_expert)
|
279 |
+
|
280 |
+
# Un-route the data for the MoE output.
|
281 |
+
return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
|
282 |
+
|
283 |
+
def forward_once(self, x, expert_weights, top_experts):
|
284 |
+
if self.args.mlp_impl == 'sparse':
|
285 |
+
return self.sparse_forward_once(x, expert_weights, top_experts)
|
286 |
+
else:
|
287 |
+
return self.grouped_forward_once(x, expert_weights, top_experts)
|
288 |
+
|
289 |
+
def permute_and_compute(
|
290 |
+
self,
|
291 |
+
x,
|
292 |
+
tokens_per_expert,
|
293 |
+
indices,
|
294 |
+
bin_ids,
|
295 |
+
expert_weights,
|
296 |
+
bins,
|
297 |
+
expert_capactiy,
|
298 |
+
top_k,
|
299 |
+
):
|
300 |
+
if self.args.mlp_impl == 'sparse':
|
301 |
+
return self.sparse_permute_and_compute(
|
302 |
+
x,
|
303 |
+
tokens_per_expert,
|
304 |
+
indices,
|
305 |
+
bin_ids,
|
306 |
+
expert_weights,
|
307 |
+
bins,
|
308 |
+
expert_capactiy,
|
309 |
+
top_k,
|
310 |
+
)
|
311 |
+
else:
|
312 |
+
return self.grouped_permute_and_compute(
|
313 |
+
x,
|
314 |
+
tokens_per_expert,
|
315 |
+
indices,
|
316 |
+
bin_ids,
|
317 |
+
expert_weights,
|
318 |
+
bins,
|
319 |
+
expert_capactiy,
|
320 |
+
top_k,
|
321 |
+
)
|
322 |
+
|
323 |
+
|
324 |
+
class dMoE(moe.MoE):
|
325 |
+
|
326 |
+
def _init_experts_mlp(self, args: Arguments):
|
327 |
+
return ParallelDroplessMLP(args)
|
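The routing in indices_and_padded_bins above rounds the per-expert token counts up to the block size and takes an inclusive cumulative sum to locate each expert's rows in the padded, permuted buffer. A standalone sketch of that bin computation (the block size and token counts below are made-up values, not taken from the configuration above):

import torch

blocking = 4                                  # stands in for self.blocking (128 above)
tokens_per_expert = torch.tensor([5, 0, 3])   # hypothetical routing histogram
padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking  # like ops.round_up
padded_bins = torch.cumsum(padded, dim=0)     # like ops.inclusive_cumsum
print(padded.tolist())       # [8, 0, 4]
print(padded_bins.tolist())  # [8, 8, 12]: expert i's padded rows end at padded_bins[i]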
torch-ext/megablocks/layers/gelu.py
ADDED
@@ -0,0 +1,43 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import stk
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
@torch.jit.script
|
10 |
+
def _gelu_backward_inplace(g, x):
|
11 |
+
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
12 |
+
ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
|
13 |
+
return g.mul_(ff)
|
14 |
+
|
15 |
+
|
16 |
+
def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
|
17 |
+
# NOTE: The two sparse matrices must have the same topology.
|
18 |
+
if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
|
19 |
+
return stk.Matrix(
|
20 |
+
x.size(),
|
21 |
+
_gelu_backward_inplace(grad.data, x.data),
|
22 |
+
x.row_indices,
|
23 |
+
x.column_indices,
|
24 |
+
x.offsets,
|
25 |
+
x.column_indices_t,
|
26 |
+
x.offsets_t,
|
27 |
+
x.block_offsets_t,
|
28 |
+
)
|
29 |
+
return _gelu_backward_inplace(grad, x)
|
30 |
+
|
31 |
+
|
32 |
+
def gelu(x: stk.Matrix):
|
33 |
+
assert isinstance(x, stk.Matrix)
|
34 |
+
return stk.Matrix(
|
35 |
+
x.size(),
|
36 |
+
F.gelu(x.data, approximate='tanh'),
|
37 |
+
x.row_indices,
|
38 |
+
x.column_indices,
|
39 |
+
x.offsets,
|
40 |
+
x.column_indices_t,
|
41 |
+
x.offsets_t,
|
42 |
+
x.block_offsets_t,
|
43 |
+
)
|
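The backward helper above hard-codes the derivative of the tanh-approximate GeLU. A standalone check of that formula against autograd on a dense tensor (the tolerance only absorbs float32 round-off):

import torch
import torch.nn.functional as F

x = torch.randn(8, requires_grad=True)
F.gelu(x, approximate='tanh').backward(torch.ones_like(x))

with torch.no_grad():
    # Same expression as _gelu_backward_inplace, applied to a dense tensor.
    t = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - t * t) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + t)

print(torch.allclose(x.grad, ff, atol=1e-4))  # True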
torch-ext/megablocks/layers/glu.py
ADDED
@@ -0,0 +1,223 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import stk.ops
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from megablocks import grouped_gemm_util as gg
|
8 |
+
from megablocks.layers import common, mpu
|
9 |
+
from megablocks.layers.activation_fn import act_fn
|
10 |
+
from megablocks.layers.arguments import Arguments
|
11 |
+
from megablocks.layers.mlp import (
|
12 |
+
SharedMLP,
|
13 |
+
SparseMLP,
|
14 |
+
create_dmoe_expert_weights,
|
15 |
+
resolve_dtensor,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class SparseGLU(SparseMLP):
|
20 |
+
|
21 |
+
def __init__(self, args: Arguments):
|
22 |
+
super().__init__(args)
|
23 |
+
self.v1 = torch.nn.Parameter(
|
24 |
+
torch.empty(
|
25 |
+
self._num_rows_per_rank,
|
26 |
+
args.hidden_size,
|
27 |
+
device=args.device,
|
28 |
+
dtype=common.dtype(args),
|
29 |
+
),
|
30 |
+
)
|
31 |
+
with torch.no_grad():
|
32 |
+
self.v1.copy_(
|
33 |
+
create_dmoe_expert_weights(
|
34 |
+
args,
|
35 |
+
args.moe_num_experts,
|
36 |
+
args.ffn_hidden_size,
|
37 |
+
args.hidden_size,
|
38 |
+
args.init_method,
|
39 |
+
),
|
40 |
+
)
|
41 |
+
|
42 |
+
mpu.set_expert_model_parallel_attributes(
|
43 |
+
self.v1,
|
44 |
+
self._should_set_parallelism_attribute,
|
45 |
+
)
|
46 |
+
|
47 |
+
def forward(self, x, topo):
|
48 |
+
if self.args.memory_optimized_mlp:
|
49 |
+
raise NotImplementedError(
|
50 |
+
'Memory optimized implementation not yet supported with GLU with sparse kernels.',
|
51 |
+
)
|
52 |
+
|
53 |
+
w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
|
54 |
+
w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
|
55 |
+
|
56 |
+
# Compute the GLU.
|
57 |
+
x1 = stk.ops.sdd(x, w1.t(), topo)
|
58 |
+
x2 = stk.ops.sdd(x, v1.t(), topo)
|
59 |
+
|
60 |
+
activation_fn_out = act_fn(x1, self.args.activation_fn)
|
61 |
+
x1 = stk.ops.mul(activation_fn_out, x2)
|
62 |
+
|
63 |
+
return stk.ops.dsd(x1, w2)
|
64 |
+
|
65 |
+
|
66 |
+
class MemoryOptimizedGroupedGLU(torch.autograd.Function):
|
67 |
+
"""GroupedMLP with manually scheduled memory reuse."""
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
@torch.amp.autocast_mode.custom_fwd(device_type='cuda')
|
71 |
+
def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
|
72 |
+
# Cast inputs using ctx dtype from AMP
|
73 |
+
if ctx._fwd_used_autocast:
|
74 |
+
x = x.to(ctx._dtype)
|
75 |
+
w1 = w1.to(ctx._dtype)
|
76 |
+
v1 = v1.to(ctx._dtype)
|
77 |
+
w2 = w2.to(ctx._dtype)
|
78 |
+
# x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
|
79 |
+
if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
|
80 |
+
raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
|
81 |
+
|
82 |
+
# Layer 0: x @ w1.t().
|
83 |
+
assert gg.backend is not None
|
84 |
+
sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
|
85 |
+
v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
|
86 |
+
|
87 |
+
# GeLU.
|
88 |
+
activation_fn_out = activation_fn(sdd_out) * v1_out
|
89 |
+
|
90 |
+
# Layer 1: x @ w2.
|
91 |
+
dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
|
92 |
+
|
93 |
+
# NOTE: Save the input to the layer and the activation_fn input for
|
94 |
+
# gradient computation. We'll re-compute the activation_fn forward
|
95 |
+
# pass in the backward pass to avoid materializing another
|
96 |
+
# intermediate.
|
97 |
+
ctx.x_shape = x.shape
|
98 |
+
ctx.sdd_out_shape = sdd_out.shape
|
99 |
+
ctx.dtype = x.dtype
|
100 |
+
ctx.activation_fn = activation_fn
|
101 |
+
ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
|
102 |
+
return dsd_out
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
@torch.amp.autocast_mode.custom_bwd(device_type='cuda')
|
106 |
+
def backward(ctx, ddsd_out):
|
107 |
+
if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
|
108 |
+
raise ValueError('Expected all MLP inputs to need grad.')
|
109 |
+
|
110 |
+
# Unpack saved tensors
|
111 |
+
# dtype = ctx.dtype
|
112 |
+
saved_tensors = ctx.saved_tensors
|
113 |
+
w1, v1, w2 = saved_tensors[:3]
|
114 |
+
batch_sizes = saved_tensors[3]
|
115 |
+
x = saved_tensors[4]
|
116 |
+
sdd_out, v1_out = saved_tensors[5:7]
|
117 |
+
|
118 |
+
# Rematerialize activation_fn output.
|
119 |
+
activation_fn = ctx.activation_fn
|
120 |
+
with torch.set_grad_enabled(True):
|
121 |
+
sdd_out.requires_grad = True
|
122 |
+
v1_out.requires_grad = True
|
123 |
+
activation_fn_out = activation_fn(sdd_out) * v1_out
|
124 |
+
activation_grad_fn = activation_fn_out.backward
|
125 |
+
|
126 |
+
# Compute dw2 with recomputed activation_fn output.
|
127 |
+
assert gg.backend is not None
|
128 |
+
dw2 = gg.backend.gmm(
|
129 |
+
activation_fn_out,
|
130 |
+
ddsd_out,
|
131 |
+
batch_sizes,
|
132 |
+
trans_a=True,
|
133 |
+
)
|
134 |
+
|
135 |
+
# Compute dactivation_fn_out.
|
136 |
+
#
|
137 |
+
# NOTE: We reuse the activation_fn_out allocation.
|
138 |
+
dactivation_fn_out = activation_fn_out
|
139 |
+
gg.backend.gmm(
|
140 |
+
ddsd_out,
|
141 |
+
w2,
|
142 |
+
batch_sizes,
|
143 |
+
trans_b=True,
|
144 |
+
c=dactivation_fn_out,
|
145 |
+
)
|
146 |
+
|
147 |
+
# Compute dsdd_out.
|
148 |
+
#
|
149 |
+
# NOTE: This reuses the dactivation_fn_out allocation.
|
150 |
+
assert activation_grad_fn is not None
|
151 |
+
activation_grad_fn(dactivation_fn_out)
|
152 |
+
dsdd_out = sdd_out.grad
|
153 |
+
dv1_out = v1_out.grad
|
154 |
+
|
155 |
+
# Compute dw1.
|
156 |
+
dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
|
157 |
+
|
158 |
+
# Compute dv1.
|
159 |
+
dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
|
160 |
+
|
161 |
+
# Compute dx.
|
162 |
+
#
|
163 |
+
# NOTE: This reuses the ddsd_out allocation.
|
164 |
+
dx = ddsd_out
|
165 |
+
gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
|
166 |
+
dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
|
167 |
+
return dx, dw1, dv1, dw2, None, None
|
168 |
+
|
169 |
+
|
170 |
+
memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
|
171 |
+
|
172 |
+
|
173 |
+
class GroupedGLU(SparseGLU):
|
174 |
+
|
175 |
+
def forward(self, x, tokens_per_expert):
|
176 |
+
batch_sizes = tokens_per_expert.cpu().to(torch.long)
|
177 |
+
w1, v1, w2 = (
|
178 |
+
self.scale_grad(self.w1),
|
179 |
+
self.scale_grad(self.v1),
|
180 |
+
self.scale_grad(self.w2),
|
181 |
+
)
|
182 |
+
w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
|
183 |
+
|
184 |
+
# Re-shape the weights for the grouped GEMMs.
|
185 |
+
ne = mpu.experts_per_rank(self.args)
|
186 |
+
w1 = w1.view(ne, -1, self.args.hidden_size)
|
187 |
+
v1 = v1.view(ne, -1, self.args.hidden_size)
|
188 |
+
w2 = w2.view(ne, -1, self.args.hidden_size)
|
189 |
+
|
190 |
+
if self.args.memory_optimized_mlp:
|
191 |
+
return memory_optimized_grouped_glu(
|
192 |
+
x,
|
193 |
+
w1,
|
194 |
+
v1,
|
195 |
+
w2,
|
196 |
+
batch_sizes,
|
197 |
+
self.args.activation_fn,
|
198 |
+
)
|
199 |
+
|
200 |
+
# Compute the MLP.
|
201 |
+
assert gg.ops is not None
|
202 |
+
x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
|
203 |
+
x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
|
204 |
+
x1 = self.args.activation_fn(x1) * x2
|
205 |
+
return gg.ops.gmm(x1, w2, batch_sizes)
|
206 |
+
|
207 |
+
|
208 |
+
class SharedGLU(SharedMLP):
|
209 |
+
"""GPU for shared expert.
|
210 |
+
|
211 |
+
Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
|
212 |
+
"""
|
213 |
+
|
214 |
+
def __init__(self, args: Arguments):
|
215 |
+
super().__init__(args)
|
216 |
+
self.gate_proj = args.fc_cls(
|
217 |
+
args.hidden_size,
|
218 |
+
self.args.shared_expert_hidden_size,
|
219 |
+
**self.fc_kwargs,
|
220 |
+
)
|
221 |
+
|
222 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
223 |
+
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
|
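For reference, a dense-tensor sketch of the GLU math that the module above computes with sparse (stk) and grouped (gg) GEMMs: a gate projection and an up projection, an elementwise product, then a down projection. The shapes mirror the per-rank weight layout in SparseGLU, but the sizes here are arbitrary:

import torch
import torch.nn.functional as F

hidden, ffn, tokens = 16, 32, 4
x = torch.randn(tokens, hidden)
w1 = torch.randn(ffn, hidden)   # gate weights, rows-by-hidden as in SparseGLU
v1 = torch.randn(ffn, hidden)   # up-projection weights
w2 = torch.randn(ffn, hidden)   # down-projection weights

out = (F.gelu(x @ w1.t(), approximate='tanh') * (x @ v1.t())) @ w2
print(out.shape)  # torch.Size([4, 16])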
torch-ext/megablocks/layers/memory_test.py
ADDED
@@ -0,0 +1,102 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import gc
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
|
9 |
+
from megablocks.layers import arguments, dmoe
|
10 |
+
|
11 |
+
_TESTS = ((8, 2048, 4096, 4096, 32, 4),)
|
12 |
+
|
13 |
+
|
14 |
+
def get_tensors():
|
15 |
+
ptrs = set()
|
16 |
+
out = []
|
17 |
+
for obj in gc.get_objects():
|
18 |
+
if torch.is_tensor(obj):
|
19 |
+
if not obj.is_contiguous() or obj.data_ptr() in ptrs:
|
20 |
+
continue
|
21 |
+
out.append(obj)
|
22 |
+
ptrs.add(obj.data_ptr())
|
23 |
+
return out
|
24 |
+
|
25 |
+
|
26 |
+
def test_memory(
|
27 |
+
group,
|
28 |
+
batch_size,
|
29 |
+
sequence_length,
|
30 |
+
hidden_size,
|
31 |
+
ffn_hidden_size,
|
32 |
+
num_experts,
|
33 |
+
top_k,
|
34 |
+
):
|
35 |
+
args = arguments.Arguments(
|
36 |
+
hidden_size=hidden_size,
|
37 |
+
ffn_hidden_size=ffn_hidden_size,
|
38 |
+
moe_num_experts=num_experts,
|
39 |
+
moe_top_k=top_k,
|
40 |
+
moe_expert_model_parallelism=True,
|
41 |
+
expert_parallel_group=group,
|
42 |
+
fp16=False,
|
43 |
+
bf16=True,
|
44 |
+
device=torch.cuda.current_device(),
|
45 |
+
)
|
46 |
+
layer = dmoe.dMoE(args).cuda()
|
47 |
+
|
48 |
+
x = torch.randn((batch_size, sequence_length, hidden_size),
|
49 |
+
device=torch.cuda.current_device(),
|
50 |
+
dtype=torch.bfloat16).requires_grad_(True)
|
51 |
+
torch.cuda.empty_cache()
|
52 |
+
|
53 |
+
# Run forward + backward.
|
54 |
+
# with torch.autograd.detect_anomaly():
|
55 |
+
out, _ = layer(x)
|
56 |
+
out.mean().backward()
|
57 |
+
|
58 |
+
# Report peak memory.
|
59 |
+
mem = torch.cuda.max_memory_allocated()
|
60 |
+
print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
|
61 |
+
print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
|
62 |
+
|
63 |
+
# Calculate weight and gradient memory usage.
|
64 |
+
weight_memory = 2 * (
|
65 |
+
layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
|
66 |
+
)
|
67 |
+
|
68 |
+
def grad_numel(x):
|
69 |
+
if x.grad is not None:
|
70 |
+
return x.grad.numel()
|
71 |
+
return 0
|
72 |
+
|
73 |
+
grad_memory = 2 * (
|
74 |
+
grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
|
75 |
+
)
|
76 |
+
weight_memory += grad_memory
|
77 |
+
|
78 |
+
print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
|
79 |
+
print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
|
80 |
+
|
81 |
+
# Manually calculate GPU memory usage from the garbage
|
82 |
+
# collector.
|
83 |
+
gc.collect()
|
84 |
+
total = 0
|
85 |
+
tensors = get_tensors()
|
86 |
+
tensors = sorted(tensors, key=lambda x: -x.numel())
|
87 |
+
for i, t in enumerate(tensors):
|
88 |
+
total += t.numel()
|
89 |
+
print(f'{i}: {t.shape}, {t.numel() * 2}')
|
90 |
+
del tensors
|
91 |
+
|
92 |
+
print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
assert dist.is_available()
|
97 |
+
group = dist.init_process_group(backend='nccl')
|
98 |
+
local_rank = dist.get_rank(group)
|
99 |
+
torch.cuda.set_device(local_rank)
|
100 |
+
|
101 |
+
for args in _TESTS:
|
102 |
+
test_memory(group, *args)
|
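The memory report above counts parameters and gradients at two bytes per element (bf16). A standalone sketch of that accounting; note that dividing by 1e6, as the script does, yields MB rather than MiB, so the printed labels are approximate:

import torch

w = torch.nn.Parameter(torch.empty(4096, 4096, dtype=torch.bfloat16))
weight_bytes = w.numel() * w.element_size()    # element_size() == 2 for bfloat16
print('{:0.0f}MB'.format(weight_bytes / 1e6))  # ~34MB for this single matrix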
torch-ext/megablocks/layers/memory_test.sh
ADDED
@@ -0,0 +1,12 @@
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
DISTRIBUTED_ARGUMENTS="\
|
4 |
+
--nproc_per_node 1 \
|
5 |
+
--nnodes 1 \
|
6 |
+
--node_rank 0 \
|
7 |
+
--master_addr localhost \
|
8 |
+
--master_port 6000"
|
9 |
+
|
10 |
+
python -m torch.distributed.launch \
|
11 |
+
${DISTRIBUTED_ARGUMENTS} \
|
12 |
+
megablocks/layers/memory_test.py
|
torch-ext/megablocks/layers/mlp.py
ADDED
@@ -0,0 +1,574 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
import stk
|
7 |
+
import stk.backend.triton_kernels
|
8 |
+
import stk.ops
|
9 |
+
import torch
|
10 |
+
from packaging import version
|
11 |
+
|
12 |
+
from megablocks import grouped_gemm_util as gg
|
13 |
+
from megablocks.layers import common, gelu, mpu
|
14 |
+
from megablocks.layers.activation_fn import act_fn
|
15 |
+
from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
|
16 |
+
|
17 |
+
|
18 |
+
class ScaleGradient(torch.autograd.Function):
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
@torch.amp.autocast_mode.custom_fwd(device_type='cuda')
|
22 |
+
def forward(ctx: Any, x: torch.Tensor, scale: float):
|
23 |
+
ctx.scale = scale
|
24 |
+
return x
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
@torch.amp.autocast_mode.custom_bwd(device_type='cuda')
|
28 |
+
def backward(ctx: torch.Tensor, grad: torch.Tensor):
|
29 |
+
return grad * ctx.scale, None
|
30 |
+
|
31 |
+
|
32 |
+
scale_gradient = ScaleGradient.apply
|
33 |
+
|
34 |
+
|
35 |
+
def resolve_dtensor(weight: torch.Tensor):
|
36 |
+
if version.parse(torch.__version__) >= version.parse('2.0.0'):
|
37 |
+
from torch.distributed._tensor import DTensor
|
38 |
+
if isinstance(weight, DTensor):
|
39 |
+
return weight.to_local()
|
40 |
+
return weight
|
41 |
+
|
42 |
+
|
43 |
+
def create_moe_expert_weights(
|
44 |
+
args: Arguments,
|
45 |
+
num_experts: int,
|
46 |
+
ffn_hidden_size: int,
|
47 |
+
hidden_size: int,
|
48 |
+
init_method: InitFn,
|
49 |
+
):
|
50 |
+
# Create the entire weight matrix such that the sampled weights will
|
51 |
+
# not vary between data parallelism and expert model parallelism for
|
52 |
+
# the same random seed.
|
53 |
+
master_weights = torch.empty(
|
54 |
+
num_experts,
|
55 |
+
ffn_hidden_size,
|
56 |
+
hidden_size,
|
57 |
+
device=args.device,
|
58 |
+
dtype=common.dtype(args),
|
59 |
+
)
|
60 |
+
init_method(master_weights)
|
61 |
+
|
62 |
+
if not args.moe_expert_model_parallelism:
|
63 |
+
return master_weights
|
64 |
+
|
65 |
+
# Calculate the amount of sharding in each dimension.
|
66 |
+
expert_sharding_degree = mpu.expert_sharding_degree(args)
|
67 |
+
hidden_sharding_degree = mpu.hidden_sharding_degree(args)
|
68 |
+
|
69 |
+
# Calculate the experts per rank.
|
70 |
+
#
|
71 |
+
# NOTE: We assign ranks to be expert parallel before going
|
72 |
+
# tensor parallel.
|
73 |
+
rank = mpu.get_expert_parallel_rank(args)
|
74 |
+
expert_rank = rank % expert_sharding_degree
|
75 |
+
num_experts_per_rank = num_experts // expert_sharding_degree
|
76 |
+
start_expert = expert_rank * num_experts_per_rank
|
77 |
+
end_expert = (expert_rank + 1) * num_experts_per_rank
|
78 |
+
|
79 |
+
# Calculate the rows per rank.
|
80 |
+
row_rank = rank // expert_sharding_degree
|
81 |
+
num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
|
82 |
+
start_row = row_rank * num_rows_per_rank
|
83 |
+
end_row = (row_rank + 1) * num_rows_per_rank
|
84 |
+
|
85 |
+
# Slice the weight matrix to get the chunk for this rank.
|
86 |
+
with torch.no_grad():
|
87 |
+
weights = master_weights[start_expert:end_expert, start_row:end_row]
|
88 |
+
return weights
|
89 |
+
|
90 |
+
|
91 |
+
class MLP(torch.nn.Module):
|
92 |
+
|
93 |
+
def __init__(self, args: Arguments):
|
94 |
+
super().__init__()
|
95 |
+
self.args = args
|
96 |
+
# expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
|
97 |
+
experts_per_rank = mpu.experts_per_rank(args)
|
98 |
+
|
99 |
+
self.w1 = torch.nn.Parameter(
|
100 |
+
torch.empty(
|
101 |
+
experts_per_rank,
|
102 |
+
args.hidden_size,
|
103 |
+
mpu.features_per_rank(args),
|
104 |
+
device=args.device,
|
105 |
+
dtype=common.dtype(args),
|
106 |
+
),
|
107 |
+
)
|
108 |
+
self.w2 = torch.nn.Parameter(
|
109 |
+
torch.empty(
|
110 |
+
experts_per_rank,
|
111 |
+
mpu.features_per_rank(args),
|
112 |
+
args.hidden_size,
|
113 |
+
device=args.device,
|
114 |
+
dtype=common.dtype(args),
|
115 |
+
),
|
116 |
+
)
|
117 |
+
mpu.set_expert_model_parallel_attributes(
|
118 |
+
self.w1,
|
119 |
+
args.moe_expert_model_parallelism,
|
120 |
+
)
|
121 |
+
mpu.set_expert_model_parallel_attributes(
|
122 |
+
self.w2,
|
123 |
+
args.moe_expert_model_parallelism,
|
124 |
+
)
|
125 |
+
|
126 |
+
# Initialize the parameters for the MLP.
|
127 |
+
#
|
128 |
+
# NOTE: It is important that we create the weight tensors prior
|
129 |
+
# to creating the master weights and slicing out the piece for
|
130 |
+
# this rank. If the master weights are created first the PyTorch
|
131 |
+
# caching allocator appears to use the same memory block for these
|
132 |
+
# and the slice which causes large increases in our peak memory
|
133 |
+
# usage.
|
134 |
+
with torch.no_grad():
|
135 |
+
w1 = create_moe_expert_weights(
|
136 |
+
args,
|
137 |
+
args.moe_num_experts,
|
138 |
+
args.ffn_hidden_size,
|
139 |
+
args.hidden_size,
|
140 |
+
args.init_method,
|
141 |
+
)
|
142 |
+
self.w1.copy_(w1.transpose(1, 2).contiguous())
|
143 |
+
self.w2.copy_(
|
144 |
+
create_moe_expert_weights(
|
145 |
+
args,
|
146 |
+
args.moe_num_experts,
|
147 |
+
args.ffn_hidden_size,
|
148 |
+
args.hidden_size,
|
149 |
+
args.output_layer_init_method,
|
150 |
+
),
|
151 |
+
)
|
152 |
+
|
153 |
+
self.gradient_scale = None
|
154 |
+
if self.args.moe_expert_model_parallelism:
|
155 |
+
self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
|
156 |
+
|
157 |
+
def scale_grad(self, w):
|
158 |
+
if self.gradient_scale is None:
|
159 |
+
return w
|
160 |
+
return scale_gradient(w, self.gradient_scale)
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
|
164 |
+
w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
|
165 |
+
x = torch.bmm(x, w1)
|
166 |
+
x = self.args.activation_fn(x)
|
167 |
+
return torch.bmm(x, w2)
|
168 |
+
|
169 |
+
|
170 |
+
def create_dmoe_expert_weights(
|
171 |
+
args: Arguments,
|
172 |
+
num_experts: int,
|
173 |
+
rows: int,
|
174 |
+
columns: int,
|
175 |
+
init_method: InitFn,
|
176 |
+
):
|
177 |
+
weights = create_moe_expert_weights(
|
178 |
+
args,
|
179 |
+
num_experts,
|
180 |
+
rows,
|
181 |
+
columns,
|
182 |
+
init_method,
|
183 |
+
)
|
184 |
+
return weights.view([-1, columns])
|
185 |
+
|
186 |
+
|
187 |
+
class MemoryOptimizedMLP(torch.autograd.Function):
|
188 |
+
"""Sparse MLP with manually scheduled memory reuse."""
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
@torch.amp.autocast_mode.custom_fwd(device_type='cuda')
|
192 |
+
def forward(ctx, x, w1, w2, topo, activation_fn):
|
193 |
+
# Cast inputs using ctx dtype from AMP
|
194 |
+
if ctx._fwd_used_autocast:
|
195 |
+
x = x.to(ctx._dtype)
|
196 |
+
w1 = w1.to(ctx._dtype)
|
197 |
+
w2 = w2.to(ctx._dtype)
|
198 |
+
# x: [m, k], w1: [n, k], w2: [n, k]
|
199 |
+
if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
|
200 |
+
raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
|
201 |
+
|
202 |
+
topo_tensors = (
|
203 |
+
topo.row_indices,
|
204 |
+
topo.column_indices,
|
205 |
+
topo.offsets,
|
206 |
+
topo.column_indices_t,
|
207 |
+
topo.offsets_t,
|
208 |
+
topo.block_offsets_t,
|
209 |
+
)
|
210 |
+
|
211 |
+
# Layer 0: x @ w1.t().
|
212 |
+
sdd_out = stk.ops.sdd(x, w1.t(), topo)
|
213 |
+
|
214 |
+
# GeLU.
|
215 |
+
activation_fn_out = act_fn(sdd_out, activation_fn)
|
216 |
+
|
217 |
+
# Layer 1: x @ w2.
|
218 |
+
dsd_out = stk.ops.dsd(activation_fn_out, w2)
|
219 |
+
|
220 |
+
# NOTE: Save the input to the layer and the activation_fn input for
|
221 |
+
# gradient computation. We'll re-compute the activation_fn forward
|
222 |
+
# pass in the backward pass to avoid materializing another
|
223 |
+
# intermediate.
|
224 |
+
ctx.shape = topo.shape
|
225 |
+
ctx.x_shape = x.shape
|
226 |
+
ctx.sdd_out_shape = sdd_out.data.shape
|
227 |
+
ctx.dtype = x.dtype
|
228 |
+
ctx.activation_fn = activation_fn
|
229 |
+
ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
|
230 |
+
return dsd_out
|
231 |
+
|
232 |
+
@staticmethod
|
233 |
+
@torch.amp.autocast_mode.custom_bwd(device_type='cuda')
|
234 |
+
def backward(ctx, ddsd_out):
|
235 |
+
if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
|
236 |
+
raise ValueError('Expected all MLP inputs to need grad.')
|
237 |
+
|
238 |
+
# unpack saved tensors
|
239 |
+
# dtype = ctx.dtype
|
240 |
+
saved_tensors = ctx.saved_tensors
|
241 |
+
w1, w2 = saved_tensors[:2]
|
242 |
+
topo_tensors = saved_tensors[2:8]
|
243 |
+
x = saved_tensors[8]
|
244 |
+
sdd_out_data = saved_tensors[9]
|
245 |
+
|
246 |
+
# rematerialize activation function output
|
247 |
+
activation_fn = ctx.activation_fn
|
248 |
+
sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
|
249 |
+
activation_fn_out, activation_grad_fn = act_fn(
|
250 |
+
sdd_out,
|
251 |
+
activation_fn,
|
252 |
+
return_grad_fn=True,
|
253 |
+
)
|
254 |
+
|
255 |
+
# Compute dw2 with recomputed activation_fn output.
|
256 |
+
dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
|
257 |
+
|
258 |
+
# Compute dactivation_fn_out.
|
259 |
+
#
|
260 |
+
# NOTE: We reuse the activation_fn_out allocation.
|
261 |
+
dactivation_fn_out = activation_fn_out
|
262 |
+
stk.backend.triton_kernels.sdd(
|
263 |
+
ddsd_out,
|
264 |
+
w2.t(),
|
265 |
+
dactivation_fn_out.shape,
|
266 |
+
dactivation_fn_out.data,
|
267 |
+
dactivation_fn_out.offsets,
|
268 |
+
dactivation_fn_out.row_indices,
|
269 |
+
dactivation_fn_out.column_indices,
|
270 |
+
)
|
271 |
+
|
272 |
+
# Compute dsdd_out.
|
273 |
+
#
|
274 |
+
# NOTE: This reuses the dactivation_fn_out allocation.
|
275 |
+
if activation_fn is DEFAULT_ACTIVATION_FN:
|
276 |
+
dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
|
277 |
+
else:
|
278 |
+
assert activation_grad_fn is not None
|
279 |
+
activation_grad_fn(dactivation_fn_out.data)
|
280 |
+
dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
|
281 |
+
|
282 |
+
# Compute dw1.
|
283 |
+
dw1 = stk.ops.dsd(dsdd_out.t(), x)
|
284 |
+
|
285 |
+
# Compute dx.
|
286 |
+
#
|
287 |
+
# NOTE: This reuses the ddsd_out allocation.
|
288 |
+
stk.backend.triton_kernels.dsd(
|
289 |
+
dsdd_out.shape,
|
290 |
+
dsdd_out.data,
|
291 |
+
dsdd_out.offsets,
|
292 |
+
dsdd_out.row_indices,
|
293 |
+
dsdd_out.column_indices,
|
294 |
+
dsdd_out.offsets_t,
|
295 |
+
dsdd_out.column_indices_t,
|
296 |
+
dsdd_out.block_offsets_t,
|
297 |
+
False,
|
298 |
+
w1,
|
299 |
+
ddsd_out,
|
300 |
+
)
|
301 |
+
dx = ddsd_out
|
302 |
+
return dx, dw1, dw2, None, None
|
303 |
+
|
304 |
+
|
305 |
+
memory_optimized_mlp = MemoryOptimizedMLP.apply
|
306 |
+
|
307 |
+
|
308 |
+
class SparseMLP(torch.nn.Module):
|
309 |
+
|
310 |
+
def __init__(self, args: Arguments):
|
311 |
+
super().__init__()
|
312 |
+
self.args = args
|
313 |
+
self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
|
314 |
+
|
315 |
+
self.w1 = torch.nn.Parameter(
|
316 |
+
torch.empty(
|
317 |
+
self._num_rows_per_rank,
|
318 |
+
args.hidden_size,
|
319 |
+
device=args.device,
|
320 |
+
dtype=common.dtype(args),
|
321 |
+
),
|
322 |
+
)
|
323 |
+
self.w2 = torch.nn.Parameter(
|
324 |
+
torch.empty(
|
325 |
+
self._num_rows_per_rank,
|
326 |
+
args.hidden_size,
|
327 |
+
device=args.device,
|
328 |
+
dtype=common.dtype(args),
|
329 |
+
),
|
330 |
+
)
|
331 |
+
|
332 |
+
# Initialize the parameters for the MLP.
|
333 |
+
#
|
334 |
+
# NOTE: It is important that we create the weight tensors prior
|
335 |
+
# to creating the master weights and slicing out the piece for
|
336 |
+
# this rank. If the master weights are created first the PyTorch
|
337 |
+
# caching allocator appears to use the same memory block for these
|
338 |
+
# and the slice which causes large increases in our peak memory
|
339 |
+
# usage.
|
340 |
+
with torch.no_grad():
|
341 |
+
self.w1.copy_(
|
342 |
+
create_dmoe_expert_weights(
|
343 |
+
args,
|
344 |
+
args.moe_num_experts,
|
345 |
+
args.ffn_hidden_size,
|
346 |
+
args.hidden_size,
|
347 |
+
args.init_method,
|
348 |
+
),
|
349 |
+
)
|
350 |
+
self.w2.copy_(
|
351 |
+
create_dmoe_expert_weights(
|
352 |
+
args,
|
353 |
+
args.moe_num_experts,
|
354 |
+
args.ffn_hidden_size,
|
355 |
+
args.hidden_size,
|
356 |
+
args.output_layer_init_method,
|
357 |
+
),
|
358 |
+
)
|
359 |
+
|
360 |
+
self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
|
361 |
+
mpu.set_expert_model_parallel_attributes(
|
362 |
+
self.w1,
|
363 |
+
self._should_set_parallelism_attribute,
|
364 |
+
)
|
365 |
+
mpu.set_expert_model_parallel_attributes(
|
366 |
+
self.w2,
|
367 |
+
self._should_set_parallelism_attribute,
|
368 |
+
)
|
369 |
+
|
370 |
+
self.gradient_scale = None
|
371 |
+
if self.args.moe_expert_model_parallelism:
|
372 |
+
self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
|
373 |
+
|
374 |
+
def scale_grad(self, w):
|
375 |
+
if self.gradient_scale is None:
|
376 |
+
return w
|
377 |
+
return scale_gradient(w, self.gradient_scale)
|
378 |
+
|
379 |
+
def forward(self, x, topo):
|
380 |
+
w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
|
381 |
+
w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
|
382 |
+
if self.args.memory_optimized_mlp:
|
383 |
+
return memory_optimized_mlp(
|
384 |
+
x,
|
385 |
+
w1,
|
386 |
+
w2,
|
387 |
+
topo,
|
388 |
+
self.args.activation_fn,
|
389 |
+
)
|
390 |
+
|
391 |
+
# Compute the MLP.
|
392 |
+
x = stk.ops.sdd(x, w1.t(), topo)
|
393 |
+
activation_fn_out = act_fn(x, self.args.activation_fn)
|
394 |
+
return stk.ops.dsd(activation_fn_out, w2)
|
395 |
+
|
396 |
+
|
397 |
+
class MemoryOptimizedGroupedMLP(torch.autograd.Function):
|
398 |
+
"""GroupedMLP with manually scheduled memory reuse."""
|
399 |
+
|
400 |
+
@staticmethod
|
401 |
+
@torch.amp.autocast_mode.custom_fwd(device_type='cuda')
|
402 |
+
def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
|
403 |
+
# Cast inputs using ctx dtype from AMP
|
404 |
+
if ctx._fwd_used_autocast:
|
405 |
+
x = x.to(ctx._dtype)
|
406 |
+
w1 = w1.to(ctx._dtype)
|
407 |
+
w2 = w2.to(ctx._dtype)
|
408 |
+
# x: [m, k], w1: [n, k], w2: [n, k]
|
409 |
+
if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
|
410 |
+
raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
|
411 |
+
|
412 |
+
# Layer 0: x @ w1.t().
|
413 |
+
assert gg.backend is not None
|
414 |
+
sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
|
415 |
+
|
416 |
+
# activation_fn
|
417 |
+
activation_fn_out = activation_fn(sdd_out)
|
418 |
+
|
419 |
+
# Layer 1: x @ w2.
|
420 |
+
dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
|
421 |
+
|
422 |
+
# NOTE: Save the input to the layer and the activation_fn input for
|
423 |
+
# gradient computation. We'll re-compute the activation_fn forward
|
424 |
+
# pass in the backward pass to avoid materializing another
|
425 |
+
# intermediate.
|
426 |
+
ctx.x_shape = x.shape
|
427 |
+
ctx.sdd_out_shape = sdd_out.shape
|
428 |
+
ctx.dtype = x.dtype
|
429 |
+
ctx.activation_fn = activation_fn
|
430 |
+
ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
|
431 |
+
return dsd_out
|
432 |
+
|
433 |
+
@staticmethod
|
434 |
+
@torch.amp.autocast_mode.custom_bwd(device_type='cuda')
|
435 |
+
def backward(ctx: Any, ddsd_out: torch.Tensor):
|
436 |
+
if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
|
437 |
+
raise ValueError('Expected all MLP inputs to need grad.')
|
438 |
+
|
439 |
+
# Unpack saved tensors
|
440 |
+
# dtype = ctx.dtype
|
441 |
+
saved_tensors = ctx.saved_tensors
|
442 |
+
w1, w2 = saved_tensors[:2]
|
443 |
+
batch_sizes = saved_tensors[2]
|
444 |
+
x = saved_tensors[3]
|
445 |
+
sdd_out = saved_tensors[4]
|
446 |
+
|
447 |
+
# Rematerialize activation_fn output.
|
448 |
+
activation_fn = ctx.activation_fn
|
449 |
+
with torch.set_grad_enabled(True):
|
450 |
+
sdd_out.requires_grad = True
|
451 |
+
activation_fn_out = activation_fn(sdd_out)
|
452 |
+
activation_grad_fn = activation_fn_out.backward
|
453 |
+
|
454 |
+
# Compute dw2 with recomputed activation_fn output.
|
455 |
+
assert gg.backend is not None
|
456 |
+
dw2 = gg.backend.gmm(
|
457 |
+
activation_fn_out,
|
458 |
+
ddsd_out,
|
459 |
+
batch_sizes,
|
460 |
+
trans_a=True,
|
461 |
+
)
|
462 |
+
|
463 |
+
# Compute dactivation_fn_out.
|
464 |
+
#
|
465 |
+
# NOTE: We reuse the activation_fn_out allocation.
|
466 |
+
dactivation_fn_out = activation_fn_out
|
467 |
+
gg.backend.gmm(
|
468 |
+
ddsd_out,
|
469 |
+
w2,
|
470 |
+
batch_sizes,
|
471 |
+
trans_b=True,
|
472 |
+
c=dactivation_fn_out,
|
473 |
+
)
|
474 |
+
|
475 |
+
# Compute dsdd_out.
|
476 |
+
#
|
477 |
+
# NOTE: This reuses the dactivation_fn_out allocation.
|
478 |
+
if activation_fn is DEFAULT_ACTIVATION_FN:
|
479 |
+
dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
|
480 |
+
else:
|
481 |
+
assert activation_grad_fn is not None
|
482 |
+
activation_grad_fn(dactivation_fn_out)
|
483 |
+
dsdd_out = sdd_out.grad
|
484 |
+
|
485 |
+
# Compute dw1.
|
486 |
+
dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
|
487 |
+
|
488 |
+
# Compute dx.
|
489 |
+
#
|
490 |
+
# NOTE: This reuses the ddsd_out allocation.
|
491 |
+
gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
|
492 |
+
dx = ddsd_out
|
493 |
+
return dx, dw1, dw2, None, None
|
494 |
+
|
495 |
+
|
496 |
+
memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
|
497 |
+
|
498 |
+
|
499 |
+
class GroupedMLP(SparseMLP):
|
500 |
+
|
501 |
+
def forward(self, x, tokens_per_expert):
|
502 |
+
batch_sizes = tokens_per_expert.cpu().to(torch.long)
|
503 |
+
w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
|
504 |
+
|
505 |
+
# Re-shape the weights for the grouped GEMMs.
|
506 |
+
ne = mpu.experts_per_rank(self.args)
|
507 |
+
w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
|
508 |
+
w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
|
509 |
+
|
510 |
+
if self.args.memory_optimized_mlp:
|
511 |
+
return memory_optimized_grouped_mlp(
|
512 |
+
x,
|
513 |
+
w1,
|
514 |
+
w2,
|
515 |
+
batch_sizes,
|
516 |
+
self.args.activation_fn,
|
517 |
+
)
|
518 |
+
|
519 |
+
# Compute the MLP.
|
520 |
+
assert gg.ops is not None
|
521 |
+
x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
|
522 |
+
x = self.args.activation_fn(x)
|
523 |
+
return gg.ops.gmm(x, w2, batch_sizes)
|
524 |
+
|
525 |
+
|
526 |
+
class SharedMLP(torch.nn.Module):
|
527 |
+
"""MLP for shared expert.
|
528 |
+
|
529 |
+
Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
|
530 |
+
"""
|
531 |
+
|
532 |
+
def __init__(self, args: Arguments):
|
533 |
+
super().__init__()
|
534 |
+
self.args = args
|
535 |
+
self.fc_kwargs: dict[str, Any] = {
|
536 |
+
'bias': args.bias,
|
537 |
+
'device': args.device,
|
538 |
+
}
|
539 |
+
self.fc_kwargs.update(args.fc_kwargs)
|
540 |
+
|
541 |
+
self.up_proj = args.fc_cls(
|
542 |
+
args.hidden_size,
|
543 |
+
args.shared_expert_hidden_size,
|
544 |
+
**self.fc_kwargs,
|
545 |
+
)
|
546 |
+
self.act = args.activation_fn
|
547 |
+
self.down_proj = args.fc_cls(
|
548 |
+
args.shared_expert_hidden_size,
|
549 |
+
args.hidden_size,
|
550 |
+
**self.fc_kwargs,
|
551 |
+
)
|
552 |
+
self.down_proj._is_residual = True # a flag for llm-foundry init
|
553 |
+
|
554 |
+
def add_experts_sharedexpert(
|
555 |
+
self,
|
556 |
+
shared_expert_out: torch.Tensor,
|
557 |
+
expert_out: torch.Tensor,
|
558 |
+
) -> torch.Tensor:
|
559 |
+
# Helper function to add expert output to shared expert output
|
560 |
+
# with optional weighted sum.
|
561 |
+
if self.args.shared_expert_weighted_sum:
|
562 |
+
# enable using weighted sum for shared expert output
|
563 |
+
# weighted by the number of experts used
|
564 |
+
t_experts = self.args.moe_top_k + 1
|
565 |
+
sh_mlp_out = shared_expert_out / t_experts
|
566 |
+
return sh_mlp_out.add(
|
567 |
+
expert_out,
|
568 |
+
alpha=(self.args.moe_top_k / t_experts),
|
569 |
+
)
|
570 |
+
|
571 |
+
return shared_expert_out + expert_out
|
572 |
+
|
573 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
574 |
+
return self.down_proj(self.act(self.up_proj(x)))
|
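The ScaleGradient function defined at the top of mlp.py is an identity in the forward pass and multiplies the incoming gradient by a fixed constant in the backward pass (set above to 1/expert-parallel world size when expert model parallelism is enabled). A minimal sketch of the same pattern without the AMP decorators:

import torch

class ScaleGradient(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x                       # identity in the forward pass

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None  # scale gradients on the way back

w = torch.ones(3, requires_grad=True)
ScaleGradient.apply(w, 0.25).sum().backward()
print(w.grad.tolist())  # [0.25, 0.25, 0.25]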
torch-ext/megablocks/layers/moe.py
ADDED
@@ -0,0 +1,475 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
from typing import Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
|
9 |
+
import megablocks.ops as ops
|
10 |
+
from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
|
11 |
+
from megablocks.layers.all_to_all import all_to_all
|
12 |
+
from megablocks.layers.arguments import Arguments
|
13 |
+
|
14 |
+
_LOAD_BALANCING_LOSS = []
|
15 |
+
|
16 |
+
|
17 |
+
def save_load_balancing_loss(loss):
|
18 |
+
global _LOAD_BALANCING_LOSS
|
19 |
+
_LOAD_BALANCING_LOSS.append(loss)
|
20 |
+
|
21 |
+
|
22 |
+
def get_load_balancing_loss():
|
23 |
+
global _LOAD_BALANCING_LOSS
|
24 |
+
return _LOAD_BALANCING_LOSS
|
25 |
+
|
26 |
+
|
27 |
+
def clear_load_balancing_loss():
|
28 |
+
global _LOAD_BALANCING_LOSS
|
29 |
+
_LOAD_BALANCING_LOSS.clear()
|
30 |
+
|
31 |
+
|
32 |
+
def batched_load_balancing_loss(args: Arguments):
|
33 |
+
if args.moe_loss_weight == 0:
|
34 |
+
return 0.0
|
35 |
+
|
36 |
+
# tokens_per_expert[i].shape = (num_experts)
|
37 |
+
# expert_scores[i].shape = (tokens, num_experts)
|
38 |
+
tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
|
39 |
+
num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
|
40 |
+
if args.num_layers_per_virtual_pipeline_stage is not None:
|
41 |
+
num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
|
42 |
+
|
43 |
+
if len(tokens_per_expert) != num_layers_per_pipeline_stage:
|
44 |
+
raise ValueError(
|
45 |
+
f'Expected {num_layers_per_pipeline_stage} token_per_experts '
|
46 |
+
f'but found {len(tokens_per_expert)}.\nnum_layers = '
|
47 |
+
f'{args.num_layers}\npipeline_model_parallel_size = '
|
48 |
+
f'{args.pipeline_model_parallel_size}\n'
|
49 |
+
'num_layers_per_virtual_pipeline_stage'
|
50 |
+
f' = {args.num_layers_per_virtual_pipeline_stage}',
|
51 |
+
)
|
52 |
+
if len(expert_scores) != num_layers_per_pipeline_stage:
|
53 |
+
raise ValueError(
|
54 |
+
f'Expected {num_layers_per_pipeline_stage} expert_scores '
|
55 |
+
f'but found {len(expert_scores)}.\nnum_layers = '
|
56 |
+
f'{args.num_layers}\npipeline_model_parallel_size = '
|
57 |
+
f'{args.pipeline_model_parallel_size}\n'
|
58 |
+
'num_layers_per_virtual_pipeline_stage'
|
59 |
+
f' = {args.num_layers_per_virtual_pipeline_stage}',
|
60 |
+
)
|
61 |
+
|
62 |
+
# Verify the shape of the tokens_per_expert and expert_scores tensors.
|
63 |
+
assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
|
64 |
+
|
65 |
+
tokens = expert_scores[0].shape[0]
|
66 |
+
assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
|
67 |
+
|
68 |
+
# Concatenate the contributions of each layer and convert to
|
69 |
+
# the correct types and formats for the dot product.
|
70 |
+
expert_scores = torch.cat(expert_scores, dim=1)
|
71 |
+
if args.moe_lbl_in_fp32:
|
72 |
+
expert_scores = expert_scores.float()
|
73 |
+
if tokens != 0:
|
74 |
+
expert_scores = expert_scores.mean(dim=0)
|
75 |
+
else:
|
76 |
+
expert_scores = expert_scores.sum(dim=0)
|
77 |
+
tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
|
78 |
+
|
79 |
+
expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
|
80 |
+
assert tokens_per_expert.numel() == expected_values
|
81 |
+
assert expert_scores.numel() == expected_values
|
82 |
+
|
83 |
+
# Calculate the total scale across all factors.
|
84 |
+
#
|
85 |
+
# loss_weight * num_experts / (num_layers * tokens * top_k)
|
86 |
+
scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
|
87 |
+
scale_denominator = (args.num_layers * tokens * args.moe_top_k)
|
88 |
+
scale = scale_numerator / scale_denominator
|
89 |
+
return scale * torch.dot(tokens_per_expert, expert_scores)
|
90 |
+
|
91 |
+
|
92 |
+
# NOTE: This class defines MoE expert computation, including expert model parallel
|
93 |
+
# communication. When using FSDP on top of MegaBlocks this is the module that should
|
94 |
+
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):
+
+    def __init__(self, args: Arguments):
+        super(ParallelMLP, self).__init__()
+        self.args = args
+
+        # Calculate the number of experts in total and the number of experts
+        # owned by this rank.
+        # world_size = mpu.get_expert_parallel_world_size(args)
+        self.num_experts = args.moe_num_experts
+        self.top_k = self.args.moe_top_k
+
+        # Calculate the number of bits needed to represent the expert indices
+        # so that we can pass it to radix sort.
+        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+
+        # Expert MLP.
+        self.mlp = mlp.MLP(args)
+
+        self.bias: Optional[torch.Tensor]
+        if self.args.bias:
+            # Note that the output bias is not parallelized with expert
+            # model parallelism.
+            self.bias = torch.nn.Parameter(
+                torch.empty(
+                    args.hidden_size,
+                    device=args.device,
+                    dtype=common.dtype(args),
+                ),
+            )
+            torch.nn.init.zeros_(self.bias)
+        else:
+            self.register_parameter('bias', None)
+
+        # Select the forward function for the operating mode.
+        self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
+
+    def expert_capacity(self, tokens: int) -> int:
+        world_size = mpu.get_expert_parallel_world_size(self.args)
+        tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
+        return int(self.args.moe_capacity_factor * tokens_per_expert)
+
+    def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
+        """Calculate the load balancing loss contribution."""
+        assert len(expert_scores.size()) == 2
+        tokens, num_experts = expert_scores.size()
+        assert num_experts == self.num_experts
+        assert len(tokens_per_expert.size()) == 1
+        num_experts, = tokens_per_expert.size()
+        assert num_experts == self.num_experts
+        scale = self.num_experts / (tokens * self.top_k)
+        return scale * torch.dot(
+            tokens_per_expert.to(expert_scores.dtype),
+            expert_scores.mean(dim=0),
+        )
+
+    def indices_and_bins(self,
+                         top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Sort the expert ids to produce the scatter/gather
+        # indices for the permutation.
+        #
+        # TODO(tgale): Is it worth doing this conversion to 32-bit
+        # prior? Could we place the `torch.max` operation to return
+        # 32-bit expert indices?
+        top_expert = top_expert.int()
+        output = ops.sort(top_expert, self.sort_end_bit)
+        assert output is not None
+        bin_ids, indices = output
+
+        # Histogram the expert ids to identify the number of
+        # tokens routed to each expert.
+        #
+        # TODO(tgale): Does the sorted data produce a more favorable
+        # data distribution for histogram? Or is the op parallelism
+        # worth more?
+        tokens_per_expert = ops.histogram(top_expert, self.num_experts)
+
+        # Calculate the bin bounds for the sorted tokens.
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        assert bins is not None
+        bins = bins.view(1) if not len(bins.size()) else bins
+
+        assert isinstance(indices, torch.Tensor)
+        assert isinstance(bin_ids, torch.Tensor)
+        assert isinstance(bins, torch.Tensor)
+        assert isinstance(tokens_per_expert, torch.Tensor)
+
+        return indices, bin_ids, bins, tokens_per_expert
+
+    def permute_and_compute(
+        self,
+        x: torch.Tensor,
+        tokens_per_expert: int,  # unused
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,  # unused
+        expert_weights: torch.Tensor,
+        bins: torch.Tensor,
+        expert_capacity: int,
+        top_k: int,
+    ):
+        # Route the tokens for MoE computation.
+        x = x.view(-1, x.shape[-1])
+        output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+        assert output is not None
+        x = output
+
+        # Perform the expert computation. Note that we don't
+        # use biases for these linear operations.
+        x = self.mlp(x)
+
+        # Un-route the data for the MoE output.
+        return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+
+    def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+        # x: [sl, bs, hs]
+        # expert_weights: [sl * bs, top-k]
+        # top_experts: [sl * bs, top-k]
+        expert_weights = expert_weights.flatten()
+        top_experts = top_experts.flatten()
+        with torch.no_grad():
+            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+            # If expert_capacity is set to zero, set the number of tokens
+            # per expert to the maximum we need to avoid dropping tokens.
+            sl, bs, _ = x.size()
+            expert_capacity = self.expert_capacity(sl * bs)
+            if expert_capacity == 0:
+                expert_capacity = torch.max(tokens_per_expert).item()
+
+        x = self.permute_and_compute(
+            x,
+            tokens_per_expert,
+            indices,
+            bin_ids,
+            expert_weights,
+            bins,
+            expert_capacity,
+            self.top_k,
+        )
+        return x, tokens_per_expert
+
+    def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+        # NOTE: This function implements the same computation as forward_once
+        # but with expert model parallelism.
+        #
+        # 1. Permute the tokens locally so that they are grouped by their
+        # expert assignments. This allows us to transfer all of the tokens
+        # for a remote device in one communication primitive.
+        #
+        # 2. Permute the tokens across the expert parallel devices. After
+        # this is completed each device has all of the tokens assigned to
+        # its set of experts in its local HBM.
+        #
+        # 3. Permute the tokens locally so that they are grouped by their
+        # expert assignment. After the distributed permutation the tokens
+        # are grouped by which device they came from. We re-order them
+        # locally to allow for efficient computation.
+        #
+        # After this series of permutations we compute the linear layers
+        # and then repeat these three steps in reverse to produce the final
+        # output.
+        #
+        # Compute the mapping of local tokens to experts.
+        expert_weights = expert_weights.flatten()
+        top_experts = top_experts.flatten()
+        with torch.no_grad():
+            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
+
+            # If we're sharding the experts along the hidden dimension
+            # multiple devices own parts of the same sets of experts.
+            # Replicate the token counts so every device gets the counts.
+            repeated_tokens_per_expert = ops.repeat(
+                tokens_per_expert,
+                (mpu.hidden_sharding_degree(self.args),),
+            )
+
+            # Pass token count information to the device on which the
+            # target expert resides.
+            parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
+            tpe_handle = dist.all_to_all_single(
+                parallel_tokens_per_expert,
+                repeated_tokens_per_expert,
+                group=self.args.expert_parallel_group,
+                async_op=True,
+            )
+
+        # Permute locally and without any padding so that tokens for each
+        # parallel device are stored contiguously.
+        #
+        # This view updates the shape of the tensor from [sl, bs, hs] to
+        # [sl * bs, hs] prior to the permutation.
+        x = x.view(-1, x.shape[-1])
+        output = ops.gather(x, indices, bin_ids, bins, self.top_k)
+        assert output is not None
+        x = output
+
+        # Compute the number of tokens that will be received from each
+        # device and permute the input data across the devices.
+        with torch.no_grad():
+            tpe_handle.wait()
+            experts_per_rank = mpu.experts_per_rank(self.args)
+
+            # Reshape to [world_size, num_experts_per_rank].
+            world_size = mpu.get_expert_parallel_world_size(self.args)
+            repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
+            parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
+
+            # TODO(tgale): It might be faster to do this on the GPU and
+            # then communicate the results back to the host.
+            send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+            parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+            recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+            # Convert the send/recv counts to lists.
+            send_counts = send_counts.tolist()
+            recv_counts = recv_counts.tolist()
+            tokens_received = sum(recv_counts)
+
+        # If we're sharding the experts along the hidden dimension
+        # multiple devices own parts of the same sets of experts.
+        # Replicate the token counts so devices that share experts
+        # get all of the tokens assigned to them.
+        #
+        # TODO(tgale): Fuse this into the prior, local permutation.
+        x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
+
+        # Start the cross-device permutation asynchronously so we can
+        # overlap communication with computation.
+        parallel_x, parallel_x_handle = all_to_all(
+            x,
+            recv_counts,
+            send_counts,
+            self.args.expert_parallel_group,
+            async_op=True,
+        )
+
+        with torch.no_grad():
+            # After we do the cross-device permutation we have the tokens on the
+            # correct device but not yet grouped by expert because we received
+            # tokens from each device as contiguous chunks. To group the tokens
+            # for expert computation we'll do one more local permutation. The
+            # rest of this torch.no_grad() scope sets up the indices and bins
+            # for this permutation.
+            replicate_bins = ops.inclusive_cumsum(
+                parallel_tokens_per_expert.flatten(),
+                0,
+            )
+            replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
+
+            # Construct the expert indices for the permuted tokens.
+            parallel_top_expert = torch.remainder(
+                torch.arange(
+                    self.num_experts * mpu.hidden_sharding_degree(self.args),
+                    dtype=torch.int32,
+                    device=indices.device,
+                ),
+                mpu.experts_per_rank(self.args),
+            )
+            parallel_top_expert = ops.replicate(
+                parallel_top_expert.unsqueeze(dim=0),
+                replicate_bins,
+                tokens_received,
+            ).flatten()
+
+            # TODO(tgale): The sort_end_bit here can be reduced.
+            parallel_bin_ids, parallel_indices = ops.sort(
+                parallel_top_expert,
+                self.sort_end_bit,
+            )
+
+            # Calculate the bin boundaries from the token counts.
+            parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+                dim=0,
+                dtype=torch.int,
+            )
+            parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
+            parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
+
+            # If expert_capacity is set to zero, set the number of tokens
+            # per expert to the maximum we need to avoid dropping tokens.
+            tokens, _ = x.size()
+            expert_capacity = self.expert_capacity(tokens)
+            if expert_capacity == 0:
+                expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+        # Locally permute the tokens and perform the expert computation.
+        # Block to make sure that the cross-device permutation is complete.
+        if self.args.mlp_impl == 'grouped':
+            # GroupedMLP requires counts on CPU. We can use the tensor already
+            # moved to CPU for the prior all_to_all, which avoids an extra
+            # device synchronization.
+            parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+                dim=0,
+                dtype=torch.int,
+            )
+        parallel_x_handle.wait()
+        parallel_x = self.permute_and_compute(
+            parallel_x,
+            parallel_tokens_per_expert,
+            parallel_indices,
+            parallel_bin_ids,
+            None,  # expert_weights
+            parallel_bins,
+            expert_capacity,
+            top_k=1,
+        )
+
+        # Un-permute the tokens across the devices.
+        x, _ = all_to_all(
+            parallel_x,
+            send_counts,
+            recv_counts,
+            self.args.expert_parallel_group,
+        )
+
+        # Reduce along the hidden sharding to get the final outputs.
+        #
+        # TODO(tgale): Fuse this into the following local permutation.
+        shape = (
+            mpu.hidden_sharding_degree(self.args),
+            -1,
+            self.args.hidden_size,
+        )
+        x = ops.sum(x.view(shape), dim=0)
+
+        # Un-permute locally to setup for the next series of operations.
+        x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
+        return x, tokens_per_expert.flatten()
+
+    def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
+        in_shape = x.size()
+
+        # Compute the experts.
+        x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
+        if self.training and self.args.moe_loss_weight > 0:
+            save_load_balancing_loss((tokens_per_expert, scores))
+        x = x.view(in_shape)
+        if self.bias is not None:
+            if self.args.return_bias:
+                return x, self.bias
+            return x + self.bias
+        return x
+
+
+class MoE(torch.nn.Module):
+
+    def __init__(self, args: Arguments):
+        super(MoE, self).__init__()
+
+        # Token router.
+        self.router = router.LearnedRouter(args)
+
+        # Expert computation helper.
+        self.experts = self._init_experts_mlp(args)
+
+        self.shared_expert = None
+        if args.shared_expert:
+            # SharedExpert computation helper.
+            self.shared_expert = sharedexpert_registry.get(args)
+
+    def _init_experts_mlp(self, args: Arguments):
+        return ParallelMLP(args)
+
+    def forward(self, x: torch.Tensor):
+        # NOTE: If we're going to cast the activations to lower precision
+        # do it before we permute the tokens to save bandwidth.
+        x = common.cast_if_autocast_enabled(x)
+
+        # Compute the expert scores and assignments.
+        scores, expert_weights, top_experts = self.router(x)
+
+        # Compute the experts.
+        out = self.experts(x, scores, expert_weights, top_experts)
+        if self.shared_expert is not None:
+            shared_expert_out = self.shared_expert(x)
+            out = self.shared_expert.add_experts_sharedexpert(
+                shared_expert_out,
+                out,
+            )
+        return out
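For orientation, the sketch below shows how the MoE module above might be instantiated and called in a single process. It is illustrative rather than part of this port: it assumes the Arguments dataclass in torch-ext/megablocks/layers/arguments.py (not shown in this hunk) accepts the fields referenced above as keyword arguments with workable defaults for the rest, and that the CUDA ops have been built.

import torch

from megablocks.layers.arguments import Arguments
from megablocks.layers.moe import MoE

# Hypothetical configuration; field names follow the attributes used above.
args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
)
layer = MoE(args)

# Input follows the [sl, bs, hs] convention noted in forward_once; match the
# device and dtype the layer was built with.
w = layer.router.layer.weight
x = torch.randn(128, 4, args.hidden_size, device=w.device, dtype=w.dtype)
out = layer(x)  # a tensor, or (tensor, bias) when args.bias and args.return_bias are set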
torch-ext/megablocks/layers/mpu.py
ADDED
@@ -0,0 +1,93 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+from megablocks.layers.arguments import Arguments
+
+
+class MoeParam(torch.Tensor):
+
+    def __init__(self):
+        super().__init__(self)
+        self.expert_model_parallel: bool
+
+
+def is_moe_param(tensor: torch.Tensor) -> bool:
+    return hasattr(tensor, 'expert_model_parallel')
+
+
+def get_expert_parallel_world_size(args: Arguments) -> int:
+    return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
+
+
+def get_expert_parallel_rank(args: Arguments) -> int:
+    return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
+
+
+def set_expert_model_parallel_attributes(
+    tensor: torch.Tensor,
+    is_parallel: bool,
+):
+    assert not hasattr(tensor, 'expert_model_parallel')
+    setattr(tensor, 'expert_model_parallel', is_parallel)
+
+
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
+    return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
+
+
+def copy_expert_model_parallel_attributes(
+    destination_tensor: torch.Tensor,
+    source_tensor: torch.Tensor,
+):
+    if hasattr(source_tensor, 'expert_model_parallel'):
+        setattr(
+            destination_tensor,
+            'expert_model_parallel',
+            getattr(source_tensor, 'expert_model_parallel'),
+        )
+
+
+def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
+    world_size = dist.get_world_size(group)
+    rank = dist.get_rank(group)
+    for i in range(world_size):
+        dist.barrier(group)
+        if i == rank:
+            print(f'rank = {rank}', *x)
+
+
+# Helpers for expert/tensor sharding.
+def expert_sharding_degree(args: Arguments) -> int:
+    world_size = get_expert_parallel_world_size(args)
+    esd = min(world_size, args.moe_num_experts)
+
+    if (args.moe_num_experts % esd) != 0:
+        raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
+    return esd
+
+
+def hidden_sharding_degree(args: Arguments) -> int:
+    world_size = get_expert_parallel_world_size(args)
+    esd = expert_sharding_degree(args)
+    hsd = world_size // esd
+
+    if (args.ffn_hidden_size % hsd) != 0:
+        raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
+    if (esd * hsd) != world_size:
+        raise ValueError(
+            f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
+        )
+    return hsd
+
+
+def experts_per_rank(args: Arguments) -> int:
+    return args.moe_num_experts // expert_sharding_degree(args)
+
+
+def features_per_rank(args: Arguments) -> int:
+    return args.ffn_hidden_size // hidden_sharding_degree(args)
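As a quick check on the sharding helpers above, the following standalone arithmetic (plain Python, no torch.distributed) reproduces what expert_sharding_degree, hidden_sharding_degree, experts_per_rank, and features_per_rank compute for a hypothetical configuration of 8 expert-parallel ranks, 4 experts, and an FFN hidden size of 8192.

# Worked example of the sharding arithmetic above; values are hypothetical.
world_size = 8
moe_num_experts = 4
ffn_hidden_size = 8192

expert_sharding_degree = min(world_size, moe_num_experts)      # 4
hidden_sharding_degree = world_size // expert_sharding_degree  # 2
experts_per_rank = moe_num_experts // expert_sharding_degree   # 1
features_per_rank = ffn_hidden_size // hidden_sharding_degree  # 4096

# The same consistency condition hidden_sharding_degree enforces.
assert expert_sharding_degree * hidden_sharding_degree == world_size
print(experts_per_rank, features_per_rank)  # 1 4096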
torch-ext/megablocks/layers/router.py
ADDED
@@ -0,0 +1,114 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import torch
+
+from megablocks.layers import common
+from megablocks.layers.arguments import Arguments
+
+_ROUTER_LOGITS = []
+
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+    if args.moe_zloss_weight == 0:
+        return
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.append(logits)
+
+
+def clear_router_zloss():
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.clear()
+
+
+def batched_router_zloss(args: Arguments):
+    global _ROUTER_LOGITS
+
+    if args.moe_zloss_weight == 0:
+        import warnings
+        warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
+        return 0
+
+    logits_per_router = _ROUTER_LOGITS
+
+    if args.moe_zloss_in_fp32:
+        logits_per_router = [logits.float() for logits in logits_per_router]
+
+    unscaled_zloss_per_router = torch.stack([
+        torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
+    ])
+
+    return args.moe_zloss_weight * unscaled_zloss_per_router
+
+
+# NOTE: To enable end-to-end benchmarking without convergence we
+# support a flag to force the router to assign tokens uniformly
+# across the experts. We do this with a custom autograd operation
+# so that PyTorch still executes the full set of router operations.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, num_experts: int):
+        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+        out = torch.remainder(out, num_experts)
+        return out.view(x.shape)
+
+
+_uniform_expert_assignment = _UniformExpertAssignment.apply
+
+
+class LearnedRouter(torch.nn.Module):
+
+    def __init__(self, args: Arguments):
+        super().__init__()
+        self.args = args
+
+        # Learned router parameters.
+        #
+        # NOTE: This weight matrix is not parallelized with expert model
+        # parallelism. Each device needs the entire router weight matrix
+        # so that it can route its batch of data correctly.
+        self.layer = torch.nn.Linear(
+            args.hidden_size,
+            args.moe_num_experts,
+            bias=False,
+            dtype=common.dtype(args),
+            device=args.device,
+        )
+        args.init_method(self.layer.weight)
+
+    def jitter(self, x: torch.Tensor):
+        low: float = 1.0 - self.args.moe_jitter_eps
+        high: float = 1.0 + self.args.moe_jitter_eps
+        noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+        return low + noise * (high - low)
+
+    def _top_k(self, scores: torch.Tensor):
+        if self.args.moe_top_k == 1:
+            return scores.max(dim=-1, keepdim=True)
+        return torch.topk(scores, self.args.moe_top_k, dim=-1)
+
+    def forward(self, x: torch.Tensor):
+        if self.training and self.args.moe_jitter_eps is not None:
+            x = x * self.jitter(x)
+
+        logits = self.layer(x.view(-1, x.shape[-1]))
+        _save_router_logits(logits, self.args)
+        scores = logits.softmax(dim=-1)
+        expert_weights, expert_indices = self._top_k(scores)
+        if self.args.moe_normalize_expert_weights:
+            expert_weights = expert_weights / torch.norm(
+                expert_weights,
+                p=self.args.moe_normalize_expert_weights,
+                dim=-1,
+                keepdim=True,
+            )
+
+        expert_indices = (
+            _uniform_expert_assignment(
+                expert_indices,
+                self.args.moe_num_experts,
+            ) if self.args.uniform_expert_assignment else expert_indices
+        )
+        return scores, expert_weights, expert_indices
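The core routing math in LearnedRouter.forward can be reproduced with plain PyTorch, which may help when reading the module above. The sizes below are made up; the re-normalization corresponds to moe_normalize_expert_weights=1, and the real module additionally applies jitter, z-loss logging, and the uniform-assignment override.

import torch

hidden_size, num_experts, top_k = 16, 4, 2
x = torch.randn(8, hidden_size)              # [tokens, hidden]
router = torch.nn.Linear(hidden_size, num_experts, bias=False)

logits = router(x)
scores = logits.softmax(dim=-1)              # [tokens, num_experts]
expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)

# Equivalent of moe_normalize_expert_weights=1: rescale the kept weights
# so they sum to one per token.
expert_weights = expert_weights / torch.norm(expert_weights, p=1, dim=-1, keepdim=True)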
torch-ext/megablocks/layers/sharedexpert_registry.py
ADDED
@@ -0,0 +1,30 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from megablocks.layers import glu, mlp
+from megablocks.layers.arguments import Arguments
+
+_REGISTRY = {
+    'mlp': mlp.SharedMLP,
+    'glu': glu.SharedGLU,
+}
+
+
+def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
+    """Returns a SharedMLP for use in a dMoE instance.
+
+    Uses the provided arguments to instantiate the appropriate
+    SharedMLP instance.
+
+    Args:
+        args: propagated Arguments dataclass.
+
+    Returns:
+        An instantiated SharedMLP constructed using the input args.
+    """
+    if args.mlp_type not in _REGISTRY:
+        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+    return _REGISTRY[args.mlp_type](args)
torch-ext/megablocks/ops/__init__.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from megablocks.ops.binned_gather import binned_gather
+from megablocks.ops.binned_scatter import binned_scatter
+from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum
+from megablocks.ops.gather import gather
+from megablocks.ops.histogram import histogram
+from megablocks.ops.padded_gather import padded_gather
+from megablocks.ops.padded_scatter import padded_scatter
+from megablocks.ops.repeat import repeat
+from megablocks.ops.replicate import replicate
+from megablocks.ops.round_up import round_up
+from megablocks.ops.scatter import scatter
+from megablocks.ops.sort import sort
+from megablocks.ops.sum import sum
+from megablocks.ops.topology import topology
+
+__all__ = [
+    'binned_gather',
+    'binned_scatter',
+    'exclusive_cumsum',
+    'inclusive_cumsum',
+    'gather',
+    'histogram',
+    'padded_gather',
+    'padded_scatter',
+    'repeat',
+    'replicate',
+    'round_up',
+    'scatter',
+    'sort',
+    'sum',
+    'topology',
+]
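For readers without the CUDA extensions, the following is a rough pure-PyTorch reference for what sort, histogram, and inclusive_cumsum are used for in indices_and_bins (layers/moe.py above). It is a sketch of the semantics implied by that call site, not the fused implementations, and it ignores the sort_end_bit radix-sort optimization.

import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 3, 2, 1, 0], dtype=torch.int32)

# ops.sort(top_expert, sort_end_bit) -> (bin_ids, indices): sorted expert ids
# and the permutation that produces them.
bin_ids, indices = torch.sort(top_expert, stable=True)

# ops.histogram(top_expert, num_experts) -> tokens routed to each expert.
tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)

# ops.inclusive_cumsum(tokens_per_expert, 0) -> bin boundaries for the
# sorted tokens.
bins = torch.cumsum(tokens_per_expert, dim=0)

print(bin_ids.tolist())            # [0, 0, 1, 2, 2, 3]
print(indices.tolist())            # [1, 5, 4, 0, 3, 2]
print(tokens_per_expert.tolist())  # [2, 1, 2, 1]
print(bins.tolist())               # [2, 3, 5, 6]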
torch-ext/megablocks/ops/all_to_all_benchmark.py
ADDED
@@ -0,0 +1,60 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+from megablocks import benchmark_util
+from megablocks.layers.all_to_all import all_to_all
+
+_ALL_TO_ALL_BENCHMARK = (
+    (8, 1024),
+    (16, 1024),
+    (32, 1024),
+    (64, 1024),
+    (128, 1024),
+    (256, 1024),
+    (512, 1024),
+    (1024, 1024),
+    (2 * 1024, 1024),
+    (4 * 1024, 1024),
+    (8 * 1024, 1024),
+    (16 * 1024, 1024),
+    (32 * 1024, 1024),
+    (64 * 1024, 1024),
+    (128 * 1024, 1024),
+    (256 * 1024, 1024),
+    (512 * 1024, 1024),
+    (1024 * 1024, 1024),
+)
+
+
+def benchmark_all_to_all(group, sl, hs):
+    world_size = dist.get_world_size(group)
+    assert (sl % world_size) == 0
+    send_recv_sizes = [sl // world_size] * world_size
+
+    x = torch.randn((sl, hs)).cuda().half()
+
+    details = {
+        'world_size': world_size,
+        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2B elements.
+    }
+
+    def benchmark():
+        return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
+
+    time, std = benchmark_util.benchmark_function(benchmark)
+
+    if dist.get_rank(group) == 0:
+        benchmark_util.log_benchmark('All-To-All', details, time, std)
+
+
+if __name__ == '__main__':
+    assert dist.is_available()
+    group = dist.init_process_group(backend='nccl')
+    local_rank = dist.get_rank(group)
+    torch.cuda.set_device(local_rank)
+
+    for args in _ALL_TO_ALL_BENCHMARK:
+        benchmark_all_to_all(group, *args)